mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 16:41:08 +02:00
chore(cutover): remove services/mana-llm/ — moved to mana-platform
Live containers on the Mac Mini build out of `../mana/services/mana-llm/`
since the 8-Doppel-Cutover commit (774852ba2). Smoke test green
2026-05-08 — health endpoints, JWKS, login flow, Stripe-webhook all
reachable from the new build path. Removing the now-stale duplicate.
Was 90M in this repo, now removed. Active code lives in
`Code/mana/services/mana-llm/` (see ../mana/CLAUDE.md).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in: parent 6103d4d2d9, commit 2b07f6ef89
42 changed files with 0 additions and 6371 deletions
services/mana-llm/.env.example

@ -1,28 +0,0 @@
# Service
PORT=3025
LOG_LEVEL=info

# Ollama (Primary)
OLLAMA_URL=http://localhost:11434
OLLAMA_DEFAULT_MODEL=gemma3:4b
OLLAMA_TIMEOUT=120

# OpenRouter (Cloud Fallback)
OPENROUTER_API_KEY=sk-or-v1-xxx
OPENROUTER_DEFAULT_MODEL=meta-llama/llama-3.1-8b-instruct

# Groq (Optional)
GROQ_API_KEY=gsk_xxx

# Together (Optional)
TOGETHER_API_KEY=xxx

# Caching (Optional)
REDIS_URL=redis://localhost:6379
CACHE_TTL=3600

# CORS
CORS_ORIGINS=http://localhost:5173,https://mana.how

# API key for cross-service auth (validated in src/api_auth.py)
GPU_API_KEY=
31 services/mana-llm/.gitignore (vendored)
@ -1,31 +0,0 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
.venv/
ENV/
.env

# IDE
.idea/
.vscode/
*.swp
*.swo

# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/

# Build
dist/
build/
*.egg-info/

# Local
.env.local
*.log
services/mana-llm/README.md

@ -1,376 +0,0 @@
# mana-llm

Central LLM abstraction service providing a unified OpenAI-compatible API for Ollama and cloud LLM providers.

## Overview

mana-llm acts as a central gateway for all LLM requests in the monorepo, providing:
- Unified OpenAI-compatible API
- Provider routing (Ollama, OpenRouter, Groq, Together)
- Streaming via Server-Sent Events (SSE)
- Vision/multimodal support
- Embeddings generation
- Prometheus metrics

## Architecture

```
┌─────────────────────────────────────────────────────────────────────┐
│                            Consumer Apps                            │
│        chat-backend │ mana web │ todo (LLM enrich) │ etc.           │
└────────────────────────────────┬────────────────────────────────────┘
                                 │ HTTP/SSE
                                 ▼
┌─────────────────────────────────────────────────────────────────────┐
│                        mana-llm (Port 3025)                         │
│  ┌─────────────┐  ┌─────────────┐  ┌─────────────┐                  │
│  │   Router    │  │    Cache    │  │   Metrics   │                  │
│  │  (Provider) │  │   (Redis)   │  │ (Prometheus)│                  │
│  └──────┬──────┘  └─────────────┘  └─────────────┘                  │
│         │                                                           │
│  ┌──────┴──────────────────────────────────────────┐                │
│  │               Provider Adapters                 │                │
│  │  ┌──────────┐  ┌──────────┐  ┌──────────────┐   │                │
│  │  │  Ollama  │  │  OpenAI  │  │  OpenRouter  │   │                │
│  │  │  Adapter │  │  Adapter │  │   Adapter    │   │                │
│  │  └──────────┘  └──────────┘  └──────────────┘   │                │
│  └─────────────────────────────────────────────────┘                │
└─────────────────────────────────────────────────────────────────────┘
```

## Quick Start

### Prerequisites

- Python 3.11+
- Ollama running locally (http://localhost:11434)
- Redis (optional, for caching)

### Development

```bash
cd services/mana-llm

# Create virtual environment
python -m venv venv
source venv/bin/activate  # or venv\Scripts\activate on Windows

# Install dependencies
pip install -r requirements.txt

# Copy environment file
cp .env.example .env

# Start Redis (optional)
docker-compose -f docker-compose.dev.yml up -d

# Run service
python -m uvicorn src.main:app --port 3025 --reload
```

### Docker

```bash
# Full stack (mana-llm + Redis)
docker-compose up -d

# View logs
docker-compose logs -f mana-llm
```

## API Endpoints

### Chat Completions

```bash
# Non-streaming
curl -X POST http://localhost:3025/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ollama/gemma3:4b",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": false
  }'

# Streaming (SSE)
curl -X POST http://localhost:3025/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ollama/gemma3:4b",
    "messages": [{"role": "user", "content": "Hello!"}],
    "stream": true
  }'
```

### Vision/Multimodal

```bash
curl -X POST http://localhost:3025/v1/chat/completions \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ollama/llava:7b",
    "messages": [{
      "role": "user",
      "content": [
        {"type": "text", "text": "What is in this image?"},
        {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
      ]
    }]
  }'
```

### Models

```bash
# List all models
curl http://localhost:3025/v1/models

# Get specific model
curl http://localhost:3025/v1/models/ollama/gemma3:4b
```

### Embeddings

```bash
curl -X POST http://localhost:3025/v1/embeddings \
  -H "Content-Type: application/json" \
  -d '{
    "model": "ollama/nomic-embed-text",
    "input": "Text to embed"
  }'
```

### Health & Metrics

```bash
# Liveness summary (legacy, terse shape — only status + per-provider status string)
curl http://localhost:3025/health

# Detailed per-provider liveness snapshot (M4)
curl http://localhost:3025/v1/health

# Prometheus metrics
curl http://localhost:3025/metrics
```

### Aliases (M4 — see `aliases.yaml` and the fallback section below)

```bash
# What does each `mana/<class>` resolve to?
curl http://localhost:3025/v1/aliases
```

## Aliases & Fallback

> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md)

### What callers send

Two acceptable shapes for the `model` field of `/v1/chat/completions`:

1. **Aliases** in the reserved `mana/` namespace — recommended for product code.
   The router resolves them via `aliases.yaml` to a chain of concrete
   `provider/model` strings and tries them in order.
2. **Direct `provider/model`** — bypasses the alias layer, no fallback.
   Useful for tests, debugging, and one-off integrations.

| Alias | Class |
|---|---|
| `mana/fast-text` | Short answers, classification, single-shot Q&A |
| `mana/long-form` | Writing, essays, stories, longer prose |
| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) |
| `mana/reasoning` | Agent missions, tool calls, multi-step plans |
| `mana/vision` | Multimodal (image + text) |

The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit
the file and `kill -HUP <pid>` to reload — no restart needed. Reload
errors keep the previous good state; check the service logs.
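
In practice a consumer just swaps the concrete model for an alias. A minimal
client sketch in Python using `httpx` (any OpenAI-compatible HTTP client works
the same way):

```python
import httpx

# Ask for a model *class*; mana-llm walks the alias chain and the first
# healthy, configured entry answers.
resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    json={
        "model": "mana/fast-text",
        "messages": [{"role": "user", "content": "Hello!"}],
        "stream": False,
    },
    timeout=120.0,
)
resp.raise_for_status()
print(resp.json()["choices"][0]["message"]["content"])
```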

### Fallback semantics

Every chain is tried in order. The router skips an entry if the provider
isn't configured at this deployment (no API key) or is currently marked
unhealthy by the health-cache. For each remaining entry the request is
attempted; on a **retryable** error (connection failure, timeout, 5xx,
rate-limit, RemoteProtocolError) the provider is marked unhealthy and
the next entry is tried. **Non-retryable** errors (auth, capability,
content-blocked, 4xx, unknown exception types) propagate immediately —
no fallback, and the health-cache is not poisoned.

Streaming follows the same logic up to the **first byte**. Once a chunk
has been yielded the provider is committed; mid-stream errors surface
as-is so we never splice two providers' voices into one output.

If every entry was skipped or failed, the response is a `503` carrying a
structured `attempts: list[(model, reason)]` log so the cause is
visible to the caller, not only in service logs.
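
A compressed sketch of that chain-walk for the non-streaming path. Everything
here is illustrative scaffolding (the real logic lives in
`src/providers/router.py`); only the health-cache calls mirror the actual class:

```python
from typing import Awaitable, Callable

class RetryableError(Exception):
    """Stand-in for the retryable bucket: timeouts, 5xx, rate-limits."""

async def complete_with_fallback(
    chain: tuple[str, ...],
    call: Callable[[str], Awaitable[str]],  # stand-in for one provider attempt
    cache,                                   # ProviderHealthCache-like object
    configured: set[str],                    # providers with credentials here
):
    attempts: list[tuple[str, str]] = []     # (model, reason), surfaced on 503
    for model in chain:
        provider = model.split("/", 1)[0]
        if provider not in configured:
            attempts.append((model, "unconfigured"))
            continue
        if not cache.is_healthy(provider):   # breaker open: fail fast
            attempts.append((model, "cache-unhealthy"))
            continue
        try:
            result = await call(model)
        except RetryableError as e:          # mark unhealthy, try next entry
            cache.mark_unhealthy(provider, str(e))
            attempts.append((model, type(e).__name__))
            continue
        # Non-retryable errors (auth, 4xx, content-blocked) propagate out
        # of the loop right here — no fallback, cache left untouched.
        cache.mark_healthy(provider)
        return result
    raise RuntimeError(f"503 attempts={attempts}")  # service returns these pairs
```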

### Resolved-model header

Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>`
(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that
actually answered. Use this for token-cost attribution when the request
used an alias. For streaming, each chunk's `model` field carries the
same info (headers go out before the chain is walked).
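
On the consumer side, cost attribution might read the header like this
(sketch; `record_cost` is a placeholder for whatever accounting hook the
caller has):

```python
resolved = resp.headers.get("X-Mana-LLM-Resolved", "unknown/unknown")
provider, _, model = resolved.partition("/")
# Attribute token usage to the model that actually answered, not to the
# alias the request named. record_cost() is hypothetical.
usage = resp.json().get("usage", {})
record_cost(provider=provider, model=model, usage=usage)
```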

### Health-cache + probe

`ProviderHealthCache` keeps a per-provider circuit-breaker:

* 1 failure: still healthy (transient blip, don't bounce).
* 2 consecutive failures: `is_healthy → False` for 60 s; the router
  fails fast straight to the next chain entry.
* After 60 s: half-open. The next call exercises the provider; success
  fully resets, failure re-arms the backoff.

A background `HealthProbe` task runs every 30 s with a 3 s timeout per
provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models`
for OpenAI-compat). One bad probe can't sink the loop; results feed
into the same cache as the call-site fallback.
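
The implementation ships in `src/health.py` later in this commit; because the
cache takes an injectable `clock`, the breaker's timeline is easy to
demonstrate deterministically:

```python
from src.health import ProviderHealthCache

now = [0.0]
cache = ProviderHealthCache(clock=lambda: now[0])   # fake clock for the demo

cache.mark_unhealthy("ollama", "ConnectError")      # 1st failure: still healthy
assert cache.is_healthy("ollama")

cache.mark_unhealthy("ollama", "ConnectError")      # 2nd failure: breaker trips
assert not cache.is_healthy("ollama")               # router now fails fast

now[0] += 61.0                                      # 60 s backoff expires
assert cache.is_healthy("ollama")                   # half-open: one retry allowed
cache.mark_healthy("ollama")                        # success fully resets state
```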

### Prometheus metrics added in M4

| Metric | Labels | Purpose |
|---|---|---|
| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. |
| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. |
| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker. |

## Provider Routing

Models use the format `provider/model`:

| Model | Provider | Target |
|-------|----------|--------|
| `ollama/gemma3:4b` | Ollama | localhost:11434 |
| `ollama/llava:7b` | Ollama | localhost:11434 |
| `openrouter/meta-llama/llama-3.1-8b-instruct` | OpenRouter | api.openrouter.ai |
| `groq/llama-3.1-8b-instant` | Groq | api.groq.com |
| `together/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` | Together | api.together.xyz |

**Default:** If no provider prefix is given (e.g., `gemma3:4b`), Ollama is used.

## Configuration

Environment variables (see `.env.example`):

| Variable | Default | Description |
|----------|---------|-------------|
| `PORT` | 3025 | Service port |
| `LOG_LEVEL` | info | Logging level |
| `OLLAMA_URL` | http://localhost:11434 | Ollama server URL |
| `OLLAMA_DEFAULT_MODEL` | gemma3:4b | Default Ollama model |
| `OLLAMA_TIMEOUT` | 120 | Ollama request timeout (seconds) |
| `OPENROUTER_API_KEY` | - | OpenRouter API key |
| `GROQ_API_KEY` | - | Groq API key |
| `TOGETHER_API_KEY` | - | Together API key |
| `REDIS_URL` | - | Redis URL for caching |
| `CACHE_TTL` | 3600 | Cache TTL in seconds |
| `CORS_ORIGINS` | localhost | Allowed CORS origins |

## Project Structure

```
services/mana-llm/
├── src/
│   ├── main.py              # FastAPI app entry point
│   ├── config.py            # Settings via pydantic-settings
│   ├── providers/
│   │   ├── base.py          # Abstract provider interface
│   │   ├── ollama.py        # Ollama provider
│   │   ├── openai_compat.py # OpenAI-compatible provider
│   │   └── router.py        # Provider routing logic
│   ├── models/
│   │   ├── requests.py      # Request Pydantic models
│   │   └── responses.py     # Response Pydantic models
│   ├── streaming/
│   │   └── sse.py           # SSE response handling
│   └── utils/
│       ├── cache.py         # Redis caching
│       └── metrics.py       # Prometheus metrics
├── tests/
│   ├── test_api.py          # API endpoint tests
│   ├── test_providers.py    # Provider tests
│   └── test_streaming.py    # Streaming tests
├── Dockerfile
├── docker-compose.yml
├── docker-compose.dev.yml
├── requirements.txt
├── pyproject.toml
└── .env.example
```

## Testing

```bash
# Run tests
pytest

# Run with coverage
pytest --cov=src

# Run specific test file
pytest tests/test_providers.py -v
```

## Integration Example

### TypeScript/Node.js Client

```typescript
// Using fetch
const response = await fetch('http://localhost:3025/v1/chat/completions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    model: 'ollama/gemma3:4b',
    messages: [{ role: 'user', content: 'Hello!' }],
    stream: false,
  }),
});

const data = await response.json();
console.log(data.choices[0].message.content);
```

### Streaming with fetch (SSE)

```typescript
const response = await fetch('http://localhost:3025/v1/chat/completions', {
  method: 'POST',
  headers: { 'Content-Type': 'application/json' },
  body: JSON.stringify({
    model: 'ollama/gemma3:4b',
    messages: [{ role: 'user', content: 'Hello!' }],
    stream: true,
  }),
});

const reader = response.body?.getReader();
const decoder = new TextDecoder();

while (true) {
  const { done, value } = await reader!.read();
  if (done) break;

  // Note: for simplicity this assumes each chunk contains whole
  // `data: ...` lines; a robust parser buffers partial lines across reads.
  const chunk = decoder.decode(value);
  const lines = chunk.split('\n').filter(line => line.startsWith('data: '));

  for (const line of lines) {
    const data = line.slice(6);
    if (data === '[DONE]') break;

    const parsed = JSON.parse(data);
    const content = parsed.choices[0]?.delta?.content;
    if (content) process.stdout.write(content);
  }
}
```

## Related Services

| Service | Port | Description |
|---------|------|-------------|
| mana-tts | 3022 | Text-to-speech service |
| mana-stt | 3023 | Speech-to-text service |
| mana-search | 3021 | Web search & extraction |
services/mana-llm/Dockerfile

@ -1,49 +0,0 @@
# Build stage
FROM python:3.12-slim AS builder

WORKDIR /app

# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
    build-essential \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements first for caching
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Production stage
FROM python:3.12-slim

WORKDIR /app

# Create non-root user
RUN useradd -m -u 1000 appuser

# Copy installed packages from builder
COPY --from=builder /root/.local /home/appuser/.local

# Copy application code + alias config. `aliases.yaml` is loaded by
# main.py's lifespan handler from `Path(__file__).parent.parent /
# 'aliases.yaml'` = /app/aliases.yaml. Without it the FastAPI app
# raises AliasConfigError on startup and the container crashloops.
COPY --chown=appuser:appuser src/ ./src/
COPY --chown=appuser:appuser aliases.yaml ./aliases.yaml

# Set environment
ENV PATH=/home/appuser/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1

# Switch to non-root user
USER appuser

# Expose port
EXPOSE 3025

# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
    CMD python -c "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"

# Run application
CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3025"]
services/mana-llm/aliases.yaml

@ -1,54 +0,0 @@
# mana-llm Model Aliases — single source of truth for which class of
# model each backend feature uses.
#
# Consumers (mana-api, mana-ai, …) send `"model": "mana/<class>"` in
# their /v1/chat/completions requests; mana-llm resolves the alias to
# the chain below and tries entries in order, skipping providers that
# the health-cache has marked unhealthy.
#
# Order in `chain` = preference. First healthy entry wins. Each chain
# should end with a cloud provider so the system stays functional even
# when the local GPU server (mana-gpu, RTX 3090) is offline.
#
# Reload at runtime: `kill -HUP <pid>` after editing — no restart needed.
# Reference: docs/plans/llm-fallback-aliases.md.

aliases:
  mana/fast-text:
    description: "Short answers, classification, single-shot Q&A"
    chain:
      - ollama/qwen2.5:7b
      - groq/llama-3.1-8b-instant
      - openrouter/anthropic/claude-3-haiku

  mana/long-form:
    description: "Writing, essays, stories, longer prose"
    chain:
      - ollama/gemma3:12b
      - groq/llama-3.3-70b-versatile
      - openrouter/anthropic/claude-3.5-haiku

  mana/structured:
    description: "JSON output (comic storyboards, research subqueries, tag suggestions)"
    chain:
      - ollama/qwen2.5:7b
      - groq/llama-3.1-8b-instant
      - openrouter/openai/gpt-4o-mini

  mana/reasoning:
    description: "Agent missions, tool calls, multi-step plans"
    # Cloud first by design — local 4-7B models are unreliable for tool calls
    chain:
      - openrouter/anthropic/claude-3.5-sonnet
      - groq/llama-3.3-70b-versatile

  mana/vision:
    description: "Multimodal (image + text)"
    chain:
      - ollama/llava:7b
      - google/gemini-2.0-flash-exp
      - openrouter/openai/gpt-4o

# Default alias used when a request omits `model` or sends an unknown
# value with no provider prefix. Keep this conservative (cheap class).
default: mana/fast-text
services/mana-llm/docker-compose.dev.yml

@ -1,15 +0,0 @@
# Development compose - only Redis (run Python locally)
version: "3.8"

services:
  redis:
    image: redis:7-alpine
    container_name: mana-llm-redis-dev
    ports:
      - "6380:6379"
    volumes:
      - redis-data-dev:/data
    restart: unless-stopped

volumes:
  redis-data-dev:
services/mana-llm/docker-compose.yml

@ -1,52 +0,0 @@
version: "3.8"

services:
  mana-llm:
    build:
      context: .
      dockerfile: Dockerfile
    container_name: mana-llm
    ports:
      - "3025:3025"
    environment:
      - PORT=3025
      - LOG_LEVEL=info
      - OLLAMA_URL=http://host.docker.internal:11434
      - OLLAMA_DEFAULT_MODEL=gemma3:4b
      - OLLAMA_TIMEOUT=120
      - REDIS_URL=redis://redis:6379
      # Add API keys via .env file
      - GOOGLE_API_KEY=${GOOGLE_API_KEY:-}
      - GOOGLE_DEFAULT_MODEL=${GOOGLE_DEFAULT_MODEL:-gemini-2.5-flash}
      - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
      - GROQ_API_KEY=${GROQ_API_KEY:-}
      - TOGETHER_API_KEY=${TOGETHER_API_KEY:-}
      - CORS_ORIGINS=http://localhost:5173,http://localhost:3000,https://mana.how
    depends_on:
      - redis
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"]
      interval: 30s
      timeout: 10s
      retries: 3
      start_period: 10s
    extra_hosts:
      - "host.docker.internal:host-gateway"

  redis:
    image: redis:7-alpine
    container_name: mana-llm-redis
    ports:
      - "6380:6379"
    volumes:
      - redis-data:/data
    restart: unless-stopped
    healthcheck:
      test: ["CMD", "redis-cli", "ping"]
      interval: 10s
      timeout: 5s
      retries: 5

volumes:
  redis-data:
services/mana-llm/pyproject.toml

@ -1,40 +0,0 @@
[project]
name = "mana-llm"
version = "0.1.0"
description = "Central LLM abstraction service for Ollama and OpenAI-compatible APIs"
requires-python = ">=3.11"
dependencies = [
    "fastapi>=0.115.0",
    "uvicorn[standard]>=0.32.0",
    "pydantic>=2.10.0",
    "pydantic-settings>=2.6.0",
    "httpx>=0.28.0",
    "sse-starlette>=2.2.0",
    "redis>=5.2.0",
    "prometheus-client>=0.21.0",
    "google-genai>=1.0.0",
    "pyyaml>=6.0.2",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.3.0",
    "pytest-asyncio>=0.24.0",
    "pytest-httpx>=0.35.0",
    "ruff>=0.8.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.ruff]
line-length = 100
target-version = "py311"

[tool.ruff.lint]
select = ["E", "F", "I", "W"]

[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]
services/mana-llm/requirements.txt

@ -1,29 +0,0 @@
# Core
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
pydantic>=2.10.0
pydantic-settings>=2.6.0

# HTTP Client
httpx>=0.28.0

# Streaming
sse-starlette>=2.2.0

# Caching (optional)
redis>=5.2.0

# Google Gemini
google-genai>=1.0.0

# Metrics
prometheus-client>=0.21.0

# Config (alias registry)
pyyaml>=6.0.2

# Dev
pytest>=8.3.0
pytest-asyncio>=0.24.0
pytest-httpx>=0.35.0
ruff>=0.8.0
@ -1,17 +0,0 @@
"""mana-llm service runner (run with pythonw.exe to run headless)."""
import os
import sys
os.chdir(r"C:\mana\services\mana-llm")
sys.path.insert(0, r"C:\mana\services\mana-llm")

# Load .env file (python-dotenv; note it is not listed in requirements.txt,
# so it is assumed to be installed on the Windows host this runner targets)
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-llm\.env")

# Redirect stdout/stderr to a log file (line-buffered, overwritten per run)
log = open(r"C:\mana\services\mana-llm\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log

import uvicorn
uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")
services/mana-llm/src/__init__.py

@ -1,3 +0,0 @@
"""mana-llm - Central LLM abstraction service."""

__version__ = "0.1.0"
services/mana-llm/src/aliases.py

@ -1,223 +0,0 @@
"""Model-alias registry.

Loads `aliases.yaml` and exposes a small API the router uses to resolve
semantic model names like ``mana/long-form`` to an ordered list of
concrete provider-prefixed model strings (``ollama/gemma3:12b`` →
``groq/llama-3.3-70b-versatile`` → …).

The registry is hot-reloadable: ``reload()`` rebuilds the in-memory
mapping atomically. Reload errors leave the previous good state intact
so a typo in the yaml file doesn't take the service down — caller logs
the error and keeps serving.

See docs/plans/llm-fallback-aliases.md for the full design.
"""

from __future__ import annotations

import logging
import threading
from dataclasses import dataclass
from pathlib import Path

import yaml

logger = logging.getLogger(__name__)

# Aliases live in this namespace. Anything else passed as `model` is
# treated as a direct provider/model string (preserves the legal
# bypass-the-alias-layer escape hatch for tests/debugging).
ALIAS_PREFIX = "mana/"


@dataclass(frozen=True)
class Alias:
    """A resolved alias entry."""

    name: str
    description: str
    chain: tuple[str, ...]


class AliasConfigError(ValueError):
    """Raised when the YAML file is malformed or violates schema constraints."""


class UnknownAliasError(KeyError):
    """Raised when a caller asks for an alias that isn't defined."""


def _validate_chain(name: str, chain: object) -> tuple[str, ...]:
    """Schema-check a single alias chain. Returns the validated tuple."""
    if not isinstance(chain, list):
        raise AliasConfigError(f"alias '{name}': chain must be a list, got {type(chain).__name__}")
    if not chain:
        raise AliasConfigError(f"alias '{name}': chain must not be empty")
    out: list[str] = []
    for i, entry in enumerate(chain):
        if not isinstance(entry, str) or not entry.strip():
            raise AliasConfigError(
                f"alias '{name}': chain[{i}] must be a non-empty string, got {entry!r}"
            )
        if "/" not in entry:
            raise AliasConfigError(
                f"alias '{name}': chain[{i}] = {entry!r} must include a provider prefix "
                f"(e.g. 'ollama/...', 'groq/...')"
            )
        out.append(entry.strip())
    return tuple(out)


def _validate_name(name: object) -> str:
    """Aliases must live in the reserved `mana/` namespace."""
    if not isinstance(name, str) or not name.startswith(ALIAS_PREFIX):
        raise AliasConfigError(
            f"alias name {name!r} must start with {ALIAS_PREFIX!r} (the reserved namespace)"
        )
    suffix = name[len(ALIAS_PREFIX) :]
    if not suffix or "/" in suffix:
        raise AliasConfigError(
            f"alias name {name!r} must have exactly one segment after {ALIAS_PREFIX!r}"
        )
    return name


def _parse_document(doc: object) -> tuple[dict[str, Alias], str | None]:
    """Parse a loaded YAML document into a normalized (aliases, default) pair."""
    if not isinstance(doc, dict):
        raise AliasConfigError(f"yaml root must be a mapping, got {type(doc).__name__}")

    raw_aliases = doc.get("aliases", {})
    if not isinstance(raw_aliases, dict):
        raise AliasConfigError(
            f"`aliases` must be a mapping, got {type(raw_aliases).__name__}"
        )
    if not raw_aliases:
        raise AliasConfigError("`aliases` is empty — at least one alias is required")

    parsed: dict[str, Alias] = {}
    for name, body in raw_aliases.items():
        validated_name = _validate_name(name)
        if not isinstance(body, dict):
            raise AliasConfigError(
                f"alias '{validated_name}': body must be a mapping, got {type(body).__name__}"
            )
        description = body.get("description", "")
        if not isinstance(description, str):
            raise AliasConfigError(
                f"alias '{validated_name}': description must be a string"
            )
        chain = _validate_chain(validated_name, body.get("chain"))
        parsed[validated_name] = Alias(
            name=validated_name,
            description=description.strip(),
            chain=chain,
        )

    default = doc.get("default")
    if default is not None:
        if not isinstance(default, str):
            raise AliasConfigError(f"`default` must be a string, got {type(default).__name__}")
        if default not in parsed:
            raise AliasConfigError(
                f"`default` references unknown alias {default!r} "
                f"(known: {sorted(parsed)})"
            )

    return parsed, default


class AliasRegistry:
    """Thread-safe in-memory registry of model aliases.

    Construct once at startup with the path to the yaml file. Call
    :meth:`reload` to re-read after a SIGHUP. Reads (``resolve``,
    ``is_alias``, …) are cheap and lock-free during steady state — they
    snapshot the current mapping reference; only the swap on reload is
    serialized.
    """

    def __init__(self, path: Path | str):
        self._path = Path(path)
        self._lock = threading.Lock()
        self._aliases: dict[str, Alias] = {}
        self._default: str | None = None
        self._load()

    @property
    def path(self) -> Path:
        return self._path

    def _load(self) -> None:
        """Initial load — propagates errors so a bad config fails fast at startup."""
        if not self._path.exists():
            raise AliasConfigError(f"alias config not found at {self._path}")
        with self._path.open("r", encoding="utf-8") as f:
            try:
                doc = yaml.safe_load(f)
            except yaml.YAMLError as e:
                raise AliasConfigError(f"failed to parse {self._path}: {e}") from e
        aliases, default = _parse_document(doc)
        # No lock needed during __init__ — nothing else can read yet.
        self._aliases = aliases
        self._default = default
        logger.info(
            "AliasRegistry loaded %d alias(es) from %s (default=%s)",
            len(aliases),
            self._path,
            default,
        )

    def reload(self) -> None:
        """Re-read the yaml file. On parse error, keep the previous state and raise.

        Designed for SIGHUP: callers should ``try/except AliasConfigError``
        and log; do not crash the service on a typo.
        """
        with self._path.open("r", encoding="utf-8") as f:
            try:
                doc = yaml.safe_load(f)
            except yaml.YAMLError as e:
                raise AliasConfigError(f"failed to parse {self._path}: {e}") from e
        aliases, default = _parse_document(doc)
        with self._lock:
            self._aliases = aliases
            self._default = default
        logger.info(
            "AliasRegistry reloaded %d alias(es) from %s (default=%s)",
            len(aliases),
            self._path,
            default,
        )

    @staticmethod
    def is_alias(name: str) -> bool:
        """Cheap syntactic check — does this name live in the alias namespace?

        Static; doesn't require a registry instance. Used by the router to
        decide whether to dispatch to the alias layer or pass through to
        provider-direct routing.
        """
        return isinstance(name, str) and name.startswith(ALIAS_PREFIX)

    def resolve(self, name: str) -> Alias:
        """Look up the named alias. Raises :class:`UnknownAliasError` if absent."""
        try:
            return self._aliases[name]
        except KeyError as e:
            raise UnknownAliasError(
                f"unknown alias {name!r} (known: {sorted(self._aliases)})"
            ) from e

    def resolve_chain(self, name: str) -> tuple[str, ...]:
        """Sugar for ``resolve(name).chain`` — the form the router actually wants."""
        return self.resolve(name).chain

    @property
    def default_alias(self) -> str | None:
        """The alias used when a request arrives with no recognizable model."""
        return self._default

    def list_aliases(self) -> list[Alias]:
        """All aliases as a snapshot list — for the GET /v1/aliases debug endpoint."""
        return [self._aliases[k] for k in sorted(self._aliases)]
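
For reference, a minimal usage sketch of the registry defined above (the path
argument is illustrative):

```python
from src.aliases import AliasConfigError, AliasRegistry

registry = AliasRegistry("services/mana-llm/aliases.yaml")
print(registry.resolve_chain("mana/fast-text"))
# -> ('ollama/qwen2.5:7b', 'groq/llama-3.1-8b-instant',
#     'openrouter/anthropic/claude-3-haiku')

# The SIGHUP-handler pattern: a bad edit must never take the service down.
try:
    registry.reload()
except AliasConfigError as e:
    print(f"reload rejected, keeping previous state: {e}")
```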

services/mana-llm/src/api_auth.py

@ -1,53 +0,0 @@
"""
Simple API Key Authentication Middleware for GPU Services.

Checks X-API-Key header or ?api_key query parameter.
Skips auth for /health, /metrics, /docs, /openapi.json, /redoc endpoints.

Environment variables:
    GPU_API_KEY: Required API key (if empty, auth is disabled)
    GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set)
"""

import os
import logging
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

GPU_API_KEY = os.getenv("GPU_API_KEY", "")
GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true"

# Endpoints that don't require auth
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"}


class ApiKeyMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        # Skip auth if disabled
        if not GPU_REQUIRE_AUTH or not GPU_API_KEY:
            return await call_next(request)

        # Skip auth for public endpoints
        if request.url.path in PUBLIC_PATHS:
            return await call_next(request)

        # Check API key from header or query param
        api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")

        if not api_key:
            return JSONResponse(
                status_code=401,
                content={"detail": "Missing API key. Provide X-API-Key header."},
            )

        if api_key != GPU_API_KEY:
            logger.warning(f"Invalid API key attempt from {request.client.host if request.client else 'unknown'}")
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key."},
            )

        return await call_next(request)
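
A short wiring-and-calling sketch for the middleware above; the server side
matches the `app.add_middleware(ApiKeyMiddleware)` call in `main.py` further
down, while the client call and key value are illustrative:

```python
import httpx
from fastapi import FastAPI

from src.api_auth import ApiKeyMiddleware

app = FastAPI()
app.add_middleware(ApiKeyMiddleware)  # GPU_API_KEY / GPU_REQUIRE_AUTH read at import time

# Client side: the key travels in the X-API-Key header (or ?api_key=...).
resp = httpx.get(
    "http://localhost:3025/v1/models",
    headers={"X-API-Key": "change-me"},  # illustrative key value
)
```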

services/mana-llm/src/config.py

@ -1,57 +0,0 @@
"""Configuration settings for mana-llm service."""

from pydantic_settings import BaseSettings


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Service
    port: int = 3025
    log_level: str = "info"

    # Ollama (Primary provider)
    ollama_url: str = "http://localhost:11434"
    ollama_default_model: str = "gemma3:4b"
    ollama_timeout: int = 120

    # OpenRouter (Cloud fallback)
    openrouter_api_key: str | None = None
    openrouter_base_url: str = "https://openrouter.ai/api/v1"
    openrouter_default_model: str = "meta-llama/llama-3.1-8b-instruct"

    # Groq (Optional)
    groq_api_key: str | None = None
    groq_base_url: str = "https://api.groq.com/openai/v1"

    # Together (Optional)
    together_api_key: str | None = None
    together_base_url: str = "https://api.together.xyz/v1"

    # Google Gemini (Fallback provider)
    google_api_key: str | None = None
    google_default_model: str = "gemini-2.5-flash"

    # Auto-fallback: Ollama → Google when Ollama is overloaded/down
    auto_fallback_enabled: bool = True
    ollama_max_concurrent: int = 3

    # Caching (Optional)
    redis_url: str | None = None
    cache_ttl: int = 3600

    # CORS
    cors_origins: str = "http://localhost:5173,http://localhost:5190,https://mana.how,https://playground.mana.how"

    @property
    def cors_origins_list(self) -> list[str]:
        """Parse CORS origins from comma-separated string."""
        return [origin.strip() for origin in self.cors_origins.split(",")]

    class Config:
        env_file = ".env"
        env_file_encoding = "utf-8"
        extra = "ignore"


settings = Settings()

services/mana-llm/src/health.py

@ -1,203 +0,0 @@
"""Provider health cache.

Tracks per-provider liveness for the LLM router. The router reads
:meth:`is_healthy` to decide whether to even try a provider in a chain;
the probe loop and the call-site fallback handler write state via
:meth:`mark_healthy` / :meth:`mark_unhealthy`.

Implements a simple circuit-breaker:

* The first failure flips no switch — providers occasionally have
  transient blips, we don't want to bounce off after a single 502.
* After ``failure_threshold`` consecutive failures the provider is
  marked unhealthy for ``unhealthy_backoff`` seconds. During that
  window :meth:`is_healthy` returns ``False`` so the router fails
  fast straight to the next chain entry.
* When the backoff expires :meth:`is_healthy` returns ``True`` again
  (half-open). The next call exercises the provider; success calls
  :meth:`mark_healthy` and fully resets state, failure re-arms the
  backoff window.

State is kept in a plain dict guarded by a ``threading.Lock``. All
operations are short. Lock-free reads of dict references aren't safe
here because we mutate state in-place — the lock keeps it boring. The
probe loop runs in the asyncio loop alongside the router, but the lock
costs are negligible at ~1 update/30s/provider.
"""

from __future__ import annotations

import logging
import threading
import time
from dataclasses import dataclass
from typing import Callable, Iterable

logger = logging.getLogger(__name__)

#: Notification fired whenever a provider transitions between healthy and
#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the
#: cache itself stays metrics-agnostic so tests don't need to mock it.
HealthChangeListener = Callable[[str, bool], None]

DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0


@dataclass
class ProviderState:
    """Per-provider liveness snapshot. All times are unix seconds."""

    healthy: bool = True
    consecutive_failures: int = 0
    last_check: float = 0.0
    last_error: str | None = None
    unhealthy_until: float = 0.0
    """When > now, the provider is currently in backoff (`is_healthy → False`)."""


class ProviderHealthCache:
    """Thread-safe per-provider liveness with circuit-breaker semantics.

    Provider IDs are arbitrary strings — by convention we use the same
    short name as the provider router (``ollama``, ``groq``, ``openrouter``,
    ``together``, ``google``). The cache is provider-list agnostic; states
    are created lazily on first ``mark_*`` or queried-but-absent ``is_healthy``
    call (returning ``True`` by default — no state means no reason to skip).
    """

    def __init__(
        self,
        *,
        failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
        unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
        clock: Callable[[], float] = time.time,
    ) -> None:
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if unhealthy_backoff_sec < 0:
            raise ValueError("unhealthy_backoff_sec must be >= 0")
        self._failure_threshold = failure_threshold
        self._unhealthy_backoff = unhealthy_backoff_sec
        self._clock = clock
        self._lock = threading.Lock()
        self._states: dict[str, ProviderState] = {}
        self._listeners: list[HealthChangeListener] = []

    def add_listener(self, listener: HealthChangeListener) -> None:
        """Register a callback fired with ``(provider_id, healthy: bool)``
        whenever a provider's healthy-flag transitions. Listeners run
        outside the cache's lock; exceptions are swallowed and logged so a
        bad listener can't break the underlying state machine.
        """
        self._listeners.append(listener)

    def _notify(self, provider_id: str, healthy: bool) -> None:
        for listener in self._listeners:
            try:
                listener(provider_id, healthy)
            except Exception as e:  # noqa: BLE001
                logger.error("health-change listener raised: %s", e)

    @property
    def failure_threshold(self) -> int:
        return self._failure_threshold

    @property
    def unhealthy_backoff_sec(self) -> float:
        return self._unhealthy_backoff

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def is_healthy(self, provider_id: str) -> bool:
        """Should the router try this provider right now?

        Returns ``True`` by default for unknown providers — the cache is
        observation-only, not a registry.
        """
        with self._lock:
            state = self._states.get(provider_id)
            if state is None:
                return True
            if state.unhealthy_until > self._clock():
                return False
            # Backoff expired: caller is allowed to try again (half-open).
            return True

    def get_state(self, provider_id: str) -> ProviderState | None:
        """Snapshot of one provider's state (for debugging / tests)."""
        with self._lock:
            state = self._states.get(provider_id)
            return None if state is None else _copy(state)

    def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
        """All known states, plus zero-state placeholders for any names in
        ``expected`` that haven't been touched yet. Used by ``GET /v1/health``
        so the response shape is stable regardless of probe order.
        """
        with self._lock:
            out = {pid: _copy(s) for pid, s in self._states.items()}
            if expected:
                for pid in expected:
                    out.setdefault(pid, ProviderState())
            return out

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def mark_healthy(self, provider_id: str) -> None:
        """Provider answered correctly — clear any failure state."""
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            transitioned = not state.healthy
            state.healthy = True
            state.consecutive_failures = 0
            state.last_check = self._clock()
            state.last_error = None
            state.unhealthy_until = 0.0
        if transitioned:
            logger.info("provider %s recovered", provider_id)
            self._notify(provider_id, True)

    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
        """Record a failure. Trips the breaker after the threshold."""
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            state.consecutive_failures += 1
            state.last_check = self._clock()
            state.last_error = reason
            tripped = state.consecutive_failures >= self._failure_threshold
            if tripped and state.healthy:
                state.healthy = False
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
                transitioned = True
                logger.warning(
                    "provider %s marked unhealthy after %d consecutive failures (%s); "
                    "backoff %.0fs",
                    provider_id,
                    state.consecutive_failures,
                    reason,
                    self._unhealthy_backoff,
                )
            elif not state.healthy:
                # Still in unhealthy window; refresh the backoff so a flapping
                # provider doesn't get re-tried every probe tick.
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
        if transitioned:
            self._notify(provider_id, False)


def _copy(state: ProviderState) -> ProviderState:
    """Return a shallow copy so callers can read without holding the lock."""
    return ProviderState(
        healthy=state.healthy,
        consecutive_failures=state.consecutive_failures,
        last_check=state.last_check,
        last_error=state.last_error,
        unhealthy_until=state.unhealthy_until,
    )
services/mana-llm/src/health_probe.py

@ -1,210 +0,0 @@
"""Background probe loop that keeps :class:`ProviderHealthCache` fresh.

The router's circuit-breaker is reactive — it only learns a provider is
sick after the next live request fails. A reactive-only design means:

* every cold start re-discovers Ollama is down by paying one 75-second
  ``ConnectError``, and
* a provider that quietly recovers stays marked unhealthy until its
  backoff expires and someone tries it.

The probe loop closes both gaps. Every ``interval`` seconds it pings
each registered provider with a small known-cheap request (Ollama:
``GET /api/tags``, OpenAI-compat: ``GET /v1/models``) and updates the
cache. Probes run concurrently per tick and respect a hard
``probe_timeout`` so a hanging provider can't stall the loop.

Probe functions are injected from outside (``probes={name: async-fn}``)
so this module stays decoupled from the provider classes — wiring lives
in ``main.py`` where we already know which providers are configured.
"""

from __future__ import annotations

import asyncio
import logging
from typing import Awaitable, Callable

from .health import ProviderHealthCache

logger = logging.getLogger(__name__)

#: Probe function: returns ``True`` for healthy. Raising or returning
#: ``False`` both count as a failure (the loop just calls
#: ``mark_unhealthy``); an exception's string form becomes the
#: ``last_error`` for the snapshot endpoint.
ProbeFn = Callable[[], Awaitable[bool]]

DEFAULT_PROBE_INTERVAL_SEC = 30.0
DEFAULT_PROBE_TIMEOUT_SEC = 3.0


class HealthProbe:
    """Periodically probes every registered provider and updates the cache.

    Lifecycle::

        probe = HealthProbe(cache, {"ollama": probe_ollama, ...})
        await probe.start()  # spawns the background task
        ...
        await probe.stop()   # cancels and awaits cleanup

    Tests typically call :meth:`tick_once` directly to exercise one cycle
    without driving the asyncio scheduler through real ``asyncio.sleep``.
    """

    def __init__(
        self,
        cache: ProviderHealthCache,
        probes: dict[str, ProbeFn],
        *,
        interval: float = DEFAULT_PROBE_INTERVAL_SEC,
        timeout: float = DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> None:
        if interval <= 0:
            raise ValueError("interval must be > 0")
        if timeout <= 0:
            raise ValueError("timeout must be > 0")
        self._cache = cache
        self._probes = dict(probes)
        self._interval = interval
        self._timeout = timeout
        self._task: asyncio.Task | None = None
        self._stop = asyncio.Event()

    @property
    def interval(self) -> float:
        return self._interval

    @property
    def timeout(self) -> float:
        return self._timeout

    @property
    def provider_ids(self) -> list[str]:
        return list(self._probes)

    @property
    def running(self) -> bool:
        return self._task is not None and not self._task.done()

    # ------------------------------------------------------------------
    # Per-tick logic — exercised directly by tests
    # ------------------------------------------------------------------

    async def tick_once(self) -> None:
        """Probe every provider once, in parallel, updating the cache.

        Errors in any one probe (including ``asyncio.TimeoutError``) are
        captured per-provider — one bad probe never sinks the loop.
        """
        if not self._probes:
            return
        results = await asyncio.gather(
            *(self._probe_one(name, fn) for name, fn in self._probes.items()),
            return_exceptions=True,
        )
        # gather(return_exceptions=True) caught everything per-probe; this
        # branch should never fire, but guard in case _probe_one ever grows
        # a code path that bypasses its try/except. asyncio.gather() returns
        # results in input order and with the same length, so the zip is a
        # 1:1 mapping back to provider names.
        for name, result in zip(self._probes, results):
            if isinstance(result, BaseException):
                logger.error("probe %s leaked exception: %s", name, result)

    async def _probe_one(self, name: str, fn: ProbeFn) -> None:
        try:
            healthy = await asyncio.wait_for(fn(), timeout=self._timeout)
        except asyncio.TimeoutError:
            self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)")
            return
        except asyncio.CancelledError:
            raise
        except Exception as e:  # noqa: BLE001 — probe SHOULD be permissive
            self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}")
            return
        if healthy:
            self._cache.mark_healthy(name)
        else:
            self._cache.mark_unhealthy(name, "probe-returned-false")

    # ------------------------------------------------------------------
    # Long-running task management
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Spawn the periodic probe task. Idempotent."""
        if self.running:
            return
        self._stop.clear()
        self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe")
        logger.info(
            "HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)",
            self._interval,
            self._timeout,
            ", ".join(self._probes) or "<none>",
        )

    async def stop(self) -> None:
        """Cancel the background task and wait for it to finish."""
        if not self.running:
            return
        self._stop.set()
        assert self._task is not None
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            pass
        finally:
            self._task = None
        logger.info("HealthProbe stopped")

    async def _run_forever(self) -> None:
        # Probe immediately at boot so we don't serve traffic for `interval`
        # seconds based on optimistic-default assumptions.
        try:
            await self.tick_once()
        except Exception as e:  # noqa: BLE001
            logger.error("HealthProbe initial tick failed: %s", e)
        while not self._stop.is_set():
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self._interval)
            except asyncio.TimeoutError:
                pass
            else:
                # _stop.wait() succeeded → stop signalled, exit.
                return
            try:
                await self.tick_once()
            except Exception as e:  # noqa: BLE001
                logger.error("HealthProbe tick failed: %s", e)


# ---------------------------------------------------------------------------
# Probe-function helpers
# ---------------------------------------------------------------------------


def make_http_probe(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    expected_status_lt: int = 500,
) -> ProbeFn:
    """Return a probe function that does ``GET <url>`` and considers the
    provider healthy iff the response status is below
    ``expected_status_lt`` (default: any non-5xx counts).

    A 401/403/404 still counts as healthy because the *server* answered —
    auth or path mistakes are misconfiguration, not provider liveness.
    """
    import httpx

    async def probe() -> bool:
        async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
            resp = await client.get(url, headers=headers or None)
            return resp.status_code < expected_status_lt

    return probe
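
Because `tick_once` is the tests' entry point, one probe cycle can be
exercised without starting the background task. A minimal sketch (the probe
URL is the default Ollama endpoint from the config):

```python
import asyncio

from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, make_http_probe

async def main() -> None:
    cache = ProviderHealthCache()
    probe = HealthProbe(
        cache,
        {"ollama": make_http_probe("http://localhost:11434/api/tags")},
        timeout=3.0,
    )
    await probe.tick_once()            # one concurrent pass, no task spawned
    print(cache.get_state("ollama"))   # ProviderState(healthy=..., last_error=...)

asyncio.run(main())
```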
|
||||
|
|
@ -1,405 +0,0 @@
|
|||
"""Main FastAPI application for mana-llm service."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import signal
|
||||
import time
|
||||
from contextlib import asynccontextmanager
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from src.aliases import AliasConfigError, AliasRegistry
|
||||
from src.api_auth import ApiKeyMiddleware
|
||||
from src.config import settings
|
||||
from src.health import ProviderHealthCache
|
||||
from src.health_probe import HealthProbe, make_http_probe
|
||||
from src.models import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
ModelInfo,
|
||||
ModelsResponse,
|
||||
)
|
||||
from src.providers import ProviderRouter
|
||||
from src.providers.errors import ProviderError
|
||||
from src.streaming import stream_chat_completion
|
||||
from src.utils.cache import close_redis
|
||||
from src.utils.metrics import (
|
||||
get_metrics,
|
||||
record_llm_error,
|
||||
record_llm_request,
|
||||
set_provider_healthy,
|
||||
)
|
||||
|
||||
#: Header carrying the concrete provider/model that actually served a
|
||||
#: non-streaming response — useful for token-cost accounting on the
|
||||
#: caller side, since `mana/long-form` could resolve to ollama, groq,
|
||||
#: or claude depending on which providers were healthy at request time.
|
||||
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"
|
||||

# Configure logging
logging.basicConfig(
    level=getattr(logging, settings.log_level.upper()),
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)

# Global service singletons
router: ProviderRouter | None = None
health_cache: ProviderHealthCache | None = None
health_probe: HealthProbe | None = None
alias_registry: AliasRegistry | None = None


def _build_provider_probes(
    providers: dict[str, Any],
) -> dict[str, Any]:
    """Wire each configured provider to a cheap HTTP probe."""
    probes: dict[str, Any] = {}

    if "ollama" in providers:
        probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags")
    if "openrouter" in providers:
        probes["openrouter"] = make_http_probe(
            f"{settings.openrouter_base_url}/models",
            headers={"Authorization": f"Bearer {settings.openrouter_api_key}"},
        )
    if "groq" in providers:
        probes["groq"] = make_http_probe(
            f"{settings.groq_base_url}/models",
            headers={"Authorization": f"Bearer {settings.groq_api_key}"},
        )
    if "together" in providers:
        probes["together"] = make_http_probe(
            f"{settings.together_base_url}/models",
            headers={"Authorization": f"Bearer {settings.together_api_key}"},
        )
    # Google: skipped — google-genai SDK is opaque enough that a probe
    # would amount to a real API call. Treat as healthy by default; the
    # router's call-site fallback will mark it unhealthy on real errors.
    return probes


def _on_health_change(provider: str, healthy: bool) -> None:
    """Mirror health-cache transitions into the Prometheus gauge."""
    set_provider_healthy(provider, healthy)


def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None:
    """Reload ``aliases.yaml`` when the process receives SIGHUP.

    Reload errors keep the previous good state in memory (see
    AliasRegistry.reload). SIGHUP isn't available on Windows; we just
    log and skip in that case.
    """

    def handler() -> None:
        if alias_registry is None:
            return
        try:
            alias_registry.reload()
        except AliasConfigError as e:
            logger.error("alias reload rejected, keeping previous state: %s", e)

    try:
        loop.add_signal_handler(signal.SIGHUP, handler)
    except (NotImplementedError, AttributeError, RuntimeError):
        # NotImplementedError on Windows; RuntimeError when the loop is
        # not running in the main thread (TestClient does this).
        logger.info("SIGHUP reload not available in this context — skipping")
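
# Illustrative trigger (assumes a POSIX host; `service_pid` is hypothetical):
# after editing aliases.yaml, signal the running process to hot-reload it.
#
#   import os, signal
#
#   os.kill(service_pid, signal.SIGHUP)
#
# The shell equivalent is `kill -HUP <pid>`; a rejected reload keeps the
# previous alias state, per the handler above.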

@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load aliases, spin up router + health probe."""
    global router, health_cache, health_probe, alias_registry

    logger.info("Starting mana-llm service...")

    aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml"
    alias_registry = AliasRegistry(aliases_path)
    logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)

    health_cache = ProviderHealthCache()
    health_cache.add_listener(_on_health_change)
    router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
    logger.info("Initialized providers: %s", list(router.providers))

    # Initial gauge values so dashboards render before the first probe
    # transition fires the listener.
    for provider_name in router.providers:
        set_provider_healthy(provider_name, True)

    health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
    await health_probe.start()

    _install_sighup_reload(asyncio.get_running_loop())

    yield

    logger.info("Shutting down mana-llm service...")
    if health_probe is not None:
        await health_probe.stop()
    if router is not None:
        await router.close()
    await close_redis()


# Create FastAPI app
app = FastAPI(
    title="mana-llm",
    description="Central LLM abstraction service for Ollama and OpenAI-compatible APIs",
    version="0.1.0",
    lifespan=lifespan,
)

# Add CORS middleware
app.add_middleware(
    CORSMiddleware,
    allow_origins=settings.cors_origins_list,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
app.add_middleware(ApiKeyMiddleware)


# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
    """Check service health and provider status."""
    if router is None:
        return {"status": "unhealthy", "error": "Router not initialized"}

    provider_health = await router.health_check()
    return {
        "status": provider_health["status"],
        "service": "mana-llm",
        "version": "0.1.0",
        "providers": provider_health["providers"],
    }


# Metrics endpoint
@app.get("/metrics")
async def metrics() -> Response:
    """Prometheus metrics endpoint."""
    return Response(content=get_metrics(), media_type="text/plain")


# ----- Alias / health debug endpoints (M4) ---------------------------------


@app.get("/v1/aliases")
async def list_aliases() -> dict[str, Any]:
    """Inspect the alias registry — what each ``mana/<class>`` resolves to.

    Useful for debugging "which model actually answered my request" and
    for confirming SIGHUP reloads picked up edits to ``aliases.yaml``.
    """
    if alias_registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")
    return {
        "default": alias_registry.default_alias,
        "aliases": [
            {
                "name": a.name,
                "description": a.description,
                "chain": list(a.chain),
            }
            for a in alias_registry.list_aliases()
        ],
    }


@app.get("/v1/health")
async def detailed_health() -> dict[str, Any]:
    """Full per-provider liveness snapshot.

    Includes the failure counter, last error, and the unhealthy-until
    backoff timestamp — info the original ``/health`` endpoint hides.
    """
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")
    return await router.health_check()


# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
    """List all available models from all providers."""
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    models = await router.list_models()
    return ModelsResponse(data=models)


@app.get("/v1/models/{model_id:path}")
async def get_model(model_id: str) -> ModelInfo:
    """Get specific model information."""
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    model = await router.get_model(model_id)
    if model is None:
        raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")

    return model


# Chat completions endpoint
@app.post("/v1/chat/completions", response_model=None)
async def chat_completions(
    request: ChatCompletionRequest,
    http_request: Request,
) -> ChatCompletionResponse | EventSourceResponse:
    """
    Create a chat completion.

    Supports both streaming (SSE) and non-streaming responses based on the
    `stream` parameter in the request body.
    """
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    # The request's `model` field is what the caller asked for — could be
    # `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For
    # error-path metrics we use that value (it's what the caller will
    # search for); for success-path metrics we use the resolved provider
    # so token-cost / latency attribute to the model that actually ran.
    requested_provider, requested_model = _split_model(request.model)

    start_time = time.time()

    try:
        if request.stream:
            # Streaming response via SSE
            logger.info(f"Streaming chat completion: {request.model}")

            async def generate():
                async for chunk in stream_chat_completion(router, request):
                    yield chunk

            # Streaming metrics: we don't yet know which provider answered
            # at request-record time. Each chunk's `model` field carries
            # the resolved name; per-token latency is harder to attribute
            # cleanly so we skip it for streams.
            record_llm_request(requested_provider, requested_model, streaming=True)

            return EventSourceResponse(
                generate(),
                media_type="text/event-stream",
            )
        else:
            # Non-streaming response
            logger.info(f"Chat completion: {request.model}")
            response = await router.chat_completion(request)

            resolved_provider, resolved_model = _split_model(response.model)
            latency = time.time() - start_time
            record_llm_request(
                provider=resolved_provider,
                model=resolved_model,
                streaming=False,
                prompt_tokens=response.usage.prompt_tokens,
                completion_tokens=response.usage.completion_tokens,
                latency=latency,
            )

            # `response.model` is the concrete provider/model the chain
            # actually resolved to. Surface it via header so the caller
            # can attribute token cost to the right model even when the
            # request used an alias.
            return JSONResponse(
                content=response.model_dump(),
                headers={RESOLVED_MODEL_HEADER: response.model},
            )

    except ValueError as e:
        logger.error(f"Invalid request: {e}")
        record_llm_error(requested_provider, requested_model, "invalid_request")
        raise HTTPException(status_code=400, detail=str(e))
    except ProviderError as e:
        logger.warning(
            f"Provider error on {requested_provider}/{requested_model}: "
            f"kind={e.kind} detail={e}"
        )
        record_llm_error(requested_provider, requested_model, e.kind)
        raise HTTPException(
            status_code=e.http_status,
            detail={"kind": e.kind, "message": str(e)},
        )
    except Exception as e:
        logger.error(f"Chat completion failed: {e}")
        record_llm_error(requested_provider, requested_model, "server_error")
        raise HTTPException(status_code=500, detail=str(e))
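
# Client sketch (illustrative; assumes the service is reachable on its
# default port and that the API-key middleware accepts the request):
#
#   import httpx
#
#   body = {
#       "model": "ollama/gemma3:4b",
#       "messages": [{"role": "user", "content": "Say hi"}],
#       "stream": False,
#   }
#   r = httpx.post("http://localhost:3025/v1/chat/completions", json=body)
#   print(r.json()["choices"][0]["message"]["content"])
#
# With "stream": True the same endpoint returns an SSE stream of
# chat.completion.chunk objects instead of a single JSON body.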

# Embeddings endpoint
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
    """Create embeddings for the input text."""
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    provider, model = _split_model(request.model)

    start_time = time.time()

    try:
        logger.info(f"Creating embeddings: {request.model}")
        response = await router.embeddings(request)

        latency = time.time() - start_time
        record_llm_request(
            provider=provider,
            model=model,
            streaming=False,
            prompt_tokens=response.usage.prompt_tokens,
            latency=latency,
        )

        return response

    except ValueError as e:
        logger.error(f"Invalid embedding request: {e}")
        record_llm_error(provider, model, "invalid_request")
        raise HTTPException(status_code=400, detail=str(e))
    except Exception as e:
        logger.error(f"Embeddings failed: {e}")
        record_llm_error(provider, model, "server_error")
        raise HTTPException(status_code=500, detail=str(e))


def _split_model(model: str) -> tuple[str, str]:
    """Split a ``provider/model`` string for metric labelling.

    Bare names with no slash default to ``ollama`` to match the legacy
    OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace
    in the metrics — that's intentional, so request-side counters tell
    you what callers ASKED for, while the resolved-side counters
    (``mana_llm_alias_resolved_total``) tell you what they GOT.
    """
    if "/" in model:
        provider, _, name = model.partition("/")
        return provider.lower(), name
    return "ollama", model
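
# Worked examples (illustrative):
#
#   _split_model("ollama/gemma3:4b")  ->  ("ollama", "gemma3:4b")
#   _split_model("mana/long-form")    ->  ("mana", "long-form")
#   _split_model("gemma3:4b")         ->  ("ollama", "gemma3:4b")  # bare name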

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(
        "src.main:app",
        host="0.0.0.0",
        port=settings.port,
        reload=True,
        log_level=settings.log_level,
    )

@ -1,47 +0,0 @@
"""Pydantic models for OpenAI-compatible API."""

from .requests import (
    ChatCompletionRequest,
    EmbeddingRequest,
    FunctionSpec,
    Message,
    ToolChoice,
    ToolSpec,
)
from .responses import (
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingData,
    EmbeddingResponse,
    MessageResponse,
    ModelInfo,
    ModelsResponse,
    StreamChoice,
    ToolCall,
    ToolCallFunction,
    Usage,
)

__all__ = [
    "ChatCompletionRequest",
    "ChatCompletionResponse",
    "ChatCompletionStreamResponse",
    "Choice",
    "DeltaContent",
    "EmbeddingData",
    "EmbeddingRequest",
    "EmbeddingResponse",
    "FunctionSpec",
    "Message",
    "MessageResponse",
    "ModelInfo",
    "ModelsResponse",
    "StreamChoice",
    "ToolCall",
    "ToolCallFunction",
    "ToolChoice",
    "ToolSpec",
    "Usage",
]

@ -1,128 +0,0 @@
"""Request models for OpenAI-compatible API."""

from typing import Any, Literal

from pydantic import BaseModel, Field


class TextContent(BaseModel):
    """Text content in a message."""

    type: Literal["text"] = "text"
    text: str


class ImageUrl(BaseModel):
    """Image URL reference."""

    url: str  # Can be http(s):// or data:image/...;base64,...


class ImageContent(BaseModel):
    """Image content in a message."""

    type: Literal["image_url"] = "image_url"
    image_url: ImageUrl


MessageContent = str | list[TextContent | ImageContent]


class ToolCallFunction(BaseModel):
    """The function portion of a tool_call on an assistant message."""

    name: str
    # Arguments are passed as a JSON string (OpenAI spec). Providers may
    # emit structured args natively; the adapter serialises them here.
    arguments: str


class ToolCall(BaseModel):
    """A tool invocation the assistant decided to make."""

    id: str
    type: Literal["function"] = "function"
    function: ToolCallFunction


class Message(BaseModel):
    """A single message in the conversation.

    `tool` messages carry the result of a previously-requested tool call
    back into the context; they must reference the originating call via
    ``tool_call_id``. Assistant messages may contain either plain
    ``content`` or ``tool_calls`` (or both, though providers typically
    only emit one at a time).
    """

    role: Literal["system", "user", "assistant", "tool"]
    content: MessageContent | None = None
    tool_call_id: str | None = None
    tool_calls: list[ToolCall] | None = None
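
# Illustrative tool-call round trip (all values invented):
#
#   messages = [
#       Message(role="user", content="What's 2+2?"),
#       Message(role="assistant", tool_calls=[ToolCall(
#           id="call_abc123",
#           function=ToolCallFunction(name="calc", arguments='{"expr": "2+2"}'),
#       )]),
#       Message(role="tool", tool_call_id="call_abc123", content="4"),
#   ]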

class FunctionSpec(BaseModel):
    """OpenAI-style function declaration."""

    name: str
    description: str
    # JSON Schema for the function's parameters. Kept as a dict here to
    # stay loose; providers translate it to their native shape.
    parameters: dict[str, Any] = Field(default_factory=dict)


class ToolSpec(BaseModel):
    """A tool the model may call. Only ``function`` tools are supported."""

    type: Literal["function"] = "function"
    function: FunctionSpec


class ToolChoiceFunction(BaseModel):
    """Force-pick a specific function for ``tool_choice``."""

    type: Literal["function"] = "function"
    function: dict[str, str]  # {"name": "tool_name"}


ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction


class ResponseFormat(BaseModel):
    """OpenAI structured-output response_format hint.

    Two shapes are accepted:
    - {"type": "json_object"} — free-form JSON
    - {"type": "json_schema",
       "json_schema": {"name": "...", "schema": {...}, "strict": bool}}
      — schema-constrained JSON; passed through to providers that
      support it (e.g. Ollama 0.5+ via its native `format` field).
    """

    type: Literal["json_object", "json_schema"]
    json_schema: dict[str, Any] | None = None
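
# Illustrative construction of the schema-constrained shape (the schema
# itself is invented; field names follow the docstring above):
#
#   rf = ResponseFormat(
#       type="json_schema",
#       json_schema={
#           "name": "sentiment",
#           "schema": {"type": "object",
#                      "properties": {"label": {"type": "string"}}},
#           "strict": True,
#       },
#   )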

class ChatCompletionRequest(BaseModel):
    """Request body for chat completions endpoint."""

    model: str = Field(..., description="Model identifier in format 'provider/model' or just 'model'")
    messages: list[Message] = Field(..., min_length=1)
    stream: bool = False
    temperature: float | None = Field(default=None, ge=0.0, le=2.0)
    max_tokens: int | None = Field(default=None, gt=0)
    top_p: float | None = Field(default=None, ge=0.0, le=1.0)
    frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
    presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
    stop: str | list[str] | None = None
    response_format: ResponseFormat | None = None
    tools: list[ToolSpec] | None = None
    tool_choice: ToolChoice | None = None


class EmbeddingRequest(BaseModel):
    """Request body for embeddings endpoint."""

    model: str = Field(..., description="Model identifier")
    input: str | list[str] = Field(..., description="Text(s) to embed")
    encoding_format: Literal["float", "base64"] = "float"

@ -1,120 +0,0 @@
"""Response models for OpenAI-compatible API."""

import time
import uuid
from typing import Literal

from pydantic import BaseModel, Field


class Usage(BaseModel):
    """Token usage information."""

    prompt_tokens: int = 0
    completion_tokens: int = 0
    total_tokens: int = 0


class ToolCallFunction(BaseModel):
    """Function portion of a tool_call the assistant produced."""

    name: str
    arguments: str  # JSON-encoded (OpenAI spec)


class ToolCall(BaseModel):
    """A tool invocation the assistant decided to make."""

    id: str
    type: Literal["function"] = "function"
    function: ToolCallFunction


class MessageResponse(BaseModel):
    """Response message from the model."""

    role: Literal["assistant"] = "assistant"
    content: str | None = None
    tool_calls: list[ToolCall] | None = None


class Choice(BaseModel):
    """A single completion choice."""

    index: int = 0
    message: MessageResponse
    finish_reason: (
        Literal["stop", "length", "content_filter", "tool_calls"] | None
    ) = "stop"


class ChatCompletionResponse(BaseModel):
    """Response from chat completions endpoint (non-streaming)."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
    object: Literal["chat.completion"] = "chat.completion"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[Choice]
    usage: Usage = Field(default_factory=Usage)


class DeltaContent(BaseModel):
    """Delta content for streaming responses."""

    role: Literal["assistant"] | None = None
    content: str | None = None
    tool_calls: list[ToolCall] | None = None


class StreamChoice(BaseModel):
    """A single streaming choice."""

    index: int = 0
    delta: DeltaContent
    finish_reason: (
        Literal["stop", "length", "content_filter", "tool_calls"] | None
    ) = None


class ChatCompletionStreamResponse(BaseModel):
    """Response chunk from chat completions endpoint (streaming)."""

    id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
    created: int = Field(default_factory=lambda: int(time.time()))
    model: str
    choices: list[StreamChoice]


class ModelInfo(BaseModel):
    """Information about a model."""

    id: str
    object: Literal["model"] = "model"
    created: int = Field(default_factory=lambda: int(time.time()))
    owned_by: str = "mana-llm"


class ModelsResponse(BaseModel):
    """Response from models endpoint."""

    object: Literal["list"] = "list"
    data: list[ModelInfo]


class EmbeddingData(BaseModel):
    """A single embedding result."""

    object: Literal["embedding"] = "embedding"
    index: int = 0
    embedding: list[float]


class EmbeddingResponse(BaseModel):
    """Response from embeddings endpoint."""

    object: Literal["list"] = "list"
    data: list[EmbeddingData]
    model: str
    usage: Usage = Field(default_factory=Usage)

@ -1,13 +0,0 @@
"""LLM Provider implementations."""

from .base import LLMProvider
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
from .router import ProviderRouter

__all__ = [
    "LLMProvider",
    "OllamaProvider",
    "OpenAICompatProvider",
    "ProviderRouter",
]

@ -1,76 +0,0 @@
"""Abstract base class for LLM providers."""

from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any

from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelInfo,
)


class LLMProvider(ABC):
    """Abstract base class for LLM providers."""

    name: str = "base"

    # Set to True if the provider supports OpenAI-style `tools` + `tool_calls`
    # for chat completions. The router rejects tool-bearing requests routed
    # to providers without support rather than silently dropping the tools.
    # Provider adapters may further narrow this per-model if needed.
    supports_tools: bool = False

    def model_supports_tools(self, model: str) -> bool:
        """Check if a specific model within this provider supports tools.

        Default: falls back to the provider-wide flag. Providers with a
        mixed capability surface (e.g. Ollama — depends on the local
        model) override this.
        """
        return self.supports_tools

    @abstractmethod
    async def chat_completion(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> ChatCompletionResponse:
        """Generate a chat completion (non-streaming)."""
        ...

    @abstractmethod
    async def chat_completion_stream(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Generate a chat completion (streaming)."""
        ...

    @abstractmethod
    async def list_models(self) -> list[ModelInfo]:
        """List available models."""
        ...

    @abstractmethod
    async def embeddings(
        self,
        request: EmbeddingRequest,
        model: str,
    ) -> EmbeddingResponse:
        """Generate embeddings for input text."""
        ...

    @abstractmethod
    async def health_check(self) -> dict[str, Any]:
        """Check provider health status."""
        ...

    async def close(self) -> None:
        """Clean up resources."""
        pass

@ -1,111 +0,0 @@
"""Structured provider errors.

These map to distinct HTTP status codes and metric labels so callers can
distinguish between "model refused to answer" (blocked), "hit the token
budget" (truncated), "auth is broken", "we were rate-limited", and
"model doesn't support what we asked" (capability). The old behaviour of
returning an empty-string response on content-filter hits silently
corrupts downstream pipelines (e.g. the planner sees "" and reports
"no JSON block found" — misleading).
"""


class ProviderError(Exception):
    """Base class for structured provider errors."""

    # Short stable identifier used in metrics and API responses.
    kind: str = "unknown"
    # Suggested HTTP status when this error bubbles out of an endpoint.
    http_status: int = 502


class ProviderBlockedError(ProviderError):
    """The provider refused to return content (safety, recitation, …).

    The model produced nothing usable because a content filter or policy
    guardrail fired. Retrying with the same inputs will fail again — the
    caller needs to adjust the prompt or give the user a clear message.
    """

    kind = "blocked"
    http_status = 422

    def __init__(self, reason: str, detail: str | None = None):
        self.reason = reason  # e.g. "SAFETY", "RECITATION"
        self.detail = detail
        msg = f"Provider blocked the response ({reason})"
        if detail:
            msg += f": {detail}"
        super().__init__(msg)


class ProviderTruncatedError(ProviderError):
    """The response was cut off before completion (hit max_tokens)."""

    kind = "truncated"
    http_status = 502

    def __init__(self, partial_text: str | None = None):
        self.partial_text = partial_text
        super().__init__(
            "Provider truncated the response (max_tokens reached) — "
            "re-run with a higher max_tokens or smaller input"
        )


class ProviderAuthError(ProviderError):
    """Authentication against the upstream provider failed."""

    kind = "auth"
    http_status = 502  # not 401 — the client's auth is fine, *ours* isn't


class ProviderRateLimitError(ProviderError):
    """The upstream provider rate-limited us."""

    kind = "rate_limit"
    http_status = 429


class ProviderCapabilityError(ProviderError):
    """The requested feature is not supported by the chosen model.

    Typically raised when a request asks for native tool-calling against
    a model that does not support it. We refuse to silently degrade.
    """

    kind = "capability"
    http_status = 400


class NoHealthyProviderError(ProviderError):
    """Every entry in the resolved chain has been tried and either was
    unconfigured, marked unhealthy by the cache, or failed in flight.

    Carries the full attempt log so the caller can see which providers
    were tried and why each one was skipped or failed — invaluable when
    a real outage hits and the API returns 503 instead of the usual 200.
    """

    kind = "no_healthy_provider"
    http_status = 503

    def __init__(
        self,
        model_or_alias: str,
        attempts: list[tuple[str, str]],
        last_exception: Exception | None = None,
    ) -> None:
        self.model_or_alias = model_or_alias
        self.attempts = list(attempts)
        self.last_exception = last_exception
        if attempts:
            attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts)
        else:
            attempt_log = "(no providers in resolved chain were configured)"
        msg = (
            f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}"
        )
        if last_exception is not None:
            msg += f". Last error: {type(last_exception).__name__}: {last_exception}"
        super().__init__(msg)
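
# Illustrative attempt log (provider names and reasons invented):
#
#   raise NoHealthyProviderError(
#       "mana/long-form",
#       attempts=[("ollama/gemma3:4b", "unhealthy"),
#                 ("groq/llama-3.1-8b-instant", "timeout")],
#   )
#   # -> "no healthy provider could serve 'mana/long-form'. Attempts:
#   #     ollama/gemma3:4b=unhealthy, groq/llama-3.1-8b-instant=timeout"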

@ -1,521 +0,0 @@
"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable)."""

import json
import logging
import uuid
from collections.abc import AsyncIterator
from typing import Any

from google import genai
from google.genai import types

from src.config import settings
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingData,
    EmbeddingRequest,
    EmbeddingResponse,
    MessageResponse,
    ModelInfo,
    StreamChoice,
    ToolCall,
    ToolCallFunction,
    ToolSpec,
    Usage,
)

from .base import LLMProvider
from .errors import (
    ProviderAuthError,
    ProviderBlockedError,
    ProviderError,
    ProviderRateLimitError,
    ProviderTruncatedError,
)

logger = logging.getLogger(__name__)


def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None:
    """Translate our ToolSpec list into Gemini ``types.Tool`` declarations."""
    if not tools:
        return None
    declarations: list[types.FunctionDeclaration] = []
    for t in tools:
        declarations.append(
            types.FunctionDeclaration(
                name=t.function.name,
                description=t.function.description,
                parameters=t.function.parameters or None,
            )
        )
    return [types.Tool(function_declarations=declarations)]


def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None:
    """Pull any ``function_call`` parts off a candidate into ToolCalls.

    Gemini emits tool calls as ``content.parts[i].function_call`` with a
    ``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON
    string). We normalise to OpenAI shape: arguments are JSON-encoded
    strings so downstream handlers can treat all providers the same.
    """
    if candidate is None:
        return None
    content = getattr(candidate, "content", None)
    parts = getattr(content, "parts", None) or []
    calls: list[ToolCall] = []
    for part in parts:
        fc = getattr(part, "function_call", None)
        if fc is None:
            continue
        name = getattr(fc, "name", None)
        if not name:
            continue
        args = getattr(fc, "args", None) or {}
        # Gemini's args are already dict-shaped; serialise to JSON string.
        try:
            arguments_json = json.dumps(dict(args), ensure_ascii=False)
        except (TypeError, ValueError):
            arguments_json = json.dumps({}, ensure_ascii=False)
        calls.append(
            ToolCall(
                id=f"call_{uuid.uuid4().hex[:12]}",
                function=ToolCallFunction(name=name, arguments=arguments_json),
            )
        )
    return calls or None


def _unwrap_gemini_response(
    response: Any, gemini_model: str
) -> tuple[str, list[ToolCall] | None, str]:
    """Validate a non-streaming Gemini response.

    Returns ``(text, tool_calls, finish_reason)``. Raises a structured
    ``ProviderError`` if the response was blocked, truncated, or
    otherwise produced no usable payload. ``response.text`` silently
    returns ``""`` on blocked responses — we refuse to propagate that.
    """
    candidates = getattr(response, "candidates", None) or []
    candidate = candidates[0] if candidates else None
    finish_reason = getattr(candidate, "finish_reason", None)
    # SDK sometimes exposes the enum name on `.name`, sometimes it's a string.
    finish_name = getattr(finish_reason, "name", None) or (
        str(finish_reason) if finish_reason is not None else None
    )
    # Strip the leading enum prefix if present (e.g. "FinishReason.SAFETY").
    if finish_name and "." in finish_name:
        finish_name = finish_name.rsplit(".", 1)[-1]

    text = response.text or ""
    tool_calls = _extract_tool_calls(candidate)

    if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}:
        # Collect the safety categories that actually blocked, if any.
        ratings = getattr(candidate, "safety_ratings", None) or []
        blocked = [
            getattr(r, "category", None)
            for r in ratings
            if getattr(r, "blocked", False)
        ]
        detail = ", ".join(str(c) for c in blocked if c) or None
        logger.warning(
            "Gemini response blocked (model=%s, reason=%s, detail=%s)",
            gemini_model,
            finish_name,
            detail,
        )
        raise ProviderBlockedError(reason=finish_name, detail=detail)

    if finish_name == "MAX_TOKENS":
        logger.warning(
            "Gemini response truncated at max_tokens (model=%s)", gemini_model
        )
        raise ProviderTruncatedError(partial_text=text or None)

    if not text and not tool_calls and finish_name not in (None, "STOP"):
        # Unknown finish reason, nothing to return — surface instead of "".
        raise ProviderError(
            f"Gemini returned no content (finish_reason={finish_name})"
        )

    # Normalise the finish_reason to our OpenAI-compatible vocabulary.
    openai_finish = "stop"
    if tool_calls:
        openai_finish = "tool_calls"
    elif finish_name == "MAX_TOKENS":
        openai_finish = "length"
    return text, tool_calls, openai_finish


def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None:
    """Find the tool name for a given ``tool_call_id``.

    OpenAI's tool-message carries only the id; the matching name lives
    on the preceding assistant message's ``tool_calls[]``. We scan from
    the end (most recent assistant turn) backwards.
    """
    if not tool_call_id:
        return None
    for m in reversed(messages):
        if m.role != "assistant" or not m.tool_calls:
            continue
        for call in m.tool_calls:
            if call.id == tool_call_id:
                return call.function.name
    return None


def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError:
    """Translate a raw Google SDK exception into a structured ProviderError.

    The SDK uses google.genai.errors.* but we avoid importing them at
    top level to keep the provider optional. String-match the class
    name instead.
    """
    cls_name = type(err).__name__
    msg = str(err) or cls_name
    if "Auth" in cls_name or "PermissionDenied" in cls_name or "Unauthenticated" in cls_name:
        return ProviderAuthError(f"Gemini auth failed for {gemini_model}: {msg}")
    if "ResourceExhausted" in cls_name or "RateLimit" in cls_name or "429" in msg:
        return ProviderRateLimitError(f"Gemini rate-limited for {gemini_model}: {msg}")
    return ProviderError(f"Gemini call failed for {gemini_model}: {msg}")


# Model mapping: Ollama model → Google Gemini equivalent
OLLAMA_TO_GEMINI: dict[str, str] = {
    "gemma3:4b": "gemini-2.5-flash",
    "gemma3:12b": "gemini-2.5-flash",
    "gemma3:27b": "gemini-2.5-pro",
    "llava:7b": "gemini-2.5-flash",  # Gemini has native vision
    "qwen3-vl:4b": "gemini-2.5-flash",  # vision fallback
    "qwen2.5-coder:7b": "gemini-2.5-flash",
    "qwen2.5-coder:14b": "gemini-2.5-pro",
    "phi3.5:latest": "gemini-2.5-flash",
    "ministral-3:3b": "gemini-2.5-flash",
    "deepseek-ocr:latest": "gemini-2.5-flash",
}


class GoogleProvider(LLMProvider):
    """Google Gemini API provider."""

    name = "google"
    # Gemini 2.x supports OpenAI-style function calling across all listed models.
    supports_tools = True

    def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"):
        self.api_key = api_key
        self.default_model = default_model
        self.client = genai.Client(api_key=api_key)

    def map_model(self, ollama_model: str) -> str:
        """Map an Ollama model name to a Google Gemini equivalent."""
        return OLLAMA_TO_GEMINI.get(ollama_model, self.default_model)
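
    # Illustrative mapping behaviour (follows OLLAMA_TO_GEMINI above):
    #
    #   provider.map_model("gemma3:27b")        ->  "gemini-2.5-pro"
    #   provider.map_model("unknown-model:1b")  ->  self.default_model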

    def _convert_messages(
        self, request: ChatCompletionRequest
    ) -> tuple[str | None, list[types.Content]]:
        """Convert OpenAI-format messages to Google Gemini format.

        Returns (system_instruction, contents). Handles text, multimodal
        image content, assistant messages carrying ``tool_calls``, and
        ``tool`` result messages (mapped to Gemini ``function_response``
        parts — Gemini has no ``tool`` role, function responses ride on
        a ``user`` turn).
        """
        system_instruction: str | None = None
        contents: list[types.Content] = []

        for msg in request.messages:
            if msg.role == "system":
                if isinstance(msg.content, str):
                    system_instruction = msg.content
                continue

            # Tool result message → function_response Part on a user turn.
            if msg.role == "tool":
                # The content is the stringified tool result. We also need
                # a tool name — the OpenAI spec carries it on the matching
                # assistant tool_call, keyed by tool_call_id. We don't
                # track that back-reference here, so we pull the name
                # from the preceding assistant message's tool_calls.
                name = _lookup_tool_name(request.messages, msg.tool_call_id)
                if not name:
                    continue  # orphan tool message — skip silently
                payload: Any
                if isinstance(msg.content, str):
                    try:
                        payload = json.loads(msg.content)
                        if not isinstance(payload, dict):
                            payload = {"result": payload}
                    except (TypeError, ValueError):
                        payload = {"result": msg.content}
                else:
                    payload = {"result": ""}
                contents.append(
                    types.Content(
                        role="user",
                        parts=[
                            types.Part.from_function_response(
                                name=name, response=payload
                            )
                        ],
                    )
                )
                continue

            role = "user" if msg.role == "user" else "model"
            parts: list[types.Part] = []

            if msg.role == "assistant" and msg.tool_calls:
                for call in msg.tool_calls:
                    try:
                        args_obj = json.loads(call.function.arguments or "{}")
                    except (TypeError, ValueError):
                        args_obj = {}
                    parts.append(
                        types.Part(
                            function_call=types.FunctionCall(
                                name=call.function.name, args=args_obj
                            )
                        )
                    )

            if msg.content is not None:
                if isinstance(msg.content, str):
                    parts.append(types.Part.from_text(text=msg.content))
                else:
                    for part in msg.content:
                        if part.type == "text":
                            parts.append(types.Part.from_text(text=part.text))
                        elif part.type == "image_url" and part.image_url:
                            url = part.image_url.url
                            if url.startswith("data:"):
                                header, b64_data = url.split(",", 1)
                                mime_type = header.split(":")[1].split(";")[0]
                                import base64

                                image_bytes = base64.b64decode(b64_data)
                                parts.append(
                                    types.Part.from_bytes(
                                        data=image_bytes, mime_type=mime_type
                                    )
                                )
                            else:
                                parts.append(
                                    types.Part.from_uri(
                                        file_uri=url, mime_type="image/jpeg"
                                    )
                                )

            if parts:
                contents.append(types.Content(role=role, parts=parts))

        return system_instruction, contents

    async def chat_completion(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> ChatCompletionResponse:
        """Generate a chat completion via Google Gemini."""
        gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
        system_instruction, contents = self._convert_messages(request)

        config: dict[str, Any] = {}
        if request.temperature is not None:
            config["temperature"] = request.temperature
        if request.max_tokens is not None:
            config["max_output_tokens"] = request.max_tokens
        if request.top_p is not None:
            config["top_p"] = request.top_p
        if request.stop:
            stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop]
            config["stop_sequences"] = stop_seqs

        gemini_tools = _build_gemini_tools(request.tools)
        if gemini_tools:
            config["tools"] = gemini_tools

        gen_config = types.GenerateContentConfig(
            system_instruction=system_instruction,
            **config,
        )

        logger.debug(f"Google Gemini request: {gemini_model}, messages: {len(contents)}")

        try:
            response = await self.client.aio.models.generate_content(
                model=gemini_model,
                contents=contents,
                config=gen_config,
            )
        except ProviderError:
            raise
        except Exception as err:
            raise _wrap_gemini_call_error(err, gemini_model) from err

        content, tool_calls, finish_reason = _unwrap_gemini_response(
            response, gemini_model
        )
        usage_meta = response.usage_metadata

        return ChatCompletionResponse(
            model=f"google/{gemini_model}",
            choices=[
                Choice(
                    index=0,
                    message=MessageResponse(
                        content=content or None,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=finish_reason,
                )
            ],
            usage=Usage(
                prompt_tokens=usage_meta.prompt_token_count if usage_meta else 0,
                completion_tokens=usage_meta.candidates_token_count if usage_meta else 0,
                total_tokens=usage_meta.total_token_count if usage_meta else 0,
            ),
        )

    async def chat_completion_stream(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Generate a streaming chat completion via Google Gemini."""
        gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
        system_instruction, contents = self._convert_messages(request)

        config: dict[str, Any] = {}
        if request.temperature is not None:
            config["temperature"] = request.temperature
        if request.max_tokens is not None:
            config["max_output_tokens"] = request.max_tokens
        if request.top_p is not None:
            config["top_p"] = request.top_p

        gemini_tools = _build_gemini_tools(request.tools)
        if gemini_tools:
            config["tools"] = gemini_tools

        gen_config = types.GenerateContentConfig(
            system_instruction=system_instruction,
            **config,
        )

        # First chunk with role
        yield ChatCompletionStreamResponse(
            model=f"google/{gemini_model}",
            choices=[
                StreamChoice(
                    delta=DeltaContent(role="assistant"),
                    finish_reason=None,
                )
            ],
        )

        last_chunk: Any = None
        emitted_any_text = False
        try:
            stream = await self.client.aio.models.generate_content_stream(
                model=gemini_model,
                contents=contents,
                config=gen_config,
            )
        except Exception as err:
            raise _wrap_gemini_call_error(err, gemini_model) from err

        async for chunk in stream:
            last_chunk = chunk
            text = chunk.text
            if text:
                emitted_any_text = True
                yield ChatCompletionStreamResponse(
                    model=f"google/{gemini_model}",
                    choices=[
                        StreamChoice(
                            delta=DeltaContent(content=text),
                            finish_reason=None,
                        )
                    ],
                )

        # Post-stream check: if the stream ended without emitting any text,
        # surface the structured reason instead of quietly closing with an
        # empty "stop". Matches _unwrap_gemini_response semantics. We
        # discard the return value here — the streaming path produces its
        # content chunk-by-chunk above, so we only need this call for its
        # side-effect of raising on SAFETY / RECITATION / MAX_TOKENS.
        if not emitted_any_text and last_chunk is not None:
            _unwrap_gemini_response(last_chunk, gemini_model)

        # Final chunk
        yield ChatCompletionStreamResponse(
            model=f"google/{gemini_model}",
            choices=[
                StreamChoice(
                    delta=DeltaContent(),
                    finish_reason="stop",
                )
            ],
        )

    async def list_models(self) -> list[ModelInfo]:
        """List available Google Gemini models."""
        # Return a static list of commonly used models
        return [
            ModelInfo(id="google/gemini-2.5-flash", owned_by="google"),
            ModelInfo(id="google/gemini-2.5-pro", owned_by="google"),
        ]

    async def embeddings(
        self,
        request: EmbeddingRequest,
        model: str,
    ) -> EmbeddingResponse:
        """Generate embeddings via Google Gemini."""
        inputs = request.input if isinstance(request.input, list) else [request.input]

        result = await self.client.aio.models.embed_content(
            model="text-embedding-004",
            contents=inputs,
        )

        return EmbeddingResponse(
            data=[
                EmbeddingData(index=i, embedding=emb.values)
                for i, emb in enumerate(result.embeddings)
            ],
            model="google/text-embedding-004",
            usage=Usage(
                prompt_tokens=sum(len(t.split()) for t in inputs),
                total_tokens=sum(len(t.split()) for t in inputs),
            ),
        )

    async def health_check(self) -> dict[str, Any]:
        """Check Google API health."""
        try:
            # Quick test: list models
            response = await self.client.aio.models.list(config={"page_size": 1})
            return {
                "status": "healthy",
                "provider": "google",
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "provider": "google",
                "error": str(e),
            }

    async def close(self) -> None:
        """No cleanup needed for Google client."""
        pass

@ -1,462 +0,0 @@
"""Ollama provider implementation."""

import json
import logging
from collections.abc import AsyncIterator
from typing import Any

import httpx

from src.config import settings
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingData,
    EmbeddingRequest,
    EmbeddingResponse,
    MessageResponse,
    ModelInfo,
    StreamChoice,
    ToolCall,
    ToolCallFunction,
    Usage,
)

from .base import LLMProvider

# Ollama emits tool_calls under /api/chat as:
#   {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}}
# Arguments arrive as a dict — we normalise to a JSON string to match
# the OpenAI-spec shape our downstream code expects.

# Ollama models known to support tool-calling reliably (as of 0.3+).
# Everything else: we still pass `tools` to the API (it ignores them on
# incompatible models), but the assistant response will be plain text.
# A shared-ai-level capability check rejects these cases before the call.
TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] = (
    "llama3.1",
    "llama3.2",
    "llama3.3",
    "qwen2.5",
    "qwen3",
    "mistral",
    "mixtral",
    "command-r",
    "firefunction",
)

logger = logging.getLogger(__name__)


def _safe_parse_args(raw: str | None) -> dict[str, Any]:
    """Best-effort parse of a JSON-encoded arguments string."""
    if not raw:
        return {}
    try:
        parsed = json.loads(raw)
    except (TypeError, ValueError):
        return {}
    return parsed if isinstance(parsed, dict) else {}


def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
    """Normalise Ollama's tool_calls into our ToolCall model."""
    if not raw:
        return None
    import uuid

    calls: list[ToolCall] = []
    for idx, c in enumerate(raw):
        fn = c.get("function") or {}
        name = fn.get("name")
        if not name:
            continue
        args = fn.get("arguments")
        if isinstance(args, dict):
            arguments_json = json.dumps(args, ensure_ascii=False)
        elif isinstance(args, str):
            arguments_json = args
        else:
            arguments_json = "{}"
        calls.append(
            ToolCall(
                id=c.get("id") or f"call_{uuid.uuid4().hex[:12]}",
                function=ToolCallFunction(name=name, arguments=arguments_json),
            )
        )
    return calls or None
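
# Illustrative normalisation (input shape per the comment block near the
# top of this module; the generated id suffix is random):
#
#   _tool_calls_from_ollama(
#       [{"function": {"name": "get_weather", "arguments": {"city": "Berlin"}}}]
#   )
#   # -> [ToolCall(id="call_<hex>", function=ToolCallFunction(
#   #        name="get_weather", arguments='{"city": "Berlin"}'))]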

def _strip_json_fences(content: str) -> str:
    """Strip ```json ... ``` markdown fences from a string if present.

    Some Ollama vision models still wrap structured-output responses in
    a markdown code block even when `format` is set. Downstream parsers
    (Vercel AI SDK generateObject, manual JSON.parse) expect clean JSON,
    so we normalize the response here at the proxy boundary.
    """
    s = content.strip()
    if s.startswith("```"):
        # Drop the opening fence (```json or ``` plus any language tag)
        first_newline = s.find("\n")
        if first_newline != -1:
            s = s[first_newline + 1 :]
        # Drop the closing fence
        if s.endswith("```"):
            s = s[:-3]
        s = s.strip()
    return s
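
# Worked examples:
#
#   _strip_json_fences('```json\n{"ok": true}\n```')  ->  '{"ok": true}'
#   _strip_json_fences('{"ok": true}')                 ->  '{"ok": true}'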

class OllamaProvider(LLMProvider):
    """Ollama LLM provider."""

    name = "ollama"
    supports_tools = True

    def model_supports_tools(self, model: str) -> bool:
        """Narrow tool capability to models trained for function calling.

        Ollama accepts the ``tools`` payload even for incompatible models
        but silently drops it — the assistant just replies in prose. We
        reject those requests upfront so callers get a clear error.
        """
        lower = model.lower()
        return any(pattern in lower for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS)

    def __init__(self, base_url: str | None = None, timeout: int | None = None):
        self.base_url = (base_url or settings.ollama_url).rstrip("/")
        self.timeout = timeout or settings.ollama_timeout
        self._client: httpx.AsyncClient | None = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Get or create HTTP client."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=httpx.Timeout(self.timeout),
            )
        return self._client

    def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
        """Convert OpenAI message format to Ollama format.

        Ollama's ``/api/chat`` uses the OpenAI shape for assistant
        ``tool_calls`` and ``tool`` result messages (with args as objects
        rather than JSON strings on output — input still accepts JSON
        strings), so we largely pass them through.
        """
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            out: dict[str, Any] = {"role": msg.role}

            if msg.tool_calls:
                out["tool_calls"] = [
                    {
                        "function": {
                            "name": c.function.name,
                            # Ollama accepts either stringified JSON or
                            # already-parsed objects; send parsed for
                            # better compatibility with older models.
                            "arguments": _safe_parse_args(c.function.arguments),
                        }
                    }
                    for c in msg.tool_calls
                ]

            if msg.content is None:
                out["content"] = ""
            elif isinstance(msg.content, str):
                out["content"] = msg.content
            else:
                text_parts: list[str] = []
                images: list[str] = []
                for part in msg.content:
                    if part.type == "text":
                        text_parts.append(part.text)
                    elif part.type == "image_url":
                        url = part.image_url.url
                        if url.startswith("data:"):
                            base64_data = url.split(",", 1)[1] if "," in url else url
                            images.append(base64_data)
                        else:
                            logger.warning(
                                "HTTP image URLs not supported, skipping: %s...",
                                url[:50],
                            )
                out["content"] = " ".join(text_parts)
                if images:
                    out["images"] = images

            messages.append(out)
        return messages

    async def chat_completion(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> ChatCompletionResponse:
        """Generate a chat completion (non-streaming)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(request),
            "stream": False,
        }

        # Pass through structured-output requests to Ollama's native
        # `format` field. Ollama supports either `"json"` (free-form
        # JSON object) or a full JSON schema dict. The OpenAI-style
        # response_format the consumer sends maps as follows:
        #   - {"type": "json_object"} → "json"
        #   - {"type": "json_schema", "json_schema": {"schema": {...}}}
        #     → the schema dict (Ollama 0.5+ supports full schemas)
        # Without this, Ollama wraps JSON in ```json ... ``` markdown
        # fences, which breaks downstream strict parsers like the AI SDK
        # generateObject() helper.
        if request.response_format is not None:
            rf = request.response_format
            if rf.type == "json_object":
                payload["format"] = "json"
            elif rf.type == "json_schema" and rf.json_schema is not None:
                # rf.json_schema is the OpenAI envelope:
                #   {"name": "...", "schema": {...}, "strict": true}
                # Ollama wants just the inner schema dict.
                inner = rf.json_schema.get("schema")
                payload["format"] = inner if inner is not None else "json"
            else:
                payload["format"] = "json"

        # Add optional parameters
        options: dict[str, Any] = {}
        if request.temperature is not None:
            options["temperature"] = request.temperature
        if request.top_p is not None:
            options["top_p"] = request.top_p
        if request.max_tokens is not None:
            options["num_predict"] = request.max_tokens
        if request.stop:
            options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]

        if options:
            payload["options"] = options

        if request.tools:
            payload["tools"] = [
                {"type": "function", "function": t.function.model_dump()}
                for t in request.tools
            ]

        logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}")

        response = await self.client.post("/api/chat", json=payload)
        response.raise_for_status()
        data = response.json()

        message = data.get("message", {})
        tool_calls = _tool_calls_from_ollama(message.get("tool_calls"))
        # Defensive fence-stripping: even with `format` set, some older
        # Ollama versions still emit ```json ... ``` wrappers for vision
        # models. Strip them so strict downstream parsers see clean JSON.
        raw_content = message.get("content", "") or ""
        content = _strip_json_fences(raw_content) if raw_content else ""

        finish_reason = "tool_calls" if tool_calls else (
            "stop" if data.get("done") else None
        )

        return ChatCompletionResponse(
            model=f"ollama/{model}",
            choices=[
                Choice(
                    message=MessageResponse(
                        content=content or None,
                        tool_calls=tool_calls,
                    ),
                    finish_reason=finish_reason,
                )
            ],
            usage=Usage(
                prompt_tokens=data.get("prompt_eval_count", 0),
                completion_tokens=data.get("eval_count", 0),
                total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
            ),
        )

    async def chat_completion_stream(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Generate a chat completion (streaming)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(request),
            "stream": True,
        }

        # Add optional parameters
        options: dict[str, Any] = {}
        if request.temperature is not None:
            options["temperature"] = request.temperature
        if request.top_p is not None:
            options["top_p"] = request.top_p
        if request.max_tokens is not None:
            options["num_predict"] = request.max_tokens
        if request.stop:
            options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]

        if options:
            payload["options"] = options

        if request.tools:
            payload["tools"] = [
                {"type": "function", "function": t.function.model_dump()}
                for t in request.tools
            ]

        logger.debug(f"Ollama streaming request: {model}")

        response_id = f"chatcmpl-{model[:8]}"
        first_chunk = True

        async with self.client.stream("POST", "/api/chat", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line:
                    continue

                try:
                    data = json.loads(line)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse Ollama response line: {line}")
                    continue

                # First chunk includes role
                if first_chunk:
                    yield ChatCompletionStreamResponse(
                        id=response_id,
                        model=f"ollama/{model}",
                        choices=[
                            StreamChoice(
                                delta=DeltaContent(role="assistant"),
                            )
                        ],
                    )
                    first_chunk = False

                # Content chunks
                content = data.get("message", {}).get("content", "")
                if content:
                    yield ChatCompletionStreamResponse(
                        id=response_id,
                        model=f"ollama/{model}",
                        choices=[
                            StreamChoice(
                                delta=DeltaContent(content=content),
                            )
                        ],
                    )

                # Final chunk with finish_reason
                if data.get("done"):
                    yield ChatCompletionStreamResponse(
                        id=response_id,
                        model=f"ollama/{model}",
                        choices=[
                            StreamChoice(
                                delta=DeltaContent(),
                                finish_reason="stop",
                            )
                        ],
                    )

    async def list_models(self) -> list[ModelInfo]:
        """List available Ollama models."""
        response = await self.client.get("/api/tags")
        response.raise_for_status()
        data = response.json()

        models = []
        for model_data in data.get("models", []):
            name = model_data.get("name", "")
            # Parse modified_at datetime string to Unix timestamp
            created = None
            if modified_at := model_data.get("modified_at"):
                try:
                    from datetime import datetime

                    # Handle ISO format with timezone
                    dt = datetime.fromisoformat(modified_at.replace("Z", "+00:00"))
                    created = int(dt.timestamp())
                except (ValueError, TypeError):
                    pass
models.append(
|
||||
ModelInfo(
|
||||
id=f"ollama/{name}",
|
||||
owned_by="ollama",
|
||||
created=created,
|
||||
)
|
||||
)
|
||||
return models
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
model: str,
|
||||
) -> EmbeddingResponse:
|
||||
"""Generate embeddings for input text."""
|
||||
inputs = request.input if isinstance(request.input, list) else [request.input]
|
||||
embeddings_data = []
|
||||
|
||||
for i, text in enumerate(inputs):
|
||||
response = await self.client.post(
|
||||
"/api/embeddings",
|
||||
json={"model": model, "prompt": text},
|
||||
)
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
|
||||
embeddings_data.append(
|
||||
EmbeddingData(
|
||||
index=i,
|
||||
embedding=data.get("embedding", []),
|
||||
)
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
data=embeddings_data,
|
||||
model=f"ollama/{model}",
|
||||
usage=Usage(
|
||||
prompt_tokens=sum(len(text.split()) for text in inputs), # Approximate
|
||||
total_tokens=sum(len(text.split()) for text in inputs),
|
||||
),
|
||||
)
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""Check Ollama health status."""
|
||||
try:
|
||||
response = await self.client.get("/api/tags")
|
||||
response.raise_for_status()
|
||||
data = response.json()
|
||||
model_count = len(data.get("models", []))
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": self.name,
|
||||
"url": self.base_url,
|
||||
"models_available": model_count,
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"provider": self.name,
|
||||
"url": self.base_url,
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close HTTP client."""
|
||||
if self._client and not self._client.is_closed:
|
||||
await self._client.aclose()
|
||||
|
|
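The fence-stripping comment above references a `_strip_json_fences` helper defined earlier in this (removed) file, outside this hunk. For reference, a minimal, hypothetical sketch of what such a fence-stripper typically looks like; the deleted implementation may differ:

```python
# Hypothetical stand-in for the _strip_json_fences helper referenced above;
# the removed file defines its own version, which may differ.
import re

_FENCE_RE = re.compile(r"^\s*```(?:json)?\s*\n(.*)\n\s*```\s*$", re.DOTALL)

def strip_json_fences(text: str) -> str:
    """Return the body of a ```json ... ``` block, or the text unchanged."""
    match = _FENCE_RE.match(text)
    return match.group(1) if match else text

assert strip_json_fences('```json\n{"ok": true}\n```') == '{"ok": true}'
assert strip_json_fences('{"ok": true}') == '{"ok": true}'
```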
@@ -1,364 +0,0 @@
"""OpenAI-compatible provider for OpenRouter, Groq, Together, etc."""

import json
import logging
from collections.abc import AsyncIterator
from typing import Any

import httpx

from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingData,
    EmbeddingRequest,
    EmbeddingResponse,
    MessageResponse,
    ModelInfo,
    StreamChoice,
    ToolCall,
    ToolCallFunction,
    Usage,
)

from .base import LLMProvider

logger = logging.getLogger(__name__)


def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None:
    """Normalise an OpenAI-spec ``tool_calls`` array into our model shape."""
    if not raw:
        return None
    calls: list[ToolCall] = []
    for c in raw:
        fn = c.get("function") or {}
        name = fn.get("name")
        if not name:
            continue
        calls.append(
            ToolCall(
                id=c.get("id") or "",
                function=ToolCallFunction(
                    name=name,
                    arguments=fn.get("arguments") or "{}",
                ),
            )
        )
    return calls or None


class OpenAICompatProvider(LLMProvider):
    """OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.)."""

    # OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec;
    # individual models within those services may or may not support it,
    # but the request shape is uniform. The upstream returns a proper
    # error for unsupported models — no silent downgrade here.
    supports_tools = True

    def __init__(
        self,
        name: str,
        base_url: str,
        api_key: str,
        default_model: str | None = None,
        timeout: int = 120,
    ):
        self.name = name
        self.base_url = base_url.rstrip("/")
        self.api_key = api_key
        self.default_model = default_model
        self.timeout = timeout
        self._client: httpx.AsyncClient | None = None

    @property
    def client(self) -> httpx.AsyncClient:
        """Get or create HTTP client."""
        if self._client is None or self._client.is_closed:
            self._client = httpx.AsyncClient(
                base_url=self.base_url,
                timeout=httpx.Timeout(self.timeout),
                headers={
                    "Authorization": f"Bearer {self.api_key}",
                    "Content-Type": "application/json",
                },
            )
        return self._client

    def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
        """Convert internal message format to OpenAI format.

        The OpenAI chat-completions endpoint is the source of truth for
        this shape, so most fields pass through verbatim — including
        ``role='tool'`` messages with ``tool_call_id`` + ``content`` and
        assistant messages carrying ``tool_calls[]``.
        """
        messages: list[dict[str, Any]] = []
        for msg in request.messages:
            out: dict[str, Any] = {"role": msg.role}

            if msg.tool_call_id:
                out["tool_call_id"] = msg.tool_call_id

            if msg.tool_calls:
                out["tool_calls"] = [
                    {
                        "id": c.id,
                        "type": c.type,
                        "function": {
                            "name": c.function.name,
                            "arguments": c.function.arguments,
                        },
                    }
                    for c in msg.tool_calls
                ]

            if msg.content is None:
                # Assistant tool-call messages have null content per spec.
                out["content"] = None
            elif isinstance(msg.content, str):
                out["content"] = msg.content
            else:
                content_parts = []
                for part in msg.content:
                    if part.type == "text":
                        content_parts.append({"type": "text", "text": part.text})
                    elif part.type == "image_url":
                        content_parts.append(
                            {
                                "type": "image_url",
                                "image_url": {"url": part.image_url.url},
                            }
                        )
                out["content"] = content_parts

            messages.append(out)
        return messages

    async def chat_completion(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> ChatCompletionResponse:
        """Generate a chat completion (non-streaming)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(request),
            "stream": False,
        }

        # Add optional parameters
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            payload["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            payload["presence_penalty"] = request.presence_penalty
        if request.stop:
            payload["stop"] = request.stop
        if request.tools:
            payload["tools"] = [
                {"type": "function", "function": t.function.model_dump()}
                for t in request.tools
            ]
        if request.tool_choice is not None:
            payload["tool_choice"] = (
                request.tool_choice
                if isinstance(request.tool_choice, str)
                else request.tool_choice.model_dump()
            )

        logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}")

        response = await self.client.post("/chat/completions", json=payload)
        response.raise_for_status()
        data = response.json()

        return ChatCompletionResponse(
            id=data.get("id", ""),
            model=f"{self.name}/{model}",
            choices=[
                Choice(
                    index=choice.get("index", 0),
                    message=MessageResponse(
                        content=choice["message"].get("content"),
                        tool_calls=_tool_calls_from_openai(
                            choice["message"].get("tool_calls")
                        ),
                    ),
                    finish_reason=choice.get("finish_reason", "stop"),
                )
                for choice in data.get("choices", [])
            ],
            usage=Usage(
                prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
                completion_tokens=data.get("usage", {}).get("completion_tokens", 0),
                total_tokens=data.get("usage", {}).get("total_tokens", 0),
            ),
        )

    async def chat_completion_stream(
        self,
        request: ChatCompletionRequest,
        model: str,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Generate a chat completion (streaming)."""
        payload: dict[str, Any] = {
            "model": model,
            "messages": self._convert_messages(request),
            "stream": True,
        }

        # Add optional parameters
        if request.temperature is not None:
            payload["temperature"] = request.temperature
        if request.max_tokens is not None:
            payload["max_tokens"] = request.max_tokens
        if request.top_p is not None:
            payload["top_p"] = request.top_p
        if request.frequency_penalty is not None:
            payload["frequency_penalty"] = request.frequency_penalty
        if request.presence_penalty is not None:
            payload["presence_penalty"] = request.presence_penalty
        if request.stop:
            payload["stop"] = request.stop
        if request.tools:
            payload["tools"] = [
                {"type": "function", "function": t.function.model_dump()}
                for t in request.tools
            ]
        if request.tool_choice is not None:
            payload["tool_choice"] = (
                request.tool_choice
                if isinstance(request.tool_choice, str)
                else request.tool_choice.model_dump()
            )

        logger.debug(f"{self.name} streaming request: {model}")

        async with self.client.stream("POST", "/chat/completions", json=payload) as response:
            response.raise_for_status()
            async for line in response.aiter_lines():
                if not line or not line.startswith("data: "):
                    continue

                data_str = line[6:]  # Remove "data: " prefix

                if data_str == "[DONE]":
                    break

                try:
                    data = json.loads(data_str)
                except json.JSONDecodeError:
                    logger.warning(f"Failed to parse stream line: {data_str}")
                    continue

                choices = data.get("choices", [])
                if not choices:
                    continue

                choice = choices[0]
                delta = choice.get("delta", {})

                yield ChatCompletionStreamResponse(
                    id=data.get("id", ""),
                    model=f"{self.name}/{model}",
                    choices=[
                        StreamChoice(
                            index=choice.get("index", 0),
                            delta=DeltaContent(
                                role=delta.get("role"),
                                content=delta.get("content"),
                                tool_calls=_tool_calls_from_openai(
                                    delta.get("tool_calls")
                                ),
                            ),
                            finish_reason=choice.get("finish_reason"),
                        )
                    ],
                )

    async def list_models(self) -> list[ModelInfo]:
        """List available models."""
        try:
            response = await self.client.get("/models")
            response.raise_for_status()
            data = response.json()

            models = []
            for model_data in data.get("data", []):
                model_id = model_data.get("id", "")
                models.append(
                    ModelInfo(
                        id=f"{self.name}/{model_id}",
                        owned_by=model_data.get("owned_by", self.name),
                    )
                )
            return models
        except httpx.HTTPError as e:
            logger.warning(f"Failed to list models from {self.name}: {e}")
            return []

    async def embeddings(
        self,
        request: EmbeddingRequest,
        model: str,
    ) -> EmbeddingResponse:
        """Generate embeddings for input text."""
        payload = {
            "model": model,
            "input": request.input,
        }

        response = await self.client.post("/embeddings", json=payload)
        response.raise_for_status()
        data = response.json()

        return EmbeddingResponse(
            data=[
                EmbeddingData(
                    index=item.get("index", i),
                    embedding=item.get("embedding", []),
                )
                for i, item in enumerate(data.get("data", []))
            ],
            model=f"{self.name}/{model}",
            usage=Usage(
                prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
                total_tokens=data.get("usage", {}).get("total_tokens", 0),
            ),
        )

    async def health_check(self) -> dict[str, Any]:
        """Check provider health status."""
        try:
            response = await self.client.get("/models")
            response.raise_for_status()
            data = response.json()
            model_count = len(data.get("data", []))
            return {
                "status": "healthy",
                "provider": self.name,
                "url": self.base_url,
                "models_available": model_count,
            }
        except Exception as e:
            return {
                "status": "unhealthy",
                "provider": self.name,
                "url": self.base_url,
                "error": str(e),
            }

    async def close(self) -> None:
        """Close HTTP client."""
        if self._client and not self._client.is_closed:
            await self._client.aclose()
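The streaming method above leans on the OpenAI SSE wire contract: payload lines carry a `data: ` prefix, and the stream ends with a literal `data: [DONE]` sentinel that is not JSON. A minimal, self-contained sketch of that parsing loop, with the httpx stream replaced by canned lines:

```python
# Standalone sketch of the SSE parsing contract used by the adapter above;
# the canned lines stand in for response.aiter_lines().
import json

raw_lines = [
    'data: {"id":"x","choices":[{"index":0,"delta":{"role":"assistant"}}]}',
    'data: {"id":"x","choices":[{"index":0,"delta":{"content":"Hello"}}]}',
    ": keep-alive comment, ignored",
    "data: [DONE]",
]

for line in raw_lines:
    if not line or not line.startswith("data: "):
        continue  # skip SSE comments and blank keep-alives
    data_str = line[6:]  # strip the "data: " prefix
    if data_str == "[DONE]":
        break  # end-of-stream sentinel, not JSON
    chunk = json.loads(data_str)
    delta = chunk["choices"][0].get("delta", {})
    if content := delta.get("content"):
        print(content, end="")
```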
@@ -1,463 +0,0 @@
"""Provider routing with alias resolution and health-aware fallback.

The router is the single entry point that the FastAPI handlers use. Its
job is:

1. Resolve the request's ``model`` field. If it lives in the ``mana/``
   namespace the :class:`AliasRegistry` returns an ordered chain of
   concrete provider/model strings; everything else is treated as a
   single-entry chain (caller passed a direct provider/model).
2. Walk the chain, skipping entries whose provider is either
   unconfigured at this deployment (no API key) or currently marked
   unhealthy in the :class:`ProviderHealthCache`.
3. Try each remaining entry. Connection errors, timeouts, 5xx, and rate
   limits are retryable — record them in the cache and move to the next
   entry. Capability/auth/blocked errors are caller-fixable and
   propagate immediately without touching the health cache.
4. Return the first successful response. If every entry was skipped or
   failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying
   the full attempt log so debugging is straightforward.

The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is
the M3 milestone.
"""

from __future__ import annotations

import logging
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any, TypeVar

import httpx

from src.aliases import AliasRegistry
from src.config import settings
from src.health import ProviderHealthCache
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    EmbeddingRequest,
    EmbeddingResponse,
    ModelInfo,
)
from src.utils.metrics import (
    record_alias_resolved,
    record_fallback,
    set_provider_healthy,
)

from .base import LLMProvider
from .errors import (
    NoHealthyProviderError,
    ProviderAuthError,
    ProviderBlockedError,
    ProviderCapabilityError,
    ProviderError,
    ProviderRateLimitError,
)
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider

logger = logging.getLogger(__name__)

T = TypeVar("T")


class ProviderRouter:
    """Health-aware provider router with alias resolution.

    Construct with the AliasRegistry and ProviderHealthCache from
    application startup; both are external dependencies so tests can
    inject mocks without going through global state.
    """

    def __init__(
        self,
        aliases: AliasRegistry,
        health_cache: ProviderHealthCache,
    ) -> None:
        self.aliases = aliases
        self.health_cache = health_cache
        self.providers: dict[str, LLMProvider] = {}
        self._initialize_providers()

    # ------------------------------------------------------------------
    # Provider initialisation
    # ------------------------------------------------------------------

    def _initialize_providers(self) -> None:
        """Spin up provider adapters based on what's configured."""
        # Ollama: always present (talks to a local/proxied server). Whether
        # it's actually reachable is the cache's job to figure out.
        self.providers["ollama"] = OllamaProvider()
        logger.info("Initialized Ollama provider at %s", settings.ollama_url)

        if settings.google_api_key:
            from .google import GoogleProvider

            self.providers["google"] = GoogleProvider(
                api_key=settings.google_api_key,
                default_model=settings.google_default_model,
            )
            logger.info("Initialized Google Gemini provider")

        if settings.openrouter_api_key:
            self.providers["openrouter"] = OpenAICompatProvider(
                name="openrouter",
                base_url=settings.openrouter_base_url,
                api_key=settings.openrouter_api_key,
                default_model=settings.openrouter_default_model,
            )
            logger.info("Initialized OpenRouter provider")

        if settings.groq_api_key:
            self.providers["groq"] = OpenAICompatProvider(
                name="groq",
                base_url=settings.groq_base_url,
                api_key=settings.groq_api_key,
            )
            logger.info("Initialized Groq provider")

        if settings.together_api_key:
            self.providers["together"] = OpenAICompatProvider(
                name="together",
                base_url=settings.together_base_url,
                api_key=settings.together_api_key,
            )
            logger.info("Initialized Together provider")

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------

    def _parse_model(self, model: str) -> tuple[str, str]:
        """Split ``provider/model`` into its parts.

        Bare names (no prefix) default to Ollama for compatibility with
        plain OpenAI-style requests. Aliases (``mana/...``) are resolved
        before this is ever called.
        """
        if "/" in model:
            provider, _, model_name = model.partition("/")
            return provider.lower(), model_name
        return "ollama", model

    def _resolve_chain(self, model_or_alias: str) -> list[str]:
        """Expand aliases to chains; pass everything else through unchanged."""
        if AliasRegistry.is_alias(model_or_alias):
            return list(self.aliases.resolve_chain(model_or_alias))
        return [model_or_alias]

    @staticmethod
    def _is_retryable(exc: BaseException) -> bool:
        """Should we treat this exception as "try the next chain entry"?

        ConnectError / timeouts / 5xx / rate-limits = yes (provider blip).
        Auth / capability / blocked / 4xx = no (caller has to fix
        something; retrying with a different provider only hides the bug).
        """
        if isinstance(exc, ProviderCapabilityError):
            return False
        if isinstance(exc, ProviderBlockedError):
            return False
        if isinstance(exc, ProviderAuthError):
            return False
        if isinstance(exc, ProviderRateLimitError):
            return True
        if isinstance(
            exc,
            (
                httpx.ConnectError,
                httpx.ConnectTimeout,
                httpx.ReadError,
                httpx.ReadTimeout,
                httpx.RemoteProtocolError,
                httpx.WriteError,
                httpx.WriteTimeout,
                httpx.PoolTimeout,
            ),
        ):
            return True
        if isinstance(exc, httpx.HTTPStatusError):
            return exc.response.status_code >= 500
        if isinstance(exc, ProviderError):
            # Any other provider-side error — treat as retryable.
            # Subclasses with explicit non-retry semantics are caught above.
            return True
        # Unknown exception types: do NOT silently retry. Better to
        # surface a strange error than hide a real bug behind a fallback.
        return False

    @staticmethod
    def _exception_summary(exc: BaseException) -> str:
        """Compact one-liner for cache.last_error and log entries."""
        return f"{type(exc).__name__}: {exc}"

    def _check_tool_capability(
        self,
        provider: LLMProvider,
        model_name: str,
        request: ChatCompletionRequest,
    ) -> None:
        """Refuse tool-bearing requests for models that don't support tools.

        Silent downgrade (dropping the ``tools`` payload) is more dangerous
        than an explicit error — the caller would get plain text back and
        have no way to tell the tools never reached the model.
        """
        if not request.tools:
            return
        if not provider.model_supports_tools(model_name):
            raise ProviderCapabilityError(
                f"{provider.name}/{model_name} does not support tool calling. "
                "Choose a tool-capable model (e.g. groq/llama-3.3-70b-versatile)"
            )

    # ------------------------------------------------------------------
    # Core fallback executor (non-streaming)
    # ------------------------------------------------------------------

    async def _execute_with_fallback(
        self,
        model_or_alias: str,
        request: ChatCompletionRequest,
        call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]],
    ) -> T:
        """Walk the resolved chain, returning the first successful result.

        ``call`` is the operation to run against each chain entry, e.g.
        ``lambda p, m, req: p.chat_completion(req, m)``. The function
        receives the provider instance, the model name (without the
        provider prefix), and the original request.
        """
        chain = self._resolve_chain(model_or_alias)
        attempts: list[tuple[str, str]] = []
        last_exc: Exception | None = None
        is_alias = AliasRegistry.is_alias(model_or_alias)

        for i, entry in enumerate(chain):
            next_entry = chain[i + 1] if i + 1 < len(chain) else ""
            provider_name, model_name = self._parse_model(entry)
            if provider_name not in self.providers:
                logger.debug(
                    "skip chain entry %s — provider %s not configured here",
                    entry,
                    provider_name,
                )
                attempts.append((entry, "unconfigured"))
                record_fallback(entry, next_entry, "unconfigured")
                continue

            if not self.health_cache.is_healthy(provider_name):
                logger.debug("skip chain entry %s — cache says unhealthy", entry)
                attempts.append((entry, "cache-unhealthy"))
                record_fallback(entry, next_entry, "cache-unhealthy")
                continue

            provider = self.providers[provider_name]
            self._check_tool_capability(provider, model_name, request)

            try:
                logger.info(
                    "execute → %s (alias=%s)",
                    entry,
                    model_or_alias if is_alias else "<direct>",
                )
                result = await call(provider, model_name, request)
                self.health_cache.mark_healthy(provider_name)
                set_provider_healthy(provider_name, True)
                if is_alias:
                    record_alias_resolved(model_or_alias, entry)
                return result
            except Exception as e:
                if not self._is_retryable(e):
                    # Caller error / non-retryable provider error — propagate
                    # without touching the health cache. The cache is for
                    # liveness, not for recording what the user asked for
                    # being wrong.
                    raise
                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
                set_provider_healthy(
                    provider_name, self.health_cache.is_healthy(provider_name)
                )
                attempts.append((entry, type(e).__name__))
                record_fallback(entry, next_entry, type(e).__name__)
                last_exc = e
                logger.warning(
                    "execute %s failed (retryable, will try next): %s",
                    entry,
                    e,
                )

        raise NoHealthyProviderError(model_or_alias, attempts, last_exc)

    # ------------------------------------------------------------------
    # Public API — non-streaming
    # ------------------------------------------------------------------

    async def chat_completion(
        self,
        request: ChatCompletionRequest,
    ) -> ChatCompletionResponse:
        """Chat completion with alias resolution + health-aware fallback."""

        async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest):
            return await provider.chat_completion(req, model)

        return await self._execute_with_fallback(request.model, request, call)

    # ------------------------------------------------------------------
    # Public API — streaming (pre-first-byte fallback)
    # ------------------------------------------------------------------

    async def chat_completion_stream(
        self,
        request: ChatCompletionRequest,
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        """Streaming variant. Falls back BEFORE the first chunk arrives;
        once the first chunk has been yielded the provider is committed
        and any further error propagates.

        Why pre-first-byte only: stitching half-streams from two different
        providers would mix two voices in the output and is impossible to
        sanity-check after the fact.
        """
        chain = self._resolve_chain(request.model)
        attempts: list[tuple[str, str]] = []
        last_exc: Exception | None = None
        is_alias = AliasRegistry.is_alias(request.model)

        for i, entry in enumerate(chain):
            next_entry = chain[i + 1] if i + 1 < len(chain) else ""
            provider_name, model_name = self._parse_model(entry)
            if provider_name not in self.providers:
                attempts.append((entry, "unconfigured"))
                record_fallback(entry, next_entry, "unconfigured")
                continue
            if not self.health_cache.is_healthy(provider_name):
                attempts.append((entry, "cache-unhealthy"))
                record_fallback(entry, next_entry, "cache-unhealthy")
                continue

            provider = self.providers[provider_name]
            self._check_tool_capability(provider, model_name, request)

            stream = provider.chat_completion_stream(request, model_name)
            try:
                first_chunk = await stream.__anext__()
            except StopAsyncIteration:
                # Empty stream is a successful but content-free response.
                # Commit and exit cleanly.
                self.health_cache.mark_healthy(provider_name)
                set_provider_healthy(provider_name, True)
                if is_alias:
                    record_alias_resolved(request.model, entry)
                logger.info("stream %s yielded empty response", entry)
                return
            except Exception as e:
                if not self._is_retryable(e):
                    raise
                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
                set_provider_healthy(
                    provider_name, self.health_cache.is_healthy(provider_name)
                )
                attempts.append((entry, type(e).__name__))
                record_fallback(entry, next_entry, type(e).__name__)
                last_exc = e
                logger.warning(
                    "stream %s failed before first byte (retryable, trying next): %s",
                    entry,
                    e,
                )
                continue

            # First byte landed — commit the provider, mark healthy, drain
            # the rest of the stream. Any error from here on propagates;
            # it is NOT safe to splice another provider's output in.
            self.health_cache.mark_healthy(provider_name)
            set_provider_healthy(provider_name, True)
            if is_alias:
                record_alias_resolved(request.model, entry)
            logger.info("stream → %s (committed after first chunk)", entry)
            yield first_chunk
            async for chunk in stream:
                yield chunk
            return

        raise NoHealthyProviderError(request.model, attempts, last_exc)

    # ------------------------------------------------------------------
    # Embeddings — no fallback (out of scope for M3, separate concerns)
    # ------------------------------------------------------------------

    async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
        """Route an embeddings request directly. No alias / fallback."""
        provider_name, model_name = self._parse_model(request.model)
        if provider_name not in self.providers:
            available = list(self.providers)
            raise ValueError(
                f"Provider '{provider_name}' not available. Available: {available}"
            )
        provider = self.providers[provider_name]
        logger.info("embeddings → %s/%s", provider_name, model_name)
        return await provider.embeddings(request, model_name)

    # ------------------------------------------------------------------
    # Discovery / introspection
    # ------------------------------------------------------------------

    async def list_models(self) -> list[ModelInfo]:
        """List all available models from all configured providers.

        Best-effort: providers that error are skipped with a warning so a
        single broken provider can't take down ``GET /v1/models``.
        """
        all_models: list[ModelInfo] = []
        for provider in self.providers.values():
            try:
                all_models.extend(await provider.list_models())
            except Exception as e:  # noqa: BLE001
                logger.warning("Failed to list models from %s: %s", provider.name, e)
        return all_models

    async def get_model(self, model_id: str) -> ModelInfo | None:
        """Look up a single model by id, dispatching on the prefix."""
        provider_name, model_name = self._parse_model(model_id)
        if provider_name not in self.providers:
            return None
        models = await self.providers[provider_name].list_models()
        for m in models:
            if m.id == model_id or m.id.endswith(f"/{model_name}"):
                return m
        return None

    async def health_check(self) -> dict[str, Any]:
        """Snapshot of the per-provider liveness cache.

        Returns the same shape as before for backwards-compat with
        ``GET /health`` (deprecated structure — M4 will swap to a
        cleaner ``/v1/health`` endpoint).
        """
        snapshot = self.health_cache.snapshot(expected=list(self.providers))
        providers_out: dict[str, Any] = {}
        for name, state in snapshot.items():
            providers_out[name] = {
                "status": "healthy" if state.healthy else "unhealthy",
                "consecutive_failures": state.consecutive_failures,
                "last_error": state.last_error,
                "last_check_unix": state.last_check or None,
                "unhealthy_until_unix": state.unhealthy_until or None,
            }

        all_healthy = all(state.healthy for state in snapshot.values())
        any_healthy = any(state.healthy for state in snapshot.values())
        return {
            "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
            "providers": providers_out,
        }

    async def close(self) -> None:
        """Close all provider clients."""
        for provider in self.providers.values():
            await provider.close()
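For readers who want the fallback walk in isolation: a toy, runnable sketch of the chain-walk the router's docstring describes, with made-up provider names and a fake retryable error standing in for the httpx failures above:

```python
# Toy illustration of the chain walk implemented by _execute_with_fallback.
# Provider names and the failure mode are made up for the example.
class Flaky(Exception):
    pass

def fake_call(entry: str) -> str:
    if entry == "ollama/qwen2.5:7b":
        raise Flaky("connection refused")  # retryable: move to next entry
    return f"served by {entry}"

chain = ["ollama/qwen2.5:7b", "groq/llama-3.1-8b-instant"]
attempts: list[tuple[str, str]] = []
for entry in chain:
    try:
        print(fake_call(entry))  # first success wins
        break
    except Flaky as e:
        attempts.append((entry, type(e).__name__))  # keep the attempt log
else:
    # Chain exhausted: the real router raises NoHealthyProviderError here.
    raise RuntimeError(f"no healthy provider; attempts={attempts}")
```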
@@ -1,5 +0,0 @@
"""Streaming utilities for SSE responses."""

from .sse import stream_chat_completion

__all__ = ["stream_chat_completion"]
@@ -1,52 +0,0 @@
"""Server-Sent Events (SSE) response handling."""

import json
import logging
from collections.abc import AsyncIterator

from src.models import ChatCompletionRequest, ChatCompletionStreamResponse
from src.providers import ProviderRouter
from src.providers.errors import ProviderError

logger = logging.getLogger(__name__)


async def stream_chat_completion(
    router: ProviderRouter,
    request: ChatCompletionRequest,
) -> AsyncIterator[dict]:
    """
    Stream chat completion responses for SSE.

    Yields dicts that EventSourceResponse will serialize as:
        data: {"choices":[{"delta":{"content":"Hello"}}]}
        data: [DONE]
    """
    try:
        async for chunk in router.chat_completion_stream(request):
            # Yield dict for EventSourceResponse to serialize
            yield {"data": json.dumps(chunk.model_dump(exclude_none=True))}

        # Send final [DONE] marker
        yield {"data": "[DONE]"}

    except ProviderError as e:
        logger.warning(f"Streaming provider error: kind={e.kind} detail={e}")
        error_data = {
            "error": {
                "message": str(e),
                "type": e.kind,
            }
        }
        yield {"data": json.dumps(error_data)}
        yield {"data": "[DONE]"}
    except Exception as e:
        logger.error(f"Streaming error: {e}")
        error_data = {
            "error": {
                "message": str(e),
                "type": "server_error",
            }
        }
        yield {"data": json.dumps(error_data)}
        yield {"data": "[DONE]"}
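From the consumer side, the contract above means a client reads `data:` lines until the `[DONE]` sentinel. A hypothetical client sketch, assuming the service is running on localhost:3025 and exposes the OpenAI-style `/v1/chat/completions` route (both assumptions; check your deployment):

```python
# Hypothetical SSE client for the endpoint served by stream_chat_completion.
# Base URL and route are assumptions, not taken from this diff.
import asyncio
import json

import httpx

async def main() -> None:
    payload = {
        "model": "mana/fast-text",
        "messages": [{"role": "user", "content": "Say hi"}],
        "stream": True,
    }
    async with httpx.AsyncClient(base_url="http://localhost:3025") as client:
        async with client.stream("POST", "/v1/chat/completions", json=payload) as resp:
            resp.raise_for_status()
            async for line in resp.aiter_lines():
                if not line.startswith("data: "):
                    continue
                body = line[6:]
                if body == "[DONE]":
                    break  # sentinel, not JSON
                delta = json.loads(body)["choices"][0]["delta"]
                print(delta.get("content") or "", end="", flush=True)

asyncio.run(main())
```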
@@ -1,5 +0,0 @@
"""Utility modules."""

from .metrics import get_metrics, metrics_middleware

__all__ = ["get_metrics", "metrics_middleware"]
@@ -1,85 +0,0 @@
"""Redis caching utilities (optional)."""

import hashlib
import json
import logging
from typing import Any

from src.config import settings

logger = logging.getLogger(__name__)

# Redis client (lazy initialized)
_redis_client = None


async def get_redis_client():
    """Get or create Redis client."""
    global _redis_client

    if _redis_client is not None:
        return _redis_client

    if not settings.redis_url:
        return None

    try:
        import redis.asyncio as redis

        _redis_client = redis.from_url(settings.redis_url)
        # Test connection
        await _redis_client.ping()
        logger.info(f"Connected to Redis at {settings.redis_url}")
        return _redis_client
    except Exception as e:
        logger.warning(f"Failed to connect to Redis: {e}")
        return None


def generate_cache_key(prefix: str, data: dict[str, Any]) -> str:
    """Generate a cache key from request data."""
    # Serialize and hash the data for a consistent key
    serialized = json.dumps(data, sort_keys=True)
    hash_value = hashlib.sha256(serialized.encode()).hexdigest()[:16]
    return f"mana-llm:{prefix}:{hash_value}"


async def get_cached(key: str) -> dict[str, Any] | None:
    """Get cached value by key."""
    client = await get_redis_client()
    if client is None:
        return None

    try:
        value = await client.get(key)
        if value:
            return json.loads(value)
    except Exception as e:
        logger.warning(f"Cache get failed: {e}")

    return None


async def set_cached(key: str, value: dict[str, Any], ttl: int | None = None) -> bool:
    """Set cached value with optional TTL."""
    client = await get_redis_client()
    if client is None:
        return False

    try:
        ttl = ttl or settings.cache_ttl
        serialized = json.dumps(value)
        await client.setex(key, ttl, serialized)
        return True
    except Exception as e:
        logger.warning(f"Cache set failed: {e}")
        return False


async def close_redis() -> None:
    """Close Redis connection."""
    global _redis_client

    if _redis_client is not None:
        await _redis_client.aclose()
        _redis_client = None
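The property `generate_cache_key` relies on is that `json.dumps(..., sort_keys=True)` makes the digest independent of dict insertion order, so semantically identical requests hit the same cache entry. A standalone, runnable copy of that logic demonstrating the property:

```python
# Standalone copy of the key-derivation logic above, showing that
# sort_keys makes the cache key independent of dict insertion order.
import hashlib
import json

def generate_cache_key(prefix: str, data: dict) -> str:
    serialized = json.dumps(data, sort_keys=True)
    return f"mana-llm:{prefix}:{hashlib.sha256(serialized.encode()).hexdigest()[:16]}"

a = generate_cache_key("chat", {"model": "m", "messages": [{"role": "user", "content": "hi"}]})
b = generate_cache_key("chat", {"messages": [{"role": "user", "content": "hi"}], "model": "m"})
assert a == b  # same payload, different key order, same cache key
print(a)
```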
@@ -1,158 +0,0 @@
"""Prometheus metrics for mana-llm."""

import time
from collections.abc import Callable

from fastapi import Request, Response
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware

# Request metrics
REQUEST_COUNT = Counter(
    "mana_llm_requests_total",
    "Total number of requests",
    ["method", "endpoint", "status"],
)

REQUEST_LATENCY = Histogram(
    "mana_llm_request_latency_seconds",
    "Request latency in seconds",
    ["method", "endpoint"],
    buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)

# LLM-specific metrics
LLM_REQUEST_COUNT = Counter(
    "mana_llm_llm_requests_total",
    "Total number of LLM requests",
    ["provider", "model", "streaming"],
)

LLM_TOKEN_COUNT = Counter(
    "mana_llm_tokens_total",
    "Total tokens processed",
    ["provider", "model", "type"],  # type: prompt, completion
)

LLM_LATENCY = Histogram(
    "mana_llm_llm_latency_seconds",
    "LLM request latency in seconds",
    ["provider", "model"],
    buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)

LLM_ERRORS = Counter(
    "mana_llm_llm_errors_total",
    "Total LLM errors",
    ["provider", "model", "error_type"],
)

# ---------------------------------------------------------------------------
# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md.
# ---------------------------------------------------------------------------

ALIAS_RESOLVED = Counter(
    "mana_llm_alias_resolved_total",
    "How often an alias resolved to a concrete provider/model. The `target` "
    "label is the chain entry that actually served the request — useful for "
    "spotting cases where the primary always falls through to a cloud entry.",
    ["alias", "target"],
)

FALLBACK_TRIGGERED = Counter(
    "mana_llm_fallback_total",
    "Fallback transitions: a chain entry failed (or was skipped via cache) "
    "and the router moved to the next entry. `reason` is the exception class "
    "name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry "
    "that didn't serve; `to_model` is empty when no further entries existed.",
    ["from_model", "to_model", "reason"],
)

PROVIDER_HEALTHY = Gauge(
    "mana_llm_provider_healthy",
    "1 when the provider is currently considered healthy by the cache, "
    "0 when in backoff. Refreshed on every probe tick and on every router "
    "call-site state transition.",
    ["provider"],
)


def get_metrics() -> bytes:
    """Generate Prometheus metrics output."""
    return generate_latest()


class MetricsMiddleware(BaseHTTPMiddleware):
    """Middleware for collecting HTTP metrics."""

    async def dispatch(self, request: Request, call_next: Callable) -> Response:
        start_time = time.time()

        response = await call_next(request)

        # Record metrics
        duration = time.time() - start_time
        endpoint = request.url.path
        method = request.method
        status = str(response.status_code)

        REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status).inc()
        REQUEST_LATENCY.labels(method=method, endpoint=endpoint).observe(duration)

        return response


# Export the middleware class (registered via app.add_middleware)
metrics_middleware = MetricsMiddleware


def record_llm_request(
    provider: str,
    model: str,
    streaming: bool,
    prompt_tokens: int = 0,
    completion_tokens: int = 0,
    latency: float | None = None,
) -> None:
    """Record LLM request metrics."""
    LLM_REQUEST_COUNT.labels(
        provider=provider,
        model=model,
        streaming=str(streaming).lower(),
    ).inc()

    if prompt_tokens > 0:
        LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="prompt").inc(prompt_tokens)

    if completion_tokens > 0:
        LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="completion").inc(
            completion_tokens
        )

    if latency is not None:
        LLM_LATENCY.labels(provider=provider, model=model).observe(latency)


def record_llm_error(provider: str, model: str, error_type: str) -> None:
    """Record LLM error metrics."""
    LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc()


def record_alias_resolved(alias: str, target: str) -> None:
    """Record which concrete model an alias resolved to for this request."""
    ALIAS_RESOLVED.labels(alias=alias, target=target).inc()


def record_fallback(from_model: str, to_model: str, reason: str) -> None:
    """Record a fallback transition. ``to_model`` is empty when the chain
    ran out (i.e. NoHealthyProviderError)."""
    FALLBACK_TRIGGERED.labels(
        from_model=from_model,
        to_model=to_model,
        reason=reason,
    ).inc()


def set_provider_healthy(provider: str, healthy: bool) -> None:
    """Mirror ``ProviderHealthCache`` state into a Prometheus gauge."""
    PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0)
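For context, a minimal sketch of how the middleware and `get_metrics` above are typically wired into a FastAPI app; the service's real `src/main.py` is not part of this hunk and may differ:

```python
# Minimal wiring sketch, assuming the removed src/utils/metrics.py module;
# mana-llm's actual src/main.py is not shown in this diff.
from fastapi import FastAPI, Response
from prometheus_client import CONTENT_TYPE_LATEST

from src.utils.metrics import MetricsMiddleware, get_metrics

app = FastAPI()
app.add_middleware(MetricsMiddleware)  # records count + latency per request

@app.get("/metrics")
def metrics() -> Response:
    # Serve the Prometheus exposition format scraped by the collector.
    return Response(content=get_metrics(), media_type=CONTENT_TYPE_LATEST)
```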
@@ -1,28 +0,0 @@
#!/bin/bash

# Start mana-llm service
# Automatically creates venv and installs dependencies if needed

cd "$(dirname "$0")"

# Check if venv exists, create if not
if [ ! -d "venv" ]; then
    echo "Creating virtual environment..."
    python3 -m venv venv
fi

# Activate venv
source venv/bin/activate

# Install/update dependencies
pip install -q -r requirements.txt

# Copy .env if it doesn't exist yet
if [ ! -f ".env" ] && [ -f ".env.example" ]; then
    cp .env.example .env
    echo "Created .env from .env.example"
fi

# Start the service
echo "Starting mana-llm on port 3025..."
exec python -m uvicorn src.main:app --port 3025 --reload
@@ -1 +0,0 @@
"""Tests for mana-llm service."""
@@ -1,300 +0,0 @@
"""Tests for the model-alias registry."""

from __future__ import annotations

from pathlib import Path

import pytest

from src.aliases import (
    ALIAS_PREFIX,
    Alias,
    AliasConfigError,
    AliasRegistry,
    UnknownAliasError,
)

# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def write_yaml(tmp_path: Path, body: str, name: str = "aliases.yaml") -> Path:
    p = tmp_path / name
    p.write_text(body, encoding="utf-8")
    return p


VALID_CONFIG = """\
aliases:
  mana/fast-text:
    description: "fast"
    chain:
      - ollama/qwen2.5:7b
      - groq/llama-3.1-8b-instant
  mana/long-form:
    description: "long"
    chain:
      - ollama/gemma3:12b
      - groq/llama-3.3-70b-versatile
default: mana/fast-text
"""


# ---------------------------------------------------------------------------
# Construction & happy-path resolution
# ---------------------------------------------------------------------------


class TestRegistryHappyPath:
    def test_loads_valid_yaml(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, VALID_CONFIG)
        reg = AliasRegistry(path)
        assert reg.path == path
        assert reg.default_alias == "mana/fast-text"

    def test_resolve_returns_alias_dataclass(self, tmp_path: Path) -> None:
        reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
        alias = reg.resolve("mana/long-form")
        assert isinstance(alias, Alias)
        assert alias.name == "mana/long-form"
        assert alias.description == "long"
        assert alias.chain == ("ollama/gemma3:12b", "groq/llama-3.3-70b-versatile")

    def test_resolve_chain_returns_tuple(self, tmp_path: Path) -> None:
        reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
        chain = reg.resolve_chain("mana/fast-text")
        assert chain == ("ollama/qwen2.5:7b", "groq/llama-3.1-8b-instant")
        # Tuples ensure callers can't mutate the registry's internal state.
        assert isinstance(chain, tuple)

    def test_list_aliases_sorted(self, tmp_path: Path) -> None:
        reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
        names = [a.name for a in reg.list_aliases()]
        assert names == sorted(names)
        assert names == ["mana/fast-text", "mana/long-form"]

    def test_unknown_alias_raises(self, tmp_path: Path) -> None:
        reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG))
        with pytest.raises(UnknownAliasError, match="mana/nope"):
            reg.resolve("mana/nope")

    def test_default_optional(self, tmp_path: Path) -> None:
        body = (
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
        )
        reg = AliasRegistry(write_yaml(tmp_path, body))
        assert reg.default_alias is None


class TestIsAlias:
    """``is_alias`` is a cheap static syntactic check used by the router."""

    @pytest.mark.parametrize(
        "name",
        ["mana/fast-text", "mana/anything", f"{ALIAS_PREFIX}foo"],
    )
    def test_recognises_alias_namespace(self, name: str) -> None:
        assert AliasRegistry.is_alias(name) is True

    @pytest.mark.parametrize(
        "name",
        ["ollama/gemma3:4b", "groq/llama", "gemma3:4b", "", "mana", "manaX/foo"],
    )
    def test_rejects_non_alias(self, name: str) -> None:
        assert AliasRegistry.is_alias(name) is False

    def test_static_no_instance_needed(self) -> None:
        # Important: callers can hit this without instantiating, so it must
        # be a free function or @staticmethod.
        assert AliasRegistry.is_alias("mana/x") is True


# ---------------------------------------------------------------------------
# Schema validation — the YAML is user-edited, must fail loudly on typos
# ---------------------------------------------------------------------------


class TestSchemaValidation:
    def test_missing_file_raises(self, tmp_path: Path) -> None:
        with pytest.raises(AliasConfigError, match="not found"):
            AliasRegistry(tmp_path / "absent.yaml")

    def test_invalid_yaml_raises(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, "aliases: [\n  unclosed")
        with pytest.raises(AliasConfigError, match="failed to parse"):
            AliasRegistry(path)

    def test_root_not_a_mapping(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, "- just-a-list\n")
        with pytest.raises(AliasConfigError, match="root must be a mapping"):
            AliasRegistry(path)

    def test_aliases_must_be_mapping(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, "aliases: just-a-string\n")
        with pytest.raises(AliasConfigError, match="`aliases` must be a mapping"):
            AliasRegistry(path)

    def test_empty_aliases_rejected(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, "aliases: {}\n")
        with pytest.raises(AliasConfigError, match="empty"):
            AliasRegistry(path)

    def test_alias_name_must_use_mana_namespace(self, tmp_path: Path) -> None:
        body = (
            "aliases:\n"
            "  fast-text:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
        )
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="mana/"):
            AliasRegistry(path)

    def test_alias_name_must_have_one_segment(self, tmp_path: Path) -> None:
        body = (
            "aliases:\n"
            "  mana/foo/bar:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
        )
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="exactly one segment"):
            AliasRegistry(path)

    def test_chain_must_be_list(self, tmp_path: Path) -> None:
        body = (
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            '    chain: "ollama/gemma3:4b"\n'
        )
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="chain must be a list"):
            AliasRegistry(path)

    def test_empty_chain_rejected(self, tmp_path: Path) -> None:
        body = "aliases:\n  mana/x:\n    chain: []\n"
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="must not be empty"):
            AliasRegistry(path)

    def test_chain_entry_without_provider_prefix_rejected(self, tmp_path: Path) -> None:
        # "gemma3:4b" without a provider/ prefix would silently default to
        # ollama and confuse the health-cache; reject loudly at config-load.
        body = "aliases:\n  mana/x:\n    chain:\n      - gemma3:4b\n"
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="provider prefix"):
            AliasRegistry(path)

    def test_chain_entry_must_be_string(self, tmp_path: Path) -> None:
        body = "aliases:\n  mana/x:\n    chain:\n      - 42\n"
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError):
            AliasRegistry(path)

    def test_default_must_reference_known_alias(self, tmp_path: Path) -> None:
        body = (
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
            "default: mana/missing\n"
        )
        path = write_yaml(tmp_path, body)
        with pytest.raises(AliasConfigError, match="references unknown alias"):
            AliasRegistry(path)


# ---------------------------------------------------------------------------
# Reload semantics — SIGHUP should be safe even with typos
# ---------------------------------------------------------------------------


class TestReload:
    def test_reload_picks_up_edits(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, VALID_CONFIG)
        reg = AliasRegistry(path)
        assert reg.resolve_chain("mana/long-form") == (
            "ollama/gemma3:12b",
            "groq/llama-3.3-70b-versatile",
        )

        # Edit on disk: shrink the long-form chain.
        new_body = (
            "aliases:\n"
            "  mana/long-form:\n"
            '    description: "shorter"\n'
            "    chain:\n"
            "      - groq/llama-3.3-70b-versatile\n"
            "default: mana/long-form\n"
        )
        path.write_text(new_body, encoding="utf-8")
        reg.reload()

        assert reg.resolve_chain("mana/long-form") == ("groq/llama-3.3-70b-versatile",)
        assert reg.default_alias == "mana/long-form"
        # Aliases that disappeared from the new file are gone.
        with pytest.raises(UnknownAliasError):
            reg.resolve("mana/fast-text")

    def test_reload_keeps_old_state_on_parse_error(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, VALID_CONFIG)
        reg = AliasRegistry(path)
        # First reload fine — establish a baseline.
        reg.reload()

        # Now break the file with obviously invalid yaml.
        path.write_text("aliases: [unclosed\n", encoding="utf-8")
        with pytest.raises(AliasConfigError):
            reg.reload()

        # The previous good state must still be queryable — service stays up.
        assert reg.resolve_chain("mana/fast-text") == (
            "ollama/qwen2.5:7b",
            "groq/llama-3.1-8b-instant",
        )
        assert reg.default_alias == "mana/fast-text"

    def test_reload_keeps_old_state_on_schema_error(self, tmp_path: Path) -> None:
        path = write_yaml(tmp_path, VALID_CONFIG)
        reg = AliasRegistry(path)

        # Empty aliases — would be rejected on first load, must also be
        # rejected here without nuking the in-memory state.
        path.write_text("aliases: {}\n", encoding="utf-8")
        with pytest.raises(AliasConfigError):
            reg.reload()

        assert "mana/fast-text" in [a.name for a in reg.list_aliases()]


# ---------------------------------------------------------------------------
# Repo-shipped aliases.yaml is itself valid
# ---------------------------------------------------------------------------


class TestShippedConfig:
    def test_repo_aliases_yaml_loads(self) -> None:
        # The yaml file checked into services/mana-llm/aliases.yaml is the
        # one that runs in production. It must always parse cleanly — this
        # test catches editor accidents before they ship.
        repo_yaml = Path(__file__).resolve().parents[1] / "aliases.yaml"
        assert repo_yaml.exists(), f"shipped config missing at {repo_yaml}"
        reg = AliasRegistry(repo_yaml)
        # Sanity: the five alias classes the plan calls out must exist.
        for expected in (
            "mana/fast-text",
            "mana/long-form",
            "mana/structured",
            "mana/reasoning",
            "mana/vision",
        ):
            reg.resolve(expected)
@@ -1,41 +0,0 @@
"""API endpoint tests."""

import pytest
from fastapi.testclient import TestClient


@pytest.fixture
def client():
    """Create test client."""
    from src.main import app

    with TestClient(app) as c:
        yield c


def test_health_endpoint(client):
    """Test health check endpoint."""
    response = client.get("/health")
    assert response.status_code == 200
    data = response.json()
    assert "status" in data
    assert "service" in data
    assert data["service"] == "mana-llm"


def test_metrics_endpoint(client):
    """Test metrics endpoint."""
    response = client.get("/metrics")
    assert response.status_code == 200
    assert "mana_llm" in response.text


def test_list_models_endpoint(client):
    """Test list models endpoint."""
    response = client.get("/v1/models")
    # May fail if Ollama is not running, but should return a valid response structure
    if response.status_code == 200:
        data = response.json()
        assert "data" in data
        assert "object" in data
        assert data["object"] == "list"
@@ -1,203 +0,0 @@
"""Tests for the provider health cache."""

from __future__ import annotations

import pytest

from src.health import (
    DEFAULT_FAILURE_THRESHOLD,
    DEFAULT_UNHEALTHY_BACKOFF_SEC,
    ProviderHealthCache,
    ProviderState,
)


class FakeClock:
    """Deterministic clock for circuit-breaker timing tests."""

    def __init__(self, start: float = 1_000_000.0) -> None:
        self.now = start

    def __call__(self) -> float:
        return self.now

    def advance(self, seconds: float) -> None:
        self.now += seconds


# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------


class TestConstruction:
    def test_defaults(self) -> None:
        c = ProviderHealthCache()
        assert c.failure_threshold == DEFAULT_FAILURE_THRESHOLD
        assert c.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC

    def test_invalid_threshold_rejected(self) -> None:
        with pytest.raises(ValueError, match="failure_threshold"):
            ProviderHealthCache(failure_threshold=0)

    def test_invalid_backoff_rejected(self) -> None:
        with pytest.raises(ValueError, match="backoff"):
            ProviderHealthCache(unhealthy_backoff_sec=-1.0)


# ---------------------------------------------------------------------------
# Default-healthy behaviour
# ---------------------------------------------------------------------------


class TestDefaults:
    def test_unknown_provider_is_healthy(self) -> None:
        # The cache is observation-only — no entry means "no reason to skip".
        c = ProviderHealthCache()
        assert c.is_healthy("ollama") is True

    def test_unknown_provider_has_no_state(self) -> None:
        c = ProviderHealthCache()
        assert c.get_state("ollama") is None

    def test_snapshot_includes_expected_zero_state(self) -> None:
        c = ProviderHealthCache()
        snap = c.snapshot(expected=["ollama", "groq"])
        assert set(snap.keys()) == {"ollama", "groq"}
        assert all(s.healthy for s in snap.values())
        assert all(s.consecutive_failures == 0 for s in snap.values())


# ---------------------------------------------------------------------------
# Failure → unhealthy state machine
# ---------------------------------------------------------------------------


class TestFailureSemantics:
    def test_first_failure_does_not_trip(self) -> None:
        # Single transient blips shouldn't bounce a provider — wait for the
        # next consecutive failure to confirm.
        c = ProviderHealthCache(failure_threshold=2)
        c.mark_unhealthy("ollama", "boom")
        assert c.is_healthy("ollama") is True
        state = c.get_state("ollama")
        assert state is not None
        assert state.consecutive_failures == 1
        assert state.healthy is True

    def test_threshold_reached_trips_breaker(self) -> None:
        clock = FakeClock()
        c = ProviderHealthCache(
            failure_threshold=2,
            unhealthy_backoff_sec=60.0,
            clock=clock,
        )
        c.mark_unhealthy("ollama", "fail-1")
        c.mark_unhealthy("ollama", "fail-2")
        assert c.is_healthy("ollama") is False
        state = c.get_state("ollama")
        assert state is not None
        assert state.healthy is False
        assert state.last_error == "fail-2"
        assert state.unhealthy_until == clock.now + 60.0

    def test_threshold_one_trips_immediately(self) -> None:
        c = ProviderHealthCache(failure_threshold=1)
        c.mark_unhealthy("ollama", "boom")
        assert c.is_healthy("ollama") is False

    def test_mark_healthy_clears_state(self) -> None:
        c = ProviderHealthCache(failure_threshold=2)
        c.mark_unhealthy("ollama", "x")
        c.mark_unhealthy("ollama", "x")
        assert c.is_healthy("ollama") is False
        c.mark_healthy("ollama")
        assert c.is_healthy("ollama") is True
        state = c.get_state("ollama")
        assert state is not None
|
||||
assert state.healthy is True
|
||||
assert state.consecutive_failures == 0
|
||||
assert state.unhealthy_until == 0.0
|
||||
assert state.last_error is None
|
||||
|
||||
|
||||
class TestBackoffWindow:
|
||||
def test_is_healthy_false_during_backoff(self) -> None:
|
||||
clock = FakeClock()
|
||||
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
|
||||
c.mark_unhealthy("ollama", "boom")
|
||||
assert c.is_healthy("ollama") is False
|
||||
clock.advance(30.0)
|
||||
assert c.is_healthy("ollama") is False
|
||||
|
||||
def test_is_healthy_true_after_backoff_expires(self) -> None:
|
||||
# After backoff: half-open. Router gets one more attempt; success
|
||||
# mark_healthy resets, failure mark_unhealthy re-arms backoff.
|
||||
clock = FakeClock()
|
||||
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
|
||||
c.mark_unhealthy("ollama", "boom")
|
||||
clock.advance(61.0)
|
||||
assert c.is_healthy("ollama") is True
|
||||
|
||||
def test_failure_during_backoff_extends_window(self) -> None:
|
||||
clock = FakeClock()
|
||||
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
|
||||
c.mark_unhealthy("ollama", "first")
|
||||
original_until = c.get_state("ollama").unhealthy_until
|
||||
clock.advance(20.0)
|
||||
c.mark_unhealthy("ollama", "second")
|
||||
new_until = c.get_state("ollama").unhealthy_until
|
||||
assert new_until > original_until
|
||||
assert new_until == clock.now + 60.0
|
||||
|
||||
def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None:
|
||||
clock = FakeClock()
|
||||
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
|
||||
c.mark_unhealthy("ollama", "boom")
|
||||
with caplog.at_level("INFO"):
|
||||
c.mark_healthy("ollama")
|
||||
# Calling mark_healthy on an already-healthy provider must not
|
||||
# spam the recovery log line.
|
||||
c.mark_healthy("ollama")
|
||||
c.mark_healthy("ollama")
|
||||
recovery_lines = [r for r in caplog.records if "recovered" in r.message]
|
||||
assert len(recovery_lines) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSnapshot:
|
||||
def test_snapshot_returns_copies(self) -> None:
|
||||
# Caller shouldn't be able to poke through to the cache's internal
|
||||
# state via a snapshot reference.
|
||||
c = ProviderHealthCache(failure_threshold=1)
|
||||
c.mark_unhealthy("ollama", "x")
|
||||
snap = c.snapshot()
|
||||
snap["ollama"].healthy = True
|
||||
# Original state untouched:
|
||||
assert c.is_healthy("ollama") is False
|
||||
|
||||
def test_snapshot_matches_recorded_state(self) -> None:
|
||||
clock = FakeClock()
|
||||
c = ProviderHealthCache(
|
||||
failure_threshold=2,
|
||||
unhealthy_backoff_sec=60.0,
|
||||
clock=clock,
|
||||
)
|
||||
c.mark_unhealthy("groq", "rate-limit")
|
||||
snap = c.snapshot()
|
||||
assert isinstance(snap["groq"], ProviderState)
|
||||
assert snap["groq"].consecutive_failures == 1
|
||||
assert snap["groq"].last_error == "rate-limit"
|
||||
assert snap["groq"].last_check == clock.now
|
||||
|
||||
def test_snapshot_expected_does_not_overwrite_real_state(self) -> None:
|
||||
c = ProviderHealthCache(failure_threshold=1)
|
||||
c.mark_unhealthy("ollama", "real-boom")
|
||||
snap = c.snapshot(expected=["ollama", "groq"])
|
||||
# ollama keeps its real (unhealthy) state, groq gets the zero-default.
|
||||
assert snap["ollama"].consecutive_failures == 1
|
||||
assert snap["groq"].consecutive_failures == 0
|
||||
|
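Taken together, these tests pin down a small circuit breaker: consecutive failures below the threshold stay healthy, reaching the threshold arms a backoff window, any failure during the window re-arms it, and the window expiring yields a half-open "try once" state. A compact sketch of that state machine, consistent with the assertions above (hypothetical field names, not the removed src/health.py):

import time

class HealthCacheSketch:
    def __init__(self, failure_threshold=2, unhealthy_backoff_sec=60.0, clock=time.monotonic):
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if unhealthy_backoff_sec < 0:
            raise ValueError("backoff must be >= 0")
        self.failure_threshold = failure_threshold
        self.unhealthy_backoff_sec = unhealthy_backoff_sec
        self._clock = clock
        self._states = {}  # provider -> {"healthy", "failures", "until"}

    def mark_unhealthy(self, provider, error):
        s = self._states.setdefault(provider, {"healthy": True, "failures": 0, "until": 0.0})
        s["failures"] += 1
        if s["failures"] >= self.failure_threshold:
            s["healthy"] = False
            s["until"] = self._clock() + self.unhealthy_backoff_sec  # (re-)arm window

    def mark_healthy(self, provider):
        self._states[provider] = {"healthy": True, "failures": 0, "until": 0.0}

    def is_healthy(self, provider):
        s = self._states.get(provider)
        if s is None or s["healthy"]:
            return True  # unknown provider: optimistic by design
        return self._clock() >= s["until"]  # half-open once backoff expires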

@@ -1,222 +0,0 @@
"""Tests for the background health-probe loop."""

from __future__ import annotations

import asyncio
from typing import Callable

import pytest

from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, ProbeFn


def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn:
    """Synthesise a probe function that always returns / raises the same."""

    async def probe() -> bool:
        if raises is not None:
            raise raises("boom")
        return returns

    return probe


def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn:
    async def probe() -> bool:
        await asyncio.sleep(delay)
        return returns

    return probe


def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]:
    """Probe that counts how many times it was awaited."""
    count = 0

    async def probe() -> bool:
        nonlocal count
        count += 1
        return True

    return probe, lambda: count


# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------


class TestConstruction:
    def test_invalid_interval(self) -> None:
        with pytest.raises(ValueError, match="interval"):
            HealthProbe(ProviderHealthCache(), {}, interval=0.0)

    def test_invalid_timeout(self) -> None:
        with pytest.raises(ValueError, match="timeout"):
            HealthProbe(ProviderHealthCache(), {}, timeout=0.0)

    def test_provider_ids_exposes_keys(self) -> None:
        cache = ProviderHealthCache()
        probe = HealthProbe(
            cache,
            {"ollama": make_probe(), "groq": make_probe()},
        )
        assert sorted(probe.provider_ids) == ["groq", "ollama"]


# ---------------------------------------------------------------------------
# tick_once — the per-cycle behaviour
# ---------------------------------------------------------------------------


class TestTickOnce:
    @pytest.mark.asyncio
    async def test_healthy_probe_marks_healthy(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        # Pre-mark unhealthy so we can verify the probe recovers it.
        cache.mark_unhealthy("ollama", "stale")
        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)})
        await probe.tick_once()
        assert cache.is_healthy("ollama") is True

    @pytest.mark.asyncio
    async def test_returning_false_marks_unhealthy(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(cache, {"ollama": make_probe(returns=False)})
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "false" in (state.last_error or "")

    @pytest.mark.asyncio
    async def test_raising_marks_unhealthy_with_exc_info(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache, {"ollama": make_probe(raises=ConnectionError)}
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "ConnectionError" in (state.last_error or "")

    @pytest.mark.asyncio
    async def test_timeout_marks_unhealthy(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache,
            {"ollama": make_slow_probe(delay=1.0)},
            timeout=0.05,
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "timeout" in (state.last_error or "").lower()

    @pytest.mark.asyncio
    async def test_one_bad_probe_does_not_sink_others(self) -> None:
        # Probe 'ollama' raises — 'groq' must still be evaluated and marked
        # healthy. Bug shape: an unhandled exception in gather() sinks the
        # whole loop.
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache,
            {
                "ollama": make_probe(raises=RuntimeError),
                "groq": make_probe(returns=True),
            },
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        assert cache.is_healthy("groq") is True

    @pytest.mark.asyncio
    async def test_concurrent_probes(self) -> None:
        # All probes should run in parallel — total elapsed wall-clock for
        # N x 100ms probes should be well under N*100ms.
        import time

        cache = ProviderHealthCache()
        probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)}
        probe = HealthProbe(cache, probes, timeout=1.0)
        t0 = time.perf_counter()
        await probe.tick_once()
        elapsed = time.perf_counter() - t0
        assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s"

    @pytest.mark.asyncio
    async def test_empty_probes_is_noop(self) -> None:
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {})
        # No exception, no state mutation.
        await probe.tick_once()
        assert cache.snapshot() == {}


# ---------------------------------------------------------------------------
# start / stop lifecycle
# ---------------------------------------------------------------------------


class TestLifecycle:
    @pytest.mark.asyncio
    async def test_start_runs_initial_tick_immediately(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("ollama", "stale")
        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0)
        await probe.start()
        # Give the loop one event-loop turn to run the initial tick before
        # blocking on the long sleep.
        await asyncio.sleep(0.01)
        assert cache.is_healthy("ollama") is True
        await probe.stop()

    @pytest.mark.asyncio
    async def test_stop_cancels_cleanly(self) -> None:
        cache = ProviderHealthCache()
        probe = HealthProbe(
            cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0
        )
        await probe.start()
        assert probe.running is True
        await probe.stop()
        assert probe.running is False

    @pytest.mark.asyncio
    async def test_start_is_idempotent(self) -> None:
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0)
        await probe.start()
        await probe.start()  # must not spawn a second task
        assert probe.running is True
        await probe.stop()

    @pytest.mark.asyncio
    async def test_stop_without_start_is_safe(self) -> None:
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {})
        await probe.stop()  # idempotent / safe pre-start

    @pytest.mark.asyncio
    async def test_loop_keeps_running_after_tick_error(self) -> None:
        # Even if every probe explodes, the loop must keep ticking.
        cache = ProviderHealthCache(failure_threshold=1)
        fn, count = make_call_counter()
        # tick_once already contains per-probe errors via
        # gather(return_exceptions=True), so rather than forcing an outer
        # error we use the call-counter to verify multiple ticks happened.
        probe = HealthProbe(
            cache,
            {"counter": fn},
            interval=0.05,  # short interval for test
            timeout=1.0,
        )
        await probe.start()
        await asyncio.sleep(0.18)  # ~3-4 ticks
        await probe.stop()
        # Initial tick + at least 2 interval ticks.
        assert count() >= 3
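The per-cycle contract these tests describe (per-probe timeout, all probes concurrent, one failure never sinking the cycle) has an obvious asyncio shape. A minimal sketch of tick_once under those assumptions, against the HealthCacheSketch interface from earlier, not the removed src/health_probe.py:

import asyncio

async def tick_once(cache, probes, timeout=5.0):
    async def run(provider, probe):
        try:
            ok = await asyncio.wait_for(probe(), timeout=timeout)
        except asyncio.TimeoutError:
            cache.mark_unhealthy(provider, f"probe timeout after {timeout}s")
        except Exception as exc:
            # Contains the failure to this provider; matches the
            # "ConnectionError" substring assertion above.
            cache.mark_unhealthy(provider, f"{type(exc).__name__}: {exc}")
        else:
            if ok:
                cache.mark_healthy(provider)
            else:
                cache.mark_unhealthy(provider, "probe returned false")

    # return_exceptions=True keeps one bad task from sinking the cycle.
    await asyncio.gather(*(run(p, fn) for p, fn in probes.items()), return_exceptions=True)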

@@ -1,415 +0,0 @@
"""Tests for M4: observability + debug endpoints + reload."""

from __future__ import annotations

from pathlib import Path

import httpx
import pytest
from fastapi.testclient import TestClient
from prometheus_client import REGISTRY

from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    Choice,
    Message,
    MessageResponse,
)
from src.providers import ProviderRouter
from src.utils.metrics import (
    record_alias_resolved,
    record_fallback,
    set_provider_healthy,
)


# ---------------------------------------------------------------------------
# Cache → listener → metric gauge
# ---------------------------------------------------------------------------


class TestHealthChangeListener:
    def test_listener_fires_on_unhealthy_transition(self) -> None:
        cache = ProviderHealthCache(failure_threshold=2)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        # First failure: still healthy → no transition.
        cache.mark_unhealthy("ollama", "blip")
        assert events == []

        # Second failure: transition healthy→unhealthy → fires.
        cache.mark_unhealthy("ollama", "boom")
        assert events == [("ollama", False)]

    def test_listener_fires_on_recovery(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        cache.mark_unhealthy("ollama", "boom")
        assert events == [("ollama", False)]

        cache.mark_healthy("ollama")
        assert events == [("ollama", False), ("ollama", True)]

    def test_steady_state_does_not_fire(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        # Three healthy ops in a row — no transitions, no events.
        for _ in range(3):
            cache.mark_healthy("ollama")
        assert events == []

    def test_listener_exception_does_not_break_cache(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)

        def bad(_provider: str, _healthy: bool) -> None:
            raise RuntimeError("listener boom")

        cache.add_listener(bad)
        # Should NOT raise — the cache must keep working with a broken
        # listener, otherwise one bad metric callback would brick the
        # whole router.
        cache.mark_unhealthy("ollama", "x")
        assert cache.is_healthy("ollama") is False

    def test_multiple_listeners(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        a: list = []
        b: list = []
        cache.add_listener(lambda p, h: a.append((p, h)))
        cache.add_listener(lambda p, h: b.append((p, h)))

        cache.mark_unhealthy("ollama", "x")
        assert a == [("ollama", False)]
        assert b == [("ollama", False)]


# ---------------------------------------------------------------------------
# Prometheus metrics — counters/gauges actually move
# ---------------------------------------------------------------------------


def _counter_value(name: str, labels: dict[str, str]) -> float:
    """Helper: read the current value of a labeled Prometheus metric."""
    samples = REGISTRY.get_sample_value(name, labels=labels)
    return samples or 0.0


class TestMetricsRecording:
    def test_record_alias_resolved_increments(self) -> None:
        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/test-class", "target": "ollama/x:1b"},
        )
        record_alias_resolved("mana/test-class", "ollama/x:1b")
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/test-class", "target": "ollama/x:1b"},
        )
        assert after - before == pytest.approx(1.0)

    def test_record_fallback_increments(self) -> None:
        before = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
        )
        record_fallback("ollama/x", "groq/y", "ConnectError")
        after = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
        )
        assert after - before == pytest.approx(1.0)

    def test_set_provider_healthy_writes_gauge(self) -> None:
        set_provider_healthy("test_provider_xyz", True)
        v = REGISTRY.get_sample_value(
            "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
        )
        assert v == 1.0

        set_provider_healthy("test_provider_xyz", False)
        v = REGISTRY.get_sample_value(
            "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
        )
        assert v == 0.0


# ---------------------------------------------------------------------------
# Router → metrics: end-to-end through a fallback
# ---------------------------------------------------------------------------


class _OkProvider:
    """Minimal provider double — only what the router uses for chat."""

    name = "ok-provider"
    supports_tools = True

    def __init__(self, name: str, fail_with: BaseException | None = None) -> None:
        self.name = name
        self.fail_with = fail_with
        self.calls = 0

    def model_supports_tools(self, model: str) -> bool:
        return True

    async def chat_completion(self, request, model):
        self.calls += 1
        if self.fail_with is not None:
            raise self.fail_with
        return ChatCompletionResponse(
            model=f"{self.name}/{model}",
            choices=[Choice(message=MessageResponse(content="ok"))],
        )

    async def chat_completion_stream(self, request, model):  # pragma: no cover
        if False:  # pragma: no cover
            yield None

    async def list_models(self):
        return []

    async def embeddings(self, request, model):
        raise NotImplementedError

    async def health_check(self):
        return {"status": "healthy"}

    async def close(self):
        pass


def _aliases(tmp_path: Path) -> AliasRegistry:
    cfg = (
        "aliases:\n"
        "  mana/two-step:\n"
        '    description: "x"\n'
        "    chain:\n"
        "      - alpha/m1\n"
        "      - beta/m2\n"
    )
    p = tmp_path / "aliases.yaml"
    p.write_text(cfg)
    return AliasRegistry(p)


class TestRouterMetricsIntegration:
    @pytest.mark.asyncio
    async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None:
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {"alpha": _OkProvider("alpha")}  # beta not configured

        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/two-step", "target": "alpha/m1"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/two-step", "target": "alpha/m1"},
        )
        assert after - before == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None:
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {
            "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")),
            "beta": _OkProvider("beta"),
        }

        before = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
        )
        assert after - before == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_direct_model_does_not_record_alias_metric(
        self, tmp_path: Path
    ) -> None:
        # Direct provider/model is not an alias — ALIAS_RESOLVED counter
        # must stay flat for those calls.
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {"alpha": _OkProvider("alpha")}

        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "alpha/anything", "target": "alpha/anything"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="alpha/anything",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "alpha/anything", "target": "alpha/anything"},
        )
        # Counter must have NOT increased — direct calls aren't aliases.
        assert after == before


# ---------------------------------------------------------------------------
# Debug endpoints: GET /v1/aliases, GET /v1/health
# ---------------------------------------------------------------------------


@pytest.fixture
def client():
    from src.main import app

    with TestClient(app) as c:
        yield c


class TestDebugEndpoints:
    def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None:
        resp = client.get("/v1/aliases")
        assert resp.status_code == 200
        data = resp.json()
        names = [a["name"] for a in data["aliases"]]
        # The five canonical classes must always be present.
        for expected in (
            "mana/fast-text",
            "mana/long-form",
            "mana/structured",
            "mana/reasoning",
            "mana/vision",
        ):
            assert expected in names
        # Default is set in the shipped config.
        assert data["default"] == "mana/fast-text"

    def test_v1_aliases_chain_format(self, client: TestClient) -> None:
        resp = client.get("/v1/aliases")
        data = resp.json()
        long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form")
        # Each chain entry is a `provider/model` string.
        assert all("/" in entry for entry in long_form["chain"])
        assert len(long_form["chain"]) >= 2  # plan requires at least one cloud fallback

    def test_v1_health_includes_all_providers(self, client: TestClient) -> None:
        resp = client.get("/v1/health")
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "providers" in data
        # ollama is always configured (provider list is non-empty).
        assert "ollama" in data["providers"]
        for name, info in data["providers"].items():
            assert "status" in info
            assert "consecutive_failures" in info


# ---------------------------------------------------------------------------
# X-Mana-LLM-Resolved header on non-streaming responses
# ---------------------------------------------------------------------------


class TestResolvedHeader:
    """The header is the consumer's hook for token-cost attribution.

    Tested at the router level — wiring through main.py would need a
    real provider connection, which isn't available in unit tests.
    """

    @pytest.mark.asyncio
    async def test_response_model_field_carries_resolved_target(
        self, tmp_path: Path
    ) -> None:
        # The header value is `response.model`; verify that field reflects
        # the actual chain entry that served, not the requested alias.
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        # Force fallback to beta.
        router.providers = {
            "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")),
            "beta": _OkProvider("beta"),
        }

        resp = await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        # Even though the caller asked for `mana/two-step`, the resolved
        # field shows the entry that actually answered.
        assert resp.model == "beta/m2"


# ---------------------------------------------------------------------------
# SIGHUP reload — only meaningful on Unix; exercised via the registry
# ---------------------------------------------------------------------------


class TestSighupReload:
    """SIGHUP triggers ``alias_registry.reload()``; a failed reload keeps state.

    The signal-handler wiring lives in main.py and only installs when
    the loop is running in the main thread. We exercise the reload
    semantics here directly on the registry instead — the signal-handler
    code path itself is a 4-line wrapper around ``reload()``.
    """

    def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None:
        path = tmp_path / "aliases.yaml"
        path.write_text(
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
        )
        reg = AliasRegistry(path)
        assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",)

        # Edit on disk, reload (this is exactly what the SIGHUP handler
        # does — minus the signal plumbing).
        path.write_text(
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/bar:1b\n"
            "      - groq/llama-3.1-8b-instant\n"
        )
        reg.reload()
        assert reg.resolve_chain("mana/x") == (
            "ollama/bar:1b",
            "groq/llama-3.1-8b-instant",
        )
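For reference, the "4-line wrapper" the docstring mentions would look roughly like the sketch below: a hypothetical reconstruction of main.py's wiring, not the removed code. asyncio.get_running_loop() and loop.add_signal_handler() are standard-library calls; the swallow-and-log behaviour follows from the reload-keeps-old-state test above.

import asyncio
import signal

def install_sighup_reload(registry):
    # Only valid when the event loop runs in the main thread (hence the
    # guard the tests talk about).
    loop = asyncio.get_running_loop()

    def _on_hup():
        try:
            registry.reload()
        except Exception:
            # Bad YAML on disk: keep serving the old in-memory config.
            pass  # the real handler presumably logged the error here

    loop.add_signal_handler(signal.SIGHUP, _on_hup)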

@@ -1,123 +0,0 @@
"""Provider tests."""

from pathlib import Path

import pytest

from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import ChatCompletionRequest, EmbeddingRequest, Message
from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter


@pytest.fixture
def shipped_aliases() -> AliasRegistry:
    """The repo's real aliases.yaml — same one production uses."""
    return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml")


@pytest.fixture
def router(shipped_aliases: AliasRegistry) -> ProviderRouter:
    return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache())


class TestProviderRouter:
    """Tests for the helpers exposed by the router."""

    def test_parse_model_with_provider(self, router: ProviderRouter) -> None:
        provider, model = router._parse_model("ollama/gemma3:4b")
        assert provider == "ollama"
        assert model == "gemma3:4b"

    def test_parse_model_without_provider(self, router: ProviderRouter) -> None:
        # Bare names default to Ollama for OpenAI-style compat.
        provider, model = router._parse_model("gemma3:4b")
        assert provider == "ollama"
        assert model == "gemma3:4b"

    def test_parse_model_openrouter(self, router: ProviderRouter) -> None:
        provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct")
        assert provider == "openrouter"
        assert model == "meta-llama/llama-3.1-8b-instruct"

    @pytest.mark.asyncio
    async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None:
        # Embeddings don't go through the alias/fallback pipeline — they
        # hit the requested provider directly. Asking for an unconfigured
        # one is a config error and must raise loudly.
        with pytest.raises(ValueError, match="not available"):
            await router.embeddings(
                EmbeddingRequest(model="bogus_provider/x", input="hi")
            )


class TestOllamaProvider:
    """Test Ollama provider."""

    def test_convert_simple_messages(self):
        """Test converting simple text messages."""
        provider = OllamaProvider()
        request = ChatCompletionRequest(
            model="gemma3:4b",
            messages=[
                Message(role="user", content="Hello"),
            ],
        )

        messages = provider._convert_messages(request)
        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "Hello"

    def test_convert_multimodal_messages(self):
        """Test converting multimodal messages."""
        provider = OllamaProvider()
        request = ChatCompletionRequest(
            model="llava:7b",
            messages=[
                Message(
                    role="user",
                    content=[
                        {"type": "text", "text": "What's in this image?"},
                        {
                            "type": "image_url",
                            "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="},
                        },
                    ],
                ),
            ],
        )

        messages = provider._convert_messages(request)
        assert len(messages) == 1
        assert messages[0]["role"] == "user"
        assert messages[0]["content"] == "What's in this image?"
        assert "images" in messages[0]
        assert len(messages[0]["images"]) == 1


class TestOpenAICompatProvider:
    """Test OpenAI-compatible provider."""

    def test_convert_simple_messages(self):
        """Test converting simple text messages."""
        provider = OpenAICompatProvider(
            name="test",
            base_url="http://localhost",
            api_key="test-key",
        )

        request = ChatCompletionRequest(
            model="test-model",
            messages=[
                Message(role="system", content="You are helpful."),
                Message(role="user", content="Hello"),
            ],
        )

        messages = provider._convert_messages(request)
        assert len(messages) == 2
        assert messages[0]["role"] == "system"
        assert messages[0]["content"] == "You are helpful."
        assert messages[1]["role"] == "user"
        assert messages[1]["content"] == "Hello"
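The three _parse_model tests fully pin down the split rule: cut on the first slash only, so OpenRouter's vendor-prefixed model IDs survive intact, and bare names default to Ollama. A standalone sketch of that behaviour (the real method lived on ProviderRouter):

def parse_model(model: str, default_provider: str = "ollama") -> tuple[str, str]:
    # Split on the FIRST slash only: "openrouter/meta-llama/llama-3.1-8b-instruct"
    # keeps "meta-llama/..." as the model name. Bare names go to Ollama.
    if "/" not in model:
        return default_provider, model
    provider, rest = model.split("/", 1)
    return provider, rest

Note that aliases like mana/long-form also contain a slash; in the tests above the alias table is consulted separately, and this helper only handles direct provider/model strings.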

@@ -1,526 +0,0 @@
"""Tests for ProviderRouter fallback / alias execution (M3)."""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

import httpx
import pytest

from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingRequest,
    EmbeddingResponse,
    Message,
    MessageResponse,
    ModelInfo,
    StreamChoice,
)
from src.providers import ProviderRouter
from src.providers.base import LLMProvider
from src.providers.errors import (
    NoHealthyProviderError,
    ProviderAuthError,
    ProviderCapabilityError,
    ProviderRateLimitError,
)


# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------


class MockProvider(LLMProvider):
    """Provider that lets tests inject a sequence of behaviours.

    Each call pops one entry from ``behaviors``. Strings ``"ok"`` and
    ``"empty"`` are sentinels for normal returns; everything else (a
    BaseException instance / class) is raised.
    """

    supports_tools = True

    def __init__(self, name: str, behaviors: list[Any] | None = None) -> None:
        self.name = name
        self._behaviors: list[Any] = list(behaviors or [])
        self.calls: list[str] = []

    def push(self, *behaviors: Any) -> None:
        self._behaviors.extend(behaviors)

    def _next(self) -> Any:
        return self._behaviors.pop(0) if self._behaviors else "ok"

    async def chat_completion(
        self, request: ChatCompletionRequest, model: str
    ) -> ChatCompletionResponse:
        self.calls.append(model)
        b = self._next()
        if isinstance(b, type) and issubclass(b, BaseException):
            raise b("simulated")
        if isinstance(b, BaseException):
            raise b
        return _ok_response(self.name, model)

    async def chat_completion_stream(
        self, request: ChatCompletionRequest, model: str
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        self.calls.append(model)
        b = self._next()
        if isinstance(b, type) and issubclass(b, BaseException):
            raise b("simulated")
        if isinstance(b, BaseException):
            raise b
        if b == "empty":
            return
        for content in ("Hello", " ", "world"):
            yield ChatCompletionStreamResponse(
                model=f"{self.name}/{model}",
                choices=[StreamChoice(delta=DeltaContent(content=content))],
            )

    async def list_models(self) -> list[ModelInfo]:
        return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")]

    async def embeddings(
        self, request: EmbeddingRequest, model: str
    ) -> EmbeddingResponse:
        raise NotImplementedError

    async def health_check(self) -> dict[str, Any]:
        return {"status": "healthy"}


class FailFirstChunkProvider(MockProvider):
    """Streaming provider that raises BEFORE the first chunk every time.

    Kept separate from MockProvider's behaviour list so the per-call
    semantics stay simple — this one models a permanently-broken streamer.
    """

    def __init__(self, name: str, exc: BaseException) -> None:
        super().__init__(name)
        self._exc = exc

    async def chat_completion_stream(self, request, model):  # type: ignore[override]
        self.calls.append(model)
        raise self._exc
        # the yield is unreachable but keeps the function an async generator
        yield  # pragma: no cover


def _ok_response(provider: str, model: str) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        model=f"{provider}/{model}",
        choices=[
            Choice(
                message=MessageResponse(content="ok"),
                finish_reason="stop",
            )
        ],
    )


def _request(model: str) -> ChatCompletionRequest:
    return ChatCompletionRequest(
        model=model,
        messages=[Message(role="user", content="hi")],
    )


def _aliases_yaml(tmp_path) -> AliasRegistry:
    """A two-alias config used across most tests."""
    cfg = (
        "aliases:\n"
        "  mana/long-form:\n"
        '    description: "long"\n'
        "    chain:\n"
        "      - alpha/m1\n"
        "      - beta/m2\n"
        "      - gamma/m3\n"
        "  mana/single:\n"
        '    description: "single-entry"\n'
        "    chain:\n"
        "      - alpha/solo\n"
    )
    p = tmp_path / "aliases.yaml"
    p.write_text(cfg)
    return AliasRegistry(p)


def _make_router(
    tmp_path,
    *,
    providers: dict[str, MockProvider],
    cache: ProviderHealthCache | None = None,
) -> ProviderRouter:
    aliases = _aliases_yaml(tmp_path)
    router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache())
    # Replace the auto-initialised live providers with the test doubles.
    router.providers = dict(providers)
    return router


# ---------------------------------------------------------------------------
# Non-streaming chain walking
# ---------------------------------------------------------------------------


class TestChatCompletionChain:
    @pytest.mark.asyncio
    async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None:
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "alpha/m1"
        assert alpha.calls == ["m1"]
        assert beta.calls == []  # never reached

    @pytest.mark.asyncio
    async def test_falls_through_on_connect_error(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"
        assert alpha.calls == ["m1"]
        assert beta.calls == ["m2"]

    @pytest.mark.asyncio
    async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None:
        # gamma isn't configured at all → chain should silently skip it
        # rather than raise.
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", [httpx.ConnectError("dead too")])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(NoHealthyProviderError) as exc_info:
            await router.chat_completion(_request("mana/long-form"))
        # All three entries appear in attempts: two as ConnectError, one
        # as unconfigured (not a fatal error, just skipped).
        attempts = exc_info.value.attempts
        assert ("alpha/m1", "ConnectError") in attempts
        assert ("beta/m2", "ConnectError") in attempts
        assert ("gamma/m3", "unconfigured") in attempts

    @pytest.mark.asyncio
    async def test_skips_cache_unhealthy(self, tmp_path) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("alpha", "stale")
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        resp = await router.chat_completion(_request("mana/long-form"))

        assert alpha.calls == []  # router skipped per cache
        assert beta.calls == ["m2"]
        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_5xx_treated_as_retryable(self, tmp_path) -> None:
        five_hundred = httpx.HTTPStatusError(
            "boom",
            request=httpx.Request("POST", "http://x"),
            response=httpx.Response(503),
        )
        alpha = MockProvider("alpha", [five_hundred])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_4xx_propagates(self, tmp_path) -> None:
        four_hundred = httpx.HTTPStatusError(
            "bad request",
            request=httpx.Request("POST", "http://x"),
            response=httpx.Response(422),
        )
        alpha = MockProvider("alpha", [four_hundred])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(httpx.HTTPStatusError):
            await router.chat_completion(_request("mana/long-form"))
        # Beta never tried — caller's request needs fixing, retrying
        # against another model would just hide the bug.
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_capability_error_propagates(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderCapabilityError):
            await router.chat_completion(_request("mana/long-form"))
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_auth_error_propagates(self, tmp_path) -> None:
        # Auth errors mean OUR setup is broken (wrong key); falling back
        # to the next provider hides the misconfiguration.
        alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderAuthError):
            await router.chat_completion(_request("mana/long-form"))
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_rate_limit_is_retryable(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [httpx.ConnectError("a")])
        beta = MockProvider("beta", [httpx.ConnectError("b")])
        gamma = MockProvider("gamma", [httpx.ConnectError("c")])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
        )

        with pytest.raises(NoHealthyProviderError) as exc_info:
            await router.chat_completion(_request("mana/long-form"))
        assert exc_info.value.model_or_alias == "mana/long-form"
        assert isinstance(exc_info.value.last_exception, httpx.ConnectError)
        # 503 status so calling code (mana-api etc.) can decide to retry
        # later vs surface a clean error to the user.
        assert exc_info.value.http_status == 503

    @pytest.mark.asyncio
    async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None:
        # Caller bypasses aliases by passing a direct provider/model.
        # No fallback chain — fail = fail.
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(NoHealthyProviderError):
            await router.chat_completion(_request("alpha/anything"))
        # Beta would have served if this had been an alias — but it
        # wasn't, so beta never gets touched.
        assert beta.calls == []


# ---------------------------------------------------------------------------
# Health-cache feedback: success clears, failure marks
# ---------------------------------------------------------------------------


class TestHealthCacheFeedback:
    @pytest.mark.asyncio
    async def test_success_marks_provider_healthy(self, tmp_path) -> None:
        # A successful call must leave an explicit healthy record (zero
        # consecutive failures), so a later half-open retry that succeeds
        # fully clears the breaker.
        alpha = MockProvider("alpha", ["ok"])
        router = _make_router(
            tmp_path,
            providers={"alpha": alpha},
            cache=ProviderHealthCache(),  # fresh cache, alpha optimistic
        )

        await router.chat_completion(_request("mana/single"))

        assert router.health_cache.get_state("alpha").healthy is True
        assert router.health_cache.get_state("alpha").consecutive_failures == 0

    @pytest.mark.asyncio
    async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None:
        # threshold=1 so a single fail is enough to flip the breaker.
        cache = ProviderHealthCache(failure_threshold=1)
        alpha = MockProvider("alpha", [httpx.ConnectError("boom")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        await router.chat_completion(_request("mana/long-form"))

        assert cache.get_state("alpha").healthy is False
        assert cache.get_state("alpha").last_error is not None
        assert "ConnectError" in cache.get_state("alpha").last_error

    @pytest.mark.asyncio
    async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None:
        # Auth/Capability errors are about CALLER state, not provider
        # health — the cache must stay clean so a real outage later
        # isn't masked by stale "marked unhealthy because of bad key".
        cache = ProviderHealthCache(failure_threshold=1)
        alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
        router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache)

        with pytest.raises(ProviderAuthError):
            await router.chat_completion(_request("mana/single"))

        # No state recorded.
        assert cache.get_state("alpha") is None


# ---------------------------------------------------------------------------
# Streaming pre-first-byte fallback
# ---------------------------------------------------------------------------


class TestChatCompletionStream:
    @pytest.mark.asyncio
    async def test_first_provider_streams_normally(self, tmp_path) -> None:
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert beta.calls == []
        assert len(chunks) == 3
        assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world"

    @pytest.mark.asyncio
    async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead"))
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert alpha.calls == ["m1"]
        assert beta.calls == ["m2"]
        assert len(chunks) == 3
        assert all(c.model == "beta/m2" for c in chunks)

    @pytest.mark.asyncio
    async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools"))
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderCapabilityError):
            async for _ in router.chat_completion_stream(_request("mana/long-form")):
                pass
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None:
        # Empty-but-successful stream is a valid response, not a failure
        # we should retry — committing avoids accidentally calling two
        # providers and double-billing.
        alpha = MockProvider("alpha", ["empty"])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert chunks == []
        assert beta.calls == []  # didn't fall through

    @pytest.mark.asyncio
    async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None:
        # Custom provider that yields once then raises mid-stream — the
        # router has already committed and must let the error propagate
        # rather than splice in another provider's voice.
        class MidStreamFailProvider(MockProvider):
            async def chat_completion_stream(self, request, model):  # type: ignore[override]
                self.calls.append(model)
                yield ChatCompletionStreamResponse(
                    model=f"{self.name}/{model}",
                    choices=[StreamChoice(delta=DeltaContent(content="halb"))],
                )
                raise httpx.RemoteProtocolError("connection dropped")

        alpha = MidStreamFailProvider("alpha")
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        collected: list[str] = []
        with pytest.raises(httpx.RemoteProtocolError):
            async for chunk in router.chat_completion_stream(_request("mana/long-form")):
                collected.append(chunk.choices[0].delta.content or "")

        # We got the half-chunk that landed before the break; beta was
        # NOT called as fallback.
        assert collected == ["halb"]
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a"))
        beta = FailFirstChunkProvider("beta", httpx.ConnectError("b"))
        gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c"))
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
        )

        with pytest.raises(NoHealthyProviderError):
            async for _ in router.chat_completion_stream(_request("mana/long-form")):
                pass


# ---------------------------------------------------------------------------
# Health-check shape (still using the cache snapshot)
# ---------------------------------------------------------------------------


class TestHealthCheck:
    @pytest.mark.asyncio
    async def test_health_check_lists_known_providers(self, tmp_path) -> None:
        # Even if no probe has run yet, every configured provider should
        # appear in the snapshot (zero-defaults) so /health has a stable
        # shape for monitors.
        alpha = MockProvider("alpha")
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        out = await router.health_check()

        assert set(out["providers"].keys()) == {"alpha", "beta"}
        assert out["status"] == "healthy"
        assert all(p["status"] == "healthy" for p in out["providers"].values())

    @pytest.mark.asyncio
    async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("alpha", "boom")
        alpha = MockProvider("alpha")
        beta = MockProvider("beta")
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        out = await router.health_check()
        assert out["status"] == "degraded"
        assert out["providers"]["alpha"]["status"] == "unhealthy"
        assert out["providers"]["beta"]["status"] == "healthy"
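The streaming tests above all hinge on one rule: a chain entry may be swapped out only until its first chunk arrives; after that the router is committed and mid-stream errors propagate. A minimal peek-then-commit sketch of that loop, assuming Python 3.10+ for the anext() builtin (hypothetical signatures, not the removed router code):

class NoHealthyProviderError(Exception):
    """Stand-in for the error type src.providers.errors defined."""

async def stream_with_fallback(chain, start_stream, retryable):
    last_exc = None
    for target in chain:
        agen = start_stream(target)
        try:
            first = await anext(agen)  # peek: nothing yielded downstream yet
        except StopAsyncIteration:
            return  # empty-but-successful stream: commit, no fallback
        except Exception as exc:
            if not retryable(exc):
                raise  # 4xx / auth / capability: caller's problem
            last_exc = exc
            continue  # pre-first-byte failure: try the next chain entry
        yield first
        # Committed: from here on, errors propagate rather than splicing
        # in another provider's voice mid-response.
        async for chunk in agen:
            yield chunk
        return
    raise NoHealthyProviderError(last_exc)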

@@ -1,57 +0,0 @@
"""Streaming tests."""

import pytest

from src.models import ChatCompletionStreamResponse, DeltaContent, StreamChoice


class TestStreamingModels:
    """Test streaming response models."""

    def test_stream_response_serialization(self):
        """Test streaming response serializes correctly."""
        response = ChatCompletionStreamResponse(
            id="test-id",
            model="ollama/gemma3:4b",
            choices=[
                StreamChoice(
                    delta=DeltaContent(content="Hello"),
                )
            ],
        )

        data = response.model_dump(exclude_none=True)
        assert data["id"] == "test-id"
        assert data["model"] == "ollama/gemma3:4b"
        assert data["choices"][0]["delta"]["content"] == "Hello"

    def test_stream_response_with_role(self):
        """Test first chunk with role."""
        response = ChatCompletionStreamResponse(
            id="test-id",
            model="ollama/gemma3:4b",
            choices=[
                StreamChoice(
                    delta=DeltaContent(role="assistant"),
                )
            ],
        )

        data = response.model_dump(exclude_none=True)
        assert data["choices"][0]["delta"]["role"] == "assistant"

    def test_stream_response_with_finish_reason(self):
        """Test final chunk with finish_reason."""
        response = ChatCompletionStreamResponse(
            id="test-id",
            model="ollama/gemma3:4b",
            choices=[
                StreamChoice(
                    delta=DeltaContent(),
                    finish_reason="stop",
                )
            ],
        )

        data = response.model_dump(exclude_none=True)
        assert data["choices"][0]["finish_reason"] == "stop"
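These chunk models are what the service put on the wire as Server-Sent Events. The exact framing used by the removed src/main.py isn't shown here, but the OpenAI-compatible convention it presumably followed looks like this (a hedged sketch; model_dump is pydantic v2, and the [DONE] sentinel is an assumption carried over from the OpenAI streaming format):

import json

def to_sse(chunk) -> str:
    # exclude_none keeps the JSON OpenAI-shaped: role only on the first
    # delta, finish_reason only on the last, as the tests above exercise.
    return f"data: {json.dumps(chunk.model_dump(exclude_none=True))}\n\n"

# A finished stream conventionally ends with the sentinel event:
#   data: [DONE]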