chore(cutover): remove services/mana-llm/ — moved to mana-platform

Live containers on the Mac Mini build out of `../mana/services/mana-llm/`
since the 8-Doppel-Cutover commit (774852ba2). Smoke test green
2026-05-08 — health endpoints, JWKS, login flow, Stripe webhook all
reachable from the new build path. Removing the now-stale duplicate.

Was 90M in this repo; gone now. Active code lives in
`Code/mana/services/mana-llm/` (see ../mana/CLAUDE.md).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Till JS 2026-05-08 18:53:54 +02:00
parent 6103d4d2d9
commit 2b07f6ef89
42 changed files with 0 additions and 6371 deletions

View file

@@ -1,28 +0,0 @@
# Service
PORT=3025
LOG_LEVEL=info
# Ollama (Primary)
OLLAMA_URL=http://localhost:11434
OLLAMA_DEFAULT_MODEL=gemma3:4b
OLLAMA_TIMEOUT=120
# OpenRouter (Cloud Fallback)
OPENROUTER_API_KEY=sk-or-v1-xxx
OPENROUTER_DEFAULT_MODEL=meta-llama/llama-3.1-8b-instruct
# Groq (Optional)
GROQ_API_KEY=gsk_xxx
# Together (Optional)
TOGETHER_API_KEY=xxx
# Caching (Optional)
REDIS_URL=redis://localhost:6379
CACHE_TTL=3600
# CORS
CORS_ORIGINS=http://localhost:5173,https://mana.how
# API key for cross-service auth (validated in src/api_auth.py)
GPU_API_KEY=

View file

@@ -1,31 +0,0 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
.venv/
ENV/
.env
# IDE
.idea/
.vscode/
*.swp
*.swo
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
# Build
dist/
build/
*.egg-info/
# Local
.env.local
*.log

View file

@@ -1,376 +0,0 @@
# mana-llm
Central LLM abstraction service providing a unified OpenAI-compatible API for Ollama and cloud LLM providers.
## Overview
mana-llm acts as a central gateway for all LLM requests in the monorepo, providing:
- Unified OpenAI-compatible API
- Provider routing (Ollama, OpenRouter, Groq, Together)
- Streaming via Server-Sent Events (SSE)
- Vision/multimodal support
- Embeddings generation
- Prometheus metrics
## Architecture
```
┌─────────────────────────────────────────────────────────────────────┐
│ Consumer Apps │
│ chat-backend │ mana web │ todo (LLM enrich) │ etc. │
└────────────────────────────────┬────────────────────────────────────┘
│ HTTP/SSE
┌─────────────────────────────────────────────────────────────────────┐
│ mana-llm (Port 3025) │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Router │ │ Cache │ │ Metrics │ │
│ │ (Provider) │ │ (Redis) │ │ (Prometheus)│ │
│ └──────┬──────┘ └─────────────┘ └─────────────┘ │
│ │ │
│ ┌──────┴──────────────────────────────────────────┐ │
│ │ Provider Adapters │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ │
│ │ │ Ollama │ │ OpenAI │ │ OpenRouter │ │ │
│ │ │ Adapter │ │ Adapter │ │ Adapter │ │ │
│ │ └──────────┘ └──────────┘ └──────────────┘ │ │
│ └─────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
## Quick Start
### Prerequisites
- Python 3.11+
- Ollama running locally (http://localhost:11434)
- Redis (optional, for caching)
### Development
```bash
cd services/mana-llm
# Create virtual environment
python -m venv venv
source venv/bin/activate # or venv\Scripts\activate on Windows
# Install dependencies
pip install -r requirements.txt
# Copy environment file
cp .env.example .env
# Start Redis (optional)
docker-compose -f docker-compose.dev.yml up -d
# Run service
python -m uvicorn src.main:app --port 3025 --reload
```
### Docker
```bash
# Full stack (mana-llm + Redis)
docker-compose up -d
# View logs
docker-compose logs -f mana-llm
```
## API Endpoints
### Chat Completions
```bash
# Non-streaming
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/gemma3:4b",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": false
}'
# Streaming (SSE)
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/gemma3:4b",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true
}'
```
### Vision/Multimodal
```bash
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/llava:7b",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
]
}]
}'
```
### Models
```bash
# List all models
curl http://localhost:3025/v1/models
# Get specific model
curl http://localhost:3025/v1/models/ollama/gemma3:4b
```
### Embeddings
```bash
curl -X POST http://localhost:3025/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/nomic-embed-text",
"input": "Text to embed"
}'
```
### Health & Metrics
```bash
# Liveness summary (legacy, terse shape — only status + per-provider status string)
curl http://localhost:3025/health
# Detailed per-provider liveness snapshot (M4)
curl http://localhost:3025/v1/health
# Prometheus metrics
curl http://localhost:3025/metrics
```
### Aliases (M4 — see `aliases.yaml` and the fallback section below)
```bash
# What does each `mana/<class>` resolve to?
curl http://localhost:3025/v1/aliases
```
## Aliases & Fallback
> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md)
### What callers send
Two acceptable shapes for the `model` field of `/v1/chat/completions`:
1. **Aliases** in the reserved `mana/` namespace — recommended for product code.
The router resolves them via `aliases.yaml` to a chain of concrete
`provider/model` strings and tries them in order.
2. **Direct `provider/model`** — bypasses the alias layer, no fallback.
Useful for tests, debugging, and one-off integrations.
| Alias | Class |
|---|---|
| `mana/fast-text` | Short answers, classification, single-shot Q&A |
| `mana/long-form` | Writing, essays, stories, longer prose |
| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) |
| `mana/reasoning` | Agent missions, tool calls, multi-step plans |
| `mana/vision` | Multimodal (image + text) |
The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit
the file and `kill -HUP <pid>` to reload — no restart needed. Reload
errors keep the previous good state; check the service logs.
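A minimal Python sketch of an alias call (assumes the service is reachable on localhost:3025, as in the dev setup above, and that `httpx` is installed):
```python
import httpx

# Ask for the semantic class; the router resolves it via aliases.yaml
# and answers with the first healthy entry in the chain.
resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    json={
        "model": "mana/fast-text",
        "messages": [{"role": "user", "content": "Is this spam? 'WIN A FREE CRUISE'"}],
        "stream": False,
    },
    timeout=120.0,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```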
### Fallback semantics
Every chain is tried in order. The router skips an entry if the provider
isn't configured in this deployment (no API key) or is currently marked
unhealthy by the health-cache. For each remaining entry the request is
attempted; on a **retryable** error (connection failure, timeout, 5xx,
rate-limit, RemoteProtocolError) the provider is marked unhealthy and
the next entry is tried. **Non-retryable** errors (auth, capability,
content-blocked, 4xx, unknown exception types) propagate immediately —
no fallback, the cache is not poisoned.
Streaming follows the same logic up to the **first byte**. Once a chunk
has been yielded the provider is committed; mid-stream errors surface
as-is so we never splice two providers' voices into one output.
If every entry was skipped or failed, the response is `503` carrying a
structured `attempts: list[(model, reason)]` log so the cause is
visible to the caller, not only in service logs.
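A sketch of a caller inspecting that 503. The `attempts` field is documented above; the exact JSON nesting (assumed here to sit under FastAPI's usual `detail` key) should be checked against the live service:
```python
import httpx

resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    json={"model": "mana/reasoning", "messages": [{"role": "user", "content": "Plan my day."}]},
)
if resp.status_code == 503:
    # Assumed shape: {"detail": {"attempts": [[model, reason], ...]}}
    for model, reason in resp.json().get("detail", {}).get("attempts", []):
        print(f"{model}: {reason}")
```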
### Resolved-model header
Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>`
(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that
actually answered. Use this for token-cost attribution when the request
used an alias. For streaming, each chunk's `model` field carries the
same info (headers go out before the chain is walked).
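For example (a sketch under the same assumptions as above):
```python
import httpx

resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    json={"model": "mana/long-form", "messages": [{"role": "user", "content": "Write a haiku."}]},
)
# The concrete model that answered, e.g. "groq/llama-3.3-70b-versatile".
# Log it next to the usage block for per-model token-cost attribution.
print(resp.headers.get("X-Mana-LLM-Resolved"))
```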
### Health-cache + probe
`ProviderHealthCache` keeps a per-provider circuit-breaker:
* 1 failure: still healthy (transient blip, don't bounce).
* 2 consecutive failures: `is_healthy → False` for 60 s; the router
fail-fasts straight to the next chain entry.
* After 60 s: half-open. Next call exercises the provider; success
fully resets, failure re-arms the backoff.
A background `HealthProbe` task runs every 30 s with a 3 s timeout per
provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models`
for OpenAI-compat). One bad probe can't sink the loop; results feed
into the same cache as the call-site fallback.
### Prometheus metrics added in M4
| Metric | Labels | Purpose |
|---|---|---|
| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. |
| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. |
| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker. |
## Provider Routing
Models use the format `provider/model`:
| Model | Provider | Target |
|-------|----------|--------|
| `ollama/gemma3:4b` | Ollama | localhost:11434 |
| `ollama/llava:7b` | Ollama | localhost:11434 |
| `openrouter/meta-llama/llama-3.1-8b-instruct` | OpenRouter | api.openrouter.ai |
| `groq/llama-3.1-8b-instant` | Groq | api.groq.com |
| `together/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` | Together | api.together.xyz |
**Default:** If no provider prefix is given (e.g., `gemma3:4b`), Ollama is used.
## Configuration
Environment variables (see `.env.example`):
| Variable | Default | Description |
|----------|---------|-------------|
| `PORT` | 3025 | Service port |
| `LOG_LEVEL` | info | Logging level |
| `OLLAMA_URL` | http://localhost:11434 | Ollama server URL |
| `OLLAMA_DEFAULT_MODEL` | gemma3:4b | Default Ollama model |
| `OLLAMA_TIMEOUT` | 120 | Ollama request timeout (seconds) |
| `OPENROUTER_API_KEY` | - | OpenRouter API key |
| `GROQ_API_KEY` | - | Groq API key |
| `TOGETHER_API_KEY` | - | Together API key |
| `REDIS_URL` | - | Redis URL for caching |
| `CACHE_TTL` | 3600 | Cache TTL in seconds |
| `CORS_ORIGINS` | localhost | Allowed CORS origins |
## Project Structure
```
services/mana-llm/
├── src/
│ ├── main.py # FastAPI app entry point
│ ├── config.py # Settings via pydantic-settings
│ ├── providers/
│ │ ├── base.py # Abstract provider interface
│ │ ├── ollama.py # Ollama provider
│ │ ├── openai_compat.py # OpenAI-compatible provider
│ │ └── router.py # Provider routing logic
│ ├── models/
│ │ ├── requests.py # Request Pydantic models
│ │ └── responses.py # Response Pydantic models
│ ├── streaming/
│ │ └── sse.py # SSE response handling
│ └── utils/
│ ├── cache.py # Redis caching
│ └── metrics.py # Prometheus metrics
├── tests/
│ ├── test_api.py # API endpoint tests
│ ├── test_providers.py # Provider tests
│ └── test_streaming.py # Streaming tests
├── Dockerfile
├── docker-compose.yml
├── docker-compose.dev.yml
├── requirements.txt
├── pyproject.toml
└── .env.example
```
## Testing
```bash
# Run tests
pytest
# Run with coverage
pytest --cov=src
# Run specific test file
pytest tests/test_providers.py -v
```
## Integration Example
### TypeScript/Node.js Client
```typescript
// Using fetch
const response = await fetch('http://localhost:3025/v1/chat/completions', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model: 'ollama/gemma3:4b',
messages: [{ role: 'user', content: 'Hello!' }],
stream: false,
}),
});
const data = await response.json();
console.log(data.choices[0].message.content);
```
### Streaming with fetch (SSE)
```typescript
const response = await fetch('http://localhost:3025/v1/chat/completions', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model: 'ollama/gemma3:4b',
messages: [{ role: 'user', content: 'Hello!' }],
stream: true,
}),
});
const reader = response.body?.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader!.read();
if (done) break;
const chunk = decoder.decode(value);
const lines = chunk.split('\n').filter(line => line.startsWith('data: '));
for (const line of lines) {
const data = line.slice(6);
if (data === '[DONE]') break;
const parsed = JSON.parse(data);
const content = parsed.choices[0]?.delta?.content;
if (content) process.stdout.write(content);
}
}
```
## Related Services
| Service | Port | Description |
|---------|------|-------------|
| mana-tts | 3022 | Text-to-speech service |
| mana-stt | 3023 | Speech-to-text service |
| mana-search | 3021 | Web search & extraction |

View file

@@ -1,49 +0,0 @@
# Build stage
FROM python:3.12-slim AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for caching
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Production stage
FROM python:3.12-slim
WORKDIR /app
# Create non-root user
RUN useradd -m -u 1000 appuser
# Copy installed packages from builder
COPY --from=builder /root/.local /home/appuser/.local
# Copy application code + alias config. `aliases.yaml` is loaded by
# main.py's lifespan handler from `Path(__file__).parent.parent /
# 'aliases.yaml'` = /app/aliases.yaml. Without it the FastAPI app
# raises AliasConfigError on startup and the container crashloops.
COPY --chown=appuser:appuser src/ ./src/
COPY --chown=appuser:appuser aliases.yaml ./aliases.yaml
# Set environment
ENV PATH=/home/appuser/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 3025
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"
# Run application
CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3025"]

View file

@@ -1,54 +0,0 @@
# mana-llm Model Aliases — single source of truth for which class of
# model each backend feature uses.
#
# Consumers (mana-api, mana-ai, …) send `"model": "mana/<class>"` in
# their /v1/chat/completions requests; mana-llm resolves the alias to
# the chain below and tries entries in order, skipping providers that
# the health-cache has marked unhealthy.
#
# Order in `chain` = preference. First healthy entry wins. Each chain
# should end with a cloud provider so the system stays functional even
# when the local GPU server (mana-gpu, RTX 3090) is offline.
#
# Reload at runtime: `kill -HUP <pid>` after editing — no restart needed.
# Reference: docs/plans/llm-fallback-aliases.md.
aliases:
mana/fast-text:
description: "Short answers, classification, single-shot Q&A"
chain:
- ollama/qwen2.5:7b
- groq/llama-3.1-8b-instant
- openrouter/anthropic/claude-3-haiku
mana/long-form:
description: "Writing, essays, stories, longer prose"
chain:
- ollama/gemma3:12b
- groq/llama-3.3-70b-versatile
- openrouter/anthropic/claude-3.5-haiku
mana/structured:
description: "JSON output (comic storyboards, research subqueries, tag suggestions)"
chain:
- ollama/qwen2.5:7b
- groq/llama-3.1-8b-instant
- openrouter/openai/gpt-4o-mini
mana/reasoning:
description: "Agent missions, tool calls, multi-step plans"
# Cloud first by design — local 4-7B models are unreliable for tool calls
chain:
- openrouter/anthropic/claude-3.5-sonnet
- groq/llama-3.3-70b-versatile
mana/vision:
description: "Multimodal (image + text)"
chain:
- ollama/llava:7b
- google/gemini-2.0-flash-exp
- openrouter/openai/gpt-4o
# Default alias used when a request omits `model` or sends an unknown
# value with no provider prefix. Keep this conservative (cheap class).
default: mana/fast-text

View file

@@ -1,15 +0,0 @@
# Development compose - only Redis (run Python locally)
version: "3.8"
services:
redis:
image: redis:7-alpine
container_name: mana-llm-redis-dev
ports:
- "6380:6379"
volumes:
- redis-data-dev:/data
restart: unless-stopped
volumes:
redis-data-dev:

View file

@@ -1,52 +0,0 @@
version: "3.8"
services:
mana-llm:
build:
context: .
dockerfile: Dockerfile
container_name: mana-llm
ports:
- "3025:3025"
environment:
- PORT=3025
- LOG_LEVEL=info
- OLLAMA_URL=http://host.docker.internal:11434
- OLLAMA_DEFAULT_MODEL=gemma3:4b
- OLLAMA_TIMEOUT=120
- REDIS_URL=redis://redis:6379
# Add API keys via .env file
- GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
- GOOGLE_DEFAULT_MODEL=${GOOGLE_DEFAULT_MODEL:-gemini-2.5-flash}
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
- GROQ_API_KEY=${GROQ_API_KEY:-}
- TOGETHER_API_KEY=${TOGETHER_API_KEY:-}
- CORS_ORIGINS=http://localhost:5173,http://localhost:3000,https://mana.how
depends_on:
- redis
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
extra_hosts:
- "host.docker.internal:host-gateway"
redis:
image: redis:7-alpine
container_name: mana-llm-redis
ports:
- "6380:6379"
volumes:
- redis-data:/data
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
volumes:
redis-data:

View file

@@ -1,40 +0,0 @@
[project]
name = "mana-llm"
version = "0.1.0"
description = "Central LLM abstraction service for Ollama and OpenAI-compatible APIs"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0",
"pydantic>=2.10.0",
"pydantic-settings>=2.6.0",
"httpx>=0.28.0",
"sse-starlette>=2.2.0",
"redis>=5.2.0",
"prometheus-client>=0.21.0",
"google-genai>=1.0.0",
"pyyaml>=6.0.2",
]
[project.optional-dependencies]
dev = [
"pytest>=8.3.0",
"pytest-asyncio>=0.24.0",
"pytest-httpx>=0.35.0",
"ruff>=0.8.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

View file

@@ -1,29 +0,0 @@
# Core
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
pydantic>=2.10.0
pydantic-settings>=2.6.0
# HTTP Client
httpx>=0.28.0
# Streaming
sse-starlette>=2.2.0
# Caching (optional)
redis>=5.2.0
# Google Gemini
google-genai>=1.0.0
# Metrics
prometheus-client>=0.21.0
# Config (alias registry)
pyyaml>=6.0.2
# Dev
pytest>=8.3.0
pytest-asyncio>=0.24.0
pytest-httpx>=0.35.0
ruff>=0.8.0

View file

@@ -1,17 +0,0 @@
"""mana-llm service runner (run with pythonw.exe to run headless)."""
import os
import sys
os.chdir(r"C:\mana\services\mana-llm")
sys.path.insert(0, r"C:\mana\services\mana-llm")
# Load .env file
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-llm\.env")
# Redirect stdout/stderr to log file
log = open(r"C:\mana\services\mana-llm\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log
import uvicorn
uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")

View file

@@ -1,3 +0,0 @@
"""mana-llm - Central LLM abstraction service."""
__version__ = "0.1.0"

View file

@@ -1,223 +0,0 @@
"""Model-alias registry.
Loads `aliases.yaml` and exposes a small API the router uses to resolve
semantic model names like ``mana/long-form`` to an ordered list of
concrete provider-prefixed model strings (``ollama/gemma3:12b``,
``groq/llama-3.3-70b-versatile``, ...).
The registry is hot-reloadable: ``reload()`` rebuilds the in-memory
mapping atomically. Reload errors leave the previous good state intact
so a typo in the yaml file doesn't take the service down — caller logs
the error and keeps serving.
See docs/plans/llm-fallback-aliases.md for the full design.
"""
from __future__ import annotations
import logging
import threading
from dataclasses import dataclass
from pathlib import Path
import yaml
logger = logging.getLogger(__name__)
# Aliases live in this namespace. Anything else passed as `model` is
# treated as a direct provider/model string (preserves the legal
# bypass-the-alias-layer escape hatch for tests/debugging).
ALIAS_PREFIX = "mana/"
@dataclass(frozen=True)
class Alias:
"""A resolved alias entry."""
name: str
description: str
chain: tuple[str, ...]
class AliasConfigError(ValueError):
"""Raised when the YAML file is malformed or violates schema constraints."""
class UnknownAliasError(KeyError):
"""Raised when a caller asks for an alias that isn't defined."""
def _validate_chain(name: str, chain: object) -> tuple[str, ...]:
"""Schema-check a single alias chain. Returns the validated tuple."""
if not isinstance(chain, list):
raise AliasConfigError(f"alias '{name}': chain must be a list, got {type(chain).__name__}")
if not chain:
raise AliasConfigError(f"alias '{name}': chain must not be empty")
out: list[str] = []
for i, entry in enumerate(chain):
if not isinstance(entry, str) or not entry.strip():
raise AliasConfigError(
f"alias '{name}': chain[{i}] must be a non-empty string, got {entry!r}"
)
if "/" not in entry:
raise AliasConfigError(
f"alias '{name}': chain[{i}] = {entry!r} must include a provider prefix "
f"(e.g. 'ollama/...', 'groq/...')"
)
out.append(entry.strip())
return tuple(out)
def _validate_name(name: object) -> str:
"""Aliases must live in the reserved `mana/` namespace."""
if not isinstance(name, str) or not name.startswith(ALIAS_PREFIX):
raise AliasConfigError(
f"alias name {name!r} must start with {ALIAS_PREFIX!r} (the reserved namespace)"
)
suffix = name[len(ALIAS_PREFIX) :]
if not suffix or "/" in suffix:
raise AliasConfigError(
f"alias name {name!r} must have exactly one segment after {ALIAS_PREFIX!r}"
)
return name
def _parse_document(doc: object) -> tuple[dict[str, Alias], str | None]:
"""Parse a loaded YAML document into a normalized (aliases, default) pair."""
if not isinstance(doc, dict):
raise AliasConfigError(f"yaml root must be a mapping, got {type(doc).__name__}")
raw_aliases = doc.get("aliases", {})
if not isinstance(raw_aliases, dict):
raise AliasConfigError(
f"`aliases` must be a mapping, got {type(raw_aliases).__name__}"
)
if not raw_aliases:
raise AliasConfigError("`aliases` is empty — at least one alias is required")
parsed: dict[str, Alias] = {}
for name, body in raw_aliases.items():
validated_name = _validate_name(name)
if not isinstance(body, dict):
raise AliasConfigError(
f"alias '{validated_name}': body must be a mapping, got {type(body).__name__}"
)
description = body.get("description", "")
if not isinstance(description, str):
raise AliasConfigError(
f"alias '{validated_name}': description must be a string"
)
chain = _validate_chain(validated_name, body.get("chain"))
parsed[validated_name] = Alias(
name=validated_name,
description=description.strip(),
chain=chain,
)
default = doc.get("default")
if default is not None:
if not isinstance(default, str):
raise AliasConfigError(f"`default` must be a string, got {type(default).__name__}")
if default not in parsed:
raise AliasConfigError(
f"`default` references unknown alias {default!r} "
f"(known: {sorted(parsed)})"
)
return parsed, default
class AliasRegistry:
"""Thread-safe in-memory registry of model aliases.
Construct once at startup with the path to the yaml file. Call
:meth:`reload` to re-read after a SIGHUP. Reads (``resolve``,
    ``is_alias``, ...) are cheap and lock-free during steady state: they
snapshot the current mapping reference; only the swap on reload is
serialized.
"""
def __init__(self, path: Path | str):
self._path = Path(path)
self._lock = threading.Lock()
self._aliases: dict[str, Alias] = {}
self._default: str | None = None
self._load()
@property
def path(self) -> Path:
return self._path
def _load(self) -> None:
"""Initial load — propagates errors so a bad config fails fast at startup."""
if not self._path.exists():
raise AliasConfigError(f"alias config not found at {self._path}")
with self._path.open("r", encoding="utf-8") as f:
try:
doc = yaml.safe_load(f)
except yaml.YAMLError as e:
raise AliasConfigError(f"failed to parse {self._path}: {e}") from e
aliases, default = _parse_document(doc)
# No lock needed during __init__ — nothing else can read yet.
self._aliases = aliases
self._default = default
logger.info(
"AliasRegistry loaded %d alias(es) from %s (default=%s)",
len(aliases),
self._path,
default,
)
def reload(self) -> None:
"""Re-read the yaml file. On parse error, keep the previous state and raise.
Designed for SIGHUP: callers should ``try/except AliasConfigError``
and log; do not crash the service on a typo.
"""
with self._path.open("r", encoding="utf-8") as f:
try:
doc = yaml.safe_load(f)
except yaml.YAMLError as e:
raise AliasConfigError(f"failed to parse {self._path}: {e}") from e
aliases, default = _parse_document(doc)
with self._lock:
self._aliases = aliases
self._default = default
logger.info(
"AliasRegistry reloaded %d alias(es) from %s (default=%s)",
len(aliases),
self._path,
default,
)
@staticmethod
def is_alias(name: str) -> bool:
"""Cheap syntactic check — does this name live in the alias namespace?
Static; doesn't require a registry instance. Used by the router to
decide whether to dispatch to the alias layer or pass through to
provider-direct routing.
"""
return isinstance(name, str) and name.startswith(ALIAS_PREFIX)
def resolve(self, name: str) -> Alias:
"""Look up the named alias. Raises :class:`UnknownAliasError` if absent."""
try:
return self._aliases[name]
except KeyError as e:
raise UnknownAliasError(
f"unknown alias {name!r} (known: {sorted(self._aliases)})"
) from e
def resolve_chain(self, name: str) -> tuple[str, ...]:
"""Sugar for ``resolve(name).chain`` — the form the router actually wants."""
return self.resolve(name).chain
@property
def default_alias(self) -> str | None:
"""The alias used when a request arrives with no recognizable model."""
return self._default
def list_aliases(self) -> list[Alias]:
"""All aliases as a snapshot list — for the GET /v1/aliases debug endpoint."""
return [self._aliases[k] for k in sorted(self._aliases)]
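A usage sketch of the registry API above (assumes a working directory containing the `aliases.yaml` shown earlier):
```python
from pathlib import Path

from src.aliases import AliasRegistry, UnknownAliasError

registry = AliasRegistry(Path("aliases.yaml"))

# Semantic name -> ordered fallback chain (values from the yaml above).
print(registry.resolve_chain("mana/fast-text"))
# ('ollama/qwen2.5:7b', 'groq/llama-3.1-8b-instant', 'openrouter/anthropic/claude-3-haiku')

# Direct provider/model strings bypass the alias layer entirely.
assert not AliasRegistry.is_alias("ollama/gemma3:4b")

try:
    registry.resolve("mana/no-such-class")
except UnknownAliasError as e:
    print(e)  # unknown alias 'mana/no-such-class' (known: [...])
```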

View file

@@ -1,53 +0,0 @@
"""
Simple API Key Authentication Middleware for GPU Services.
Checks X-API-Key header or ?api_key query parameter.
Skips auth for /health, /docs, /openapi.json, /redoc endpoints.
Environment variables:
GPU_API_KEY: Required API key (if empty, auth is disabled)
GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set)
"""
import os
import logging
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
GPU_API_KEY = os.getenv("GPU_API_KEY", "")
GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true"
# Endpoints that don't require auth
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"}
class ApiKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip auth if disabled
if not GPU_REQUIRE_AUTH or not GPU_API_KEY:
return await call_next(request)
# Skip auth for public endpoints
if request.url.path in PUBLIC_PATHS:
return await call_next(request)
# Check API key from header or query param
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
if not api_key:
return JSONResponse(
status_code=401,
content={"detail": "Missing API key. Provide X-API-Key header."},
)
if api_key != GPU_API_KEY:
logger.warning(f"Invalid API key attempt from {request.client.host if request.client else 'unknown'}")
return JSONResponse(
status_code=401,
content={"detail": "Invalid API key."},
)
return await call_next(request)
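Client-side counterpart, as a sketch (assumes the service runs on 3025 with `GPU_API_KEY` set; the key value here is a placeholder):
```python
import httpx

# /health is in PUBLIC_PATHS, so no key is needed for it.
httpx.get("http://localhost:3025/health").raise_for_status()

# Everything else wants the shared key via header (or ?api_key= query param).
resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    headers={"X-API-Key": "change-me"},  # placeholder, not a real key
    json={"model": "ollama/gemma3:4b", "messages": [{"role": "user", "content": "Hi"}]},
)
print(resp.status_code)  # 401 unless the key matches GPU_API_KEY
```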

View file

@@ -1,57 +0,0 @@
"""Configuration settings for mana-llm service."""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Service
port: int = 3025
log_level: str = "info"
# Ollama (Primary provider)
ollama_url: str = "http://localhost:11434"
ollama_default_model: str = "gemma3:4b"
ollama_timeout: int = 120
# OpenRouter (Cloud fallback)
openrouter_api_key: str | None = None
openrouter_base_url: str = "https://openrouter.ai/api/v1"
openrouter_default_model: str = "meta-llama/llama-3.1-8b-instruct"
# Groq (Optional)
groq_api_key: str | None = None
groq_base_url: str = "https://api.groq.com/openai/v1"
# Together (Optional)
together_api_key: str | None = None
together_base_url: str = "https://api.together.xyz/v1"
# Google Gemini (Fallback provider)
google_api_key: str | None = None
google_default_model: str = "gemini-2.5-flash"
# Auto-fallback: Ollama → Google when Ollama is overloaded/down
auto_fallback_enabled: bool = True
ollama_max_concurrent: int = 3
# Caching (Optional)
redis_url: str | None = None
cache_ttl: int = 3600
# CORS
cors_origins: str = "http://localhost:5173,http://localhost:5190,https://mana.how,https://playground.mana.how"
@property
def cors_origins_list(self) -> list[str]:
"""Parse CORS origins from comma-separated string."""
return [origin.strip() for origin in self.cors_origins.split(",")]
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
settings = Settings()

View file

@@ -1,203 +0,0 @@
"""Provider health cache.
Tracks per-provider liveness for the LLM router. The router reads
:meth:`is_healthy` to decide whether to even try a provider in a chain;
the probe loop and the call-site fallback handler write state via
:meth:`mark_healthy` / :meth:`mark_unhealthy`.
Implements a simple circuit-breaker:
* The first failure flips no switch: providers occasionally have
  transient blips, and we don't want to bounce a provider after a single 502.
* After ``failure_threshold`` consecutive failures the provider is
marked unhealthy for ``unhealthy_backoff`` seconds. During that
window :meth:`is_healthy` returns ``False`` so the router fails
fast straight to the next chain entry.
* When the backoff expires :meth:`is_healthy` returns ``True`` again
(half-open). The next call exercises the provider; success calls
:meth:`mark_healthy` and fully resets state, failure re-arms the
backoff window.
State is kept in a plain dict guarded by a ``threading.Lock``. All
operations are short; lock-free reads of dict references aren't safe
because we mutate state in-place, so the lock keeps it boring. The
probe loop runs in the asyncio loop alongside the router, but the lock
costs are negligible at ~1 update/30s/provider.
"""
from __future__ import annotations
import logging
import threading
import time
from dataclasses import dataclass
from typing import Callable, Iterable
logger = logging.getLogger(__name__)
#: Notification fired whenever a provider transitions between healthy and
#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the
#: cache itself stays metrics-agnostic so tests don't need to mock it.
HealthChangeListener = Callable[[str, bool], None]
DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
@dataclass
class ProviderState:
"""Per-provider liveness snapshot. All times are unix seconds."""
healthy: bool = True
consecutive_failures: int = 0
last_check: float = 0.0
last_error: str | None = None
unhealthy_until: float = 0.0
"""When > now, the provider is currently in backoff (`is_healthy → False`)."""
class ProviderHealthCache:
"""Thread-safe per-provider liveness with circuit-breaker semantics.
    Provider IDs are arbitrary strings; by convention we use the same
    short names as the provider router (``ollama``, ``groq``, ``openrouter``,
    ``together``, ``google``). The cache is provider-list agnostic; states
    are created lazily on the first ``mark_*`` or queried-but-absent
    ``is_healthy`` call (returning ``True`` by default: no state means no
    reason to skip).
"""
def __init__(
self,
*,
failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
        clock: Callable[[], float] = time.time,
) -> None:
if failure_threshold < 1:
raise ValueError("failure_threshold must be >= 1")
if unhealthy_backoff_sec < 0:
raise ValueError("unhealthy_backoff_sec must be >= 0")
self._failure_threshold = failure_threshold
self._unhealthy_backoff = unhealthy_backoff_sec
self._clock = clock
self._lock = threading.Lock()
self._states: dict[str, ProviderState] = {}
self._listeners: list[HealthChangeListener] = []
def add_listener(self, listener: HealthChangeListener) -> None:
"""Register a callback fired with ``(provider_id, healthy: bool)``
whenever a provider's healthy-flag transitions. Listeners run
outside the cache's lock; exceptions are swallowed and logged so a
bad listener can't break the underlying state machine.
"""
self._listeners.append(listener)
def _notify(self, provider_id: str, healthy: bool) -> None:
for listener in self._listeners:
try:
listener(provider_id, healthy)
except Exception as e: # noqa: BLE001
logger.error("health-change listener raised: %s", e)
@property
def failure_threshold(self) -> int:
return self._failure_threshold
@property
def unhealthy_backoff_sec(self) -> float:
return self._unhealthy_backoff
# ------------------------------------------------------------------
# Reads
# ------------------------------------------------------------------
def is_healthy(self, provider_id: str) -> bool:
"""Should the router try this provider right now?
        Returns ``True`` by default for unknown providers: the cache is
        observation-only, not a registry.
"""
with self._lock:
state = self._states.get(provider_id)
if state is None:
return True
if state.unhealthy_until > self._clock():
return False
# Backoff expired: caller is allowed to try again (half-open).
return True
def get_state(self, provider_id: str) -> ProviderState | None:
"""Snapshot of one provider's state (for debugging / tests)."""
with self._lock:
state = self._states.get(provider_id)
return None if state is None else _copy(state)
def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
"""All known states, plus zero-state placeholders for any names in
``expected`` that haven't been touched yet. Used by ``GET /v1/health``
so the response shape is stable regardless of probe order.
"""
with self._lock:
out = {pid: _copy(s) for pid, s in self._states.items()}
if expected:
for pid in expected:
out.setdefault(pid, ProviderState())
return out
# ------------------------------------------------------------------
# Writes
# ------------------------------------------------------------------
def mark_healthy(self, provider_id: str) -> None:
"""Provider answered correctly — clear any failure state."""
transitioned = False
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
transitioned = not state.healthy
state.healthy = True
state.consecutive_failures = 0
state.last_check = self._clock()
state.last_error = None
state.unhealthy_until = 0.0
if transitioned:
logger.info("provider %s recovered", provider_id)
self._notify(provider_id, True)
def mark_unhealthy(self, provider_id: str, reason: str) -> None:
"""Record a failure. Trips the breaker after the threshold."""
transitioned = False
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
state.consecutive_failures += 1
state.last_check = self._clock()
state.last_error = reason
tripped = state.consecutive_failures >= self._failure_threshold
if tripped and state.healthy:
state.healthy = False
state.unhealthy_until = self._clock() + self._unhealthy_backoff
transitioned = True
logger.warning(
"provider %s marked unhealthy after %d consecutive failures (%s); "
"backoff %.0fs",
provider_id,
state.consecutive_failures,
reason,
self._unhealthy_backoff,
)
elif not state.healthy:
# Still in unhealthy window; refresh the backoff so a flapping
# provider doesn't get re-tried every probe tick.
state.unhealthy_until = self._clock() + self._unhealthy_backoff
if transitioned:
self._notify(provider_id, False)
def _copy(state: ProviderState) -> ProviderState:
"""Return a shallow copy so callers can read without holding the lock."""
return ProviderState(
healthy=state.healthy,
consecutive_failures=state.consecutive_failures,
last_check=state.last_check,
last_error=state.last_error,
unhealthy_until=state.unhealthy_until,
)
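A test-style walkthrough of the breaker semantics with an injected fake clock (only names defined in this module):
```python
from src.health import ProviderHealthCache

now = [0.0]
cache = ProviderHealthCache(clock=lambda: now[0])

cache.mark_unhealthy("groq", "HTTP 502")
assert cache.is_healthy("groq")      # 1st failure: still healthy

cache.mark_unhealthy("groq", "HTTP 502")
assert not cache.is_healthy("groq")  # 2nd failure trips the breaker (60s backoff)

now[0] = 61.0
assert cache.is_healthy("groq")      # backoff expired: half-open, try again

cache.mark_healthy("groq")           # success fully resets the state
assert cache.get_state("groq").consecutive_failures == 0
```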

View file

@@ -1,210 +0,0 @@
"""Background probe loop that keeps :class:`ProviderHealthCache` fresh.
The router's circuit-breaker is reactive — it only learns a provider is
sick after the next live request fails. A reactive-only design means:
* every cold start re-discovers Ollama is down by paying one 75-second
``ConnectError``, and
* a provider that quietly recovers stays marked unhealthy until its
backoff expires and someone tries it.
The probe loop closes both gaps. Every ``interval`` seconds it pings
each registered provider with a small known-cheap request (Ollama:
``GET /api/tags``, OpenAI-compat: ``GET /v1/models``) and updates the
cache. Probes run concurrently per tick and respect a hard
``probe_timeout`` so a hanging provider can't stall the loop.
Probe functions are injected from outside (``probes={name: async-fn}``)
so this module stays decoupled from the provider classes; the wiring
lives in ``main.py``, where we already know which providers are configured.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Awaitable, Callable
from .health import ProviderHealthCache
logger = logging.getLogger(__name__)
#: Probe function: returns ``True`` for healthy. Raising or returning
#: ``False`` both count as a failure (the loop just calls
#: ``mark_unhealthy``); an exception's string form becomes the
#: ``last_error`` for the snapshot endpoint.
ProbeFn = Callable[[], Awaitable[bool]]
DEFAULT_PROBE_INTERVAL_SEC = 30.0
DEFAULT_PROBE_TIMEOUT_SEC = 3.0
class HealthProbe:
"""Periodically probes every registered provider and updates the cache.
Lifecycle::
probe = HealthProbe(cache, {"ollama": probe_ollama, ...})
await probe.start() # spawns the background task
...
await probe.stop() # cancels and awaits cleanup
Tests typically call :meth:`tick_once` directly to exercise one cycle
without driving the asyncio scheduler through real ``asyncio.sleep``.
"""
def __init__(
self,
cache: ProviderHealthCache,
probes: dict[str, ProbeFn],
*,
interval: float = DEFAULT_PROBE_INTERVAL_SEC,
timeout: float = DEFAULT_PROBE_TIMEOUT_SEC,
) -> None:
if interval <= 0:
raise ValueError("interval must be > 0")
if timeout <= 0:
raise ValueError("timeout must be > 0")
self._cache = cache
self._probes = dict(probes)
self._interval = interval
self._timeout = timeout
self._task: asyncio.Task | None = None
self._stop = asyncio.Event()
@property
def interval(self) -> float:
return self._interval
@property
def timeout(self) -> float:
return self._timeout
@property
def provider_ids(self) -> list[str]:
return list(self._probes)
@property
def running(self) -> bool:
return self._task is not None and not self._task.done()
# ------------------------------------------------------------------
# Per-tick logic — exercised directly by tests
# ------------------------------------------------------------------
async def tick_once(self) -> None:
"""Probe every provider once, in parallel, updating the cache.
Errors in any one probe (including ``asyncio.TimeoutError``) are
        captured per-provider; one bad probe never sinks the loop.
"""
if not self._probes:
return
results = await asyncio.gather(
*(self._probe_one(name, fn) for name, fn in self._probes.items()),
return_exceptions=True,
)
# gather(return_exceptions=True) caught everything per-probe; this
# branch should never fire, but guard in case _probe_one ever grows
# a code path that bypasses its try/except.
# asyncio.gather() returns results in input order and same length,
# so the zip is a 1:1 mapping back to provider names.
for name, result in zip(self._probes, results):
if isinstance(result, BaseException):
logger.error("probe %s leaked exception: %s", name, result)
async def _probe_one(self, name: str, fn: ProbeFn) -> None:
try:
healthy = await asyncio.wait_for(fn(), timeout=self._timeout)
except asyncio.TimeoutError:
self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)")
return
except asyncio.CancelledError:
raise
except Exception as e: # noqa: BLE001 — probe SHOULD be permissive
self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}")
return
if healthy:
self._cache.mark_healthy(name)
else:
self._cache.mark_unhealthy(name, "probe-returned-false")
# ------------------------------------------------------------------
# Long-running task management
# ------------------------------------------------------------------
async def start(self) -> None:
"""Spawn the periodic probe task. Idempotent."""
if self.running:
return
self._stop.clear()
self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe")
logger.info(
"HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)",
self._interval,
self._timeout,
", ".join(self._probes) or "<none>",
)
async def stop(self) -> None:
"""Cancel the background task and wait for it to finish."""
if not self.running:
return
self._stop.set()
assert self._task is not None
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
finally:
self._task = None
logger.info("HealthProbe stopped")
async def _run_forever(self) -> None:
# Probe immediately at boot so we don't serve traffic for `interval`
# seconds based on optimistic-default assumptions.
try:
await self.tick_once()
except Exception as e: # noqa: BLE001
logger.error("HealthProbe initial tick failed: %s", e)
while not self._stop.is_set():
try:
await asyncio.wait_for(self._stop.wait(), timeout=self._interval)
except asyncio.TimeoutError:
pass
else:
# _stop.wait() succeeded → stop signalled, exit.
return
try:
await self.tick_once()
except Exception as e: # noqa: BLE001
logger.error("HealthProbe tick failed: %s", e)
# ---------------------------------------------------------------------------
# Probe-function helpers
# ---------------------------------------------------------------------------
def make_http_probe(
url: str,
*,
headers: dict[str, str] | None = None,
expected_status_lt: int = 500,
) -> ProbeFn:
"""Return a probe function that does ``GET <url>`` and considers the
provider healthy iff the response status is below
``expected_status_lt`` (default: any non-5xx counts).
    A 401/403/404 still counts as healthy because the *server* answered;
    auth or path mistakes are misconfiguration, not provider liveness.
"""
import httpx
async def probe() -> bool:
async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
resp = await client.get(url, headers=headers or None)
return resp.status_code < expected_status_lt
return probe
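A standalone wiring sketch of the lifecycle from the class docstring (probe URLs are the cheap endpoints named above; the Groq key is a placeholder):
```python
import asyncio

from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, make_http_probe

async def main() -> None:
    cache = ProviderHealthCache()
    probe = HealthProbe(
        cache,
        {
            "ollama": make_http_probe("http://localhost:11434/api/tags"),
            "groq": make_http_probe(
                "https://api.groq.com/openai/v1/models",
                headers={"Authorization": "Bearer gsk_xxx"},  # placeholder
            ),
        },
    )
    await probe.start()
    await asyncio.sleep(35)  # boot tick plus one scheduled tick
    print(cache.snapshot(expected=probe.provider_ids))
    await probe.stop()

asyncio.run(main())
```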

View file

@@ -1,405 +0,0 @@
"""Main FastAPI application for mana-llm service."""
import asyncio
import logging
import signal
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, Response
from sse_starlette.sse import EventSourceResponse
from src.aliases import AliasConfigError, AliasRegistry
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, make_http_probe
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ModelsResponse,
)
from src.providers import ProviderRouter
from src.providers.errors import ProviderError
from src.streaming import stream_chat_completion
from src.utils.cache import close_redis
from src.utils.metrics import (
get_metrics,
record_llm_error,
record_llm_request,
set_provider_healthy,
)
#: Header carrying the concrete provider/model that actually served a
#: non-streaming response — useful for token-cost accounting on the
#: caller side, since `mana/long-form` could resolve to ollama, groq,
#: or claude depending on which providers were healthy at request time.
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"
# Configure logging
logging.basicConfig(
level=getattr(logging, settings.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Global service singletons
router: ProviderRouter | None = None
health_cache: ProviderHealthCache | None = None
health_probe: HealthProbe | None = None
alias_registry: AliasRegistry | None = None
def _build_provider_probes(
providers: dict[str, Any],
) -> dict[str, Any]:
"""Wire each configured provider to a cheap HTTP probe."""
probes: dict[str, Any] = {}
if "ollama" in providers:
probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags")
if "openrouter" in providers:
probes["openrouter"] = make_http_probe(
f"{settings.openrouter_base_url}/models",
headers={"Authorization": f"Bearer {settings.openrouter_api_key}"},
)
if "groq" in providers:
probes["groq"] = make_http_probe(
f"{settings.groq_base_url}/models",
headers={"Authorization": f"Bearer {settings.groq_api_key}"},
)
if "together" in providers:
probes["together"] = make_http_probe(
f"{settings.together_base_url}/models",
headers={"Authorization": f"Bearer {settings.together_api_key}"},
)
# Google: skipped — google-genai SDK is opaque enough that a probe
# would amount to a real API call. Treat as healthy by default; the
# router's call-site fallback will mark it unhealthy on real errors.
return probes
def _on_health_change(provider: str, healthy: bool) -> None:
"""Mirror health-cache transitions into the Prometheus gauge."""
set_provider_healthy(provider, healthy)
def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None:
"""Reload ``aliases.yaml`` when the process receives SIGHUP.
Reload errors keep the previous good state in memory (see
AliasRegistry.reload). SIGHUP isn't available on Windows; we just
log and skip in that case.
"""
def handler() -> None:
if alias_registry is None:
return
try:
alias_registry.reload()
except AliasConfigError as e:
logger.error("alias reload rejected, keeping previous state: %s", e)
try:
loop.add_signal_handler(signal.SIGHUP, handler)
except (NotImplementedError, AttributeError, RuntimeError):
# NotImplementedError on Windows; RuntimeError when the loop is
# not running in the main thread (TestClient does this).
logger.info("SIGHUP reload not available in this context — skipping")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: load aliases, spin up router + health probe."""
global router, health_cache, health_probe, alias_registry
logger.info("Starting mana-llm service...")
aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml"
alias_registry = AliasRegistry(aliases_path)
logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)
health_cache = ProviderHealthCache()
health_cache.add_listener(_on_health_change)
router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
logger.info("Initialized providers: %s", list(router.providers))
# Initial gauge values so dashboards render before the first probe
# transition fires the listener.
for provider_name in router.providers:
set_provider_healthy(provider_name, True)
health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
await health_probe.start()
_install_sighup_reload(asyncio.get_running_loop())
yield
logger.info("Shutting down mana-llm service...")
if health_probe is not None:
await health_probe.stop()
if router is not None:
await router.close()
await close_redis()
# Create FastAPI app
app = FastAPI(
title="mana-llm",
description="Central LLM abstraction service for Ollama and OpenAI-compatible APIs",
version="0.1.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(ApiKeyMiddleware)
# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Check service health and provider status."""
if router is None:
return {"status": "unhealthy", "error": "Router not initialized"}
provider_health = await router.health_check()
return {
"status": provider_health["status"],
"service": "mana-llm",
"version": "0.1.0",
"providers": provider_health["providers"],
}
# Metrics endpoint
@app.get("/metrics")
async def metrics() -> Response:
"""Prometheus metrics endpoint."""
return Response(content=get_metrics(), media_type="text/plain")
# ----- Alias / health debug endpoints (M4) ---------------------------------
@app.get("/v1/aliases")
async def list_aliases() -> dict[str, Any]:
"""Inspect the alias registry — what each ``mana/<class>`` resolves to.
Useful for debugging "which model actually answered my request" and
for confirming SIGHUP reloads picked up edits to ``aliases.yaml``.
"""
if alias_registry is None:
raise HTTPException(status_code=503, detail="Service not ready")
return {
"default": alias_registry.default_alias,
"aliases": [
{
"name": a.name,
"description": a.description,
"chain": list(a.chain),
}
for a in alias_registry.list_aliases()
],
}
@app.get("/v1/health")
async def detailed_health() -> dict[str, Any]:
"""Full per-provider liveness snapshot.
Includes the failure counter, last error, and the unhealthy-until
    backoff timestamp: info the original ``/health`` endpoint hides.
"""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
return await router.health_check()
# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
"""List all available models from all providers."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
models = await router.list_models()
return ModelsResponse(data=models)
@app.get("/v1/models/{model_id:path}")
async def get_model(model_id: str) -> ModelInfo:
"""Get specific model information."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
model = await router.get_model(model_id)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
return model
# Chat completions endpoint
@app.post("/v1/chat/completions", response_model=None)
async def chat_completions(
request: ChatCompletionRequest,
http_request: Request,
) -> ChatCompletionResponse | EventSourceResponse:
"""
Create a chat completion.
Supports both streaming (SSE) and non-streaming responses based on the
`stream` parameter in the request body.
"""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# The request's `model` field is what the caller asked for — could be
# `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For
# error-path metrics we use that value (it's what the caller will
# search for); for success-path metrics we use the resolved provider
# so token-cost / latency attribute to the model that actually ran.
requested_provider, requested_model = _split_model(request.model)
start_time = time.time()
try:
if request.stream:
# Streaming response via SSE
logger.info(f"Streaming chat completion: {request.model}")
async def generate():
async for chunk in stream_chat_completion(router, request):
yield chunk
# Streaming metrics: we don't yet know which provider answered
# at request-record time. Each chunk's `model` field carries
# the resolved name; per-token latency is harder to attribute
# cleanly so we skip it for streams.
record_llm_request(requested_provider, requested_model, streaming=True)
return EventSourceResponse(
generate(),
media_type="text/event-stream",
)
else:
# Non-streaming response
logger.info(f"Chat completion: {request.model}")
response = await router.chat_completion(request)
resolved_provider, resolved_model = _split_model(response.model)
latency = time.time() - start_time
record_llm_request(
provider=resolved_provider,
model=resolved_model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
latency=latency,
)
# `response.model` is the concrete provider/model the chain
# actually resolved to. Surface it via header so the caller
# can attribute token cost to the right model even when the
# request used an alias.
return JSONResponse(
content=response.model_dump(),
headers={RESOLVED_MODEL_HEADER: response.model},
)
except ValueError as e:
logger.error(f"Invalid request: {e}")
record_llm_error(requested_provider, requested_model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except ProviderError as e:
logger.warning(
f"Provider error on {requested_provider}/{requested_model}: "
f"kind={e.kind} detail={e}"
)
record_llm_error(requested_provider, requested_model, e.kind)
raise HTTPException(
status_code=e.http_status,
detail={"kind": e.kind, "message": str(e)},
)
except Exception as e:
logger.error(f"Chat completion failed: {e}")
record_llm_error(requested_provider, requested_model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
# Embeddings endpoint
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
"""Create embeddings for the input text."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
provider, model = _split_model(request.model)
start_time = time.time()
try:
logger.info(f"Creating embeddings: {request.model}")
response = await router.embeddings(request)
latency = time.time() - start_time
record_llm_request(
provider=provider,
model=model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
latency=latency,
)
return response
except ValueError as e:
logger.error(f"Invalid embedding request: {e}")
record_llm_error(provider, model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Embeddings failed: {e}")
record_llm_error(provider, model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
def _split_model(model: str) -> tuple[str, str]:
"""Split a ``provider/model`` string for metric labelling.
Bare names with no slash default to ``ollama`` to match the legacy
OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace
    in the metrics; that's intentional, so request-side counters tell
you what callers ASKED for, while the resolved-side counters
(``mana_llm_alias_resolved_total``) tell you what they GOT.
"""
if "/" in model:
provider, _, name = model.partition("/")
return provider.lower(), name
return "ollama", model
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"src.main:app",
host="0.0.0.0",
port=settings.port,
reload=True,
log_level=settings.log_level,
)

View file

@@ -1,47 +0,0 @@
"""Pydantic models for OpenAI-compatible API."""
from .requests import (
ChatCompletionRequest,
EmbeddingRequest,
FunctionSpec,
Message,
ToolChoice,
ToolSpec,
)
from .responses import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingResponse,
MessageResponse,
ModelInfo,
ModelsResponse,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
__all__ = [
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionStreamResponse",
"Choice",
"DeltaContent",
"EmbeddingData",
"EmbeddingRequest",
"EmbeddingResponse",
"FunctionSpec",
"Message",
"MessageResponse",
"ModelInfo",
"ModelsResponse",
"StreamChoice",
"ToolCall",
"ToolCallFunction",
"ToolChoice",
"ToolSpec",
"Usage",
]

View file

@@ -1,128 +0,0 @@
"""Request models for OpenAI-compatible API."""
from typing import Any, Literal
from pydantic import BaseModel, Field
class TextContent(BaseModel):
"""Text content in a message."""
type: Literal["text"] = "text"
text: str
class ImageUrl(BaseModel):
"""Image URL reference."""
url: str # Can be http(s):// or data:image/...;base64,...
class ImageContent(BaseModel):
"""Image content in a message."""
type: Literal["image_url"] = "image_url"
image_url: ImageUrl
MessageContent = str | list[TextContent | ImageContent]
class ToolCallFunction(BaseModel):
"""The function portion of a tool_call on an assistant message."""
name: str
# Arguments are passed as a JSON string (OpenAI spec). Providers may
# emit structured args natively; the adapter serialises them here.
arguments: str
class ToolCall(BaseModel):
"""A tool invocation the assistant decided to make."""
id: str
type: Literal["function"] = "function"
function: ToolCallFunction
class Message(BaseModel):
"""A single message in the conversation.
`tool` messages carry the result of a previously-requested tool call
back into the context; they must reference the originating call via
``tool_call_id``. Assistant messages may contain either plain
``content`` or ``tool_calls`` (or both, though providers typically
only emit one at a time).
"""
role: Literal["system", "user", "assistant", "tool"]
content: MessageContent | None = None
tool_call_id: str | None = None
tool_calls: list[ToolCall] | None = None
class FunctionSpec(BaseModel):
"""OpenAI-style function declaration."""
name: str
description: str
# JSON Schema for the function's parameters. Kept as a dict here to
# stay loose; providers translate it to their native shape.
parameters: dict[str, Any] = Field(default_factory=dict)
class ToolSpec(BaseModel):
"""A tool the model may call. Only ``function`` tools are supported."""
type: Literal["function"] = "function"
function: FunctionSpec
class ToolChoiceFunction(BaseModel):
"""Force-pick a specific function for ``tool_choice``."""
type: Literal["function"] = "function"
function: dict[str, str] # {"name": "tool_name"}
ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction
class ResponseFormat(BaseModel):
"""OpenAI structured-output response_format hint.
Two shapes are accepted:
- {"type": "json_object"} free-form JSON
- {"type": "json_schema",
"json_schema": {"name": "...", "schema": {...}, "strict": bool}}
schema-constrained JSON; passed through to providers that
support it (e.g. Ollama 0.5+ via its native `format` field).
"""
type: Literal["json_object", "json_schema"]
json_schema: dict[str, Any] | None = None
class ChatCompletionRequest(BaseModel):
"""Request body for chat completions endpoint."""
model: str = Field(..., description="Model identifier in format 'provider/model' or just 'model'")
messages: list[Message] = Field(..., min_length=1)
stream: bool = False
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, gt=0)
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
stop: str | list[str] | None = None
response_format: ResponseFormat | None = None
tools: list[ToolSpec] | None = None
tool_choice: ToolChoice | None = None
class EmbeddingRequest(BaseModel):
"""Request body for embeddings endpoint."""
model: str = Field(..., description="Model identifier")
input: str | list[str] = Field(..., description="Text(s) to embed")
encoding_format: Literal["float", "base64"] = "float"
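# Sketch: building a schema-constrained request from the models above. The
# model string and schema are illustrative values, not anything this repo ships.
def _example_structured_request() -> ChatCompletionRequest:
    return ChatCompletionRequest(
        model="ollama/gemma3:4b",
        messages=[Message(role="user", content="Extract the city: 'Rain in Paris'")],
        response_format=ResponseFormat(
            type="json_schema",
            json_schema={
                "name": "city",
                "schema": {"type": "object", "properties": {"city": {"type": "string"}}},
                "strict": True,
            },
        ),
    )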

View file

@ -1,120 +0,0 @@
"""Response models for OpenAI-compatible API."""
import time
import uuid
from typing import Literal
from pydantic import BaseModel, Field
class Usage(BaseModel):
"""Token usage information."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class ToolCallFunction(BaseModel):
"""Function portion of a tool_call the assistant produced."""
name: str
arguments: str # JSON-encoded (OpenAI spec)
class ToolCall(BaseModel):
"""A tool invocation the assistant decided to make."""
id: str
type: Literal["function"] = "function"
function: ToolCallFunction
class MessageResponse(BaseModel):
"""Response message from the model."""
role: Literal["assistant"] = "assistant"
content: str | None = None
tool_calls: list[ToolCall] | None = None
class Choice(BaseModel):
"""A single completion choice."""
index: int = 0
message: MessageResponse
finish_reason: (
Literal["stop", "length", "content_filter", "tool_calls"] | None
) = "stop"
class ChatCompletionResponse(BaseModel):
"""Response from chat completions endpoint (non-streaming)."""
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[Choice]
usage: Usage = Field(default_factory=Usage)
class DeltaContent(BaseModel):
"""Delta content for streaming responses."""
role: Literal["assistant"] | None = None
content: str | None = None
tool_calls: list[ToolCall] | None = None
class StreamChoice(BaseModel):
"""A single streaming choice."""
index: int = 0
delta: DeltaContent
finish_reason: (
Literal["stop", "length", "content_filter", "tool_calls"] | None
) = None
class ChatCompletionStreamResponse(BaseModel):
"""Response chunk from chat completions endpoint (streaming)."""
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[StreamChoice]
class ModelInfo(BaseModel):
"""Information about a model."""
id: str
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "mana-llm"
class ModelsResponse(BaseModel):
"""Response from models endpoint."""
object: Literal["list"] = "list"
data: list[ModelInfo]
class EmbeddingData(BaseModel):
"""A single embedding result."""
object: Literal["embedding"] = "embedding"
index: int = 0
embedding: list[float]
class EmbeddingResponse(BaseModel):
"""Response from embeddings endpoint."""
object: Literal["list"] = "list"
data: list[EmbeddingData]
model: str
usage: Usage = Field(default_factory=Usage)
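# Sketch: what a single SSE frame carries when streaming. Pydantic v2's
# model_dump_json is assumed here; id/created fall back to the generated
# defaults above.
def _example_sse_frame() -> str:
    chunk = ChatCompletionStreamResponse(
        model="ollama/gemma3:4b",
        choices=[StreamChoice(delta=DeltaContent(content="Hel"))],
    )
    return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"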

View file

@ -1,13 +0,0 @@
"""LLM Provider implementations."""
from .base import LLMProvider
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
from .router import ProviderRouter
__all__ = [
"LLMProvider",
"OllamaProvider",
"OpenAICompatProvider",
"ProviderRouter",
]

View file

@ -1,76 +0,0 @@
"""Abstract base class for LLM providers."""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
)
class LLMProvider(ABC):
"""Abstract base class for LLM providers."""
name: str = "base"
# Set to True if the provider supports OpenAI-style `tools` + `tool_calls`
# for chat completions. The router rejects tool-bearing requests routed
# to providers without support rather than silently dropping the tools.
# Provider adapters may further narrow this per-model if needed.
supports_tools: bool = False
def model_supports_tools(self, model: str) -> bool:
"""Check if a specific model within this provider supports tools.
Default: falls back to the provider-wide flag. Providers with a
mixed capability surface (e.g. Ollama, where support depends on
the local model) override this.
"""
return self.supports_tools
@abstractmethod
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
...
@abstractmethod
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
...
@abstractmethod
async def list_models(self) -> list[ModelInfo]:
"""List available models."""
...
@abstractmethod
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
...
@abstractmethod
async def health_check(self) -> dict[str, Any]:
"""Check provider health status."""
...
async def close(self) -> None:
"""Clean up resources."""
pass
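# Sketch (comment-only; a real adapter must also implement the abstract
# methods above): how a provider with a mixed capability surface narrows
# tool support per model. The class and prefix are hypothetical.
#
#     class ExampleProvider(LLMProvider):
#         name = "example"
#         supports_tools = True
#
#         def model_supports_tools(self, model: str) -> bool:
#             return not model.startswith("vision-")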

View file

@ -1,111 +0,0 @@
"""Structured provider errors.
These map to distinct HTTP status codes and metric labels so callers can
distinguish between "model refused to answer" (blocked), "hit the token
budget" (truncated), "auth is broken", "we were rate-limited", and
"model doesn't support what we asked" (capability). The old behaviour of
returning an empty-string response on content-filter hits silently
corrupts downstream pipelines (e.g. the planner sees "" and reports
"no JSON block found" misleading).
"""
class ProviderError(Exception):
"""Base class for structured provider errors."""
# Short stable identifier used in metrics and API responses.
kind: str = "unknown"
# Suggested HTTP status when this error bubbles out of an endpoint.
http_status: int = 502
class ProviderBlockedError(ProviderError):
"""The provider refused to return content (safety, recitation, …).
The model produced nothing usable because a content filter or policy
guardrail fired. Retrying with the same inputs will fail again; the
caller needs to adjust the prompt or give the user a clear message.
"""
kind = "blocked"
http_status = 422
def __init__(self, reason: str, detail: str | None = None):
self.reason = reason # e.g. "SAFETY", "RECITATION"
self.detail = detail
msg = f"Provider blocked the response ({reason})"
if detail:
msg += f": {detail}"
super().__init__(msg)
class ProviderTruncatedError(ProviderError):
"""The response was cut off before completion (hit max_tokens)."""
kind = "truncated"
http_status = 502
def __init__(self, partial_text: str | None = None):
self.partial_text = partial_text
super().__init__(
"Provider truncated the response (max_tokens reached) — "
"re-run with a higher max_tokens or smaller input"
)
class ProviderAuthError(ProviderError):
"""Authentication against the upstream provider failed."""
kind = "auth"
http_status = 502 # not 401 — the client's auth is fine, *ours* isn't
class ProviderRateLimitError(ProviderError):
"""The upstream provider rate-limited us."""
kind = "rate_limit"
http_status = 429
class ProviderCapabilityError(ProviderError):
"""The requested feature is not supported by the chosen model.
Typically raised when a request asks for native tool-calling against
a model that does not support it. We refuse to silently degrade.
"""
kind = "capability"
http_status = 400
class NoHealthyProviderError(ProviderError):
"""Every entry in the resolved chain has been tried and either was
unconfigured, marked unhealthy by the cache, or failed in flight.
Carries the full attempt log so the caller can see which providers
were tried and why each one was skipped or failed; invaluable when
a real outage hits and the API returns 503 instead of the usual 200.
"""
kind = "no_healthy_provider"
http_status = 503
def __init__(
self,
model_or_alias: str,
attempts: list[tuple[str, str]],
last_exception: Exception | None = None,
) -> None:
self.model_or_alias = model_or_alias
self.attempts = list(attempts)
self.last_exception = last_exception
if attempts:
attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts)
else:
attempt_log = "(no providers in resolved chain were configured)"
msg = (
f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}"
)
if last_exception is not None:
msg += f". Last error: {type(last_exception).__name__}: {last_exception}"
super().__init__(msg)
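# Sketch: surfacing these errors uniformly at the API edge. FastAPI's
# HTTPException is assumed (the service's endpoints already raise it).
def example_provider_error_to_http(err: ProviderError):
    from fastapi import HTTPException

    # `kind` doubles as the metric label; `http_status` picks the code.
    return HTTPException(
        status_code=err.http_status,
        detail={"kind": err.kind, "message": str(err)},
    )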

View file

@ -1,521 +0,0 @@
"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable)."""
import json
import logging
import uuid
from collections.abc import AsyncIterator
from typing import Any
from google import genai
from google.genai import types
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
ToolSpec,
Usage,
)
from .base import LLMProvider
from .errors import (
ProviderAuthError,
ProviderBlockedError,
ProviderError,
ProviderRateLimitError,
ProviderTruncatedError,
)
logger = logging.getLogger(__name__)
def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None:
"""Translate our ToolSpec list into Gemini ``types.Tool`` declarations."""
if not tools:
return None
declarations: list[types.FunctionDeclaration] = []
for t in tools:
declarations.append(
types.FunctionDeclaration(
name=t.function.name,
description=t.function.description,
parameters=t.function.parameters or None,
)
)
return [types.Tool(function_declarations=declarations)]
def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None:
"""Pull any ``function_call`` parts off a candidate into ToolCalls.
Gemini emits tool calls as ``content.parts[i].function_call`` with a
``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON
string). We normalise to OpenAI shape: arguments are JSON-encoded
strings so downstream handlers can treat all providers the same.
"""
if candidate is None:
return None
content = getattr(candidate, "content", None)
parts = getattr(content, "parts", None) or []
calls: list[ToolCall] = []
for part in parts:
fc = getattr(part, "function_call", None)
if fc is None:
continue
name = getattr(fc, "name", None)
if not name:
continue
args = getattr(fc, "args", None) or {}
# Gemini's args are already dict-shaped; serialise to JSON string.
try:
arguments_json = json.dumps(dict(args), ensure_ascii=False)
except (TypeError, ValueError):
arguments_json = json.dumps({}, ensure_ascii=False)
calls.append(
ToolCall(
id=f"call_{uuid.uuid4().hex[:12]}",
function=ToolCallFunction(name=name, arguments=arguments_json),
)
)
return calls or None
def _unwrap_gemini_response(
response: Any, gemini_model: str
) -> tuple[str, list[ToolCall] | None, str]:
"""Validate a non-streaming Gemini response.
Returns ``(text, tool_calls, finish_reason)``. Raises a structured
``ProviderError`` if the response was blocked, truncated, or
otherwise produced no usable payload. ``response.text`` silently
returns ``""`` on blocked responses we refuse to propagate that.
"""
candidates = getattr(response, "candidates", None) or []
candidate = candidates[0] if candidates else None
finish_reason = getattr(candidate, "finish_reason", None)
# SDK sometimes exposes the enum name on `.name`, sometimes it's a string.
finish_name = getattr(finish_reason, "name", None) or (
str(finish_reason) if finish_reason is not None else None
)
# Strip the leading enum prefix if present (e.g. "FinishReason.SAFETY").
if finish_name and "." in finish_name:
finish_name = finish_name.rsplit(".", 1)[-1]
text = response.text or ""
tool_calls = _extract_tool_calls(candidate)
if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}:
# Pull the first safety rating that actually blocked if present.
ratings = getattr(candidate, "safety_ratings", None) or []
blocked = [
getattr(r, "category", None)
for r in ratings
if getattr(r, "blocked", False)
]
detail = ", ".join(str(c) for c in blocked if c) or None
logger.warning(
"Gemini response blocked (model=%s, reason=%s, detail=%s)",
gemini_model,
finish_name,
detail,
)
raise ProviderBlockedError(reason=finish_name, detail=detail)
if finish_name == "MAX_TOKENS":
logger.warning(
"Gemini response truncated at max_tokens (model=%s)", gemini_model
)
raise ProviderTruncatedError(partial_text=text or None)
if not text and not tool_calls and finish_name not in (None, "STOP"):
# Unknown finish reason, nothing to return — surface instead of "".
raise ProviderError(
f"Gemini returned no content (finish_reason={finish_name})"
)
# Normalise the finish_reason to our OpenAI-compatible vocabulary.
openai_finish = "stop"
if tool_calls:
openai_finish = "tool_calls"
elif finish_name == "MAX_TOKENS":
openai_finish = "length"
return text, tool_calls, openai_finish
def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None:
"""Find the tool name for a given ``tool_call_id``.
OpenAI's tool-message carries only the id; the matching name lives
on the preceding assistant message's ``tool_calls[]``. We scan from
the end (most recent assistant turn) backwards.
"""
if not tool_call_id:
return None
for m in reversed(messages):
if m.role != "assistant" or not m.tool_calls:
continue
for call in m.tool_calls:
if call.id == tool_call_id:
return call.function.name
return None
def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError:
"""Translate a raw Google SDK exception into a structured ProviderError.
The SDK uses google.genai.errors.* but we avoid importing them at
top level to keep the provider optional. String-match the class
name instead.
"""
cls_name = type(err).__name__
msg = str(err) or cls_name
if "Auth" in cls_name or "PermissionDenied" in cls_name or "Unauthenticated" in cls_name:
return ProviderAuthError(f"Gemini auth failed for {gemini_model}: {msg}")
if "ResourceExhausted" in cls_name or "RateLimit" in cls_name or "429" in msg:
return ProviderRateLimitError(f"Gemini rate-limited for {gemini_model}: {msg}")
return ProviderError(f"Gemini call failed for {gemini_model}: {msg}")
# Model mapping: Ollama model → Google Gemini equivalent
OLLAMA_TO_GEMINI: dict[str, str] = {
"gemma3:4b": "gemini-2.5-flash",
"gemma3:12b": "gemini-2.5-flash",
"gemma3:27b": "gemini-2.5-pro",
"llava:7b": "gemini-2.5-flash", # Gemini has native vision
"qwen3-vl:4b": "gemini-2.5-flash", # vision fallback
"qwen2.5-coder:7b": "gemini-2.5-flash",
"qwen2.5-coder:14b": "gemini-2.5-pro",
"phi3.5:latest": "gemini-2.5-flash",
"ministral-3:3b": "gemini-2.5-flash",
"deepseek-ocr:latest": "gemini-2.5-flash",
}
class GoogleProvider(LLMProvider):
"""Google Gemini API provider."""
name = "google"
# Gemini 2.x supports OpenAI-style function calling across all listed models.
supports_tools = True
def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"):
self.api_key = api_key
self.default_model = default_model
self.client = genai.Client(api_key=api_key)
def map_model(self, ollama_model: str) -> str:
"""Map an Ollama model name to a Google Gemini equivalent."""
return OLLAMA_TO_GEMINI.get(ollama_model, self.default_model)
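# Illustrative mappings: map_model("gemma3:27b") -> "gemini-2.5-pro";
# names missing from OLLAMA_TO_GEMINI fall back to self.default_model
# ("gemini-2.5-flash" unless overridden).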
def _convert_messages(
self, request: ChatCompletionRequest
) -> tuple[str | None, list[types.Content]]:
"""Convert OpenAI-format messages to Google Gemini format.
Returns (system_instruction, contents). Handles text, multimodal
image content, assistant messages carrying ``tool_calls``, and
``tool`` result messages (mapped to Gemini ``function_response``
parts; Gemini has no ``tool`` role, so function responses ride on
a ``user`` turn).
"""
system_instruction: str | None = None
contents: list[types.Content] = []
for msg in request.messages:
if msg.role == "system":
if isinstance(msg.content, str):
system_instruction = msg.content
continue
# Tool result message → function_response Part on a user turn.
if msg.role == "tool":
# The content is the stringified tool result. We also need
# a tool name — the OpenAI spec carries it on the matching
# assistant tool_call, keyed by tool_call_id. We don't
# track that back-reference here, so we pull the name
# from the preceding assistant message's tool_calls.
name = _lookup_tool_name(request.messages, msg.tool_call_id)
if not name:
continue # orphan tool message — skip silently
payload: Any
if isinstance(msg.content, str):
try:
payload = json.loads(msg.content)
if not isinstance(payload, dict):
payload = {"result": payload}
except (TypeError, ValueError):
payload = {"result": msg.content}
else:
payload = {"result": ""}
contents.append(
types.Content(
role="user",
parts=[
types.Part.from_function_response(
name=name, response=payload
)
],
)
)
continue
role = "user" if msg.role == "user" else "model"
parts: list[types.Part] = []
if msg.role == "assistant" and msg.tool_calls:
for call in msg.tool_calls:
try:
args_obj = json.loads(call.function.arguments or "{}")
except (TypeError, ValueError):
args_obj = {}
parts.append(
types.Part(
function_call=types.FunctionCall(
name=call.function.name, args=args_obj
)
)
)
if msg.content is not None:
if isinstance(msg.content, str):
parts.append(types.Part.from_text(text=msg.content))
else:
for part in msg.content:
if part.type == "text":
parts.append(types.Part.from_text(text=part.text))
elif part.type == "image_url" and part.image_url:
url = part.image_url.url
if url.startswith("data:"):
header, b64_data = url.split(",", 1)
mime_type = header.split(":")[1].split(";")[0]
import base64
image_bytes = base64.b64decode(b64_data)
parts.append(
types.Part.from_bytes(
data=image_bytes, mime_type=mime_type
)
)
else:
parts.append(
types.Part.from_uri(
file_uri=url, mime_type="image/jpeg"
)
)
if parts:
contents.append(types.Content(role=role, parts=parts))
return system_instruction, contents
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion via Google Gemini."""
gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
system_instruction, contents = self._convert_messages(request)
config: dict[str, Any] = {}
if request.temperature is not None:
config["temperature"] = request.temperature
if request.max_tokens is not None:
config["max_output_tokens"] = request.max_tokens
if request.top_p is not None:
config["top_p"] = request.top_p
if request.stop:
stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop]
config["stop_sequences"] = stop_seqs
gemini_tools = _build_gemini_tools(request.tools)
if gemini_tools:
config["tools"] = gemini_tools
gen_config = types.GenerateContentConfig(
system_instruction=system_instruction,
**config,
)
logger.debug(f"Google Gemini request: {gemini_model}, messages: {len(contents)}")
try:
response = await self.client.aio.models.generate_content(
model=gemini_model,
contents=contents,
config=gen_config,
)
except ProviderError:
raise
except Exception as err:
raise _wrap_gemini_call_error(err, gemini_model) from err
content, tool_calls, finish_reason = _unwrap_gemini_response(
response, gemini_model
)
usage_meta = response.usage_metadata
return ChatCompletionResponse(
model=f"google/{gemini_model}",
choices=[
Choice(
index=0,
message=MessageResponse(
content=content or None,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
],
usage=Usage(
prompt_tokens=usage_meta.prompt_token_count if usage_meta else 0,
completion_tokens=usage_meta.candidates_token_count if usage_meta else 0,
total_tokens=usage_meta.total_token_count if usage_meta else 0,
),
)
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a streaming chat completion via Google Gemini."""
gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
system_instruction, contents = self._convert_messages(request)
config: dict[str, Any] = {}
if request.temperature is not None:
config["temperature"] = request.temperature
if request.max_tokens is not None:
config["max_output_tokens"] = request.max_tokens
if request.top_p is not None:
config["top_p"] = request.top_p
gemini_tools = _build_gemini_tools(request.tools)
if gemini_tools:
config["tools"] = gemini_tools
gen_config = types.GenerateContentConfig(
system_instruction=system_instruction,
**config,
)
# First chunk with role
yield ChatCompletionStreamResponse(
model=f"google/{gemini_model}",
choices=[
StreamChoice(
delta=DeltaContent(role="assistant"),
finish_reason=None,
)
],
)
last_chunk: Any = None
emitted_any_text = False
try:
stream = await self.client.aio.models.generate_content_stream(
model=gemini_model,
contents=contents,
config=gen_config,
)
except Exception as err:
raise _wrap_gemini_call_error(err, gemini_model) from err
async for chunk in stream:
last_chunk = chunk
text = chunk.text
if text:
emitted_any_text = True
yield ChatCompletionStreamResponse(
model=f"google/{gemini_model}",
choices=[
StreamChoice(
delta=DeltaContent(content=text),
finish_reason=None,
)
],
)
# Post-stream check: if the stream ended without emitting any text,
# surface the structured reason instead of quietly closing with an
# empty "stop". Matches _unwrap_gemini_response semantics. We
# discard the return value here — the streaming path produces its
# content chunk-by-chunk above, so we only need this call for its
# side-effect of raising on SAFETY / RECITATION / MAX_TOKENS.
if not emitted_any_text and last_chunk is not None:
_unwrap_gemini_response(last_chunk, gemini_model)
# Final chunk
yield ChatCompletionStreamResponse(
model=f"google/{gemini_model}",
choices=[
StreamChoice(
delta=DeltaContent(),
finish_reason="stop",
)
],
)
async def list_models(self) -> list[ModelInfo]:
"""List available Google Gemini models."""
# Return a static list of commonly used models
return [
ModelInfo(id="google/gemini-2.5-flash", owned_by="google"),
ModelInfo(id="google/gemini-2.5-pro", owned_by="google"),
]
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings via Google Gemini."""
inputs = request.input if isinstance(request.input, list) else [request.input]
result = await self.client.aio.models.embed_content(
model="text-embedding-004",
contents=inputs,
)
return EmbeddingResponse(
data=[
EmbeddingData(index=i, embedding=emb.values)
for i, emb in enumerate(result.embeddings)
],
model="google/text-embedding-004",
usage=Usage(
prompt_tokens=sum(len(t.split()) for t in inputs),
total_tokens=sum(len(t.split()) for t in inputs),
),
)
async def health_check(self) -> dict[str, Any]:
"""Check Google API health."""
try:
# Quick test: list models
await self.client.aio.models.list(config={"page_size": 1})
return {
"status": "healthy",
"provider": "google",
}
except Exception as e:
return {
"status": "unhealthy",
"provider": "google",
"error": str(e),
}
async def close(self) -> None:
"""No cleanup needed for Google client."""
pass

View file

@ -1,462 +0,0 @@
"""Ollama provider implementation."""
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
import httpx
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
from .base import LLMProvider
# Ollama emits tool_calls under /api/chat as:
# {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}}
# Arguments arrive as a dict — we normalise to a JSON string to match
# the OpenAI-spec shape our downstream code expects.
# Ollama models known to support tool-calling reliably (as of 0.3+).
# Everything else: we still pass `tools` to the API (it ignores them on
# incompatible models), but the assistant response will be plain text.
# A shared-ai-level capability check rejects these cases before the call.
TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] = (
"llama3.1",
"llama3.2",
"llama3.3",
"qwen2.5",
"qwen3",
"mistral",
"mixtral",
"command-r",
"firefunction",
)
logger = logging.getLogger(__name__)
def _safe_parse_args(raw: str | None) -> dict[str, Any]:
"""Best-effort parse of a JSON-encoded arguments string."""
if not raw:
return {}
try:
parsed = json.loads(raw)
except (TypeError, ValueError):
return {}
return parsed if isinstance(parsed, dict) else {}
def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
"""Normalise Ollama's tool_calls into our ToolCall model."""
if not raw:
return None
import uuid
calls: list[ToolCall] = []
for idx, c in enumerate(raw):
fn = c.get("function") or {}
name = fn.get("name")
if not name:
continue
args = fn.get("arguments")
if isinstance(args, dict):
arguments_json = json.dumps(args, ensure_ascii=False)
elif isinstance(args, str):
arguments_json = args
else:
arguments_json = "{}"
calls.append(
ToolCall(
id=c.get("id") or f"call_{uuid.uuid4().hex[:12]}",
function=ToolCallFunction(name=name, arguments=arguments_json),
)
)
return calls or None
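# Illustrative normalisation (hypothetical payload):
#   {"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}
# becomes ToolCall(id="call_<12 hex chars>", function=ToolCallFunction(
#   name="get_weather", arguments='{"city": "Berlin"}'))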
def _strip_json_fences(content: str) -> str:
"""Strip ```json ... ``` markdown fences from a string if present.
Some Ollama vision models still wrap structured-output responses in
a markdown code block even when `format` is set. Downstream parsers
(Vercel AI SDK generateObject, manual JSON.parse) expect clean JSON,
so we normalize the response here at the proxy boundary.
"""
s = content.strip()
if s.startswith("```"):
# Drop the opening fence (```json or ``` plus any language tag)
first_newline = s.find("\n")
if first_newline != -1:
s = s[first_newline + 1 :]
# Drop the closing fence
if s.endswith("```"):
s = s[:-3]
s = s.strip()
return s
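# Example: '```json\n{"a": 1}\n```' -> '{"a": 1}'; strings without fences
# pass through unchanged (modulo the outer strip()).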
class OllamaProvider(LLMProvider):
"""Ollama LLM provider."""
name = "ollama"
supports_tools = True
def model_supports_tools(self, model: str) -> bool:
"""Narrow tool capability to models trained for function calling.
Ollama accepts the ``tools`` payload even for incompatible models
but silently drops it; the assistant just replies in prose. We
reject those requests upfront so callers get a clear error.
"""
lower = model.lower()
return any(pattern in lower for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS)
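# Illustrative checks: "qwen2.5-coder:14b" contains "qwen2.5" -> True;
# "gemma3:4b" matches no pattern -> False, so tool-bearing requests to it
# are rejected upfront by the router.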
def __init__(self, base_url: str | None = None, timeout: int | None = None):
self.base_url = (base_url or settings.ollama_url).rstrip("/")
self.timeout = timeout or settings.ollama_timeout
self._client: httpx.AsyncClient | None = None
@property
def client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout),
)
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert OpenAI message format to Ollama format.
Ollama's ``/api/chat`` uses the OpenAI shape for assistant
``tool_calls`` and ``tool`` result messages (with args as objects
rather than JSON strings on output; input still accepts JSON
strings), so we largely pass them through.
"""
messages: list[dict[str, Any]] = []
for msg in request.messages:
out: dict[str, Any] = {"role": msg.role}
if msg.tool_calls:
out["tool_calls"] = [
{
"function": {
"name": c.function.name,
# Ollama accepts either stringified JSON or
# already-parsed objects; send parsed for
# better compatibility with older models.
"arguments": _safe_parse_args(c.function.arguments),
}
}
for c in msg.tool_calls
]
if msg.content is None:
out["content"] = ""
elif isinstance(msg.content, str):
out["content"] = msg.content
else:
text_parts: list[str] = []
images: list[str] = []
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
elif part.type == "image_url":
url = part.image_url.url
if url.startswith("data:"):
base64_data = url.split(",", 1)[1] if "," in url else url
images.append(base64_data)
else:
logger.warning(
"HTTP image URLs not supported, skipping: %s...",
url[:50],
)
out["content"] = " ".join(text_parts)
if images:
out["images"] = images
messages.append(out)
return messages
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": False,
}
# Pass through structured-output requests to Ollama's native
# `format` field. Ollama supports either `"json"` (free-form
# JSON object) or a full JSON schema dict. The OpenAI-style
# response_format the consumer sends maps as follows:
# - {"type": "json_object"} → "json"
# - {"type": "json_schema", "json_schema": {"schema": {...}}}
# → the schema dict (Ollama 0.5+ supports full schemas)
# Without this, Ollama wraps JSON in ```json ... ``` markdown
# fences, which breaks downstream strict parsers like the AI SDK
# generateObject() helper.
if request.response_format is not None:
rf = request.response_format
if rf.type == "json_object":
payload["format"] = "json"
elif rf.type == "json_schema" and rf.json_schema is not None:
# rf.json_schema is the OpenAI envelope:
# {"name": "...", "schema": {...}, "strict": true}
# Ollama wants just the inner schema dict.
inner = rf.json_schema.get("schema")
payload["format"] = inner if inner is not None else "json"
else:
payload["format"] = "json"
# Add optional parameters
options: dict[str, Any] = {}
if request.temperature is not None:
options["temperature"] = request.temperature
if request.top_p is not None:
options["top_p"] = request.top_p
if request.max_tokens is not None:
options["num_predict"] = request.max_tokens
if request.stop:
options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]
if options:
payload["options"] = options
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}")
response = await self.client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
message = data.get("message", {})
tool_calls = _tool_calls_from_ollama(message.get("tool_calls"))
# Defensive fence-stripping: even with `format` set, some older
# Ollama versions still emit ```json ... ``` wrappers for vision
# models. Strip them so strict downstream parsers see clean JSON.
raw_content = message.get("content", "") or ""
content = _strip_json_fences(raw_content) if raw_content else ""
finish_reason = "tool_calls" if tool_calls else (
"stop" if data.get("done") else None
)
return ChatCompletionResponse(
model=f"ollama/{model}",
choices=[
Choice(
message=MessageResponse(
content=content or None,
tool_calls=tool_calls,
),
finish_reason=finish_reason,
)
],
usage=Usage(
prompt_tokens=data.get("prompt_eval_count", 0),
completion_tokens=data.get("eval_count", 0),
total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
),
)
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": True,
}
# Add optional parameters
options: dict[str, Any] = {}
if request.temperature is not None:
options["temperature"] = request.temperature
if request.top_p is not None:
options["top_p"] = request.top_p
if request.max_tokens is not None:
options["num_predict"] = request.max_tokens
if request.stop:
options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]
if options:
payload["options"] = options
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
logger.debug(f"Ollama streaming request: {model}")
response_id = f"chatcmpl-{model[:8]}"
first_chunk = True
async with self.client.stream("POST", "/api/chat", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
logger.warning(f"Failed to parse Ollama response line: {line}")
continue
# First chunk includes role
if first_chunk:
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(role="assistant"),
)
],
)
first_chunk = False
# Content chunks
content = data.get("message", {}).get("content", "")
if content:
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(content=content),
)
],
)
# Final chunk with finish_reason
if data.get("done"):
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(),
finish_reason="stop",
)
],
)
async def list_models(self) -> list[ModelInfo]:
"""List available Ollama models."""
response = await self.client.get("/api/tags")
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("models", []):
name = model_data.get("name", "")
# Parse modified_at datetime string to Unix timestamp
created = None
if modified_at := model_data.get("modified_at"):
try:
from datetime import datetime
# Handle ISO format with timezone
dt = datetime.fromisoformat(modified_at.replace("Z", "+00:00"))
created = int(dt.timestamp())
except (ValueError, TypeError):
pass
models.append(
ModelInfo(
id=f"ollama/{name}",
owned_by="ollama",
created=created,
)
)
return models
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
inputs = request.input if isinstance(request.input, list) else [request.input]
embeddings_data = []
for i, text in enumerate(inputs):
response = await self.client.post(
"/api/embeddings",
json={"model": model, "prompt": text},
)
response.raise_for_status()
data = response.json()
embeddings_data.append(
EmbeddingData(
index=i,
embedding=data.get("embedding", []),
)
)
return EmbeddingResponse(
data=embeddings_data,
model=f"ollama/{model}",
usage=Usage(
prompt_tokens=sum(len(text.split()) for text in inputs), # Approximate
total_tokens=sum(len(text.split()) for text in inputs),
),
)
async def health_check(self) -> dict[str, Any]:
"""Check Ollama health status."""
try:
response = await self.client.get("/api/tags")
response.raise_for_status()
data = response.json()
model_count = len(data.get("models", []))
return {
"status": "healthy",
"provider": self.name,
"url": self.base_url,
"models_available": model_count,
}
except Exception as e:
return {
"status": "unhealthy",
"provider": self.name,
"url": self.base_url,
"error": str(e),
}
async def close(self) -> None:
"""Close HTTP client."""
if self._client and not self._client.is_closed:
await self._client.aclose()

View file

@ -1,364 +0,0 @@
"""OpenAI-compatible provider for OpenRouter, Groq, Together, etc."""
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
import httpx
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
MessageResponse,
ModelInfo,
StreamChoice,
ToolCall,
ToolCallFunction,
Usage,
)
from .base import LLMProvider
logger = logging.getLogger(__name__)
def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
"""Normalise an OpenAI-spec ``tool_calls`` array into our model shape."""
if not raw:
return None
calls: list[ToolCall] = []
for c in raw:
fn = c.get("function") or {}
name = fn.get("name")
if not name:
continue
calls.append(
ToolCall(
id=c.get("id") or "",
function=ToolCallFunction(
name=name,
arguments=fn.get("arguments") or "{}",
),
)
)
return calls or None
class OpenAICompatProvider(LLMProvider):
"""OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.)."""
# OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec;
# individual models within those services may or may not support it,
# but the request shape is uniform. The upstream returns a proper
# error for unsupported models — no silent downgrade here.
supports_tools = True
def __init__(
self,
name: str,
base_url: str,
api_key: str,
default_model: str | None = None,
timeout: int = 120,
):
self.name = name
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self.default_model = default_model
self.timeout = timeout
self._client: httpx.AsyncClient | None = None
@property
def client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert internal message format to OpenAI format.
The OpenAI chat-completions endpoint is the source of truth for
this shape, so most fields pass through verbatim, including
``role='tool'`` messages with ``tool_call_id`` + ``content`` and
assistant messages carrying ``tool_calls[]``.
"""
messages: list[dict[str, Any]] = []
for msg in request.messages:
out: dict[str, Any] = {"role": msg.role}
if msg.tool_call_id:
out["tool_call_id"] = msg.tool_call_id
if msg.tool_calls:
out["tool_calls"] = [
{
"id": c.id,
"type": c.type,
"function": {
"name": c.function.name,
"arguments": c.function.arguments,
},
}
for c in msg.tool_calls
]
if msg.content is None:
# Assistant tool-call messages have null content per spec.
out["content"] = None
elif isinstance(msg.content, str):
out["content"] = msg.content
else:
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image_url":
content_parts.append(
{
"type": "image_url",
"image_url": {"url": part.image_url.url},
}
)
out["content"] = content_parts
messages.append(out)
return messages
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": False,
}
# Add optional parameters
if request.temperature is not None:
payload["temperature"] = request.temperature
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.top_p is not None:
payload["top_p"] = request.top_p
if request.frequency_penalty is not None:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty is not None:
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
if request.tool_choice is not None:
payload["tool_choice"] = (
request.tool_choice
if isinstance(request.tool_choice, str)
else request.tool_choice.model_dump()
)
logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}")
response = await self.client.post("/chat/completions", json=payload)
response.raise_for_status()
data = response.json()
return ChatCompletionResponse(
id=data.get("id", ""),
model=f"{self.name}/{model}",
choices=[
Choice(
index=choice.get("index", 0),
message=MessageResponse(
content=choice["message"].get("content"),
tool_calls=_tool_calls_from_openai(
choice["message"].get("tool_calls")
),
),
finish_reason=choice.get("finish_reason", "stop"),
)
for choice in data.get("choices", [])
],
usage=Usage(
prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
completion_tokens=data.get("usage", {}).get("completion_tokens", 0),
total_tokens=data.get("usage", {}).get("total_tokens", 0),
),
)
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": True,
}
# Add optional parameters
if request.temperature is not None:
payload["temperature"] = request.temperature
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.top_p is not None:
payload["top_p"] = request.top_p
if request.frequency_penalty is not None:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty is not None:
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
if request.tools:
payload["tools"] = [
{"type": "function", "function": t.function.model_dump()}
for t in request.tools
]
if request.tool_choice is not None:
payload["tool_choice"] = (
request.tool_choice
if isinstance(request.tool_choice, str)
else request.tool_choice.model_dump()
)
logger.debug(f"{self.name} streaming request: {model}")
async with self.client.stream("POST", "/chat/completions", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to parse stream line: {data_str}")
continue
choices = data.get("choices", [])
if not choices:
continue
choice = choices[0]
delta = choice.get("delta", {})
yield ChatCompletionStreamResponse(
id=data.get("id", ""),
model=f"{self.name}/{model}",
choices=[
StreamChoice(
index=choice.get("index", 0),
delta=DeltaContent(
role=delta.get("role"),
content=delta.get("content"),
tool_calls=_tool_calls_from_openai(
delta.get("tool_calls")
),
),
finish_reason=choice.get("finish_reason"),
)
],
)
async def list_models(self) -> list[ModelInfo]:
"""List available models."""
try:
response = await self.client.get("/models")
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("data", []):
model_id = model_data.get("id", "")
models.append(
ModelInfo(
id=f"{self.name}/{model_id}",
owned_by=model_data.get("owned_by", self.name),
)
)
return models
except httpx.HTTPError as e:
logger.warning(f"Failed to list models from {self.name}: {e}")
return []
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
payload = {
"model": model,
"input": request.input,
}
response = await self.client.post("/embeddings", json=payload)
response.raise_for_status()
data = response.json()
return EmbeddingResponse(
data=[
EmbeddingData(
index=item.get("index", i),
embedding=item.get("embedding", []),
)
for i, item in enumerate(data.get("data", []))
],
model=f"{self.name}/{model}",
usage=Usage(
prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
total_tokens=data.get("usage", {}).get("total_tokens", 0),
),
)
async def health_check(self) -> dict[str, Any]:
"""Check provider health status."""
try:
response = await self.client.get("/models")
response.raise_for_status()
data = response.json()
model_count = len(data.get("data", []))
return {
"status": "healthy",
"provider": self.name,
"url": self.base_url,
"models_available": model_count,
}
except Exception as e:
return {
"status": "unhealthy",
"provider": self.name,
"url": self.base_url,
"error": str(e),
}
async def close(self) -> None:
"""Close HTTP client."""
if self._client and not self._client.is_closed:
await self._client.aclose()

View file

@ -1,463 +0,0 @@
"""Provider routing with alias resolution and health-aware fallback.
The router is the single entry point that the FastAPI handlers use. Its
job is:
1. Resolve the request's ``model`` field. If it lives in the ``mana/``
namespace the :class:`AliasRegistry` returns an ordered chain of
concrete provider/model strings; everything else is treated as a
single-entry chain (caller passed a direct provider/model).
2. Walk the chain, skipping entries whose provider is either
unconfigured at this deployment (no API key) or currently marked
unhealthy in the :class:`ProviderHealthCache`.
3. Try each remaining entry. Connection errors, timeouts, 5xx, and rate
limits are retryable; record them in the cache and move to the next
entry. Capability/auth/blocked errors are caller-fixable and
propagate immediately without touching the health cache.
4. Return the first successful response. If every entry was skipped or
failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying
the full attempt log so debugging is straightforward.
The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is
the M3 milestone.
"""
from __future__ import annotations
import logging
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any, TypeVar
import httpx
from src.aliases import AliasRegistry
from src.config import settings
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
)
from src.utils.metrics import (
record_alias_resolved,
record_fallback,
set_provider_healthy,
)
from .base import LLMProvider
from .errors import (
NoHealthyProviderError,
ProviderAuthError,
ProviderBlockedError,
ProviderCapabilityError,
ProviderError,
ProviderRateLimitError,
)
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
logger = logging.getLogger(__name__)
T = TypeVar("T")
class ProviderRouter:
"""Health-aware provider router with alias resolution.
Construct with the AliasRegistry and ProviderHealthCache from
application startup; both are external dependencies so tests can
inject mocks without going through global state.
"""
def __init__(
self,
aliases: AliasRegistry,
health_cache: ProviderHealthCache,
) -> None:
self.aliases = aliases
self.health_cache = health_cache
self.providers: dict[str, LLMProvider] = {}
self._initialize_providers()
# ------------------------------------------------------------------
# Provider initialisation
# ------------------------------------------------------------------
def _initialize_providers(self) -> None:
"""Spin up provider adapters based on what's configured."""
# Ollama: always present (talks to a local/proxied server). Whether
# it's actually reachable is the cache's job to figure out.
self.providers["ollama"] = OllamaProvider()
logger.info("Initialized Ollama provider at %s", settings.ollama_url)
if settings.google_api_key:
from .google import GoogleProvider
self.providers["google"] = GoogleProvider(
api_key=settings.google_api_key,
default_model=settings.google_default_model,
)
logger.info("Initialized Google Gemini provider")
if settings.openrouter_api_key:
self.providers["openrouter"] = OpenAICompatProvider(
name="openrouter",
base_url=settings.openrouter_base_url,
api_key=settings.openrouter_api_key,
default_model=settings.openrouter_default_model,
)
logger.info("Initialized OpenRouter provider")
if settings.groq_api_key:
self.providers["groq"] = OpenAICompatProvider(
name="groq",
base_url=settings.groq_base_url,
api_key=settings.groq_api_key,
)
logger.info("Initialized Groq provider")
if settings.together_api_key:
self.providers["together"] = OpenAICompatProvider(
name="together",
base_url=settings.together_base_url,
api_key=settings.together_api_key,
)
logger.info("Initialized Together provider")
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _parse_model(self, model: str) -> tuple[str, str]:
"""Split ``provider/model`` into its parts.
Bare names (no prefix) default to Ollama for compatibility with
plain OpenAI-style requests. Aliases (``mana/...``) are resolved
before this is ever called.
"""
if "/" in model:
provider, _, model_name = model.partition("/")
return provider.lower(), model_name
return "ollama", model
def _resolve_chain(self, model_or_alias: str) -> list[str]:
"""Expand aliases to chains; pass everything else through unchanged."""
if AliasRegistry.is_alias(model_or_alias):
return list(self.aliases.resolve_chain(model_or_alias))
return [model_or_alias]
@staticmethod
def _is_retryable(exc: BaseException) -> bool:
"""Should we treat this exception as "try the next chain entry"?
ConnectError / timeouts / 5xx / rate-limits = yes (provider blip).
Auth / capability / blocked / 4xx = no (caller has to fix
something; retrying with a different provider only hides the bug).
"""
if isinstance(exc, ProviderCapabilityError):
return False
if isinstance(exc, ProviderBlockedError):
return False
if isinstance(exc, ProviderAuthError):
return False
if isinstance(exc, ProviderRateLimitError):
return True
if isinstance(
exc,
(
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
httpx.WriteError,
httpx.WriteTimeout,
httpx.PoolTimeout,
),
):
return True
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code >= 500
if isinstance(exc, ProviderError):
# Any other provider-side error — treat as retryable.
# Subclasses with explicit non-retry semantics are caught above.
return True
# Unknown exception types: do NOT silently retry. Better to
# surface a strange error than hide a real bug behind a fallback.
return False
@staticmethod
def _exception_summary(exc: BaseException) -> str:
"""Compact one-liner for cache.last_error and log entries."""
return f"{type(exc).__name__}: {exc}"
def _check_tool_capability(
self,
provider: LLMProvider,
model_name: str,
request: ChatCompletionRequest,
) -> None:
"""Refuse tool-bearing requests for models that don't support tools.
Silent downgrade (dropping the ``tools`` payload) is more dangerous
than an explicit error; the caller would get plain text back and
have no way to tell the tools never reached the model.
"""
if not request.tools:
return
if not provider.model_supports_tools(model_name):
raise ProviderCapabilityError(
f"{provider.name}/{model_name} does not support tool calling. "
"Choose a tool-capable model (e.g. groq/llama-3.3-70b-versatile)"
)
# ------------------------------------------------------------------
# Core fallback executor (non-streaming)
# ------------------------------------------------------------------
async def _execute_with_fallback(
self,
model_or_alias: str,
request: ChatCompletionRequest,
call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]],
) -> T:
"""Walk the resolved chain, returning the first successful result.
``call`` is the operation to run against each chain entry, e.g.
``lambda p, m, req: p.chat_completion(req, m)``. The function
receives the provider instance, the model name (without the
provider prefix), and the original request.
"""
chain = self._resolve_chain(model_or_alias)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(model_or_alias)
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
logger.debug(
"skip chain entry %s — provider %s not configured here",
entry,
provider_name,
)
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
logger.debug("skip chain entry %s — cache says unhealthy", entry)
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
self._check_tool_capability(provider, model_name, request)
try:
logger.info(
"execute → %s (alias=%s)",
entry,
model_or_alias if is_alias else "<direct>",
)
result = await call(provider, model_name, request)
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(model_or_alias, entry)
return result
except Exception as e:
if not self._is_retryable(e):
# Caller error / non-retryable provider error — propagate
# without touching the health cache. The cache is for
# liveness, not for recording that what the user asked for
# was wrong.
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"execute %s failed (retryable, will try next): %s",
entry,
e,
)
raise NoHealthyProviderError(model_or_alias, attempts, last_exc)
# ------------------------------------------------------------------
# Public API — non-streaming
# ------------------------------------------------------------------
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
"""Chat completion with alias resolution + health-aware fallback."""
async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest):
return await provider.chat_completion(req, model)
return await self._execute_with_fallback(request.model, request, call)
# ------------------------------------------------------------------
# Public API — streaming (pre-first-byte fallback)
# ------------------------------------------------------------------
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Streaming variant. Falls back BEFORE the first chunk arrives;
once the first chunk has been yielded the provider is committed
and any further error propagates.
Why pre-first-byte only: stitching half-streams from two different
providers would mix two voices in the output and is impossible to
sanity-check after the fact.
"""
chain = self._resolve_chain(request.model)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(request.model)
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
self._check_tool_capability(provider, model_name, request)
stream = provider.chat_completion_stream(request, model_name)
try:
first_chunk = await stream.__anext__()
except StopAsyncIteration:
# Empty stream is a successful but content-free response.
# Commit and exit cleanly.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream %s yielded empty response", entry)
return
except Exception as e:
if not self._is_retryable(e):
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"stream %s failed before first byte (retryable, trying next): %s",
entry,
e,
)
continue
# First byte landed — commit the provider, mark healthy, drain
# the rest of the stream. Any error from here on propagates;
# it is NOT safe to splice another provider's output in.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream → %s (committed after first chunk)", entry)
yield first_chunk
async for chunk in stream:
yield chunk
return
raise NoHealthyProviderError(request.model, attempts, last_exc)
# ------------------------------------------------------------------
# Embeddings — no fallback (out of scope for M3, separate concerns)
# ------------------------------------------------------------------
async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""Route an embeddings request directly. No alias / fallback."""
provider_name, model_name = self._parse_model(request.model)
if provider_name not in self.providers:
available = list(self.providers)
raise ValueError(
f"Provider '{provider_name}' not available. Available: {available}"
)
provider = self.providers[provider_name]
logger.info("embeddings → %s/%s", provider_name, model_name)
return await provider.embeddings(request, model_name)
# ------------------------------------------------------------------
# Discovery / introspection
# ------------------------------------------------------------------
async def list_models(self) -> list[ModelInfo]:
"""List all available models from all configured providers.
Best-effort: providers that error are skipped with a warning so a
single broken provider can't take down ``GET /v1/models``.
"""
all_models: list[ModelInfo] = []
for provider in self.providers.values():
try:
all_models.extend(await provider.list_models())
except Exception as e: # noqa: BLE001
logger.warning("Failed to list models from %s: %s", provider.name, e)
return all_models
async def get_model(self, model_id: str) -> ModelInfo | None:
"""Look up a single model by id, dispatching on the prefix."""
provider_name, model_name = self._parse_model(model_id)
if provider_name not in self.providers:
return None
models = await self.providers[provider_name].list_models()
for m in models:
if m.id == model_id or m.id.endswith(f"/{model_name}"):
return m
return None
async def health_check(self) -> dict[str, Any]:
"""Snapshot of the per-provider liveness cache.
Returns the same shape as before for backwards-compat with
``GET /health`` (deprecated structure; M4 will swap to a
cleaner ``/v1/health`` endpoint).
"""
snapshot = self.health_cache.snapshot(expected=list(self.providers))
providers_out: dict[str, Any] = {}
for name, state in snapshot.items():
providers_out[name] = {
"status": "healthy" if state.healthy else "unhealthy",
"consecutive_failures": state.consecutive_failures,
"last_error": state.last_error,
"last_check_unix": state.last_check or None,
"unhealthy_until_unix": state.unhealthy_until or None,
}
all_healthy = all(state.healthy for state in snapshot.values())
any_healthy = any(state.healthy for state in snapshot.values())
return {
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
"providers": providers_out,
}
async def close(self) -> None:
"""Close all provider clients."""
for provider in self.providers.values():
await provider.close()
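
A minimal consumer sketch of the streaming contract above (imports match this repo's module layout; `mana/fast-text` is one of the shipped aliases):

```python
# Sketch: consuming the pre-first-byte fallback stream. Assumes a
# configured ProviderRouter instance is passed in.
from src.models import ChatCompletionRequest, Message
from src.providers import ProviderRouter
from src.providers.errors import NoHealthyProviderError

async def demo(router: ProviderRouter) -> None:
    request = ChatCompletionRequest(
        model="mana/fast-text",  # alias; the router resolves the chain
        messages=[Message(role="user", content="hi")],
    )
    try:
        async for chunk in router.chat_completion_stream(request):
            # After the first chunk the provider is committed; errors
            # from here on propagate instead of triggering a fallback.
            print(chunk.choices[0].delta.content or "", end="", flush=True)
    except NoHealthyProviderError as exc:
        # Every chain entry failed before the first byte.
        print(f"no provider could serve: {exc}")
```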

View file

@ -1,5 +0,0 @@
"""Streaming utilities for SSE responses."""
from .sse import stream_chat_completion
__all__ = ["stream_chat_completion"]

View file

@ -1,52 +0,0 @@
"""Server-Sent Events (SSE) response handling."""
import json
import logging
from collections.abc import AsyncIterator
from src.models import ChatCompletionRequest, ChatCompletionStreamResponse
from src.providers import ProviderRouter
from src.providers.errors import ProviderError
logger = logging.getLogger(__name__)
async def stream_chat_completion(
router: ProviderRouter,
request: ChatCompletionRequest,
) -> AsyncIterator[dict]:
"""
Stream chat completion responses for SSE.
Yields dicts that EventSourceResponse will serialize as:
data: {"choices":[{"delta":{"content":"Hello"}}]}
data: [DONE]
"""
try:
async for chunk in router.chat_completion_stream(request):
# Yield dict for EventSourceResponse to serialize
yield {"data": json.dumps(chunk.model_dump(exclude_none=True))}
# Send final [DONE] marker
yield {"data": "[DONE]"}
except ProviderError as e:
logger.warning(f"Streaming provider error: kind={e.kind} detail={e}")
error_data = {
"error": {
"message": str(e),
"type": e.kind,
}
}
yield {"data": json.dumps(error_data)}
yield {"data": "[DONE]"}
except Exception as e:
logger.error(f"Streaming error: {e}")
error_data = {
"error": {
"message": str(e),
"type": "server_error",
}
}
yield {"data": json.dumps(error_data)}
yield {"data": "[DONE]"}

View file

@ -1,5 +0,0 @@
"""Utility modules."""
from .metrics import get_metrics, metrics_middleware
__all__ = ["get_metrics", "metrics_middleware"]

View file

@ -1,85 +0,0 @@
"""Redis caching utilities (optional)."""
import hashlib
import json
import logging
from typing import Any
from src.config import settings
logger = logging.getLogger(__name__)
# Redis client (lazy initialized)
_redis_client = None
async def get_redis_client():
"""Get or create Redis client."""
global _redis_client
if _redis_client is not None:
return _redis_client
if not settings.redis_url:
return None
try:
import redis.asyncio as redis
_redis_client = redis.from_url(settings.redis_url)
# Test connection
await _redis_client.ping()
logger.info(f"Connected to Redis at {settings.redis_url}")
return _redis_client
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}")
return None
def generate_cache_key(prefix: str, data: dict[str, Any]) -> str:
"""Generate a cache key from request data."""
# Serialize and hash the data for consistent key
serialized = json.dumps(data, sort_keys=True)
hash_value = hashlib.sha256(serialized.encode()).hexdigest()[:16]
return f"mana-llm:{prefix}:{hash_value}"
async def get_cached(key: str) -> dict[str, Any] | None:
"""Get cached value by key."""
client = await get_redis_client()
if client is None:
return None
try:
value = await client.get(key)
if value:
return json.loads(value)
except Exception as e:
logger.warning(f"Cache get failed: {e}")
return None
async def set_cached(key: str, value: dict[str, Any], ttl: int | None = None) -> bool:
"""Set cached value with optional TTL."""
client = await get_redis_client()
if client is None:
return False
try:
ttl = ttl or settings.cache_ttl
serialized = json.dumps(value)
await client.setex(key, ttl, serialized)
return True
except Exception as e:
logger.warning(f"Cache set failed: {e}")
return False
async def close_redis() -> None:
"""Close Redis connection."""
global _redis_client
if _redis_client is not None:
await _redis_client.aclose()
_redis_client = None
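
A consumer-side sketch tying the three helpers together as a read-through cache; whether chat responses should be cached at all is the caller's policy, and `ChatCompletionResponse` is the response model from `src.models`:

```python
# Sketch: read-through caching of a non-streaming completion using
# generate_cache_key / get_cached / set_cached from this module.
from src.models import ChatCompletionResponse

async def cached_chat(router, request) -> ChatCompletionResponse:
    key = generate_cache_key("chat", request.model_dump(exclude_none=True))
    hit = await get_cached(key)
    if hit is not None:
        return ChatCompletionResponse(**hit)
    resp = await router.chat_completion(request)
    await set_cached(key, resp.model_dump())  # TTL defaults to CACHE_TTL
    return resp
```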

View file

@ -1,158 +0,0 @@
"""Prometheus metrics for mana-llm."""
import time
from collections.abc import Callable
from fastapi import Request, Response
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware
# Request metrics
REQUEST_COUNT = Counter(
"mana_llm_requests_total",
"Total number of requests",
["method", "endpoint", "status"],
)
REQUEST_LATENCY = Histogram(
"mana_llm_request_latency_seconds",
"Request latency in seconds",
["method", "endpoint"],
buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)
# LLM-specific metrics
LLM_REQUEST_COUNT = Counter(
"mana_llm_llm_requests_total",
"Total number of LLM requests",
["provider", "model", "streaming"],
)
LLM_TOKEN_COUNT = Counter(
"mana_llm_tokens_total",
"Total tokens processed",
["provider", "model", "type"], # type: prompt, completion
)
LLM_LATENCY = Histogram(
"mana_llm_llm_latency_seconds",
"LLM request latency in seconds",
["provider", "model"],
buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)
LLM_ERRORS = Counter(
"mana_llm_llm_errors_total",
"Total LLM errors",
["provider", "model", "error_type"],
)
# ---------------------------------------------------------------------------
# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md.
# ---------------------------------------------------------------------------
ALIAS_RESOLVED = Counter(
"mana_llm_alias_resolved_total",
"How often an alias resolved to a concrete provider/model. The `target` "
"label is the chain entry that actually served the request — useful for "
"spotting cases where the primary always falls through to a cloud entry.",
["alias", "target"],
)
FALLBACK_TRIGGERED = Counter(
"mana_llm_fallback_total",
"Fallback transitions: a chain entry failed (or was skipped via cache) "
"and the router moved to the next entry. `reason` is the exception class "
"name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry "
"that didn't serve, `to_model` is empty when no further entries existed.",
["from_model", "to_model", "reason"],
)
PROVIDER_HEALTHY = Gauge(
"mana_llm_provider_healthy",
"1 when the provider is currently considered healthy by the cache, "
"0 when in backoff. Refreshed on every probe tick and on every router "
"call-site state transition.",
["provider"],
)
def get_metrics() -> bytes:
"""Generate Prometheus metrics output."""
return generate_latest()
class MetricsMiddleware(BaseHTTPMiddleware):
"""Middleware for collecting HTTP metrics."""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
start_time = time.time()
response = await call_next(request)
# Record metrics
duration = time.time() - start_time
endpoint = request.url.path
method = request.method
status = str(response.status_code)
REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status).inc()
REQUEST_LATENCY.labels(method=method, endpoint=endpoint).observe(duration)
return response
# Export middleware class (registered via app.add_middleware)
metrics_middleware = MetricsMiddleware
def record_llm_request(
provider: str,
model: str,
streaming: bool,
prompt_tokens: int = 0,
completion_tokens: int = 0,
latency: float | None = None,
) -> None:
"""Record LLM request metrics."""
LLM_REQUEST_COUNT.labels(
provider=provider,
model=model,
streaming=str(streaming).lower(),
).inc()
if prompt_tokens > 0:
LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="prompt").inc(prompt_tokens)
if completion_tokens > 0:
LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="completion").inc(
completion_tokens
)
if latency is not None:
LLM_LATENCY.labels(provider=provider, model=model).observe(latency)
def record_llm_error(provider: str, model: str, error_type: str) -> None:
"""Record LLM error metrics."""
LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc()
def record_alias_resolved(alias: str, target: str) -> None:
"""Record which concrete model an alias resolved to for this request."""
ALIAS_RESOLVED.labels(alias=alias, target=target).inc()
def record_fallback(from_model: str, to_model: str, reason: str) -> None:
"""Record a fallback transition. ``to_model`` is empty when the chain
ran out (i.e. NoHealthyProviderError)."""
FALLBACK_TRIGGERED.labels(
from_model=from_model,
to_model=to_model,
reason=reason,
).inc()
def set_provider_healthy(provider: str, healthy: bool) -> None:
"""Mirror ``ProviderHealthCache`` state into a Prometheus gauge."""
PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0)
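
A sketch of the intended call-site pattern: the gauge stays fresh through the cache's listener hook (exercised in the M4 tests below), and the per-call counters/histogram wrap each provider call. Token counts are omitted here because the usage shape is provider-specific:

```python
# Sketch: wiring the helpers above at the router call-site.
import time

def wire_health_gauge(cache) -> None:
    # Mirror every healthy<->unhealthy transition into the gauge.
    cache.add_listener(set_provider_healthy)

async def timed_chat(provider, request, model: str):
    start = time.time()
    try:
        resp = await provider.chat_completion(request, model)
    except Exception as exc:
        record_llm_error(provider.name, model, type(exc).__name__)
        raise
    record_llm_request(
        provider.name, model, streaming=False, latency=time.time() - start
    )
    return resp
```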

View file

@ -1,28 +0,0 @@
#!/bin/bash
# Start mana-llm service
# Automatically creates venv and installs dependencies if needed
cd "$(dirname "$0")"
# Check if venv exists, create if not
if [ ! -d "venv" ]; then
echo "Creating virtual environment..."
python3 -m venv venv
fi
# Activate venv
source venv/bin/activate
# Install/update dependencies
pip install -q -r requirements.txt
# Copy .env if not exists
if [ ! -f ".env" ] && [ -f ".env.example" ]; then
cp .env.example .env
echo "Created .env from .env.example"
fi
# Start the service
echo "Starting mana-llm on port 3025..."
exec python -m uvicorn src.main:app --port 3025 --reload

View file

@ -1 +0,0 @@
"""Tests for mana-llm service."""

View file

@ -1,300 +0,0 @@
"""Tests for the model-alias registry."""
from __future__ import annotations
from pathlib import Path
import pytest
from src.aliases import (
ALIAS_PREFIX,
Alias,
AliasConfigError,
AliasRegistry,
UnknownAliasError,
)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def write_yaml(tmp_path: Path, body: str, name: str = "aliases.yaml") -> Path:
p = tmp_path / name
p.write_text(body, encoding="utf-8")
return p
VALID_CONFIG = """\
aliases:
mana/fast-text:
description: "fast"
chain:
- ollama/qwen2.5:7b
- groq/llama-3.1-8b-instant
mana/long-form:
description: "long"
chain:
- ollama/gemma3:12b
- groq/llama-3.3-70b-versatile
default: mana/fast-text
"""
# ---------------------------------------------------------------------------
# Construction & happy-path resolution
# ---------------------------------------------------------------------------
class TestRegistryHappyPath:
def test_loads_valid_yaml(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, VALID_CONFIG)
reg = AliasRegistry(path)
assert reg.path == path
assert reg.default_alias == "mana/fast-text"
def test_resolve_returns_alias_dataclass(self, tmp_path: Path) -> None:
reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
alias = reg.resolve("mana/long-form")
assert isinstance(alias, Alias)
assert alias.name == "mana/long-form"
assert alias.description == "long"
assert alias.chain == ("ollama/gemma3:12b", "groq/llama-3.3-70b-versatile")
def test_resolve_chain_returns_tuple(self, tmp_path: Path) -> None:
reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
chain = reg.resolve_chain("mana/fast-text")
assert chain == ("ollama/qwen2.5:7b", "groq/llama-3.1-8b-instant")
# Tuples ensure callers can't mutate the registry's internal state.
assert isinstance(chain, tuple)
def test_list_aliases_sorted(self, tmp_path: Path) -> None:
reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
names = [a.name for a in reg.list_aliases()]
assert names == sorted(names)
assert names == ["mana/fast-text", "mana/long-form"]
def test_unknown_alias_raises(self, tmp_path: Path) -> None:
reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
with pytest.raises(UnknownAliasError, match="mana/nope"):
reg.resolve("mana/nope")
def test_default_optional(self, tmp_path: Path) -> None:
body = (
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
)
reg = AliasRegistry(write_yaml(tmp_path, body))
assert reg.default_alias is None
class TestIsAlias:
"""``is_alias`` is a cheap static syntactic check used by the router."""
@pytest.mark.parametrize(
"name",
["mana/fast-text", "mana/anything", f"{ALIAS_PREFIX}foo"],
)
def test_recognises_alias_namespace(self, name: str) -> None:
assert AliasRegistry.is_alias(name) is True
@pytest.mark.parametrize(
"name",
["ollama/gemma3:4b", "groq/llama", "gemma3:4b", "", "mana", "manaX/foo"],
)
def test_rejects_non_alias(self, name: str) -> None:
assert AliasRegistry.is_alias(name) is False
def test_static_no_instance_needed(self) -> None:
# Important: callers can hit this without instantiating, so it must
# be a free function or @staticmethod.
assert AliasRegistry.is_alias("mana/x") is True
# ---------------------------------------------------------------------------
# Schema validation — the YAML is user-edited, must fail loudly on typos
# ---------------------------------------------------------------------------
class TestSchemaValidation:
def test_missing_file_raises(self, tmp_path: Path) -> None:
with pytest.raises(AliasConfigError, match="not found"):
AliasRegistry(tmp_path / "absent.yaml")
def test_invalid_yaml_raises(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, "aliases: [\n unclosed")
with pytest.raises(AliasConfigError, match="failed to parse"):
AliasRegistry(path)
def test_root_not_a_mapping(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, "- just-a-list\n")
with pytest.raises(AliasConfigError, match="root must be a mapping"):
AliasRegistry(path)
def test_aliases_must_be_mapping(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, "aliases: just-a-string\n")
with pytest.raises(AliasConfigError, match="`aliases` must be a mapping"):
AliasRegistry(path)
def test_empty_aliases_rejected(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, "aliases: {}\n")
with pytest.raises(AliasConfigError, match="empty"):
AliasRegistry(path)
def test_alias_name_must_use_mana_namespace(self, tmp_path: Path) -> None:
body = (
"aliases:\n"
" fast-text:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
)
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="mana/"):
AliasRegistry(path)
def test_alias_name_must_have_one_segment(self, tmp_path: Path) -> None:
body = (
"aliases:\n"
" mana/foo/bar:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
)
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="exactly one segment"):
AliasRegistry(path)
def test_chain_must_be_list(self, tmp_path: Path) -> None:
body = (
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
' chain: "ollama/gemma3:4b"\n'
)
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="chain must be a list"):
AliasRegistry(path)
def test_empty_chain_rejected(self, tmp_path: Path) -> None:
body = "aliases:\n mana/x:\n chain: []\n"
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="must not be empty"):
AliasRegistry(path)
def test_chain_entry_without_provider_prefix_rejected(self, tmp_path: Path) -> None:
# "gemma3:4b" without a provider/ prefix would silently default to
# ollama and confuse the health-cache; reject loudly at config-load.
body = "aliases:\n mana/x:\n chain:\n - gemma3:4b\n"
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="provider prefix"):
AliasRegistry(path)
def test_chain_entry_must_be_string(self, tmp_path: Path) -> None:
body = "aliases:\n mana/x:\n chain:\n - 42\n"
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError):
AliasRegistry(path)
def test_default_must_reference_known_alias(self, tmp_path: Path) -> None:
body = (
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
"default: mana/missing\n"
)
path = write_yaml(tmp_path, body)
with pytest.raises(AliasConfigError, match="references unknown alias"):
AliasRegistry(path)
# ---------------------------------------------------------------------------
# Reload semantics — SIGHUP should be safe even with typos
# ---------------------------------------------------------------------------
class TestReload:
def test_reload_picks_up_edits(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, VALID_CONFIG)
reg = AliasRegistry(path)
assert reg.resolve_chain("mana/long-form") == (
"ollama/gemma3:12b",
"groq/llama-3.3-70b-versatile",
)
# Edit on disk: shrink the long-form chain.
new_body = (
"aliases:\n"
" mana/long-form:\n"
' description: "shorter"\n'
" chain:\n"
" - groq/llama-3.3-70b-versatile\n"
"default: mana/long-form\n"
)
path.write_text(new_body, encoding="utf-8")
reg.reload()
assert reg.resolve_chain("mana/long-form") == ("groq/llama-3.3-70b-versatile",)
assert reg.default_alias == "mana/long-form"
# Aliases that disappeared from the new file are gone.
with pytest.raises(UnknownAliasError):
reg.resolve("mana/fast-text")
def test_reload_keeps_old_state_on_parse_error(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, VALID_CONFIG)
reg = AliasRegistry(path)
# First reload fine — establish a baseline.
reg.reload()
# Now break the file with an obviously invalid yaml.
path.write_text("aliases: [unclosed\n", encoding="utf-8")
with pytest.raises(AliasConfigError):
reg.reload()
# The previous good state must still be queryable — service stays up.
assert reg.resolve_chain("mana/fast-text") == (
"ollama/qwen2.5:7b",
"groq/llama-3.1-8b-instant",
)
assert reg.default_alias == "mana/fast-text"
def test_reload_keeps_old_state_on_schema_error(self, tmp_path: Path) -> None:
path = write_yaml(tmp_path, VALID_CONFIG)
reg = AliasRegistry(path)
# Empty aliases — would be rejected on first load, must also be
# rejected here without nuking the in-memory state.
path.write_text("aliases: {}\n", encoding="utf-8")
with pytest.raises(AliasConfigError):
reg.reload()
assert "mana/fast-text" in [a.name for a in reg.list_aliases()]
# ---------------------------------------------------------------------------
# Repo-shipped aliases.yaml is itself valid
# ---------------------------------------------------------------------------
class TestShippedConfig:
def test_repo_aliases_yaml_loads(self) -> None:
# The yaml file checked into services/mana-llm/aliases.yaml is the
# one that runs in production. It must always parse cleanly — this
# test catches editor accidents before they ship.
repo_yaml = Path(__file__).resolve().parents[1] / "aliases.yaml"
assert repo_yaml.exists(), f"shipped config missing at {repo_yaml}"
reg = AliasRegistry(repo_yaml)
# Sanity: the five classes the plan calls out must exist.
for expected in (
"mana/fast-text",
"mana/long-form",
"mana/structured",
"mana/reasoning",
"mana/vision",
):
reg.resolve(expected)

View file

@ -1,41 +0,0 @@
"""API endpoint tests."""
import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def client():
"""Create test client."""
from src.main import app
with TestClient(app) as c:
yield c
def test_health_endpoint(client):
"""Test health check endpoint."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "service" in data
assert data["service"] == "mana-llm"
def test_metrics_endpoint(client):
"""Test metrics endpoint."""
response = client.get("/metrics")
assert response.status_code == 200
assert "mana_llm" in response.text
def test_list_models_endpoint(client):
"""Test list models endpoint."""
response = client.get("/v1/models")
# May fail if Ollama is not running, but should return valid response structure
if response.status_code == 200:
data = response.json()
assert "data" in data
assert "object" in data
assert data["object"] == "list"

View file

@ -1,203 +0,0 @@
"""Tests for the provider health cache."""
from __future__ import annotations
import pytest
from src.health import (
DEFAULT_FAILURE_THRESHOLD,
DEFAULT_UNHEALTHY_BACKOFF_SEC,
ProviderHealthCache,
ProviderState,
)
class FakeClock:
"""Deterministic clock for circuit-breaker timing tests."""
def __init__(self, start: float = 1_000_000.0) -> None:
self.now = start
def __call__(self) -> float:
return self.now
def advance(self, seconds: float) -> None:
self.now += seconds
# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------
class TestConstruction:
def test_defaults(self) -> None:
c = ProviderHealthCache()
assert c.failure_threshold == DEFAULT_FAILURE_THRESHOLD
assert c.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC
def test_invalid_threshold_rejected(self) -> None:
with pytest.raises(ValueError, match="failure_threshold"):
ProviderHealthCache(failure_threshold=0)
def test_invalid_backoff_rejected(self) -> None:
with pytest.raises(ValueError, match="backoff"):
ProviderHealthCache(unhealthy_backoff_sec=-1.0)
# ---------------------------------------------------------------------------
# Default-healthy behaviour
# ---------------------------------------------------------------------------
class TestDefaults:
def test_unknown_provider_is_healthy(self) -> None:
# The cache is observation-only — no entry means "no reason to skip".
c = ProviderHealthCache()
assert c.is_healthy("ollama") is True
def test_unknown_provider_has_no_state(self) -> None:
c = ProviderHealthCache()
assert c.get_state("ollama") is None
def test_snapshot_includes_expected_zero_state(self) -> None:
c = ProviderHealthCache()
snap = c.snapshot(expected=["ollama", "groq"])
assert set(snap.keys()) == {"ollama", "groq"}
assert all(s.healthy for s in snap.values())
assert all(s.consecutive_failures == 0 for s in snap.values())
# ---------------------------------------------------------------------------
# Failure → unhealthy state machine
# ---------------------------------------------------------------------------
class TestFailureSemantics:
def test_first_failure_does_not_trip(self) -> None:
# Single transient blips shouldn't bounce a provider — wait for the
# next consecutive failure to confirm.
c = ProviderHealthCache(failure_threshold=2)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is True
state = c.get_state("ollama")
assert state is not None
assert state.consecutive_failures == 1
assert state.healthy is True
def test_threshold_reached_trips_breaker(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(
failure_threshold=2,
unhealthy_backoff_sec=60.0,
clock=clock,
)
c.mark_unhealthy("ollama", "fail-1")
c.mark_unhealthy("ollama", "fail-2")
assert c.is_healthy("ollama") is False
state = c.get_state("ollama")
assert state is not None
assert state.healthy is False
assert state.last_error == "fail-2"
assert state.unhealthy_until == clock.now + 60.0
def test_threshold_one_trips_immediately(self) -> None:
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is False
def test_mark_healthy_clears_state(self) -> None:
c = ProviderHealthCache(failure_threshold=2)
c.mark_unhealthy("ollama", "x")
c.mark_unhealthy("ollama", "x")
assert c.is_healthy("ollama") is False
c.mark_healthy("ollama")
assert c.is_healthy("ollama") is True
state = c.get_state("ollama")
assert state is not None
assert state.healthy is True
assert state.consecutive_failures == 0
assert state.unhealthy_until == 0.0
assert state.last_error is None
class TestBackoffWindow:
def test_is_healthy_false_during_backoff(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is False
clock.advance(30.0)
assert c.is_healthy("ollama") is False
def test_is_healthy_true_after_backoff_expires(self) -> None:
# After backoff: half-open. Router gets one more attempt; success
# mark_healthy resets, failure mark_unhealthy re-arms backoff.
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
clock.advance(61.0)
assert c.is_healthy("ollama") is True
def test_failure_during_backoff_extends_window(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "first")
original_until = c.get_state("ollama").unhealthy_until
clock.advance(20.0)
c.mark_unhealthy("ollama", "second")
new_until = c.get_state("ollama").unhealthy_until
assert new_until > original_until
assert new_until == clock.now + 60.0
def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
with caplog.at_level("INFO"):
c.mark_healthy("ollama")
# Calling mark_healthy on an already-healthy provider must not
# spam the recovery log line.
c.mark_healthy("ollama")
c.mark_healthy("ollama")
recovery_lines = [r for r in caplog.records if "recovered" in r.message]
assert len(recovery_lines) == 1
# ---------------------------------------------------------------------------
# Snapshot shape
# ---------------------------------------------------------------------------
class TestSnapshot:
def test_snapshot_returns_copies(self) -> None:
# Caller shouldn't be able to poke through to the cache's internal
# state via a snapshot reference.
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "x")
snap = c.snapshot()
snap["ollama"].healthy = True
# Original state untouched:
assert c.is_healthy("ollama") is False
def test_snapshot_matches_recorded_state(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(
failure_threshold=2,
unhealthy_backoff_sec=60.0,
clock=clock,
)
c.mark_unhealthy("groq", "rate-limit")
snap = c.snapshot()
assert isinstance(snap["groq"], ProviderState)
assert snap["groq"].consecutive_failures == 1
assert snap["groq"].last_error == "rate-limit"
assert snap["groq"].last_check == clock.now
def test_snapshot_expected_does_not_overwrite_real_state(self) -> None:
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "real-boom")
snap = c.snapshot(expected=["ollama", "groq"])
# ollama keeps its real (unhealthy) state, groq gets the zero-default.
assert snap["ollama"].consecutive_failures == 1
assert snap["groq"].consecutive_failures == 0

View file

@ -1,222 +0,0 @@
"""Tests for the background health-probe loop."""
from __future__ import annotations
import asyncio
from typing import Awaitable, Callable
import pytest
from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, ProbeFn
def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn:
"""Synthesise a probe function that always returns / raises the same."""
async def probe() -> bool:
if raises is not None:
raise raises("boom")
return returns
return probe
def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn:
async def probe() -> bool:
await asyncio.sleep(delay)
return returns
return probe
def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]:
"""Probe that counts how many times it was awaited."""
count = 0
async def probe() -> bool:
nonlocal count
count += 1
return True
return probe, lambda: count
# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------
class TestConstruction:
def test_invalid_interval(self) -> None:
with pytest.raises(ValueError, match="interval"):
HealthProbe(ProviderHealthCache(), {}, interval=0.0)
def test_invalid_timeout(self) -> None:
with pytest.raises(ValueError, match="timeout"):
HealthProbe(ProviderHealthCache(), {}, timeout=0.0)
def test_provider_ids_exposes_keys(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(
cache,
{"ollama": make_probe(), "groq": make_probe()},
)
assert sorted(probe.provider_ids) == ["groq", "ollama"]
# ---------------------------------------------------------------------------
# tick_once — the per-cycle behaviour
# ---------------------------------------------------------------------------
class TestTickOnce:
@pytest.mark.asyncio
async def test_healthy_probe_marks_healthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
# Pre-mark unhealthy so we can verify the probe recovers it.
cache.mark_unhealthy("ollama", "stale")
probe = HealthProbe(cache, {"ollama": make_probe(returns=True)})
await probe.tick_once()
assert cache.is_healthy("ollama") is True
@pytest.mark.asyncio
async def test_returning_false_marks_unhealthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(cache, {"ollama": make_probe(returns=False)})
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "false" in (state.last_error or "")
@pytest.mark.asyncio
async def test_raising_marks_unhealthy_with_exc_info(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache, {"ollama": make_probe(raises=ConnectionError)}
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "ConnectionError" in (state.last_error or "")
@pytest.mark.asyncio
async def test_timeout_marks_unhealthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache,
{"ollama": make_slow_probe(delay=1.0)},
timeout=0.05,
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "timeout" in (state.last_error or "").lower()
@pytest.mark.asyncio
async def test_one_bad_probe_does_not_sink_others(self) -> None:
# Probe 'ollama' raises — 'groq' must still be evaluated and marked
# healthy. Bug shape: an unhandled exception in gather() sinks the
# whole loop.
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache,
{
"ollama": make_probe(raises=RuntimeError),
"groq": make_probe(returns=True),
},
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
assert cache.is_healthy("groq") is True
@pytest.mark.asyncio
async def test_concurrent_probes(self) -> None:
# All probes should run in parallel — total elapsed wall-clock for
# N x 100ms probes should be well under N*100ms.
import time
cache = ProviderHealthCache()
probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)}
probe = HealthProbe(cache, probes, timeout=1.0)
t0 = time.perf_counter()
await probe.tick_once()
elapsed = time.perf_counter() - t0
assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s"
@pytest.mark.asyncio
async def test_empty_probes_is_noop(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {})
# No exception, no state mutation.
await probe.tick_once()
assert cache.snapshot() == {}
# ---------------------------------------------------------------------------
# start / stop lifecycle
# ---------------------------------------------------------------------------
class TestLifecycle:
@pytest.mark.asyncio
async def test_start_runs_initial_tick_immediately(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("ollama", "stale")
probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0)
await probe.start()
# Give the loop one event-loop turn to run the initial tick before
# blocking on the long sleep.
await asyncio.sleep(0.01)
assert cache.is_healthy("ollama") is True
await probe.stop()
@pytest.mark.asyncio
async def test_stop_cancels_cleanly(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(
cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0
)
await probe.start()
assert probe.running is True
await probe.stop()
assert probe.running is False
@pytest.mark.asyncio
async def test_start_is_idempotent(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0)
await probe.start()
await probe.start() # must not spawn a second task
assert probe.running is True
await probe.stop()
@pytest.mark.asyncio
async def test_stop_without_start_is_safe(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {})
await probe.stop() # idempotent / safe pre-start
@pytest.mark.asyncio
async def test_loop_keeps_running_after_tick_error(self) -> None:
# Even if every probe explodes, the loop must keep ticking.
cache = ProviderHealthCache(failure_threshold=1)
fn, count = make_call_counter()
# tick_once catches per-probe errors internally via
# gather(return_exceptions=True), so there is no clean way to force
# an error out of the loop body; instead, use the call-counter to
# verify that the loop kept ticking across multiple intervals.
probe = HealthProbe(
cache,
{"counter": fn},
interval=0.05, # short interval for test
timeout=1.0,
)
await probe.start()
await asyncio.sleep(0.18) # ~3-4 ticks
await probe.stop()
# Initial tick + at least 2 interval ticks.
assert count() >= 3
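
For context, a sketch of the production wiring these tests stand in for; the HealthProbe API is taken from this file, while the Ollama liveness endpoint and the lifecycle hook names are assumptions:

```python
# Sketch: wiring HealthProbe at app startup / shutdown.
import httpx

from src.health import ProviderHealthCache
from src.health_probe import HealthProbe

async def ollama_probe() -> bool:
    # Liveness ping; /api/tags as the endpoint is an assumption.
    async with httpx.AsyncClient(timeout=2.0) as client:
        resp = await client.get("http://localhost:11434/api/tags")
        return resp.status_code == 200

async def startup(cache: ProviderHealthCache) -> HealthProbe:
    probe = HealthProbe(cache, {"ollama": ollama_probe}, interval=15.0, timeout=2.0)
    await probe.start()  # initial tick runs immediately, then every 15s
    return probe         # call `await probe.stop()` on shutdown
```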

View file

@ -1,415 +0,0 @@
"""Tests for M4: observability + debug endpoints + reload."""
from __future__ import annotations
import asyncio
import os
import signal
from pathlib import Path
import httpx
import pytest
from fastapi.testclient import TestClient
from prometheus_client import REGISTRY
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
Choice,
Message,
MessageResponse,
)
from src.providers import ProviderRouter
from src.utils.metrics import (
record_alias_resolved,
record_fallback,
set_provider_healthy,
)
# ---------------------------------------------------------------------------
# Cache → listener → metric gauge
# ---------------------------------------------------------------------------
class TestHealthChangeListener:
def test_listener_fires_on_unhealthy_transition(self) -> None:
cache = ProviderHealthCache(failure_threshold=2)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
# First failure: still healthy → no transition.
cache.mark_unhealthy("ollama", "blip")
assert events == []
# Second failure: transition healthy→unhealthy → fires.
cache.mark_unhealthy("ollama", "boom")
assert events == [("ollama", False)]
def test_listener_fires_on_recovery(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
cache.mark_unhealthy("ollama", "boom")
assert events == [("ollama", False)]
cache.mark_healthy("ollama")
assert events == [("ollama", False), ("ollama", True)]
def test_steady_state_does_not_fire(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
# Three healthy ops in a row — no transitions, no events.
for _ in range(3):
cache.mark_healthy("ollama")
assert events == []
def test_listener_exception_does_not_break_cache(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
def bad(_provider: str, _healthy: bool) -> None:
raise RuntimeError("listener boom")
cache.add_listener(bad)
# Should NOT raise — the cache must keep working with a broken
# listener, otherwise one bad metric callback would brick the
# whole router.
cache.mark_unhealthy("ollama", "x")
assert cache.is_healthy("ollama") is False
def test_multiple_listeners(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
a: list = []
b: list = []
cache.add_listener(lambda p, h: a.append((p, h)))
cache.add_listener(lambda p, h: b.append((p, h)))
cache.mark_unhealthy("ollama", "x")
assert a == [("ollama", False)]
assert b == [("ollama", False)]
# ---------------------------------------------------------------------------
# Prometheus metrics — counters/gauges actually move
# ---------------------------------------------------------------------------
def _counter_value(name: str, labels: dict[str, str]) -> float:
"""Helper: read the current value of a labeled Prometheus metric."""
value = REGISTRY.get_sample_value(name, labels=labels)
return value if value is not None else 0.0
class TestMetricsRecording:
def test_record_alias_resolved_increments(self) -> None:
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/test-class", "target": "ollama/x:1b"},
)
record_alias_resolved("mana/test-class", "ollama/x:1b")
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/test-class", "target": "ollama/x:1b"},
)
assert after - before == pytest.approx(1.0)
def test_record_fallback_increments(self) -> None:
before = _counter_value(
"mana_llm_fallback_total",
{"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
)
record_fallback("ollama/x", "groq/y", "ConnectError")
after = _counter_value(
"mana_llm_fallback_total",
{"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
)
assert after - before == pytest.approx(1.0)
def test_set_provider_healthy_writes_gauge(self) -> None:
set_provider_healthy("test_provider_xyz", True)
v = REGISTRY.get_sample_value(
"mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
)
assert v == 1.0
set_provider_healthy("test_provider_xyz", False)
v = REGISTRY.get_sample_value(
"mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
)
assert v == 0.0
# ---------------------------------------------------------------------------
# Router → metrics: end-to-end through a fallback
# ---------------------------------------------------------------------------
class _OkProvider:
"""Minimal provider double — only what the router uses for chat."""
name = "ok-provider"
supports_tools = True
def __init__(self, name: str, fail_with: BaseException | None = None) -> None:
self.name = name
self.fail_with = fail_with
self.calls = 0
def model_supports_tools(self, model: str) -> bool:
return True
async def chat_completion(self, request, model):
self.calls += 1
if self.fail_with is not None:
raise self.fail_with
return ChatCompletionResponse(
model=f"{self.name}/{model}",
choices=[Choice(message=MessageResponse(content="ok"))],
)
async def chat_completion_stream(self, request, model): # pragma: no cover
if False: # pragma: no cover
yield None
async def list_models(self):
return []
async def embeddings(self, request, model):
raise NotImplementedError
async def health_check(self):
return {"status": "healthy"}
async def close(self):
pass
def _aliases(tmp_path: Path) -> AliasRegistry:
cfg = (
"aliases:\n"
" mana/two-step:\n"
' description: "x"\n'
" chain:\n"
" - alpha/m1\n"
" - beta/m2\n"
)
p = tmp_path / "aliases.yaml"
p.write_text(cfg)
return AliasRegistry(p)
class TestRouterMetricsIntegration:
@pytest.mark.asyncio
async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None:
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {"alpha": _OkProvider("alpha")} # beta not configured
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/two-step", "target": "alpha/m1"},
)
await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/two-step", "target": "alpha/m1"},
)
assert after - before == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None:
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {
"alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")),
"beta": _OkProvider("beta"),
}
before = _counter_value(
"mana_llm_fallback_total",
{"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
)
await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_fallback_total",
{"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
)
assert after - before == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_direct_model_does_not_record_alias_metric(
self, tmp_path: Path
) -> None:
# Direct provider/model is not an alias — ALIAS_RESOLVED counter
# must stay flat for those calls.
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {"alpha": _OkProvider("alpha")}
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "alpha/anything", "target": "alpha/anything"},
)
await router.chat_completion(
ChatCompletionRequest(
model="alpha/anything",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "alpha/anything", "target": "alpha/anything"},
)
# Counter must have NOT increased — direct calls aren't aliases.
assert after == before
# ---------------------------------------------------------------------------
# Debug endpoints: GET /v1/aliases, GET /v1/health
# ---------------------------------------------------------------------------
@pytest.fixture
def client():
from src.main import app
with TestClient(app) as c:
yield c
class TestDebugEndpoints:
def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None:
resp = client.get("/v1/aliases")
assert resp.status_code == 200
data = resp.json()
names = [a["name"] for a in data["aliases"]]
# The five canonical classes must always be present.
for expected in (
"mana/fast-text",
"mana/long-form",
"mana/structured",
"mana/reasoning",
"mana/vision",
):
assert expected in names
# Default is set in the shipped config.
assert data["default"] == "mana/fast-text"
def test_v1_aliases_chain_format(self, client: TestClient) -> None:
resp = client.get("/v1/aliases")
data = resp.json()
long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form")
# Each chain entry is a `provider/model` string.
assert all("/" in entry for entry in long_form["chain"])
assert len(long_form["chain"]) >= 2 # plan requires at least one cloud fallback
def test_v1_health_includes_all_providers(self, client: TestClient) -> None:
resp = client.get("/v1/health")
assert resp.status_code == 200
data = resp.json()
assert "status" in data
assert "providers" in data
# ollama is always configured (provider list is non-empty).
assert "ollama" in data["providers"]
for name, info in data["providers"].items():
assert "status" in info
assert "consecutive_failures" in info
# ---------------------------------------------------------------------------
# X-Mana-LLM-Resolved header on non-streaming responses
# ---------------------------------------------------------------------------
class TestResolvedHeader:
"""The header is the consumer's hook for token-cost attribution.
Tested at the router level; wiring through main.py would need a
real provider connection, which isn't available in unit tests.
"""
@pytest.mark.asyncio
async def test_response_model_field_carries_resolved_target(
self, tmp_path: Path
) -> None:
# The header value is `response.model`; verify that field reflects
# the actual chain entry that served, not the requested alias.
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
# Force fallback to beta.
router.providers = {
"alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")),
"beta": _OkProvider("beta"),
}
resp = await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
# Even though the caller asked for `mana/two-step`, the resolved
# field shows the entry that actually answered.
assert resp.model == "beta/m2"
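# For reference, the main.py wiring this class stands in for would be
# roughly the following (a sketch; route shape and names are
# assumptions, the header name comes from the section title above):
#
#   @app.post("/v1/chat/completions")
#   async def chat(req: ChatCompletionRequest, response: Response):
#       resp = await router.chat_completion(req)
#       response.headers["X-Mana-LLM-Resolved"] = resp.model
#       return resp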
# ---------------------------------------------------------------------------
# SIGHUP reload — only meaningful on Unix; tested by signalling the proc
# ---------------------------------------------------------------------------
class TestSighupReload:
"""SIGHUP triggers ``alias_registry.reload()``; reload-error keeps state.
The signal-handler wiring lives in main.py and only installs when
the loop is running in the main thread. We exercise the reload
semantics here directly on the registry instead; the signal-handler
code path itself is a 4-line wrapper around ``reload()``.
"""
def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None:
path = tmp_path / "aliases.yaml"
path.write_text(
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
)
reg = AliasRegistry(path)
assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",)
# Edit on disk, reload (this is exactly what the SIGHUP handler
# does — minus the signal plumbing).
path.write_text(
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/bar:1b\n"
" - groq/llama-3.1-8b-instant\n"
)
reg.reload()
assert reg.resolve_chain("mana/x") == (
"ollama/bar:1b",
"groq/llama-3.1-8b-instant",
)
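
The wrapper the docstring above summarizes would look roughly like this (a sketch; names and logging are assumptions, while the keep-old-state behaviour is ``reload()``'s tested contract):

```python
# Sketch: SIGHUP -> alias reload, with config errors contained so a
# bad edit never takes the service down.
import asyncio
import logging
import signal

from src.aliases import AliasConfigError, AliasRegistry

logger = logging.getLogger(__name__)

def install_sighup_reload(registry: AliasRegistry) -> None:
    def _reload() -> None:
        try:
            registry.reload()
        except AliasConfigError as exc:
            # reload() kept the previous in-memory state; just report.
            logger.warning("alias reload failed, keeping old config: %s", exc)
    asyncio.get_running_loop().add_signal_handler(signal.SIGHUP, _reload)
```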

View file

@ -1,123 +0,0 @@
"""Provider tests."""
from pathlib import Path
import pytest
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import ChatCompletionRequest, EmbeddingRequest, Message
from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter
@pytest.fixture
def shipped_aliases() -> AliasRegistry:
"""The repo's real aliases.yaml — same one production uses."""
return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml")
@pytest.fixture
def router(shipped_aliases: AliasRegistry) -> ProviderRouter:
return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache())
class TestProviderRouter:
"""Tests for the helpers exposed by the router."""
def test_parse_model_with_provider(self, router: ProviderRouter) -> None:
provider, model = router._parse_model("ollama/gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_without_provider(self, router: ProviderRouter) -> None:
# Bare names default to Ollama for OpenAI-style compat.
provider, model = router._parse_model("gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_openrouter(self, router: ProviderRouter) -> None:
provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct")
assert provider == "openrouter"
assert model == "meta-llama/llama-3.1-8b-instruct"
@pytest.mark.asyncio
async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None:
# Embeddings don't go through the alias/fallback pipeline — they
# hit the requested provider directly. Asking for an unconfigured
# one is a config error and must raise loudly.
with pytest.raises(ValueError, match="not available"):
await router.embeddings(
EmbeddingRequest(model="bogus_provider/x", input="hi")
)
class TestOllamaProvider:
"""Test Ollama provider."""
def test_convert_simple_messages(self):
"""Test converting simple text messages."""
provider = OllamaProvider()
request = ChatCompletionRequest(
model="gemma3:4b",
messages=[
Message(role="user", content="Hello"),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 1
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "Hello"
def test_convert_multimodal_messages(self):
"""Test converting multimodal messages."""
provider = OllamaProvider()
request = ChatCompletionRequest(
model="llava:7b",
messages=[
Message(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,iVBORw0KGgo="},
},
],
),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 1
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "What's in this image?"
assert "images" in messages[0]
assert len(messages[0]["images"]) == 1
class TestOpenAICompatProvider:
"""Test OpenAI-compatible provider."""
def test_convert_simple_messages(self):
"""Test converting simple text messages."""
provider = OpenAICompatProvider(
name="test",
base_url="http://localhost",
api_key="test-key",
)
request = ChatCompletionRequest(
model="test-model",
messages=[
Message(role="system", content="You are helpful."),
Message(role="user", content="Hello"),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 2
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "You are helpful."
assert messages[1]["role"] == "user"
assert messages[1]["content"] == "Hello"

View file

@ -1,526 +0,0 @@
"""Tests for ProviderRouter fallback / alias execution (M3)."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import httpx
import pytest
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingRequest,
EmbeddingResponse,
Message,
MessageResponse,
ModelInfo,
StreamChoice,
)
from src.providers import ProviderRouter
from src.providers.base import LLMProvider
from src.providers.errors import (
NoHealthyProviderError,
ProviderAuthError,
ProviderCapabilityError,
ProviderRateLimitError,
)
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class MockProvider(LLMProvider):
"""Provider that lets tests inject a sequence of behaviours.
Each call pops one entry from ``behaviors``. Strings ``"ok"`` and
``"empty"`` are sentinels for normal returns; everything else (a
BaseException instance / class) is raised.
"""
supports_tools = True
def __init__(self, name: str, behaviors: list[Any] | None = None) -> None:
self.name = name
self._behaviors: list[Any] = list(behaviors or [])
self.calls: list[str] = []
def push(self, *behaviors: Any) -> None:
self._behaviors.extend(behaviors)
def _next(self) -> Any:
return self._behaviors.pop(0) if self._behaviors else "ok"
async def chat_completion(
self, request: ChatCompletionRequest, model: str
) -> ChatCompletionResponse:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
return _ok_response(self.name, model)
async def chat_completion_stream(
self, request: ChatCompletionRequest, model: str
) -> AsyncIterator[ChatCompletionStreamResponse]:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
if b == "empty":
return
for content in ("Hello", " ", "world"):
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content=content))],
)
async def list_models(self) -> list[ModelInfo]:
return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")]
async def embeddings(
self, request: EmbeddingRequest, model: str
) -> EmbeddingResponse:
raise NotImplementedError
async def health_check(self) -> dict[str, Any]:
return {"status": "healthy"}
class FailFirstChunkProvider(MockProvider):
"""Streaming provider that raises BEFORE the first chunk every time.
Kept separate from MockProvider's behaviour list so the per-call
semantics stay simple; this one models a permanently-broken streamer.
"""
def __init__(self, name: str, exc: BaseException) -> None:
super().__init__(name)
self._exc = exc
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
raise self._exc
# the yield is unreachable but keeps the function an async generator
yield # pragma: no cover
def _ok_response(provider: str, model: str) -> ChatCompletionResponse:
return ChatCompletionResponse(
model=f"{provider}/{model}",
choices=[
Choice(
message=MessageResponse(content="ok"),
finish_reason="stop",
)
],
)
def _request(model: str) -> ChatCompletionRequest:
return ChatCompletionRequest(
model=model,
messages=[Message(role="user", content="hi")],
)
def _aliases_yaml(tmp_path) -> AliasRegistry:
"""A two-alias config used across most tests."""
cfg = (
"aliases:\n"
" mana/long-form:\n"
' description: "long"\n'
" chain:\n"
" - alpha/m1\n"
" - beta/m2\n"
" - gamma/m3\n"
" mana/single:\n"
' description: "single-entry"\n'
" chain:\n"
" - alpha/solo\n"
)
p = tmp_path / "aliases.yaml"
p.write_text(cfg)
return AliasRegistry(p)
def _make_router(
tmp_path,
*,
providers: dict[str, MockProvider],
cache: ProviderHealthCache | None = None,
) -> ProviderRouter:
aliases = _aliases_yaml(tmp_path)
router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache())
# Replace the auto-initialised live providers with the test doubles.
router.providers = dict(providers)
return router
# ---------------------------------------------------------------------------
# Non-streaming chain walking
# ---------------------------------------------------------------------------
class TestChatCompletionChain:
@pytest.mark.asyncio
async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "alpha/m1"
assert alpha.calls == ["m1"]
assert beta.calls == [] # never reached
@pytest.mark.asyncio
async def test_falls_through_on_connect_error(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
@pytest.mark.asyncio
async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None:
# gamma isn't configured at all → chain should silently skip it
# rather than raise.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", [httpx.ConnectError("dead too")])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
# All three entries appear in attempts: two as ConnectError, one
# as unconfigured (not a fatal error, just skipped).
attempts = exc_info.value.attempts
assert ("alpha/m1", "ConnectError") in attempts
assert ("beta/m2", "ConnectError") in attempts
assert ("gamma/m3", "unconfigured") in attempts
@pytest.mark.asyncio
async def test_skips_cache_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "stale")
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
resp = await router.chat_completion(_request("mana/long-form"))
assert alpha.calls == [] # router skipped per cache
assert beta.calls == ["m2"]
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_5xx_treated_as_retryable(self, tmp_path) -> None:
five_hundred = httpx.HTTPStatusError(
"boom",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(503),
)
alpha = MockProvider("alpha", [five_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_4xx_propagates(self, tmp_path) -> None:
four_hundred = httpx.HTTPStatusError(
"bad request",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(422),
)
alpha = MockProvider("alpha", [four_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(httpx.HTTPStatusError):
await router.chat_completion(_request("mana/long-form"))
# Beta never tried — caller's request needs fixing, retrying
# against another model would just hide the bug.
assert beta.calls == []
@pytest.mark.asyncio
async def test_capability_error_propagates(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_auth_error_propagates(self, tmp_path) -> None:
# Auth errors mean OUR setup is broken (wrong key); falling back
# to the next provider hides the misconfiguration.
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_rate_limit_is_retryable(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("a")])
beta = MockProvider("beta", [httpx.ConnectError("b")])
gamma = MockProvider("gamma", [httpx.ConnectError("c")])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
assert exc_info.value.model_or_alias == "mana/long-form"
assert isinstance(exc_info.value.last_exception, httpx.ConnectError)
# 503 status so calling code (mana-api etc.) can decide to retry
# later vs surface a clean error to the user.
assert exc_info.value.http_status == 503
@pytest.mark.asyncio
async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None:
# Caller bypasses aliases by passing a direct provider/model.
# No fallback chain — fail = fail.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError):
await router.chat_completion(_request("alpha/anything"))
# Beta would have served if this had been an alias — but it
# wasn't, so beta never gets touched.
assert beta.calls == []
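# Taken together, the tests above pin down a chain walk of roughly this
# shape (a sketch; the real implementation lives in ProviderRouter, and
# names like RETRYABLE and is_healthy are illustrative):
#
#   for entry in chain:                            # e.g. ["alpha/m1", "beta/m2", ...]
#       provider_name, model = entry.split("/", 1)
#       provider = self.providers.get(provider_name)
#       if provider is None:
#           attempts.append((entry, "unconfigured")); continue
#       if not self.health_cache.is_healthy(provider_name):
#           continue                               # breaker open: skip
#       try:
#           return await provider.chat_completion(request, model)
#       except RETRYABLE as exc:                   # ConnectError, 5xx, rate limit
#           attempts.append((entry, type(exc).__name__)); continue
#       # auth, capability and 4xx errors are re-raised, never retried
#   raise NoHealthyProviderError(model_or_alias, attempts, http_status=503)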
# ---------------------------------------------------------------------------
# Health-cache feedback: success clears, failure marks
# ---------------------------------------------------------------------------
class TestHealthCacheFeedback:
@pytest.mark.asyncio
async def test_success_marks_provider_healthy(self, tmp_path) -> None:
# A fresh cache returns no state until something is recorded (see
# test_propagating_error_does_not_touch_cache below), so the assertions
# prove that a success actively writes a healthy state with zeroed
# failure counters, not merely that nothing went wrong.
alpha = MockProvider("alpha", ["ok"])
router = _make_router(
tmp_path,
providers={"alpha": alpha},
cache=ProviderHealthCache(), # fresh cache, alpha optimistic
)
await router.chat_completion(_request("mana/single"))
assert router.health_cache.get_state("alpha").healthy is True
assert router.health_cache.get_state("alpha").consecutive_failures == 0
@pytest.mark.asyncio
async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None:
# threshold=1 so a single fail is enough to flip the breaker.
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [httpx.ConnectError("boom")])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
await router.chat_completion(_request("mana/long-form"))
assert cache.get_state("alpha").healthy is False
assert cache.get_state("alpha").last_error is not None
assert "ConnectError" in cache.get_state("alpha").last_error
@pytest.mark.asyncio
async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None:
# Auth/Capability errors are about CALLER state, not provider
# health — the cache must stay clean so a real outage later
# isn't masked by stale "marked unhealthy because of bad key".
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache)
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/single"))
# No state recorded.
assert cache.get_state("alpha") is None
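# A minimal state shape consistent with the rules above (a sketch; the
# real ProviderHealthCache is imported from src, and the field names here
# only mirror what the assertions in this class touch):
#
#   @dataclass
#   class _State:
#       healthy: bool = True
#       consecutive_failures: int = 0
#       last_error: str | None = None
#
#   def mark_failure(self, name: str, error: str) -> None:
#       s = self._states.setdefault(name, _State())
#       s.consecutive_failures += 1
#       s.last_error = error
#       s.healthy = s.consecutive_failures < self.failure_threshold
#
#   def mark_success(self, name: str) -> None:
#       self._states[name] = _State()  # full reset, per the success test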
# ---------------------------------------------------------------------------
# Streaming pre-first-byte fallback
# ---------------------------------------------------------------------------
class TestChatCompletionStream:
@pytest.mark.asyncio
async def test_first_provider_streams_normally(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert beta.calls == []
assert len(chunks) == 3
assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world"
@pytest.mark.asyncio
async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
assert len(chunks) == 3
assert all(c.model == "beta/m2" for c in chunks)
@pytest.mark.asyncio
async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
assert beta.calls == []
@pytest.mark.asyncio
async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None:
# Empty-but-successful stream is a valid response, not a failure
# we should retry — committing avoids accidentally calling two
# providers and double-billing.
alpha = MockProvider("alpha", ["empty"])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert chunks == []
assert beta.calls == [] # didn't fall through
@pytest.mark.asyncio
async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None:
# Custom provider that yields once then raises mid-stream — the
# router has already committed and must let the error propagate
# rather than splice in another provider's voice.
class MidStreamFailProvider(MockProvider):
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content="half"))],
)
raise httpx.RemoteProtocolError("connection dropped")
alpha = MidStreamFailProvider("alpha")
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
collected: list[str] = []
with pytest.raises(httpx.RemoteProtocolError):
async for chunk in router.chat_completion_stream(_request("mana/long-form")):
collected.append(chunk.choices[0].delta.content or "")
# We got the half-chunk that landed before the break; beta was
# NOT called as fallback.
assert collected == ["half"]
assert beta.calls == []
@pytest.mark.asyncio
async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a"))
beta = FailFirstChunkProvider("beta", httpx.ConnectError("b"))
gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c"))
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
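# The commit point these streaming tests encode, sketched (illustrative;
# the real generator lives in ProviderRouter.chat_completion_stream). Fall
# back only while nothing has been yielded; once the first chunk is out,
# errors must propagate so two providers' outputs never get spliced.
#
#   for entry in chain:
#       agen = provider.chat_completion_stream(request, model)
#       try:
#           first = await agen.__anext__()  # pre-first-byte window
#       except StopAsyncIteration:
#           return                          # empty stream: commit, no fallback
#       except RETRYABLE:
#           continue                        # nothing sent yet: safe to fall back
#       yield first                         # committed from here on
#       async for chunk in agen:            # mid-stream errors propagate
#           yield chunk
#       return
#   raise NoHealthyProviderError(...)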
# ---------------------------------------------------------------------------
# Health-check shape (still using the cache snapshot)
# ---------------------------------------------------------------------------
class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_lists_known_providers(self, tmp_path) -> None:
# Even if no probe has run yet, every configured provider should
# appear in the snapshot (zero-defaults) so /health has a stable
# shape for monitors.
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
out = await router.health_check()
assert set(out["providers"].keys()) == {"alpha", "beta"}
assert out["status"] == "healthy"
assert all(p["status"] == "healthy" for p in out["providers"].values())
@pytest.mark.asyncio
async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "boom")
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
out = await router.health_check()
assert out["status"] == "degraded"
assert out["providers"]["alpha"]["status"] == "unhealthy"
assert out["providers"]["beta"]["status"] == "healthy"
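# Snapshot shape the monitors rely on, consolidated from the assertions
# above (a sketch with zero-defaults for never-probed providers):
#
#   {"status": "healthy" | "degraded",
#    "providers": {"alpha": {"status": "healthy", ...},
#                  "beta":  {"status": "unhealthy", ...}}}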

View file

@ -1,57 +0,0 @@
"""Streaming tests."""
import pytest
from src.models import ChatCompletionStreamResponse, DeltaContent, StreamChoice
class TestStreamingModels:
"""Test streaming response models."""
def test_stream_response_serialization(self):
"""Test streaming response serializes correctly."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(content="Hello"),
)
],
)
data = response.model_dump(exclude_none=True)
assert data["id"] == "test-id"
assert data["model"] == "ollama/gemma3:4b"
assert data["choices"][0]["delta"]["content"] == "Hello"
def test_stream_response_with_role(self):
"""Test first chunk with role."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(role="assistant"),
)
],
)
data = response.model_dump(exclude_none=True)
assert data["choices"][0]["delta"]["role"] == "assistant"
def test_stream_response_with_finish_reason(self):
"""Test final chunk with finish_reason."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(),
finish_reason="stop",
)
],
)
data = response.model_dump(exclude_none=True)
assert data["choices"][0]["finish_reason"] == "stop"
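# For reference, a chunk like the ones above travels as one SSE frame (a
# sketch; the actual framing lives in the HTTP layer, not in src.models):
def _sse_frame(chunk: ChatCompletionStreamResponse) -> str:
    return f"data: {chunk.model_dump_json(exclude_none=True)}\n\n"
# Streams conventionally end with a literal "data: [DONE]\n\n" sentinel.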