From 2b07f6ef89260f8e4c2d00fcfa15a840eb411c96 Mon Sep 17 00:00:00 2001 From: Till JS Date: Fri, 8 May 2026 18:53:54 +0200 Subject: [PATCH] =?UTF-8?q?chore(cutover):=20remove=20services/mana-llm/?= =?UTF-8?q?=20=E2=80=94=20moved=20to=20mana-platform?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Live containers on the Mac Mini build out of `../mana/services/mana-llm/` since the 8-Doppel-Cutover commit (774852ba2). Smoke test green on 2026-05-08 — health endpoints, JWKS, login flow, and the Stripe webhook are all reachable from the new build path. Removing the now-stale duplicate (roughly 90M in this repo). Active code lives in `Code/mana/services/mana-llm/` (see ../mana/CLAUDE.md). Co-Authored-By: Claude Opus 4.7 (1M context) --- services/mana-llm/.env.example | 28 - services/mana-llm/.gitignore | 31 -- services/mana-llm/CLAUDE.md | 376 ------- services/mana-llm/Dockerfile | 49 -- services/mana-llm/aliases.yaml | 54 -- services/mana-llm/docker-compose.dev.yml | 15 - services/mana-llm/docker-compose.yml | 52 -- services/mana-llm/pyproject.toml | 40 -- services/mana-llm/requirements.txt | 29 - services/mana-llm/service.pyw | 17 - services/mana-llm/src/__init__.py | 3 - services/mana-llm/src/aliases.py | 223 -------- services/mana-llm/src/api_auth.py | 53 -- services/mana-llm/src/config.py | 57 -- services/mana-llm/src/health.py | 203 ------- services/mana-llm/src/health_probe.py | 210 ------- services/mana-llm/src/main.py | 405 -------------- services/mana-llm/src/models/__init__.py | 47 -- services/mana-llm/src/models/requests.py | 128 ----- services/mana-llm/src/models/responses.py | 120 ---- services/mana-llm/src/providers/__init__.py | 13 - services/mana-llm/src/providers/base.py | 76 --- services/mana-llm/src/providers/errors.py | 111 ---- services/mana-llm/src/providers/google.py | 521 ----------------- services/mana-llm/src/providers/ollama.py | 462 --------------- .../mana-llm/src/providers/openai_compat.py | 364 ------------ services/mana-llm/src/providers/router.py | 463 --------------- services/mana-llm/src/streaming/__init__.py | 5 - services/mana-llm/src/streaming/sse.py | 52 -- services/mana-llm/src/utils/__init__.py | 5 - services/mana-llm/src/utils/cache.py | 85 --- services/mana-llm/src/utils/metrics.py | 158 ------ services/mana-llm/start.sh | 28 - services/mana-llm/tests/__init__.py | 1 - services/mana-llm/tests/test_aliases.py | 300 ---------- services/mana-llm/tests/test_api.py | 41 -- services/mana-llm/tests/test_health.py | 203 ------- services/mana-llm/tests/test_health_probe.py | 222 -------- .../mana-llm/tests/test_m4_observability.py | 415 -------------- services/mana-llm/tests/test_providers.py | 123 ---- .../mana-llm/tests/test_router_fallback.py | 526 ------------------ services/mana-llm/tests/test_streaming.py | 57 -- 42 files changed, 6371 deletions(-) delete mode 100644 services/mana-llm/.env.example delete mode 100644 services/mana-llm/.gitignore delete mode 100644 services/mana-llm/CLAUDE.md delete mode 100644 services/mana-llm/Dockerfile delete mode 100644 services/mana-llm/aliases.yaml delete mode 100644 services/mana-llm/docker-compose.dev.yml delete mode 100644 services/mana-llm/docker-compose.yml delete mode 100644 services/mana-llm/pyproject.toml delete mode 100644 services/mana-llm/requirements.txt delete mode 100644 services/mana-llm/service.pyw delete mode 100644 services/mana-llm/src/__init__.py delete mode 100644 services/mana-llm/src/aliases.py delete mode 100644 services/mana-llm/src/api_auth.py delete
mode 100644 services/mana-llm/src/config.py delete mode 100644 services/mana-llm/src/health.py delete mode 100644 services/mana-llm/src/health_probe.py delete mode 100644 services/mana-llm/src/main.py delete mode 100644 services/mana-llm/src/models/__init__.py delete mode 100644 services/mana-llm/src/models/requests.py delete mode 100644 services/mana-llm/src/models/responses.py delete mode 100644 services/mana-llm/src/providers/__init__.py delete mode 100644 services/mana-llm/src/providers/base.py delete mode 100644 services/mana-llm/src/providers/errors.py delete mode 100644 services/mana-llm/src/providers/google.py delete mode 100644 services/mana-llm/src/providers/ollama.py delete mode 100644 services/mana-llm/src/providers/openai_compat.py delete mode 100644 services/mana-llm/src/providers/router.py delete mode 100644 services/mana-llm/src/streaming/__init__.py delete mode 100644 services/mana-llm/src/streaming/sse.py delete mode 100644 services/mana-llm/src/utils/__init__.py delete mode 100644 services/mana-llm/src/utils/cache.py delete mode 100644 services/mana-llm/src/utils/metrics.py delete mode 100755 services/mana-llm/start.sh delete mode 100644 services/mana-llm/tests/__init__.py delete mode 100644 services/mana-llm/tests/test_aliases.py delete mode 100644 services/mana-llm/tests/test_api.py delete mode 100644 services/mana-llm/tests/test_health.py delete mode 100644 services/mana-llm/tests/test_health_probe.py delete mode 100644 services/mana-llm/tests/test_m4_observability.py delete mode 100644 services/mana-llm/tests/test_providers.py delete mode 100644 services/mana-llm/tests/test_router_fallback.py delete mode 100644 services/mana-llm/tests/test_streaming.py diff --git a/services/mana-llm/.env.example b/services/mana-llm/.env.example deleted file mode 100644 index 98d47d8ba..000000000 --- a/services/mana-llm/.env.example +++ /dev/null @@ -1,28 +0,0 @@ -# Service -PORT=3025 -LOG_LEVEL=info - -# Ollama (Primary) -OLLAMA_URL=http://localhost:11434 -OLLAMA_DEFAULT_MODEL=gemma3:4b -OLLAMA_TIMEOUT=120 - -# OpenRouter (Cloud Fallback) -OPENROUTER_API_KEY=sk-or-v1-xxx -OPENROUTER_DEFAULT_MODEL=meta-llama/llama-3.1-8b-instruct - -# Groq (Optional) -GROQ_API_KEY=gsk_xxx - -# Together (Optional) -TOGETHER_API_KEY=xxx - -# Caching (Optional) -REDIS_URL=redis://localhost:6379 -CACHE_TTL=3600 - -# CORS -CORS_ORIGINS=http://localhost:5173,https://mana.how - -# API key for cross-service auth (validated in src/api_auth.py) -GPU_API_KEY= diff --git a/services/mana-llm/.gitignore b/services/mana-llm/.gitignore deleted file mode 100644 index 602e8abc4..000000000 --- a/services/mana-llm/.gitignore +++ /dev/null @@ -1,31 +0,0 @@ -# Python -__pycache__/ -*.py[cod] -*$py.class -*.so -.Python -venv/ -.venv/ -ENV/ -.env - -# IDE -.idea/ -.vscode/ -*.swp -*.swo - -# Testing -.pytest_cache/ -.coverage -htmlcov/ -.tox/ - -# Build -dist/ -build/ -*.egg-info/ - -# Local -.env.local -*.log diff --git a/services/mana-llm/CLAUDE.md b/services/mana-llm/CLAUDE.md deleted file mode 100644 index bdf59de4f..000000000 --- a/services/mana-llm/CLAUDE.md +++ /dev/null @@ -1,376 +0,0 @@ -# mana-llm - -Central LLM abstraction service providing a unified OpenAI-compatible API for Ollama and cloud LLM providers. 
- -## Overview - -mana-llm acts as a central gateway for all LLM requests in the monorepo, providing: -- Unified OpenAI-compatible API -- Provider routing (Ollama, OpenRouter, Groq, Together) -- Streaming via Server-Sent Events (SSE) -- Vision/multimodal support -- Embeddings generation -- Prometheus metrics - -## Architecture - -``` -┌─────────────────────────────────────────────────────────────────────┐ -│ Consumer Apps │ -│ chat-backend │ mana web │ todo (LLM enrich) │ etc. │ -└────────────────────────────────┬────────────────────────────────────┘ - │ HTTP/SSE - ▼ -┌─────────────────────────────────────────────────────────────────────┐ -│ mana-llm (Port 3025) │ -│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │ -│ │ Router │ │ Cache │ │ Metrics │ │ -│ │ (Provider) │ │ (Redis) │ │ (Prometheus)│ │ -│ └──────┬──────┘ └─────────────┘ └─────────────┘ │ -│ │ │ -│ ┌──────┴──────────────────────────────────────────┐ │ -│ │ Provider Adapters │ │ -│ │ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ │ -│ │ │ Ollama │ │ OpenAI │ │ OpenRouter │ │ │ -│ │ │ Adapter │ │ Adapter │ │ Adapter │ │ │ -│ │ └──────────┘ └──────────┘ └──────────────┘ │ │ -│ └─────────────────────────────────────────────────┘ │ -└─────────────────────────────────────────────────────────────────────┘ -``` - -## Quick Start - -### Prerequisites - -- Python 3.11+ -- Ollama running locally (http://localhost:11434) -- Redis (optional, for caching) - -### Development - -```bash -cd services/mana-llm - -# Create virtual environment -python -m venv venv -source venv/bin/activate # or venv\Scripts\activate on Windows - -# Install dependencies -pip install -r requirements.txt - -# Copy environment file -cp .env.example .env - -# Start Redis (optional) -docker-compose -f docker-compose.dev.yml up -d - -# Run service -python -m uvicorn src.main:app --port 3025 --reload -``` - -### Docker - -```bash -# Full stack (mana-llm + Redis) -docker-compose up -d - -# View logs -docker-compose logs -f mana-llm -``` - -## API Endpoints - -### Chat Completions - -```bash -# Non-streaming -curl -X POST http://localhost:3025/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "ollama/gemma3:4b", - "messages": [{"role": "user", "content": "Hello!"}], - "stream": false - }' - -# Streaming (SSE) -curl -X POST http://localhost:3025/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "ollama/gemma3:4b", - "messages": [{"role": "user", "content": "Hello!"}], - "stream": true - }' -``` - -### Vision/Multimodal - -```bash -curl -X POST http://localhost:3025/v1/chat/completions \ - -H "Content-Type: application/json" \ - -d '{ - "model": "ollama/llava:7b", - "messages": [{ - "role": "user", - "content": [ - {"type": "text", "text": "What is in this image?"}, - {"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}} - ] - }] - }' -``` - -### Models - -```bash -# List all models -curl http://localhost:3025/v1/models - -# Get specific model -curl http://localhost:3025/v1/models/ollama/gemma3:4b -``` - -### Embeddings - -```bash -curl -X POST http://localhost:3025/v1/embeddings \ - -H "Content-Type: application/json" \ - -d '{ - "model": "ollama/nomic-embed-text", - "input": "Text to embed" - }' -``` - -### Health & Metrics - -```bash -# Liveness summary (legacy, terse shape — only status + per-provider status string) -curl http://localhost:3025/health - -# Detailed per-provider liveness snapshot (M4) -curl http://localhost:3025/v1/health - -# Prometheus metrics -curl 
http://localhost:3025/metrics ``` - -### Aliases (M4 — see `aliases.yaml` and the fallback section below) - -```bash -# What does each `mana/<alias>` resolve to? -curl http://localhost:3025/v1/aliases -``` - -## Aliases & Fallback - -> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md) - -### What callers send - -Two acceptable shapes for the `model` field of `/v1/chat/completions`: - -1. **Aliases** in the reserved `mana/` namespace — recommended for product code. -   The router resolves them via `aliases.yaml` to a chain of concrete -   `provider/model` strings and tries them in order. -2. **Direct `provider/model`** — bypasses the alias layer, no fallback. -   Useful for tests, debugging, and one-off integrations. - -| Alias | Class | -|---|---| -| `mana/fast-text` | Short answers, classification, single-shot Q&A | -| `mana/long-form` | Writing, essays, stories, longer prose | -| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) | -| `mana/reasoning` | Agent missions, tool calls, multi-step plans | -| `mana/vision` | Multimodal (image + text) | - -The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit -the file and `kill -HUP <pid>` to reload — no restart needed. Reload -errors keep the previous good state; check the service logs. - -### Fallback semantics - -Every chain is tried in order. The router skips an entry if the provider -isn't configured in this deployment (no API key) or is currently marked -unhealthy by the health-cache. For each remaining entry the request is -attempted; on a **retryable** error (connection failure, timeout, 5xx, -rate-limit, RemoteProtocolError) the provider is marked unhealthy and -the next entry is tried. **Non-retryable** errors (auth, capability, -content-blocked, 4xx, unknown exception types) propagate immediately — -no fallback, the cache is not poisoned. - -Streaming follows the same logic up to the **first byte**. Once a chunk -has been yielded the provider is committed; mid-stream errors surface -as-is so we never splice two providers' voices into one output. - -If every entry was skipped or failed, the response is `503` carrying a -structured `attempts: list[(model, reason)]` log so the cause is -visible to the caller, not only in service logs. - -### Resolved-model header - -Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>` -(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that -actually answered. Use this for token-cost attribution when the request -used an alias (see the sketch at the end of this section). For streaming, each chunk's `model` field carries the -same info (headers go out before the chain is walked). - -### Health-cache + probe - -`ProviderHealthCache` keeps a per-provider circuit-breaker: - -* 1 failure: still healthy (transient blip, don't bounce). -* 2 consecutive failures: `is_healthy → False` for 60 s; the router -  fails fast straight to the next chain entry. -* After 60 s: half-open. Next call exercises the provider; success -  fully resets, failure re-arms the backoff. - -A background `HealthProbe` task runs every 30 s with a 3 s timeout per -provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models` -for OpenAI-compat). One bad probe can't sink the loop; results feed -into the same cache as the call-site fallback.
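The snippet below is a minimal consumer-side sketch of the alias pattern described above: send a `mana/` alias, then read `X-Mana-LLM-Resolved` to attribute token cost to the concrete model that answered. It assumes the service is reachable on `localhost:3025` as configured earlier; `report_usage` is a hypothetical accounting hook, not part of this service.

```python
# Minimal consumer sketch (Python + httpx). Assumes mana-llm listens on
# localhost:3025; `report_usage` is a hypothetical accounting hook.
import httpx


def ask_via_alias(prompt: str, alias: str = "mana/fast-text") -> str:
    resp = httpx.post(
        "http://localhost:3025/v1/chat/completions",
        json={
            "model": alias,  # resolved through aliases.yaml, with fallback
            "messages": [{"role": "user", "content": prompt}],
            "stream": False,
        },
        timeout=120.0,
    )
    resp.raise_for_status()
    data = resp.json()

    # Concrete provider/model that actually answered,
    # e.g. "groq/llama-3.3-70b-versatile".
    resolved = resp.headers.get("X-Mana-LLM-Resolved", data["model"])
    usage = data.get("usage", {})
    report_usage(resolved, usage.get("prompt_tokens", 0), usage.get("completion_tokens", 0))

    return data["choices"][0]["message"]["content"]


def report_usage(model: str, prompt_tokens: int, completion_tokens: int) -> None:
    # Placeholder: wire this to whatever cost accounting the caller uses.
    print(f"{model}: {prompt_tokens} prompt + {completion_tokens} completion tokens")
```

If every chain entry is skipped or fails, the POST returns `503` and `raise_for_status()` surfaces it; the response body's `attempts` list explains why each entry was rejected.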
- -### Prometheus metrics added in M4 - -| Metric | Labels | Purpose | -|---|---|---| -| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. | -| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. | -| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker. | - -## Provider Routing - -Models use the format `provider/model`: - -| Model | Provider | Target | -|-------|----------|--------| -| `ollama/gemma3:4b` | Ollama | localhost:11434 | -| `ollama/llava:7b` | Ollama | localhost:11434 | -| `openrouter/meta-llama/llama-3.1-8b-instruct` | OpenRouter | api.openrouter.ai | -| `groq/llama-3.1-8b-instant` | Groq | api.groq.com | -| `together/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` | Together | api.together.xyz | - -**Default:** If no provider prefix is given (e.g., `gemma3:4b`), Ollama is used. - -## Configuration - -Environment variables (see `.env.example`): - -| Variable | Default | Description | -|----------|---------|-------------| -| `PORT` | 3025 | Service port | -| `LOG_LEVEL` | info | Logging level | -| `OLLAMA_URL` | http://localhost:11434 | Ollama server URL | -| `OLLAMA_DEFAULT_MODEL` | gemma3:4b | Default Ollama model | -| `OLLAMA_TIMEOUT` | 120 | Ollama request timeout (seconds) | -| `OPENROUTER_API_KEY` | - | OpenRouter API key | -| `GROQ_API_KEY` | - | Groq API key | -| `TOGETHER_API_KEY` | - | Together API key | -| `REDIS_URL` | - | Redis URL for caching | -| `CACHE_TTL` | 3600 | Cache TTL in seconds | -| `CORS_ORIGINS` | localhost | Allowed CORS origins | - -## Project Structure - -``` -services/mana-llm/ -├── src/ -│ ├── main.py # FastAPI app entry point -│ ├── config.py # Settings via pydantic-settings -│ ├── providers/ -│ │ ├── base.py # Abstract provider interface -│ │ ├── ollama.py # Ollama provider -│ │ ├── openai_compat.py # OpenAI-compatible provider -│ │ └── router.py # Provider routing logic -│ ├── models/ -│ │ ├── requests.py # Request Pydantic models -│ │ └── responses.py # Response Pydantic models -│ ├── streaming/ -│ │ └── sse.py # SSE response handling -│ └── utils/ -│ ├── cache.py # Redis caching -│ └── metrics.py # Prometheus metrics -├── tests/ -│ ├── test_api.py # API endpoint tests -│ ├── test_providers.py # Provider tests -│ └── test_streaming.py # Streaming tests -├── Dockerfile -├── docker-compose.yml -├── docker-compose.dev.yml -├── requirements.txt -├── pyproject.toml -└── .env.example -``` - -## Testing - -```bash -# Run tests -pytest - -# Run with coverage -pytest --cov=src - -# Run specific test file -pytest tests/test_providers.py -v -``` - -## Integration Example - -### TypeScript/Node.js Client - -```typescript -// Using fetch -const response = await fetch('http://localhost:3025/v1/chat/completions', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model: 'ollama/gemma3:4b', - messages: [{ role: 'user', content: 'Hello!' 
}], - stream: false, - }), -}); - -const data = await response.json(); -console.log(data.choices[0].message.content); -``` - -### Streaming with EventSource - -```typescript -const response = await fetch('http://localhost:3025/v1/chat/completions', { - method: 'POST', - headers: { 'Content-Type': 'application/json' }, - body: JSON.stringify({ - model: 'ollama/gemma3:4b', - messages: [{ role: 'user', content: 'Hello!' }], - stream: true, - }), -}); - -const reader = response.body?.getReader(); -const decoder = new TextDecoder(); - -while (true) { - const { done, value } = await reader!.read(); - if (done) break; - - const chunk = decoder.decode(value); - const lines = chunk.split('\n').filter(line => line.startsWith('data: ')); - - for (const line of lines) { - const data = line.slice(6); - if (data === '[DONE]') break; - - const parsed = JSON.parse(data); - const content = parsed.choices[0]?.delta?.content; - if (content) process.stdout.write(content); - } -} -``` - -## Related Services - -| Service | Port | Description | -|---------|------|-------------| -| mana-tts | 3022 | Text-to-speech service | -| mana-stt | 3023 | Speech-to-text service | -| mana-search | 3021 | Web search & extraction | diff --git a/services/mana-llm/Dockerfile b/services/mana-llm/Dockerfile deleted file mode 100644 index 8d782184e..000000000 --- a/services/mana-llm/Dockerfile +++ /dev/null @@ -1,49 +0,0 @@ -# Build stage -FROM python:3.12-slim AS builder - -WORKDIR /app - -# Install build dependencies -RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Copy requirements first for caching -COPY requirements.txt . -RUN pip install --no-cache-dir --user -r requirements.txt - -# Production stage -FROM python:3.12-slim - -WORKDIR /app - -# Create non-root user -RUN useradd -m -u 1000 appuser - -# Copy installed packages from builder -COPY --from=builder /root/.local /home/appuser/.local - -# Copy application code + alias config. `aliases.yaml` is loaded by -# main.py's lifespan handler from `Path(__file__).parent.parent / -# 'aliases.yaml'` = /app/aliases.yaml. Without it the FastAPI app -# raises AliasConfigError on startup and the container crashloops. -COPY --chown=appuser:appuser src/ ./src/ -COPY --chown=appuser:appuser aliases.yaml ./aliases.yaml - -# Set environment -ENV PATH=/home/appuser/.local/bin:$PATH -ENV PYTHONUNBUFFERED=1 -ENV PYTHONDONTWRITEBYTECODE=1 - -# Switch to non-root user -USER appuser - -# Expose port -EXPOSE 3025 - -# Health check -HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ - CMD python -c "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()" - -# Run application -CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3025"] diff --git a/services/mana-llm/aliases.yaml b/services/mana-llm/aliases.yaml deleted file mode 100644 index f6850a748..000000000 --- a/services/mana-llm/aliases.yaml +++ /dev/null @@ -1,54 +0,0 @@ -# mana-llm Model Aliases — single source of truth for which class of -# model each backend feature uses. -# -# Consumers (mana-api, mana-ai, …) send `"model": "mana/"` in -# their /v1/chat/completions requests; mana-llm resolves the alias to -# the chain below and tries entries in order, skipping providers that -# the health-cache has marked unhealthy. -# -# Order in `chain` = preference. First healthy entry wins. 
Each chain -# should end with a cloud provider so the system stays functional even -# when the local GPU server (mana-gpu, RTX 3090) is offline. -# -# Reload at runtime: `kill -HUP ` after editing — no restart needed. -# Reference: docs/plans/llm-fallback-aliases.md. - -aliases: - mana/fast-text: - description: "Short answers, classification, single-shot Q&A" - chain: - - ollama/qwen2.5:7b - - groq/llama-3.1-8b-instant - - openrouter/anthropic/claude-3-haiku - - mana/long-form: - description: "Writing, essays, stories, longer prose" - chain: - - ollama/gemma3:12b - - groq/llama-3.3-70b-versatile - - openrouter/anthropic/claude-3.5-haiku - - mana/structured: - description: "JSON output (comic storyboards, research subqueries, tag suggestions)" - chain: - - ollama/qwen2.5:7b - - groq/llama-3.1-8b-instant - - openrouter/openai/gpt-4o-mini - - mana/reasoning: - description: "Agent missions, tool calls, multi-step plans" - # Cloud first by design — local 4-7B models are unreliable for tool calls - chain: - - openrouter/anthropic/claude-3.5-sonnet - - groq/llama-3.3-70b-versatile - - mana/vision: - description: "Multimodal (image + text)" - chain: - - ollama/llava:7b - - google/gemini-2.0-flash-exp - - openrouter/openai/gpt-4o - -# Default alias used when a request omits `model` or sends an unknown -# value with no provider prefix. Keep this conservative (cheap class). -default: mana/fast-text diff --git a/services/mana-llm/docker-compose.dev.yml b/services/mana-llm/docker-compose.dev.yml deleted file mode 100644 index fe41d1aee..000000000 --- a/services/mana-llm/docker-compose.dev.yml +++ /dev/null @@ -1,15 +0,0 @@ -# Development compose - only Redis (run Python locally) -version: "3.8" - -services: - redis: - image: redis:7-alpine - container_name: mana-llm-redis-dev - ports: - - "6380:6379" - volumes: - - redis-data-dev:/data - restart: unless-stopped - -volumes: - redis-data-dev: diff --git a/services/mana-llm/docker-compose.yml b/services/mana-llm/docker-compose.yml deleted file mode 100644 index afe5dc223..000000000 --- a/services/mana-llm/docker-compose.yml +++ /dev/null @@ -1,52 +0,0 @@ -version: "3.8" - -services: - mana-llm: - build: - context: . 
- dockerfile: Dockerfile - container_name: mana-llm - ports: - - "3025:3025" - environment: - - PORT=3025 - - LOG_LEVEL=info - - OLLAMA_URL=http://host.docker.internal:11434 - - OLLAMA_DEFAULT_MODEL=gemma3:4b - - OLLAMA_TIMEOUT=120 - - REDIS_URL=redis://redis:6379 - # Add API keys via .env file - - GOOGLE_API_KEY=${GOOGLE_API_KEY:-} - - GOOGLE_DEFAULT_MODEL=${GOOGLE_DEFAULT_MODEL:-gemini-2.5-flash} - - OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-} - - GROQ_API_KEY=${GROQ_API_KEY:-} - - TOGETHER_API_KEY=${TOGETHER_API_KEY:-} - - CORS_ORIGINS=http://localhost:5173,http://localhost:3000,https://mana.how - depends_on: - - redis - restart: unless-stopped - healthcheck: - test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"] - interval: 30s - timeout: 10s - retries: 3 - start_period: 10s - extra_hosts: - - "host.docker.internal:host-gateway" - - redis: - image: redis:7-alpine - container_name: mana-llm-redis - ports: - - "6380:6379" - volumes: - - redis-data:/data - restart: unless-stopped - healthcheck: - test: ["CMD", "redis-cli", "ping"] - interval: 10s - timeout: 5s - retries: 5 - -volumes: - redis-data: diff --git a/services/mana-llm/pyproject.toml b/services/mana-llm/pyproject.toml deleted file mode 100644 index a2125b1c6..000000000 --- a/services/mana-llm/pyproject.toml +++ /dev/null @@ -1,40 +0,0 @@ -[project] -name = "mana-llm" -version = "0.1.0" -description = "Central LLM abstraction service for Ollama and OpenAI-compatible APIs" -requires-python = ">=3.11" -dependencies = [ - "fastapi>=0.115.0", - "uvicorn[standard]>=0.32.0", - "pydantic>=2.10.0", - "pydantic-settings>=2.6.0", - "httpx>=0.28.0", - "sse-starlette>=2.2.0", - "redis>=5.2.0", - "prometheus-client>=0.21.0", - "google-genai>=1.0.0", - "pyyaml>=6.0.2", -] - -[project.optional-dependencies] -dev = [ - "pytest>=8.3.0", - "pytest-asyncio>=0.24.0", - "pytest-httpx>=0.35.0", - "ruff>=0.8.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.ruff] -line-length = 100 -target-version = "py311" - -[tool.ruff.lint] -select = ["E", "F", "I", "W"] - -[tool.pytest.ini_options] -asyncio_mode = "auto" -testpaths = ["tests"] diff --git a/services/mana-llm/requirements.txt b/services/mana-llm/requirements.txt deleted file mode 100644 index eb4686e25..000000000 --- a/services/mana-llm/requirements.txt +++ /dev/null @@ -1,29 +0,0 @@ -# Core -fastapi>=0.115.0 -uvicorn[standard]>=0.32.0 -pydantic>=2.10.0 -pydantic-settings>=2.6.0 - -# HTTP Client -httpx>=0.28.0 - -# Streaming -sse-starlette>=2.2.0 - -# Caching (optional) -redis>=5.2.0 - -# Google Gemini -google-genai>=1.0.0 - -# Metrics -prometheus-client>=0.21.0 - -# Config (alias registry) -pyyaml>=6.0.2 - -# Dev -pytest>=8.3.0 -pytest-asyncio>=0.24.0 -pytest-httpx>=0.35.0 -ruff>=0.8.0 diff --git a/services/mana-llm/service.pyw b/services/mana-llm/service.pyw deleted file mode 100644 index 5a21e29de..000000000 --- a/services/mana-llm/service.pyw +++ /dev/null @@ -1,17 +0,0 @@ -"""mana-llm service runner (run with pythonw.exe to run headless).""" -import os -import sys -os.chdir(r"C:\mana\services\mana-llm") -sys.path.insert(0, r"C:\mana\services\mana-llm") - -# Load .env file -from dotenv import load_dotenv -load_dotenv(r"C:\mana\services\mana-llm\.env") - -# Redirect stdout/stderr to log file -log = open(r"C:\mana\services\mana-llm\service.log", "w", buffering=1) -sys.stdout = log -sys.stderr = log - -import uvicorn -uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info") diff --git 
a/services/mana-llm/src/__init__.py b/services/mana-llm/src/__init__.py deleted file mode 100644 index 202451520..000000000 --- a/services/mana-llm/src/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -"""mana-llm - Central LLM abstraction service.""" - -__version__ = "0.1.0" diff --git a/services/mana-llm/src/aliases.py b/services/mana-llm/src/aliases.py deleted file mode 100644 index 9e4edd26e..000000000 --- a/services/mana-llm/src/aliases.py +++ /dev/null @@ -1,223 +0,0 @@ -"""Model-alias registry. - -Loads `aliases.yaml` and exposes a small API the router uses to resolve -semantic model names like ``mana/long-form`` to an ordered list of -concrete provider-prefixed model strings (``ollama/gemma3:12b`` → -``groq/llama-3.3-70b-versatile`` → …). - -The registry is hot-reloadable: ``reload()`` rebuilds the in-memory -mapping atomically. Reload errors leave the previous good state intact -so a typo in the yaml file doesn't take the service down — caller logs -the error and keeps serving. - -See docs/plans/llm-fallback-aliases.md for the full design. -""" - -from __future__ import annotations - -import logging -import threading -from dataclasses import dataclass -from pathlib import Path - -import yaml - -logger = logging.getLogger(__name__) - -# Aliases live in this namespace. Anything else passed as `model` is -# treated as a direct provider/model string (preserves the legal -# bypass-the-alias-layer escape hatch for tests/debugging). -ALIAS_PREFIX = "mana/" - - -@dataclass(frozen=True) -class Alias: - """A resolved alias entry.""" - - name: str - description: str - chain: tuple[str, ...] - - -class AliasConfigError(ValueError): - """Raised when the YAML file is malformed or violates schema constraints.""" - - -class UnknownAliasError(KeyError): - """Raised when a caller asks for an alias that isn't defined.""" - - -def _validate_chain(name: str, chain: object) -> tuple[str, ...]: - """Schema-check a single alias chain. Returns the validated tuple.""" - if not isinstance(chain, list): - raise AliasConfigError(f"alias '{name}': chain must be a list, got {type(chain).__name__}") - if not chain: - raise AliasConfigError(f"alias '{name}': chain must not be empty") - out: list[str] = [] - for i, entry in enumerate(chain): - if not isinstance(entry, str) or not entry.strip(): - raise AliasConfigError( - f"alias '{name}': chain[{i}] must be a non-empty string, got {entry!r}" - ) - if "/" not in entry: - raise AliasConfigError( - f"alias '{name}': chain[{i}] = {entry!r} must include a provider prefix " - f"(e.g. 
'ollama/...', 'groq/...')" - ) - out.append(entry.strip()) - return tuple(out) - - -def _validate_name(name: object) -> str: - """Aliases must live in the reserved `mana/` namespace.""" - if not isinstance(name, str) or not name.startswith(ALIAS_PREFIX): - raise AliasConfigError( - f"alias name {name!r} must start with {ALIAS_PREFIX!r} (the reserved namespace)" - ) - suffix = name[len(ALIAS_PREFIX) :] - if not suffix or "/" in suffix: - raise AliasConfigError( - f"alias name {name!r} must have exactly one segment after {ALIAS_PREFIX!r}" - ) - return name - - -def _parse_document(doc: object) -> tuple[dict[str, Alias], str | None]: - """Parse a loaded YAML document into a normalized (aliases, default) pair.""" - if not isinstance(doc, dict): - raise AliasConfigError(f"yaml root must be a mapping, got {type(doc).__name__}") - - raw_aliases = doc.get("aliases", {}) - if not isinstance(raw_aliases, dict): - raise AliasConfigError( - f"`aliases` must be a mapping, got {type(raw_aliases).__name__}" - ) - if not raw_aliases: - raise AliasConfigError("`aliases` is empty — at least one alias is required") - - parsed: dict[str, Alias] = {} - for name, body in raw_aliases.items(): - validated_name = _validate_name(name) - if not isinstance(body, dict): - raise AliasConfigError( - f"alias '{validated_name}': body must be a mapping, got {type(body).__name__}" - ) - description = body.get("description", "") - if not isinstance(description, str): - raise AliasConfigError( - f"alias '{validated_name}': description must be a string" - ) - chain = _validate_chain(validated_name, body.get("chain")) - parsed[validated_name] = Alias( - name=validated_name, - description=description.strip(), - chain=chain, - ) - - default = doc.get("default") - if default is not None: - if not isinstance(default, str): - raise AliasConfigError(f"`default` must be a string, got {type(default).__name__}") - if default not in parsed: - raise AliasConfigError( - f"`default` references unknown alias {default!r} " - f"(known: {sorted(parsed)})" - ) - - return parsed, default - - -class AliasRegistry: - """Thread-safe in-memory registry of model aliases. - - Construct once at startup with the path to the yaml file. Call - :meth:`reload` to re-read after a SIGHUP. Reads (``resolve``, - ``is_alias``, …) are cheap and lock-free during steady state — they - snapshot the current mapping reference; only the swap on reload is - serialized. - """ - - def __init__(self, path: Path | str): - self._path = Path(path) - self._lock = threading.Lock() - self._aliases: dict[str, Alias] = {} - self._default: str | None = None - self._load() - - @property - def path(self) -> Path: - return self._path - - def _load(self) -> None: - """Initial load — propagates errors so a bad config fails fast at startup.""" - if not self._path.exists(): - raise AliasConfigError(f"alias config not found at {self._path}") - with self._path.open("r", encoding="utf-8") as f: - try: - doc = yaml.safe_load(f) - except yaml.YAMLError as e: - raise AliasConfigError(f"failed to parse {self._path}: {e}") from e - aliases, default = _parse_document(doc) - # No lock needed during __init__ — nothing else can read yet. - self._aliases = aliases - self._default = default - logger.info( - "AliasRegistry loaded %d alias(es) from %s (default=%s)", - len(aliases), - self._path, - default, - ) - - def reload(self) -> None: - """Re-read the yaml file. On parse error, keep the previous state and raise. 
- - Designed for SIGHUP: callers should ``try/except AliasConfigError`` - and log; do not crash the service on a typo. - """ - with self._path.open("r", encoding="utf-8") as f: - try: - doc = yaml.safe_load(f) - except yaml.YAMLError as e: - raise AliasConfigError(f"failed to parse {self._path}: {e}") from e - aliases, default = _parse_document(doc) - with self._lock: - self._aliases = aliases - self._default = default - logger.info( - "AliasRegistry reloaded %d alias(es) from %s (default=%s)", - len(aliases), - self._path, - default, - ) - - @staticmethod - def is_alias(name: str) -> bool: - """Cheap syntactic check — does this name live in the alias namespace? - - Static; doesn't require a registry instance. Used by the router to - decide whether to dispatch to the alias layer or pass through to - provider-direct routing. - """ - return isinstance(name, str) and name.startswith(ALIAS_PREFIX) - - def resolve(self, name: str) -> Alias: - """Look up the named alias. Raises :class:`UnknownAliasError` if absent.""" - try: - return self._aliases[name] - except KeyError as e: - raise UnknownAliasError( - f"unknown alias {name!r} (known: {sorted(self._aliases)})" - ) from e - - def resolve_chain(self, name: str) -> tuple[str, ...]: - """Sugar for ``resolve(name).chain`` — the form the router actually wants.""" - return self.resolve(name).chain - - @property - def default_alias(self) -> str | None: - """The alias used when a request arrives with no recognizable model.""" - return self._default - - def list_aliases(self) -> list[Alias]: - """All aliases as a snapshot list — for the GET /v1/aliases debug endpoint.""" - return [self._aliases[k] for k in sorted(self._aliases)] diff --git a/services/mana-llm/src/api_auth.py b/services/mana-llm/src/api_auth.py deleted file mode 100644 index 0f5813735..000000000 --- a/services/mana-llm/src/api_auth.py +++ /dev/null @@ -1,53 +0,0 @@ -""" -Simple API Key Authentication Middleware for GPU Services. - -Checks X-API-Key header or ?api_key query parameter. -Skips auth for /health, /docs, /openapi.json, /redoc endpoints. - -Environment variables: - GPU_API_KEY: Required API key (if empty, auth is disabled) - GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set) -""" - -import os -import logging -from fastapi import Request -from fastapi.responses import JSONResponse -from starlette.middleware.base import BaseHTTPMiddleware - -logger = logging.getLogger(__name__) - -GPU_API_KEY = os.getenv("GPU_API_KEY", "") -GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true" - -# Endpoints that don't require auth -PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"} - - -class ApiKeyMiddleware(BaseHTTPMiddleware): - async def dispatch(self, request: Request, call_next): - # Skip auth if disabled - if not GPU_REQUIRE_AUTH or not GPU_API_KEY: - return await call_next(request) - - # Skip auth for public endpoints - if request.url.path in PUBLIC_PATHS: - return await call_next(request) - - # Check API key from header or query param - api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key") - - if not api_key: - return JSONResponse( - status_code=401, - content={"detail": "Missing API key. 
Provide X-API-Key header."}, - ) - - if api_key != GPU_API_KEY: - logger.warning(f"Invalid API key attempt from {request.client.host if request.client else 'unknown'}") - return JSONResponse( - status_code=401, - content={"detail": "Invalid API key."}, - ) - - return await call_next(request) diff --git a/services/mana-llm/src/config.py b/services/mana-llm/src/config.py deleted file mode 100644 index b17fdc0ba..000000000 --- a/services/mana-llm/src/config.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Configuration settings for mana-llm service.""" - -from pydantic_settings import BaseSettings - - -class Settings(BaseSettings): - """Application settings loaded from environment variables.""" - - # Service - port: int = 3025 - log_level: str = "info" - - # Ollama (Primary provider) - ollama_url: str = "http://localhost:11434" - ollama_default_model: str = "gemma3:4b" - ollama_timeout: int = 120 - - # OpenRouter (Cloud fallback) - openrouter_api_key: str | None = None - openrouter_base_url: str = "https://openrouter.ai/api/v1" - openrouter_default_model: str = "meta-llama/llama-3.1-8b-instruct" - - # Groq (Optional) - groq_api_key: str | None = None - groq_base_url: str = "https://api.groq.com/openai/v1" - - # Together (Optional) - together_api_key: str | None = None - together_base_url: str = "https://api.together.xyz/v1" - - # Google Gemini (Fallback provider) - google_api_key: str | None = None - google_default_model: str = "gemini-2.5-flash" - - # Auto-fallback: Ollama → Google when Ollama is overloaded/down - auto_fallback_enabled: bool = True - ollama_max_concurrent: int = 3 - - # Caching (Optional) - redis_url: str | None = None - cache_ttl: int = 3600 - - # CORS - cors_origins: str = "http://localhost:5173,http://localhost:5190,https://mana.how,https://playground.mana.how" - - @property - def cors_origins_list(self) -> list[str]: - """Parse CORS origins from comma-separated string.""" - return [origin.strip() for origin in self.cors_origins.split(",")] - - class Config: - env_file = ".env" - env_file_encoding = "utf-8" - extra = "ignore" - - -settings = Settings() diff --git a/services/mana-llm/src/health.py b/services/mana-llm/src/health.py deleted file mode 100644 index eda34526c..000000000 --- a/services/mana-llm/src/health.py +++ /dev/null @@ -1,203 +0,0 @@ -"""Provider health cache. - -Tracks per-provider liveness for the LLM router. The router reads -:meth:`is_healthy` to decide whether to even try a provider in a chain; -the probe loop and the call-site fallback handler write state via -:meth:`mark_healthy` / :meth:`mark_unhealthy`. - -Implements a simple circuit-breaker: - -* The first failure flips no switch — providers occasionally have - transient blips, we don't want to bounce off after a single 502. -* After ``failure_threshold`` consecutive failures the provider is - marked unhealthy for ``unhealthy_backoff`` seconds. During that - window :meth:`is_healthy` returns ``False`` so the router fails - fast straight to the next chain entry. -* When the backoff expires :meth:`is_healthy` returns ``True`` again - (half-open). The next call exercises the provider; success calls - :meth:`mark_healthy` and fully resets state, failure re-arms the - backoff window. - -State is kept in a plain dict guarded by a ``threading.Lock``. All -operations are short, lock-free reads of dict references aren't safe -because we mutate state in-place — the lock keeps it boring. Probe -loop runs in the asyncio loop alongside the router, but the lock -costs are negligible at ~1 update/30s/provider. 
-""" - -from __future__ import annotations - -import logging -import threading -import time -from dataclasses import dataclass, field -from typing import Callable, Iterable - -logger = logging.getLogger(__name__) - -#: Notification fired whenever a provider transitions between healthy and -#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the -#: cache itself stays metrics-agnostic so tests don't need to mock it. -HealthChangeListener = Callable[[str, bool], None] - -DEFAULT_FAILURE_THRESHOLD = 2 -DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0 - - -@dataclass -class ProviderState: - """Per-provider liveness snapshot. All times are unix seconds.""" - - healthy: bool = True - consecutive_failures: int = 0 - last_check: float = 0.0 - last_error: str | None = None - unhealthy_until: float = 0.0 - """When > now, the provider is currently in backoff (`is_healthy → False`).""" - - -class ProviderHealthCache: - """Thread-safe per-provider liveness with circuit-breaker semantics. - - Provider IDs are arbitrary strings — by convention we use the same - short name as the provider router (``ollama``, ``groq``, ``openrouter``, - ``together``, ``google``). The cache is provider-list agnostic; states - are created lazily on first ``mark_*`` or queried-but-absent ``is_healthy`` - call (returning ``True`` by default — no state means no reason to skip). - """ - - def __init__( - self, - *, - failure_threshold: int = DEFAULT_FAILURE_THRESHOLD, - unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC, - clock: callable = time.time, - ) -> None: - if failure_threshold < 1: - raise ValueError("failure_threshold must be >= 1") - if unhealthy_backoff_sec < 0: - raise ValueError("unhealthy_backoff_sec must be >= 0") - self._failure_threshold = failure_threshold - self._unhealthy_backoff = unhealthy_backoff_sec - self._clock = clock - self._lock = threading.Lock() - self._states: dict[str, ProviderState] = {} - self._listeners: list[HealthChangeListener] = [] - - def add_listener(self, listener: HealthChangeListener) -> None: - """Register a callback fired with ``(provider_id, healthy: bool)`` - whenever a provider's healthy-flag transitions. Listeners run - outside the cache's lock; exceptions are swallowed and logged so a - bad listener can't break the underlying state machine. - """ - self._listeners.append(listener) - - def _notify(self, provider_id: str, healthy: bool) -> None: - for listener in self._listeners: - try: - listener(provider_id, healthy) - except Exception as e: # noqa: BLE001 - logger.error("health-change listener raised: %s", e) - - @property - def failure_threshold(self) -> int: - return self._failure_threshold - - @property - def unhealthy_backoff_sec(self) -> float: - return self._unhealthy_backoff - - # ------------------------------------------------------------------ - # Reads - # ------------------------------------------------------------------ - - def is_healthy(self, provider_id: str) -> bool: - """Should the router try this provider right now? - - Returns ``True`` by default for unknown providers — the cache is - observation-only, not a registry. - """ - with self._lock: - state = self._states.get(provider_id) - if state is None: - return True - if state.unhealthy_until > self._clock(): - return False - # Backoff expired: caller is allowed to try again (half-open). 
- return True - - def get_state(self, provider_id: str) -> ProviderState | None: - """Snapshot of one provider's state (for debugging / tests).""" - with self._lock: - state = self._states.get(provider_id) - return None if state is None else _copy(state) - - def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]: - """All known states, plus zero-state placeholders for any names in - ``expected`` that haven't been touched yet. Used by ``GET /v1/health`` - so the response shape is stable regardless of probe order. - """ - with self._lock: - out = {pid: _copy(s) for pid, s in self._states.items()} - if expected: - for pid in expected: - out.setdefault(pid, ProviderState()) - return out - - # ------------------------------------------------------------------ - # Writes - # ------------------------------------------------------------------ - - def mark_healthy(self, provider_id: str) -> None: - """Provider answered correctly — clear any failure state.""" - transitioned = False - with self._lock: - state = self._states.setdefault(provider_id, ProviderState()) - transitioned = not state.healthy - state.healthy = True - state.consecutive_failures = 0 - state.last_check = self._clock() - state.last_error = None - state.unhealthy_until = 0.0 - if transitioned: - logger.info("provider %s recovered", provider_id) - self._notify(provider_id, True) - - def mark_unhealthy(self, provider_id: str, reason: str) -> None: - """Record a failure. Trips the breaker after the threshold.""" - transitioned = False - with self._lock: - state = self._states.setdefault(provider_id, ProviderState()) - state.consecutive_failures += 1 - state.last_check = self._clock() - state.last_error = reason - tripped = state.consecutive_failures >= self._failure_threshold - if tripped and state.healthy: - state.healthy = False - state.unhealthy_until = self._clock() + self._unhealthy_backoff - transitioned = True - logger.warning( - "provider %s marked unhealthy after %d consecutive failures (%s); " - "backoff %.0fs", - provider_id, - state.consecutive_failures, - reason, - self._unhealthy_backoff, - ) - elif not state.healthy: - # Still in unhealthy window; refresh the backoff so a flapping - # provider doesn't get re-tried every probe tick. - state.unhealthy_until = self._clock() + self._unhealthy_backoff - if transitioned: - self._notify(provider_id, False) - - -def _copy(state: ProviderState) -> ProviderState: - """Return a shallow copy so callers can read without holding the lock.""" - return ProviderState( - healthy=state.healthy, - consecutive_failures=state.consecutive_failures, - last_check=state.last_check, - last_error=state.last_error, - unhealthy_until=state.unhealthy_until, - ) diff --git a/services/mana-llm/src/health_probe.py b/services/mana-llm/src/health_probe.py deleted file mode 100644 index 9f1cdfa28..000000000 --- a/services/mana-llm/src/health_probe.py +++ /dev/null @@ -1,210 +0,0 @@ -"""Background probe loop that keeps :class:`ProviderHealthCache` fresh. - -The router's circuit-breaker is reactive — it only learns a provider is -sick after the next live request fails. A reactive-only design means: - -* every cold start re-discovers Ollama is down by paying one 75-second - ``ConnectError``, and -* a provider that quietly recovers stays marked unhealthy until its - backoff expires and someone tries it. - -The probe loop closes both gaps. 
Every ``interval`` seconds it pings -each registered provider with a small known-cheap request (Ollama: -``GET /api/tags``, OpenAI-compat: ``GET /v1/models``) and updates the -cache. Probes run concurrently per tick and respect a hard -``probe_timeout`` so a hanging provider can't stall the loop. - -Probe functions are injected from outside (``probes={name: async-fn}``) -so this module stays decoupled from the provider classes — wiring lives -in ``main.py`` where we already know which providers are configured. -""" - -from __future__ import annotations - -import asyncio -import logging -from typing import Awaitable, Callable - -from .health import ProviderHealthCache - -logger = logging.getLogger(__name__) - -#: Probe function: returns ``True`` for healthy. Raising or returning -#: ``False`` both count as a failure (the loop just calls -#: ``mark_unhealthy``); an exception's string form becomes the -#: ``last_error`` for the snapshot endpoint. -ProbeFn = Callable[[], Awaitable[bool]] - -DEFAULT_PROBE_INTERVAL_SEC = 30.0 -DEFAULT_PROBE_TIMEOUT_SEC = 3.0 - - -class HealthProbe: - """Periodically probes every registered provider and updates the cache. - - Lifecycle:: - - probe = HealthProbe(cache, {"ollama": probe_ollama, ...}) - await probe.start() # spawns the background task - ... - await probe.stop() # cancels and awaits cleanup - - Tests typically call :meth:`tick_once` directly to exercise one cycle - without driving the asyncio scheduler through real ``asyncio.sleep``. - """ - - def __init__( - self, - cache: ProviderHealthCache, - probes: dict[str, ProbeFn], - *, - interval: float = DEFAULT_PROBE_INTERVAL_SEC, - timeout: float = DEFAULT_PROBE_TIMEOUT_SEC, - ) -> None: - if interval <= 0: - raise ValueError("interval must be > 0") - if timeout <= 0: - raise ValueError("timeout must be > 0") - self._cache = cache - self._probes = dict(probes) - self._interval = interval - self._timeout = timeout - self._task: asyncio.Task | None = None - self._stop = asyncio.Event() - - @property - def interval(self) -> float: - return self._interval - - @property - def timeout(self) -> float: - return self._timeout - - @property - def provider_ids(self) -> list[str]: - return list(self._probes) - - @property - def running(self) -> bool: - return self._task is not None and not self._task.done() - - # ------------------------------------------------------------------ - # Per-tick logic — exercised directly by tests - # ------------------------------------------------------------------ - - async def tick_once(self) -> None: - """Probe every provider once, in parallel, updating the cache. - - Errors in any one probe (including ``asyncio.TimeoutError``) are - captured per-provider — one bad probe never sinks the loop. - """ - if not self._probes: - return - results = await asyncio.gather( - *(self._probe_one(name, fn) for name, fn in self._probes.items()), - return_exceptions=True, - ) - # gather(return_exceptions=True) caught everything per-probe; this - # branch should never fire, but guard in case _probe_one ever grows - # a code path that bypasses its try/except. - # asyncio.gather() returns results in input order and same length, - # so the zip is a 1:1 mapping back to provider names. 
- for name, result in zip(self._probes, results): - if isinstance(result, BaseException): - logger.error("probe %s leaked exception: %s", name, result) - - async def _probe_one(self, name: str, fn: ProbeFn) -> None: - try: - healthy = await asyncio.wait_for(fn(), timeout=self._timeout) - except asyncio.TimeoutError: - self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)") - return - except asyncio.CancelledError: - raise - except Exception as e: # noqa: BLE001 — probe SHOULD be permissive - self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}") - return - if healthy: - self._cache.mark_healthy(name) - else: - self._cache.mark_unhealthy(name, "probe-returned-false") - - # ------------------------------------------------------------------ - # Long-running task management - # ------------------------------------------------------------------ - - async def start(self) -> None: - """Spawn the periodic probe task. Idempotent.""" - if self.running: - return - self._stop.clear() - self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe") - logger.info( - "HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)", - self._interval, - self._timeout, - ", ".join(self._probes) or "", - ) - - async def stop(self) -> None: - """Cancel the background task and wait for it to finish.""" - if not self.running: - return - self._stop.set() - assert self._task is not None - self._task.cancel() - try: - await self._task - except asyncio.CancelledError: - pass - finally: - self._task = None - logger.info("HealthProbe stopped") - - async def _run_forever(self) -> None: - # Probe immediately at boot so we don't serve traffic for `interval` - # seconds based on optimistic-default assumptions. - try: - await self.tick_once() - except Exception as e: # noqa: BLE001 - logger.error("HealthProbe initial tick failed: %s", e) - while not self._stop.is_set(): - try: - await asyncio.wait_for(self._stop.wait(), timeout=self._interval) - except asyncio.TimeoutError: - pass - else: - # _stop.wait() succeeded → stop signalled, exit. - return - try: - await self.tick_once() - except Exception as e: # noqa: BLE001 - logger.error("HealthProbe tick failed: %s", e) - - -# --------------------------------------------------------------------------- -# Probe-function helpers -# --------------------------------------------------------------------------- - - -def make_http_probe( - url: str, - *, - headers: dict[str, str] | None = None, - expected_status_lt: int = 500, -) -> ProbeFn: - """Return a probe function that does ``GET `` and considers the - provider healthy iff the response status is below - ``expected_status_lt`` (default: any non-5xx counts). - - A 401/403/404 still counts as healthy because the *server* answered — - auth or path mistakes are misconfiguration, not provider liveness. 
- """ - import httpx - - async def probe() -> bool: - async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client: - resp = await client.get(url, headers=headers or None) - return resp.status_code < expected_status_lt - - return probe diff --git a/services/mana-llm/src/main.py b/services/mana-llm/src/main.py deleted file mode 100644 index 24f78db9d..000000000 --- a/services/mana-llm/src/main.py +++ /dev/null @@ -1,405 +0,0 @@ -"""Main FastAPI application for mana-llm service.""" - -import asyncio -import logging -import signal -import time -from contextlib import asynccontextmanager -from pathlib import Path -from typing import Any - -from fastapi import FastAPI, HTTPException, Request -from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import JSONResponse, Response -from sse_starlette.sse import EventSourceResponse - -from src.aliases import AliasConfigError, AliasRegistry -from src.api_auth import ApiKeyMiddleware -from src.config import settings -from src.health import ProviderHealthCache -from src.health_probe import HealthProbe, make_http_probe -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - EmbeddingRequest, - EmbeddingResponse, - ModelInfo, - ModelsResponse, -) -from src.providers import ProviderRouter -from src.providers.errors import ProviderError -from src.streaming import stream_chat_completion -from src.utils.cache import close_redis -from src.utils.metrics import ( - get_metrics, - record_llm_error, - record_llm_request, - set_provider_healthy, -) - -#: Header carrying the concrete provider/model that actually served a -#: non-streaming response — useful for token-cost accounting on the -#: caller side, since `mana/long-form` could resolve to ollama, groq, -#: or claude depending on which providers were healthy at request time. -RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved" - -# Configure logging -logging.basicConfig( - level=getattr(logging, settings.log_level.upper()), - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", -) -logger = logging.getLogger(__name__) - -# Global service singletons -router: ProviderRouter | None = None -health_cache: ProviderHealthCache | None = None -health_probe: HealthProbe | None = None -alias_registry: AliasRegistry | None = None - - -def _build_provider_probes( - providers: dict[str, Any], -) -> dict[str, Any]: - """Wire each configured provider to a cheap HTTP probe.""" - probes: dict[str, Any] = {} - - if "ollama" in providers: - probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags") - if "openrouter" in providers: - probes["openrouter"] = make_http_probe( - f"{settings.openrouter_base_url}/models", - headers={"Authorization": f"Bearer {settings.openrouter_api_key}"}, - ) - if "groq" in providers: - probes["groq"] = make_http_probe( - f"{settings.groq_base_url}/models", - headers={"Authorization": f"Bearer {settings.groq_api_key}"}, - ) - if "together" in providers: - probes["together"] = make_http_probe( - f"{settings.together_base_url}/models", - headers={"Authorization": f"Bearer {settings.together_api_key}"}, - ) - # Google: skipped — google-genai SDK is opaque enough that a probe - # would amount to a real API call. Treat as healthy by default; the - # router's call-site fallback will mark it unhealthy on real errors. 
- return probes - - -def _on_health_change(provider: str, healthy: bool) -> None: - """Mirror health-cache transitions into the Prometheus gauge.""" - set_provider_healthy(provider, healthy) - - -def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None: - """Reload ``aliases.yaml`` when the process receives SIGHUP. - - Reload errors keep the previous good state in memory (see - AliasRegistry.reload). SIGHUP isn't available on Windows; we just - log and skip in that case. - """ - - def handler() -> None: - if alias_registry is None: - return - try: - alias_registry.reload() - except AliasConfigError as e: - logger.error("alias reload rejected, keeping previous state: %s", e) - - try: - loop.add_signal_handler(signal.SIGHUP, handler) - except (NotImplementedError, AttributeError, RuntimeError): - # NotImplementedError on Windows; RuntimeError when the loop is - # not running in the main thread (TestClient does this). - logger.info("SIGHUP reload not available in this context — skipping") - - -@asynccontextmanager -async def lifespan(app: FastAPI): - """Application lifespan: load aliases, spin up router + health probe.""" - global router, health_cache, health_probe, alias_registry - - logger.info("Starting mana-llm service...") - - aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml" - alias_registry = AliasRegistry(aliases_path) - logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path) - - health_cache = ProviderHealthCache() - health_cache.add_listener(_on_health_change) - router = ProviderRouter(aliases=alias_registry, health_cache=health_cache) - logger.info("Initialized providers: %s", list(router.providers)) - - # Initial gauge values so dashboards render before the first probe - # transition fires the listener. - for provider_name in router.providers: - set_provider_healthy(provider_name, True) - - health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers)) - await health_probe.start() - - _install_sighup_reload(asyncio.get_running_loop()) - - yield - - logger.info("Shutting down mana-llm service...") - if health_probe is not None: - await health_probe.stop() - if router is not None: - await router.close() - await close_redis() - - -# Create FastAPI app -app = FastAPI( - title="mana-llm", - description="Central LLM abstraction service for Ollama and OpenAI-compatible APIs", - version="0.1.0", - lifespan=lifespan, -) - -# Add CORS middleware -app.add_middleware( - CORSMiddleware, - allow_origins=settings.cors_origins_list, - allow_credentials=True, - allow_methods=["*"], - allow_headers=["*"], -) -app.add_middleware(ApiKeyMiddleware) - - -# Health endpoint -@app.get("/health") -async def health_check() -> dict[str, Any]: - """Check service health and provider status.""" - if router is None: - return {"status": "unhealthy", "error": "Router not initialized"} - - provider_health = await router.health_check() - return { - "status": provider_health["status"], - "service": "mana-llm", - "version": "0.1.0", - "providers": provider_health["providers"], - } - - -# Metrics endpoint -@app.get("/metrics") -async def metrics() -> Response: - """Prometheus metrics endpoint.""" - return Response(content=get_metrics(), media_type="text/plain") - - -# ----- Alias / health debug endpoints (M4) --------------------------------- - - -@app.get("/v1/aliases") -async def list_aliases() -> dict[str, Any]: - """Inspect the alias registry — what each ``mana/`` resolves to. 
- - Useful for debugging "which model actually answered my request" and - for confirming SIGHUP reloads picked up edits to ``aliases.yaml``. - """ - if alias_registry is None: - raise HTTPException(status_code=503, detail="Service not ready") - return { - "default": alias_registry.default_alias, - "aliases": [ - { - "name": a.name, - "description": a.description, - "chain": list(a.chain), - } - for a in alias_registry.list_aliases() - ], - } - - -@app.get("/v1/health") -async def detailed_health() -> dict[str, Any]: - """Full per-provider liveness snapshot. - - Includes the failure counter, last error, and the unhealthy-until - backoff timestamp — info the original ``/health`` endpoint hides. - """ - if router is None: - raise HTTPException(status_code=503, detail="Service not ready") - return await router.health_check() - - -# Models endpoints -@app.get("/v1/models", response_model=ModelsResponse) -async def list_models() -> ModelsResponse: - """List all available models from all providers.""" - if router is None: - raise HTTPException(status_code=503, detail="Service not ready") - - models = await router.list_models() - return ModelsResponse(data=models) - - -@app.get("/v1/models/{model_id:path}") -async def get_model(model_id: str) -> ModelInfo: - """Get specific model information.""" - if router is None: - raise HTTPException(status_code=503, detail="Service not ready") - - model = await router.get_model(model_id) - if model is None: - raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found") - - return model - - -# Chat completions endpoint -@app.post("/v1/chat/completions", response_model=None) -async def chat_completions( - request: ChatCompletionRequest, - http_request: Request, -) -> ChatCompletionResponse | EventSourceResponse: - """ - Create a chat completion. - - Supports both streaming (SSE) and non-streaming responses based on the - `stream` parameter in the request body. - """ - if router is None: - raise HTTPException(status_code=503, detail="Service not ready") - - # The request's `model` field is what the caller asked for — could be - # `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For - # error-path metrics we use that value (it's what the caller will - # search for); for success-path metrics we use the resolved provider - # so token-cost / latency attribute to the model that actually ran. - requested_provider, requested_model = _split_model(request.model) - - start_time = time.time() - - try: - if request.stream: - # Streaming response via SSE - logger.info(f"Streaming chat completion: {request.model}") - - async def generate(): - async for chunk in stream_chat_completion(router, request): - yield chunk - - # Streaming metrics: we don't yet know which provider answered - # at request-record time. Each chunk's `model` field carries - # the resolved name; per-token latency is harder to attribute - # cleanly so we skip it for streams. 
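From the caller's side, the resolved-model header makes alias-aware cost attribution a one-liner. A usage sketch under assumptions: the base URL and port are placeholders, and any API key the middleware requires is omitted here:

    import httpx

    BASE_URL = "http://localhost:8000"  # placeholder; use the deployed service URL


    def ask(prompt: str) -> None:
        resp = httpx.post(
            f"{BASE_URL}/v1/chat/completions",
            json={
                "model": "mana/long-form",  # alias; resolved server-side
                "messages": [{"role": "user", "content": prompt}],
                "stream": False,
            },
            timeout=120.0,
        )
        resp.raise_for_status()
        body = resp.json()
        # The header names the provider/model that actually answered, even though
        # the request only carried an alias.
        resolved = resp.headers.get("X-Mana-LLM-Resolved", body["model"])
        usage = body.get("usage", {})
        print(f"served by {resolved}, {usage.get('total_tokens', 0)} tokens")


    if __name__ == "__main__":
        ask("Summarise the alias fallback design in one sentence.")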
- record_llm_request(requested_provider, requested_model, streaming=True) - - return EventSourceResponse( - generate(), - media_type="text/event-stream", - ) - else: - # Non-streaming response - logger.info(f"Chat completion: {request.model}") - response = await router.chat_completion(request) - - resolved_provider, resolved_model = _split_model(response.model) - latency = time.time() - start_time - record_llm_request( - provider=resolved_provider, - model=resolved_model, - streaming=False, - prompt_tokens=response.usage.prompt_tokens, - completion_tokens=response.usage.completion_tokens, - latency=latency, - ) - - # `response.model` is the concrete provider/model the chain - # actually resolved to. Surface it via header so the caller - # can attribute token cost to the right model even when the - # request used an alias. - return JSONResponse( - content=response.model_dump(), - headers={RESOLVED_MODEL_HEADER: response.model}, - ) - - except ValueError as e: - logger.error(f"Invalid request: {e}") - record_llm_error(requested_provider, requested_model, "invalid_request") - raise HTTPException(status_code=400, detail=str(e)) - except ProviderError as e: - logger.warning( - f"Provider error on {requested_provider}/{requested_model}: " - f"kind={e.kind} detail={e}" - ) - record_llm_error(requested_provider, requested_model, e.kind) - raise HTTPException( - status_code=e.http_status, - detail={"kind": e.kind, "message": str(e)}, - ) - except Exception as e: - logger.error(f"Chat completion failed: {e}") - record_llm_error(requested_provider, requested_model, "server_error") - raise HTTPException(status_code=500, detail=str(e)) - - -# Embeddings endpoint -@app.post("/v1/embeddings", response_model=EmbeddingResponse) -async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: - """Create embeddings for the input text.""" - if router is None: - raise HTTPException(status_code=503, detail="Service not ready") - - provider, model = _split_model(request.model) - - start_time = time.time() - - try: - logger.info(f"Creating embeddings: {request.model}") - response = await router.embeddings(request) - - latency = time.time() - start_time - record_llm_request( - provider=provider, - model=model, - streaming=False, - prompt_tokens=response.usage.prompt_tokens, - latency=latency, - ) - - return response - - except ValueError as e: - logger.error(f"Invalid embedding request: {e}") - record_llm_error(provider, model, "invalid_request") - raise HTTPException(status_code=400, detail=str(e)) - except Exception as e: - logger.error(f"Embeddings failed: {e}") - record_llm_error(provider, model, "server_error") - raise HTTPException(status_code=500, detail=str(e)) - - -def _split_model(model: str) -> tuple[str, str]: - """Split a ``provider/model`` string for metric labelling. - - Bare names with no slash default to ``ollama`` to match the legacy - OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace - in the metrics — that's intentional, so request-side counters tell - you what callers ASKED for, while the resolved-side counters - (``mana_llm_alias_resolved_total``) tell you what they GOT. 
- """ - if "/" in model: - provider, _, name = model.partition("/") - return provider.lower(), name - return "ollama", model - - -if __name__ == "__main__": - import uvicorn - - uvicorn.run( - "src.main:app", - host="0.0.0.0", - port=settings.port, - reload=True, - log_level=settings.log_level, - ) diff --git a/services/mana-llm/src/models/__init__.py b/services/mana-llm/src/models/__init__.py deleted file mode 100644 index ac99636cd..000000000 --- a/services/mana-llm/src/models/__init__.py +++ /dev/null @@ -1,47 +0,0 @@ -"""Pydantic models for OpenAI-compatible API.""" - -from .requests import ( - ChatCompletionRequest, - EmbeddingRequest, - FunctionSpec, - Message, - ToolChoice, - ToolSpec, -) -from .responses import ( - ChatCompletionResponse, - ChatCompletionStreamResponse, - Choice, - DeltaContent, - EmbeddingData, - EmbeddingResponse, - MessageResponse, - ModelInfo, - ModelsResponse, - StreamChoice, - ToolCall, - ToolCallFunction, - Usage, -) - -__all__ = [ - "ChatCompletionRequest", - "ChatCompletionResponse", - "ChatCompletionStreamResponse", - "Choice", - "DeltaContent", - "EmbeddingData", - "EmbeddingRequest", - "EmbeddingResponse", - "FunctionSpec", - "Message", - "MessageResponse", - "ModelInfo", - "ModelsResponse", - "StreamChoice", - "ToolCall", - "ToolCallFunction", - "ToolChoice", - "ToolSpec", - "Usage", -] diff --git a/services/mana-llm/src/models/requests.py b/services/mana-llm/src/models/requests.py deleted file mode 100644 index 4c444124a..000000000 --- a/services/mana-llm/src/models/requests.py +++ /dev/null @@ -1,128 +0,0 @@ -"""Request models for OpenAI-compatible API.""" - -from typing import Any, Literal - -from pydantic import BaseModel, Field - - -class TextContent(BaseModel): - """Text content in a message.""" - - type: Literal["text"] = "text" - text: str - - -class ImageUrl(BaseModel): - """Image URL reference.""" - - url: str # Can be http(s):// or data:image/...;base64,... - - -class ImageContent(BaseModel): - """Image content in a message.""" - - type: Literal["image_url"] = "image_url" - image_url: ImageUrl - - -MessageContent = str | list[TextContent | ImageContent] - - -class ToolCallFunction(BaseModel): - """The function portion of a tool_call on an assistant message.""" - - name: str - # Arguments are passed as a JSON string (OpenAI spec). Providers may - # emit structured args natively; the adapter serialises them here. - arguments: str - - -class ToolCall(BaseModel): - """A tool invocation the assistant decided to make.""" - - id: str - type: Literal["function"] = "function" - function: ToolCallFunction - - -class Message(BaseModel): - """A single message in the conversation. - - `tool` messages carry the result of a previously-requested tool call - back into the context; they must reference the originating call via - ``tool_call_id``. Assistant messages may contain either plain - ``content`` or ``tool_calls`` (or both, though providers typically - only emit one at a time). - """ - - role: Literal["system", "user", "assistant", "tool"] - content: MessageContent | None = None - tool_call_id: str | None = None - tool_calls: list[ToolCall] | None = None - - -class FunctionSpec(BaseModel): - """OpenAI-style function declaration.""" - - name: str - description: str - # JSON Schema for the function's parameters. Kept as a dict here to - # stay loose; providers translate it to their native shape. - parameters: dict[str, Any] = Field(default_factory=dict) - - -class ToolSpec(BaseModel): - """A tool the model may call. 
Only ``function`` tools are supported.""" - - type: Literal["function"] = "function" - function: FunctionSpec - - -class ToolChoiceFunction(BaseModel): - """Force-pick a specific function for ``tool_choice``.""" - - type: Literal["function"] = "function" - function: dict[str, str] # {"name": "tool_name"} - - -ToolChoice = Literal["auto", "none", "required"] | ToolChoiceFunction - - -class ResponseFormat(BaseModel): - """OpenAI structured-output response_format hint. - - Two shapes are accepted: - - {"type": "json_object"} — free-form JSON - - {"type": "json_schema", - "json_schema": {"name": "...", "schema": {...}, "strict": bool}} - — schema-constrained JSON; passed through to providers that - support it (e.g. Ollama 0.5+ via its native `format` field). - """ - - type: Literal["json_object", "json_schema"] - json_schema: dict[str, Any] | None = None - - -class ChatCompletionRequest(BaseModel): - """Request body for chat completions endpoint.""" - - model: str = Field(..., description="Model identifier in format 'provider/model' or just 'model'") - messages: list[Message] = Field(..., min_length=1) - stream: bool = False - temperature: float | None = Field(default=None, ge=0.0, le=2.0) - max_tokens: int | None = Field(default=None, gt=0) - top_p: float | None = Field(default=None, ge=0.0, le=1.0) - frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0) - presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0) - stop: str | list[str] | None = None - response_format: ResponseFormat | None = None - tools: list[ToolSpec] | None = None - tool_choice: ToolChoice | None = None - - -class EmbeddingRequest(BaseModel): - """Request body for embeddings endpoint.""" - - model: str = Field(..., description="Model identifier") - input: str | list[str] = Field(..., description="Text(s) to embed") - encoding_format: Literal["float", "base64"] = "float" diff --git a/services/mana-llm/src/models/responses.py b/services/mana-llm/src/models/responses.py deleted file mode 100644 index 37ac4af2e..000000000 --- a/services/mana-llm/src/models/responses.py +++ /dev/null @@ -1,120 +0,0 @@ -"""Response models for OpenAI-compatible API.""" - -import time -import uuid -from typing import Literal - -from pydantic import BaseModel, Field - - -class Usage(BaseModel): - """Token usage information.""" - - prompt_tokens: int = 0 - completion_tokens: int = 0 - total_tokens: int = 0 - - -class ToolCallFunction(BaseModel): - """Function portion of a tool_call the assistant produced.""" - - name: str - arguments: str # JSON-encoded (OpenAI spec) - - -class ToolCall(BaseModel): - """A tool invocation the assistant decided to make.""" - - id: str - type: Literal["function"] = "function" - function: ToolCallFunction - - -class MessageResponse(BaseModel): - """Response message from the model.""" - - role: Literal["assistant"] = "assistant" - content: str | None = None - tool_calls: list[ToolCall] | None = None - - -class Choice(BaseModel): - """A single completion choice.""" - - index: int = 0 - message: MessageResponse - finish_reason: ( - Literal["stop", "length", "content_filter", "tool_calls"] | None - ) = "stop" - - -class ChatCompletionResponse(BaseModel): - """Response from chat completions endpoint (non-streaming).""" - - id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}") - object: Literal["chat.completion"] = "chat.completion" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[Choice] - usage: Usage = 
Field(default_factory=Usage) - - -class DeltaContent(BaseModel): - """Delta content for streaming responses.""" - - role: Literal["assistant"] | None = None - content: str | None = None - tool_calls: list[ToolCall] | None = None - - -class StreamChoice(BaseModel): - """A single streaming choice.""" - - index: int = 0 - delta: DeltaContent - finish_reason: ( - Literal["stop", "length", "content_filter", "tool_calls"] | None - ) = None - - -class ChatCompletionStreamResponse(BaseModel): - """Response chunk from chat completions endpoint (streaming).""" - - id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}") - object: Literal["chat.completion.chunk"] = "chat.completion.chunk" - created: int = Field(default_factory=lambda: int(time.time())) - model: str - choices: list[StreamChoice] - - -class ModelInfo(BaseModel): - """Information about a model.""" - - id: str - object: Literal["model"] = "model" - created: int = Field(default_factory=lambda: int(time.time())) - owned_by: str = "mana-llm" - - -class ModelsResponse(BaseModel): - """Response from models endpoint.""" - - object: Literal["list"] = "list" - data: list[ModelInfo] - - -class EmbeddingData(BaseModel): - """A single embedding result.""" - - object: Literal["embedding"] = "embedding" - index: int = 0 - embedding: list[float] - - -class EmbeddingResponse(BaseModel): - """Response from embeddings endpoint.""" - - object: Literal["list"] = "list" - data: list[EmbeddingData] - model: str - usage: Usage = Field(default_factory=Usage) diff --git a/services/mana-llm/src/providers/__init__.py b/services/mana-llm/src/providers/__init__.py deleted file mode 100644 index 04127d67e..000000000 --- a/services/mana-llm/src/providers/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -"""LLM Provider implementations.""" - -from .base import LLMProvider -from .ollama import OllamaProvider -from .openai_compat import OpenAICompatProvider -from .router import ProviderRouter - -__all__ = [ - "LLMProvider", - "OllamaProvider", - "OpenAICompatProvider", - "ProviderRouter", -] diff --git a/services/mana-llm/src/providers/base.py b/services/mana-llm/src/providers/base.py deleted file mode 100644 index a3909e5ae..000000000 --- a/services/mana-llm/src/providers/base.py +++ /dev/null @@ -1,76 +0,0 @@ -"""Abstract base class for LLM providers.""" - -from abc import ABC, abstractmethod -from collections.abc import AsyncIterator -from typing import Any - -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - EmbeddingRequest, - EmbeddingResponse, - ModelInfo, -) - - -class LLMProvider(ABC): - """Abstract base class for LLM providers.""" - - name: str = "base" - - # Set to True if the provider supports OpenAI-style `tools` + `tool_calls` - # for chat completions. The router rejects tool-bearing requests routed - # to providers without support rather than silently dropping the tools. - # Provider adapters may further narrow this per-model if needed. - supports_tools: bool = False - - def model_supports_tools(self, model: str) -> bool: - """Check if a specific model within this provider supports tools. - - Default: falls back to the provider-wide flag. Providers with a - mixed capability surface (e.g. Ollama — depends on the local - model) override this. - """ - return self.supports_tools - - @abstractmethod - async def chat_completion( - self, - request: ChatCompletionRequest, - model: str, - ) -> ChatCompletionResponse: - """Generate a chat completion (non-streaming).""" - ... 
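The capability gate (a provider-wide supports_tools flag that individual providers narrow per model) is easiest to see in isolation. An illustrative sketch; the class names and the pattern tuple are invented for the example:

    class Provider:
        """Illustrative stand-in for the provider-wide capability flag."""

        supports_tools: bool = False

        def model_supports_tools(self, model: str) -> bool:
            # Default: the provider-wide flag decides.
            return self.supports_tools


    class MixedCapabilityProvider(Provider):
        """Provider whose models differ in tool support, like local Ollama models."""

        supports_tools = True
        TOOL_CAPABLE = ("llama3.1", "qwen2.5", "mistral")

        def model_supports_tools(self, model: str) -> bool:
            name = model.lower()
            return any(pattern in name for pattern in self.TOOL_CAPABLE)


    if __name__ == "__main__":
        p = MixedCapabilityProvider()
        print(p.model_supports_tools("qwen2.5-coder:14b"))  # True
        print(p.model_supports_tools("gemma3:4b"))          # False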
- - @abstractmethod - async def chat_completion_stream( - self, - request: ChatCompletionRequest, - model: str, - ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Generate a chat completion (streaming).""" - ... - - @abstractmethod - async def list_models(self) -> list[ModelInfo]: - """List available models.""" - ... - - @abstractmethod - async def embeddings( - self, - request: EmbeddingRequest, - model: str, - ) -> EmbeddingResponse: - """Generate embeddings for input text.""" - ... - - @abstractmethod - async def health_check(self) -> dict[str, Any]: - """Check provider health status.""" - ... - - async def close(self) -> None: - """Clean up resources.""" - pass diff --git a/services/mana-llm/src/providers/errors.py b/services/mana-llm/src/providers/errors.py deleted file mode 100644 index 440651270..000000000 --- a/services/mana-llm/src/providers/errors.py +++ /dev/null @@ -1,111 +0,0 @@ -"""Structured provider errors. - -These map to distinct HTTP status codes and metric labels so callers can -distinguish between "model refused to answer" (blocked), "hit the token -budget" (truncated), "auth is broken", "we were rate-limited", and -"model doesn't support what we asked" (capability). The old behaviour of -returning an empty-string response on content-filter hits silently -corrupts downstream pipelines (e.g. the planner sees "" and reports -"no JSON block found" — misleading). -""" - - -class ProviderError(Exception): - """Base class for structured provider errors.""" - - # Short stable identifier used in metrics and API responses. - kind: str = "unknown" - # Suggested HTTP status when this error bubbles out of an endpoint. - http_status: int = 502 - - -class ProviderBlockedError(ProviderError): - """The provider refused to return content (safety, recitation, …). - - The model produced nothing usable because a content filter or policy - guardrail fired. Retry with the same inputs will fail again — the - caller needs to adjust the prompt or give the user a clear message. - """ - - kind = "blocked" - http_status = 422 - - def __init__(self, reason: str, detail: str | None = None): - self.reason = reason # e.g. "SAFETY", "RECITATION" - self.detail = detail - msg = f"Provider blocked the response ({reason})" - if detail: - msg += f": {detail}" - super().__init__(msg) - - -class ProviderTruncatedError(ProviderError): - """The response was cut off before completion (hit max_tokens).""" - - kind = "truncated" - http_status = 502 - - def __init__(self, partial_text: str | None = None): - self.partial_text = partial_text - super().__init__( - "Provider truncated the response (max_tokens reached) — " - "re-run with a higher max_tokens or smaller input" - ) - - -class ProviderAuthError(ProviderError): - """Authentication against the upstream provider failed.""" - - kind = "auth" - http_status = 502 # not 401 — the client's auth is fine, *ours* isn't - - -class ProviderRateLimitError(ProviderError): - """The upstream provider rate-limited us.""" - - kind = "rate_limit" - http_status = 429 - - -class ProviderCapabilityError(ProviderError): - """The requested feature is not supported by the chosen model. - - Typically raised when a request asks for native tool-calling against - a model that does not support it. We refuse to silently degrade. - """ - - kind = "capability" - http_status = 400 - - -class NoHealthyProviderError(ProviderError): - """Every entry in the resolved chain has been tried and either was - unconfigured, marked unhealthy by the cache, or failed in flight. 
- - Carries the full attempt log so the caller can see which providers - were tried and why each one was skipped or failed — invaluable when - a real outage hits and the API returns 503 instead of the usual 200. - """ - - kind = "no_healthy_provider" - http_status = 503 - - def __init__( - self, - model_or_alias: str, - attempts: list[tuple[str, str]], - last_exception: Exception | None = None, - ) -> None: - self.model_or_alias = model_or_alias - self.attempts = list(attempts) - self.last_exception = last_exception - if attempts: - attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts) - else: - attempt_log = "(no providers in resolved chain were configured)" - msg = ( - f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}" - ) - if last_exception is not None: - msg += f". Last error: {type(last_exception).__name__}: {last_exception}" - super().__init__(msg) diff --git a/services/mana-llm/src/providers/google.py b/services/mana-llm/src/providers/google.py deleted file mode 100644 index bdf768a74..000000000 --- a/services/mana-llm/src/providers/google.py +++ /dev/null @@ -1,521 +0,0 @@ -"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable).""" - -import json -import logging -import uuid -from collections.abc import AsyncIterator -from typing import Any - -from google import genai -from google.genai import types - -from src.config import settings -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - Choice, - DeltaContent, - EmbeddingData, - EmbeddingRequest, - EmbeddingResponse, - MessageResponse, - ModelInfo, - StreamChoice, - ToolCall, - ToolCallFunction, - ToolSpec, - Usage, -) - -from .base import LLMProvider -from .errors import ( - ProviderAuthError, - ProviderBlockedError, - ProviderError, - ProviderRateLimitError, - ProviderTruncatedError, -) - -logger = logging.getLogger(__name__) - - -def _build_gemini_tools(tools: list[ToolSpec] | None) -> list[types.Tool] | None: - """Translate our ToolSpec list into Gemini ``types.Tool`` declarations.""" - if not tools: - return None - declarations: list[types.FunctionDeclaration] = [] - for t in tools: - declarations.append( - types.FunctionDeclaration( - name=t.function.name, - description=t.function.description, - parameters=t.function.parameters or None, - ) - ) - return [types.Tool(function_declarations=declarations)] - - -def _extract_tool_calls(candidate: Any) -> list[ToolCall] | None: - """Pull any ``function_call`` parts off a candidate into ToolCalls. - - Gemini emits tool calls as ``content.parts[i].function_call`` with a - ``FunctionCall(name, args)`` where ``args`` is a dict (not a JSON - string). We normalise to OpenAI shape: arguments are JSON-encoded - strings so downstream handlers can treat all providers the same. - """ - if candidate is None: - return None - content = getattr(candidate, "content", None) - parts = getattr(content, "parts", None) or [] - calls: list[ToolCall] = [] - for part in parts: - fc = getattr(part, "function_call", None) - if fc is None: - continue - name = getattr(fc, "name", None) - if not name: - continue - args = getattr(fc, "args", None) or {} - # Gemini's args are already dict-shaped; serialise to JSON string. 
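That dict-to-JSON-string normalisation is the whole trick that lets downstream code treat Gemini tool calls like OpenAI ones. A standalone sketch using plain dicts instead of the Pydantic models:

    import json
    import uuid


    def normalise_function_call(name: str, args: dict | None) -> dict:
        # Gemini emits function-call args as a dict; the OpenAI tool_call shape
        # wants a JSON-encoded string, so every provider looks the same downstream.
        try:
            arguments = json.dumps(dict(args or {}), ensure_ascii=False)
        except (TypeError, ValueError):
            arguments = "{}"
        return {
            "id": f"call_{uuid.uuid4().hex[:12]}",
            "type": "function",
            "function": {"name": name, "arguments": arguments},
        }


    if __name__ == "__main__":
        print(normalise_function_call("get_weather", {"city": "Berlin", "unit": "celsius"}))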
- try: - arguments_json = json.dumps(dict(args), ensure_ascii=False) - except (TypeError, ValueError): - arguments_json = json.dumps({}, ensure_ascii=False) - calls.append( - ToolCall( - id=f"call_{uuid.uuid4().hex[:12]}", - function=ToolCallFunction(name=name, arguments=arguments_json), - ) - ) - return calls or None - - -def _unwrap_gemini_response( - response: Any, gemini_model: str -) -> tuple[str, list[ToolCall] | None, str]: - """Validate a non-streaming Gemini response. - - Returns ``(text, tool_calls, finish_reason)``. Raises a structured - ``ProviderError`` if the response was blocked, truncated, or - otherwise produced no usable payload. ``response.text`` silently - returns ``""`` on blocked responses — we refuse to propagate that. - """ - candidates = getattr(response, "candidates", None) or [] - candidate = candidates[0] if candidates else None - finish_reason = getattr(candidate, "finish_reason", None) - # SDK sometimes exposes the enum name on `.name`, sometimes it's a string. - finish_name = getattr(finish_reason, "name", None) or ( - str(finish_reason) if finish_reason is not None else None - ) - # Strip the leading enum prefix if present (e.g. "FinishReason.SAFETY"). - if finish_name and "." in finish_name: - finish_name = finish_name.rsplit(".", 1)[-1] - - text = response.text or "" - tool_calls = _extract_tool_calls(candidate) - - if finish_name in {"SAFETY", "RECITATION", "PROHIBITED_CONTENT", "SPII", "BLOCKLIST"}: - # Pull the first safety rating that actually blocked if present. - ratings = getattr(candidate, "safety_ratings", None) or [] - blocked = [ - getattr(r, "category", None) - for r in ratings - if getattr(r, "blocked", False) - ] - detail = ", ".join(str(c) for c in blocked if c) or None - logger.warning( - "Gemini response blocked (model=%s, reason=%s, detail=%s)", - gemini_model, - finish_name, - detail, - ) - raise ProviderBlockedError(reason=finish_name, detail=detail) - - if finish_name == "MAX_TOKENS": - logger.warning( - "Gemini response truncated at max_tokens (model=%s)", gemini_model - ) - raise ProviderTruncatedError(partial_text=text or None) - - if not text and not tool_calls and finish_name not in (None, "STOP"): - # Unknown finish reason, nothing to return — surface instead of "". - raise ProviderError( - f"Gemini returned no content (finish_reason={finish_name})" - ) - - # Normalise the finish_reason to our OpenAI-compatible vocabulary. - openai_finish = "stop" - if tool_calls: - openai_finish = "tool_calls" - elif finish_name == "MAX_TOKENS": - openai_finish = "length" - return text, tool_calls, openai_finish - - -def _lookup_tool_name(messages: list[Any], tool_call_id: str | None) -> str | None: - """Find the tool name for a given ``tool_call_id``. - - OpenAI's tool-message carries only the id; the matching name lives - on the preceding assistant message's ``tool_calls[]``. We scan from - the end (most recent assistant turn) backwards. - """ - if not tool_call_id: - return None - for m in reversed(messages): - if m.role != "assistant" or not m.tool_calls: - continue - for call in m.tool_calls: - if call.id == tool_call_id: - return call.function.name - return None - - -def _wrap_gemini_call_error(err: Exception, gemini_model: str) -> ProviderError: - """Translate a raw Google SDK exception into a structured ProviderError. - - The SDK uses google.genai.errors.* but we avoid importing them at - top level to keep the provider optional. String-match the class - name instead. 
- """ - cls_name = type(err).__name__ - msg = str(err) or cls_name - if "Auth" in cls_name or "PermissionDenied" in cls_name or "Unauthenticated" in cls_name: - return ProviderAuthError(f"Gemini auth failed for {gemini_model}: {msg}") - if "ResourceExhausted" in cls_name or "RateLimit" in cls_name or "429" in msg: - return ProviderRateLimitError(f"Gemini rate-limited for {gemini_model}: {msg}") - return ProviderError(f"Gemini call failed for {gemini_model}: {msg}") - -# Model mapping: Ollama model → Google Gemini equivalent -OLLAMA_TO_GEMINI: dict[str, str] = { - "gemma3:4b": "gemini-2.5-flash", - "gemma3:12b": "gemini-2.5-flash", - "gemma3:27b": "gemini-2.5-pro", - "llava:7b": "gemini-2.5-flash", # Gemini has native vision - "qwen3-vl:4b": "gemini-2.5-flash", # vision fallback - "qwen2.5-coder:7b": "gemini-2.5-flash", - "qwen2.5-coder:14b": "gemini-2.5-pro", - "phi3.5:latest": "gemini-2.5-flash", - "ministral-3:3b": "gemini-2.5-flash", - "deepseek-ocr:latest": "gemini-2.5-flash", -} - - -class GoogleProvider(LLMProvider): - """Google Gemini API provider.""" - - name = "google" - # Gemini 2.x supports OpenAI-style function calling across all listed models. - supports_tools = True - - def __init__(self, api_key: str, default_model: str = "gemini-2.5-flash"): - self.api_key = api_key - self.default_model = default_model - self.client = genai.Client(api_key=api_key) - - def map_model(self, ollama_model: str) -> str: - """Map an Ollama model name to a Google Gemini equivalent.""" - return OLLAMA_TO_GEMINI.get(ollama_model, self.default_model) - - def _convert_messages( - self, request: ChatCompletionRequest - ) -> tuple[str | None, list[types.Content]]: - """Convert OpenAI-format messages to Google Gemini format. - - Returns (system_instruction, contents). Handles text, multimodal - image content, assistant messages carrying ``tool_calls``, and - ``tool`` result messages (mapped to Gemini ``function_response`` - parts — Gemini has no ``tool`` role, function responses ride on - a ``user`` turn). - """ - system_instruction: str | None = None - contents: list[types.Content] = [] - - for msg in request.messages: - if msg.role == "system": - if isinstance(msg.content, str): - system_instruction = msg.content - continue - - # Tool result message → function_response Part on a user turn. - if msg.role == "tool": - # The content is the stringified tool result. We also need - # a tool name — the OpenAI spec carries it on the matching - # assistant tool_call, keyed by tool_call_id. We don't - # track that back-reference here, so we pull the name - # from the preceding assistant message's tool_calls. 
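Recovering the tool name from the preceding assistant turn, as described above, is a short backwards scan. A dict-based sketch of the same lookup; the message shapes follow the OpenAI chat format:

    def lookup_tool_name(messages: list[dict], tool_call_id: str | None) -> str | None:
        # OpenAI 'tool' messages carry only the call id; the matching function name
        # lives on the earlier assistant message's tool_calls[]. Scan newest first.
        if not tool_call_id:
            return None
        for msg in reversed(messages):
            if msg.get("role") != "assistant":
                continue
            for call in msg.get("tool_calls") or []:
                if call.get("id") == tool_call_id:
                    return call["function"]["name"]
        return None


    if __name__ == "__main__":
        history = [
            {"role": "user", "content": "Weather in Berlin?"},
            {"role": "assistant", "tool_calls": [
                {"id": "call_abc123", "type": "function",
                 "function": {"name": "get_weather", "arguments": '{"city": "Berlin"}'}},
            ]},
            {"role": "tool", "tool_call_id": "call_abc123", "content": '{"temp_c": 18}'},
        ]
        print(lookup_tool_name(history, "call_abc123"))  # get_weather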
- name = _lookup_tool_name(request.messages, msg.tool_call_id) - if not name: - continue # orphan tool message — skip silently - payload: Any - if isinstance(msg.content, str): - try: - payload = json.loads(msg.content) - if not isinstance(payload, dict): - payload = {"result": payload} - except (TypeError, ValueError): - payload = {"result": msg.content} - else: - payload = {"result": ""} - contents.append( - types.Content( - role="user", - parts=[ - types.Part.from_function_response( - name=name, response=payload - ) - ], - ) - ) - continue - - role = "user" if msg.role == "user" else "model" - parts: list[types.Part] = [] - - if msg.role == "assistant" and msg.tool_calls: - for call in msg.tool_calls: - try: - args_obj = json.loads(call.function.arguments or "{}") - except (TypeError, ValueError): - args_obj = {} - parts.append( - types.Part( - function_call=types.FunctionCall( - name=call.function.name, args=args_obj - ) - ) - ) - - if msg.content is not None: - if isinstance(msg.content, str): - parts.append(types.Part.from_text(text=msg.content)) - else: - for part in msg.content: - if part.type == "text": - parts.append(types.Part.from_text(text=part.text)) - elif part.type == "image_url" and part.image_url: - url = part.image_url.url - if url.startswith("data:"): - header, b64_data = url.split(",", 1) - mime_type = header.split(":")[1].split(";")[0] - import base64 - - image_bytes = base64.b64decode(b64_data) - parts.append( - types.Part.from_bytes( - data=image_bytes, mime_type=mime_type - ) - ) - else: - parts.append( - types.Part.from_uri( - file_uri=url, mime_type="image/jpeg" - ) - ) - - if parts: - contents.append(types.Content(role=role, parts=parts)) - - return system_instruction, contents - - async def chat_completion( - self, - request: ChatCompletionRequest, - model: str, - ) -> ChatCompletionResponse: - """Generate a chat completion via Google Gemini.""" - gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model - system_instruction, contents = self._convert_messages(request) - - config: dict[str, Any] = {} - if request.temperature is not None: - config["temperature"] = request.temperature - if request.max_tokens is not None: - config["max_output_tokens"] = request.max_tokens - if request.top_p is not None: - config["top_p"] = request.top_p - if request.stop: - stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop] - config["stop_sequences"] = stop_seqs - - gemini_tools = _build_gemini_tools(request.tools) - if gemini_tools: - config["tools"] = gemini_tools - - gen_config = types.GenerateContentConfig( - system_instruction=system_instruction, - **config, - ) - - logger.debug(f"Google Gemini request: {gemini_model}, messages: {len(contents)}") - - try: - response = await self.client.aio.models.generate_content( - model=gemini_model, - contents=contents, - config=gen_config, - ) - except ProviderError: - raise - except Exception as err: - raise _wrap_gemini_call_error(err, gemini_model) from err - - content, tool_calls, finish_reason = _unwrap_gemini_response( - response, gemini_model - ) - usage_meta = response.usage_metadata - - return ChatCompletionResponse( - model=f"google/{gemini_model}", - choices=[ - Choice( - index=0, - message=MessageResponse( - content=content or None, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, - ) - ], - usage=Usage( - prompt_tokens=usage_meta.prompt_token_count if usage_meta else 0, - completion_tokens=usage_meta.candidates_token_count if usage_meta else 0, - 
total_tokens=usage_meta.total_token_count if usage_meta else 0, - ), - ) - - async def chat_completion_stream( - self, - request: ChatCompletionRequest, - model: str, - ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Generate a streaming chat completion via Google Gemini.""" - gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model - system_instruction, contents = self._convert_messages(request) - - config: dict[str, Any] = {} - if request.temperature is not None: - config["temperature"] = request.temperature - if request.max_tokens is not None: - config["max_output_tokens"] = request.max_tokens - if request.top_p is not None: - config["top_p"] = request.top_p - - gemini_tools = _build_gemini_tools(request.tools) - if gemini_tools: - config["tools"] = gemini_tools - - gen_config = types.GenerateContentConfig( - system_instruction=system_instruction, - **config, - ) - - # First chunk with role - yield ChatCompletionStreamResponse( - model=f"google/{gemini_model}", - choices=[ - StreamChoice( - delta=DeltaContent(role="assistant"), - finish_reason=None, - ) - ], - ) - - last_chunk: Any = None - emitted_any_text = False - try: - stream = await self.client.aio.models.generate_content_stream( - model=gemini_model, - contents=contents, - config=gen_config, - ) - except Exception as err: - raise _wrap_gemini_call_error(err, gemini_model) from err - - async for chunk in stream: - last_chunk = chunk - text = chunk.text - if text: - emitted_any_text = True - yield ChatCompletionStreamResponse( - model=f"google/{gemini_model}", - choices=[ - StreamChoice( - delta=DeltaContent(content=text), - finish_reason=None, - ) - ], - ) - - # Post-stream check: if the stream ended without emitting any text, - # surface the structured reason instead of quietly closing with an - # empty "stop". Matches _unwrap_gemini_response semantics. We - # discard the return value here — the streaming path produces its - # content chunk-by-chunk above, so we only need this call for its - # side-effect of raising on SAFETY / RECITATION / MAX_TOKENS. 
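The post-stream rule (an empty stream is an error, not an empty answer) generalises beyond Gemini. A minimal sketch of the pattern; the error class is a stand-in, not the real ProviderBlockedError:

    import asyncio
    from collections.abc import AsyncIterator


    class EmptyStreamError(RuntimeError):
        """Stand-in for the structured blocked/truncated provider errors."""


    async def relay(chunks: AsyncIterator[str]) -> AsyncIterator[str]:
        # Forward text as it arrives, but if the stream closes without emitting
        # anything, raise instead of quietly ending with an empty completion.
        emitted_any = False
        async for text in chunks:
            if text:
                emitted_any = True
                yield text
        if not emitted_any:
            raise EmptyStreamError("upstream produced no text (filter or truncation)")


    async def main() -> None:
        async def empty() -> AsyncIterator[str]:
            return
            yield  # unreachable; makes this an (empty) async generator

        try:
            async for _ in relay(empty()):
                pass
        except EmptyStreamError as err:
            print("caught:", err)


    if __name__ == "__main__":
        asyncio.run(main())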
- if not emitted_any_text and last_chunk is not None: - _unwrap_gemini_response(last_chunk, gemini_model) - - # Final chunk - yield ChatCompletionStreamResponse( - model=f"google/{gemini_model}", - choices=[ - StreamChoice( - delta=DeltaContent(), - finish_reason="stop", - ) - ], - ) - - async def list_models(self) -> list[ModelInfo]: - """List available Google Gemini models.""" - # Return a static list of commonly used models - return [ - ModelInfo(id="google/gemini-2.5-flash", owned_by="google"), - ModelInfo(id="google/gemini-2.5-pro", owned_by="google"), - ] - - async def embeddings( - self, - request: EmbeddingRequest, - model: str, - ) -> EmbeddingResponse: - """Generate embeddings via Google Gemini.""" - inputs = request.input if isinstance(request.input, list) else [request.input] - - result = await self.client.aio.models.embed_content( - model="text-embedding-004", - contents=inputs, - ) - - return EmbeddingResponse( - data=[ - EmbeddingData(index=i, embedding=emb.values) - for i, emb in enumerate(result.embeddings) - ], - model="google/text-embedding-004", - usage=Usage( - prompt_tokens=sum(len(t.split()) for t in inputs), - total_tokens=sum(len(t.split()) for t in inputs), - ), - ) - - async def health_check(self) -> dict[str, Any]: - """Check Google API health.""" - try: - # Quick test: list models - response = await self.client.aio.models.list(config={"page_size": 1}) - return { - "status": "healthy", - "provider": "google", - } - except Exception as e: - return { - "status": "unhealthy", - "provider": "google", - "error": str(e), - } - - async def close(self) -> None: - """No cleanup needed for Google client.""" - pass diff --git a/services/mana-llm/src/providers/ollama.py b/services/mana-llm/src/providers/ollama.py deleted file mode 100644 index 0cc22fe5f..000000000 --- a/services/mana-llm/src/providers/ollama.py +++ /dev/null @@ -1,462 +0,0 @@ -"""Ollama provider implementation.""" - -import json -import logging -from collections.abc import AsyncIterator -from typing import Any - -import httpx - -from src.config import settings -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - Choice, - DeltaContent, - EmbeddingData, - EmbeddingRequest, - EmbeddingResponse, - MessageResponse, - ModelInfo, - StreamChoice, - ToolCall, - ToolCallFunction, - Usage, -) - -from .base import LLMProvider - -# Ollama emits tool_calls under /api/chat as: -# {"message":{"content":"", "tool_calls":[{"function":{"name":"x","arguments":{...}}}]}} -# Arguments arrive as a dict — we normalise to a JSON string to match -# the OpenAI-spec shape our downstream code expects. - -# Ollama models known to support tool-calling reliably (as of 0.3+). -# Everything else: we still pass `tools` to the API (it ignores them on -# incompatible models), but the assistant response will be plain text. -# A shared-ai-level capability check rejects these cases before the call. -TOOL_CAPABLE_OLLAMA_PATTERNS: tuple[str, ...] 
= ( - "llama3.1", - "llama3.2", - "llama3.3", - "qwen2.5", - "qwen3", - "mistral", - "mixtral", - "command-r", - "firefunction", -) - -logger = logging.getLogger(__name__) - - -def _safe_parse_args(raw: str | None) -> dict[str, Any]: - """Best-effort parse of a JSON-encoded arguments string.""" - if not raw: - return {} - try: - parsed = json.loads(raw) - except (TypeError, ValueError): - return {} - return parsed if isinstance(parsed, dict) else {} - - -def _tool_calls_from_ollama(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None: - """Normalise Ollama's tool_calls into our ToolCall model.""" - if not raw: - return None - import uuid - - calls: list[ToolCall] = [] - for idx, c in enumerate(raw): - fn = c.get("function") or {} - name = fn.get("name") - if not name: - continue - args = fn.get("arguments") - if isinstance(args, dict): - arguments_json = json.dumps(args, ensure_ascii=False) - elif isinstance(args, str): - arguments_json = args - else: - arguments_json = "{}" - calls.append( - ToolCall( - id=c.get("id") or f"call_{uuid.uuid4().hex[:12]}", - function=ToolCallFunction(name=name, arguments=arguments_json), - ) - ) - return calls or None - - -def _strip_json_fences(content: str) -> str: - """Strip ```json ... ``` markdown fences from a string if present. - - Some Ollama vision models still wrap structured-output responses in - a markdown code block even when `format` is set. Downstream parsers - (Vercel AI SDK generateObject, manual JSON.parse) expect clean JSON, - so we normalize the response here at the proxy boundary. - """ - s = content.strip() - if s.startswith("```"): - # Drop the opening fence (```json or ``` plus any language tag) - first_newline = s.find("\n") - if first_newline != -1: - s = s[first_newline + 1 :] - # Drop the closing fence - if s.endswith("```"): - s = s[:-3] - s = s.strip() - return s - - -class OllamaProvider(LLMProvider): - """Ollama LLM provider.""" - - name = "ollama" - supports_tools = True - - def model_supports_tools(self, model: str) -> bool: - """Narrow tool capability to models trained for function calling. - - Ollama accepts the ``tools`` payload even for incompatible models - but silently drops it — the assistant just replies in prose. We - reject those requests upfront so callers get a clear error. - """ - lower = model.lower() - return any(pattern in lower for pattern in TOOL_CAPABLE_OLLAMA_PATTERNS) - - def __init__(self, base_url: str | None = None, timeout: int | None = None): - self.base_url = (base_url or settings.ollama_url).rstrip("/") - self.timeout = timeout or settings.ollama_timeout - self._client: httpx.AsyncClient | None = None - - @property - def client(self) -> httpx.AsyncClient: - """Get or create HTTP client.""" - if self._client is None or self._client.is_closed: - self._client = httpx.AsyncClient( - base_url=self.base_url, - timeout=httpx.Timeout(self.timeout), - ) - return self._client - - def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]: - """Convert OpenAI message format to Ollama format. - - Ollama's ``/api/chat`` uses the OpenAI shape for assistant - ``tool_calls`` and ``tool`` result messages (with args as objects - rather than JSON strings on output — input still accepts JSON - strings), so we largely pass them through. 
- """ - messages: list[dict[str, Any]] = [] - for msg in request.messages: - out: dict[str, Any] = {"role": msg.role} - - if msg.tool_calls: - out["tool_calls"] = [ - { - "function": { - "name": c.function.name, - # Ollama accepts either stringified JSON or - # already-parsed objects; send parsed for - # better compatibility with older models. - "arguments": _safe_parse_args(c.function.arguments), - } - } - for c in msg.tool_calls - ] - - if msg.content is None: - out["content"] = "" - elif isinstance(msg.content, str): - out["content"] = msg.content - else: - text_parts: list[str] = [] - images: list[str] = [] - for part in msg.content: - if part.type == "text": - text_parts.append(part.text) - elif part.type == "image_url": - url = part.image_url.url - if url.startswith("data:"): - base64_data = url.split(",", 1)[1] if "," in url else url - images.append(base64_data) - else: - logger.warning( - "HTTP image URLs not supported, skipping: %s...", - url[:50], - ) - out["content"] = " ".join(text_parts) - if images: - out["images"] = images - - messages.append(out) - return messages - - async def chat_completion( - self, - request: ChatCompletionRequest, - model: str, - ) -> ChatCompletionResponse: - """Generate a chat completion (non-streaming).""" - payload: dict[str, Any] = { - "model": model, - "messages": self._convert_messages(request), - "stream": False, - } - - # Pass through structured-output requests to Ollama's native - # `format` field. Ollama supports either `"json"` (free-form - # JSON object) or a full JSON schema dict. The OpenAI-style - # response_format the consumer sends maps as follows: - # - {"type": "json_object"} → "json" - # - {"type": "json_schema", "json_schema": {"schema": {...}}} - # → the schema dict (Ollama 0.5+ supports full schemas) - # Without this, Ollama wraps JSON in ```json ... ``` markdown - # fences, which breaks downstream strict parsers like the AI SDK - # generateObject() helper. - if request.response_format is not None: - rf = request.response_format - if rf.type == "json_object": - payload["format"] = "json" - elif rf.type == "json_schema" and rf.json_schema is not None: - # rf.json_schema is the OpenAI envelope: - # {"name": "...", "schema": {...}, "strict": true} - # Ollama wants just the inner schema dict. - inner = rf.json_schema.get("schema") - payload["format"] = inner if inner is not None else "json" - else: - payload["format"] = "json" - - # Add optional parameters - options: dict[str, Any] = {} - if request.temperature is not None: - options["temperature"] = request.temperature - if request.top_p is not None: - options["top_p"] = request.top_p - if request.max_tokens is not None: - options["num_predict"] = request.max_tokens - if request.stop: - options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop] - - if options: - payload["options"] = options - - if request.tools: - payload["tools"] = [ - {"type": "function", "function": t.function.model_dump()} - for t in request.tools - ] - - logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}") - - response = await self.client.post("/api/chat", json=payload) - response.raise_for_status() - data = response.json() - - message = data.get("message", {}) - tool_calls = _tool_calls_from_ollama(message.get("tool_calls")) - # Defensive fence-stripping: even with `format` set, some older - # Ollama versions still emit ```json ... ``` wrappers for vision - # models. Strip them so strict downstream parsers see clean JSON. 
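Both comments above serve the same goal of handing strict parsers clean JSON. The response_format-to-format mapping on its own, as a dict-based sketch with an invented helper name:

    def to_ollama_format(response_format: dict | None):
        # {"type": "json_object"}                        -> "json"
        # {"type": "json_schema", "json_schema": {...}}  -> the inner schema dict
        # None                                           -> no format constraint
        if response_format is None:
            return None
        if response_format.get("type") == "json_schema":
            envelope = response_format.get("json_schema") or {}
            return envelope.get("schema") or "json"
        return "json"


    if __name__ == "__main__":
        print(to_ollama_format({"type": "json_object"}))
        schema = {"type": "object", "properties": {"temp_c": {"type": "number"}}}
        print(to_ollama_format({"type": "json_schema",
                                "json_schema": {"name": "weather", "schema": schema}}))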
- raw_content = message.get("content", "") or "" - content = _strip_json_fences(raw_content) if raw_content else "" - - finish_reason = "tool_calls" if tool_calls else ( - "stop" if data.get("done") else None - ) - - return ChatCompletionResponse( - model=f"ollama/{model}", - choices=[ - Choice( - message=MessageResponse( - content=content or None, - tool_calls=tool_calls, - ), - finish_reason=finish_reason, - ) - ], - usage=Usage( - prompt_tokens=data.get("prompt_eval_count", 0), - completion_tokens=data.get("eval_count", 0), - total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0), - ), - ) - - async def chat_completion_stream( - self, - request: ChatCompletionRequest, - model: str, - ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Generate a chat completion (streaming).""" - payload: dict[str, Any] = { - "model": model, - "messages": self._convert_messages(request), - "stream": True, - } - - # Add optional parameters - options: dict[str, Any] = {} - if request.temperature is not None: - options["temperature"] = request.temperature - if request.top_p is not None: - options["top_p"] = request.top_p - if request.max_tokens is not None: - options["num_predict"] = request.max_tokens - if request.stop: - options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop] - - if options: - payload["options"] = options - - if request.tools: - payload["tools"] = [ - {"type": "function", "function": t.function.model_dump()} - for t in request.tools - ] - - logger.debug(f"Ollama streaming request: {model}") - - response_id = f"chatcmpl-{model[:8]}" - first_chunk = True - - async with self.client.stream("POST", "/api/chat", json=payload) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if not line: - continue - - try: - data = json.loads(line) - except json.JSONDecodeError: - logger.warning(f"Failed to parse Ollama response line: {line}") - continue - - # First chunk includes role - if first_chunk: - yield ChatCompletionStreamResponse( - id=response_id, - model=f"ollama/{model}", - choices=[ - StreamChoice( - delta=DeltaContent(role="assistant"), - ) - ], - ) - first_chunk = False - - # Content chunks - content = data.get("message", {}).get("content", "") - if content: - yield ChatCompletionStreamResponse( - id=response_id, - model=f"ollama/{model}", - choices=[ - StreamChoice( - delta=DeltaContent(content=content), - ) - ], - ) - - # Final chunk with finish_reason - if data.get("done"): - yield ChatCompletionStreamResponse( - id=response_id, - model=f"ollama/{model}", - choices=[ - StreamChoice( - delta=DeltaContent(), - finish_reason="stop", - ) - ], - ) - - async def list_models(self) -> list[ModelInfo]: - """List available Ollama models.""" - response = await self.client.get("/api/tags") - response.raise_for_status() - data = response.json() - - models = [] - for model_data in data.get("models", []): - name = model_data.get("name", "") - # Parse modified_at datetime string to Unix timestamp - created = None - if modified_at := model_data.get("modified_at"): - try: - from datetime import datetime - # Handle ISO format with timezone - dt = datetime.fromisoformat(modified_at.replace("Z", "+00:00")) - created = int(dt.timestamp()) - except (ValueError, TypeError): - pass - models.append( - ModelInfo( - id=f"ollama/{name}", - owned_by="ollama", - created=created, - ) - ) - return models - - async def embeddings( - self, - request: EmbeddingRequest, - model: str, - ) -> EmbeddingResponse: - """Generate embeddings for 
input text.""" - inputs = request.input if isinstance(request.input, list) else [request.input] - embeddings_data = [] - - for i, text in enumerate(inputs): - response = await self.client.post( - "/api/embeddings", - json={"model": model, "prompt": text}, - ) - response.raise_for_status() - data = response.json() - - embeddings_data.append( - EmbeddingData( - index=i, - embedding=data.get("embedding", []), - ) - ) - - return EmbeddingResponse( - data=embeddings_data, - model=f"ollama/{model}", - usage=Usage( - prompt_tokens=sum(len(text.split()) for text in inputs), # Approximate - total_tokens=sum(len(text.split()) for text in inputs), - ), - ) - - async def health_check(self) -> dict[str, Any]: - """Check Ollama health status.""" - try: - response = await self.client.get("/api/tags") - response.raise_for_status() - data = response.json() - model_count = len(data.get("models", [])) - return { - "status": "healthy", - "provider": self.name, - "url": self.base_url, - "models_available": model_count, - } - except Exception as e: - return { - "status": "unhealthy", - "provider": self.name, - "url": self.base_url, - "error": str(e), - } - - async def close(self) -> None: - """Close HTTP client.""" - if self._client and not self._client.is_closed: - await self._client.aclose() diff --git a/services/mana-llm/src/providers/openai_compat.py b/services/mana-llm/src/providers/openai_compat.py deleted file mode 100644 index a720fe7f2..000000000 --- a/services/mana-llm/src/providers/openai_compat.py +++ /dev/null @@ -1,364 +0,0 @@ -"""OpenAI-compatible provider for OpenRouter, Groq, Together, etc.""" - -import json -import logging -from collections.abc import AsyncIterator -from typing import Any - -import httpx - -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - Choice, - DeltaContent, - EmbeddingData, - EmbeddingRequest, - EmbeddingResponse, - MessageResponse, - ModelInfo, - StreamChoice, - ToolCall, - ToolCallFunction, - Usage, -) - -from .base import LLMProvider - -logger = logging.getLogger(__name__) - - -def _tool_calls_from_openai(raw: list[dict[str, Any]] | None) -> list[ToolCall] | None: - """Normalise an OpenAI-spec ``tool_calls`` array into our model shape.""" - if not raw: - return None - calls: list[ToolCall] = [] - for c in raw: - fn = c.get("function") or {} - name = fn.get("name") - if not name: - continue - calls.append( - ToolCall( - id=c.get("id") or "", - function=ToolCallFunction( - name=name, - arguments=fn.get("arguments") or "{}", - ), - ) - ) - return calls or None - - -class OpenAICompatProvider(LLMProvider): - """OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.).""" - - # OpenRouter/Groq/Together all expose tool_calls per the OpenAI spec; - # individual models within those services may or may not support it, - # but the request shape is uniform. The upstream returns a proper - # error for unsupported models — no silent downgrade here. 
- supports_tools = True - - def __init__( - self, - name: str, - base_url: str, - api_key: str, - default_model: str | None = None, - timeout: int = 120, - ): - self.name = name - self.base_url = base_url.rstrip("/") - self.api_key = api_key - self.default_model = default_model - self.timeout = timeout - self._client: httpx.AsyncClient | None = None - - @property - def client(self) -> httpx.AsyncClient: - """Get or create HTTP client.""" - if self._client is None or self._client.is_closed: - self._client = httpx.AsyncClient( - base_url=self.base_url, - timeout=httpx.Timeout(self.timeout), - headers={ - "Authorization": f"Bearer {self.api_key}", - "Content-Type": "application/json", - }, - ) - return self._client - - def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]: - """Convert internal message format to OpenAI format. - - The OpenAI chat-completions endpoint is the source of truth for - this shape, so most fields pass through verbatim — including - ``role='tool'`` messages with ``tool_call_id`` + ``content`` and - assistant messages carrying ``tool_calls[]``. - """ - messages: list[dict[str, Any]] = [] - for msg in request.messages: - out: dict[str, Any] = {"role": msg.role} - - if msg.tool_call_id: - out["tool_call_id"] = msg.tool_call_id - - if msg.tool_calls: - out["tool_calls"] = [ - { - "id": c.id, - "type": c.type, - "function": { - "name": c.function.name, - "arguments": c.function.arguments, - }, - } - for c in msg.tool_calls - ] - - if msg.content is None: - # Assistant tool-call messages have null content per spec. - out["content"] = None - elif isinstance(msg.content, str): - out["content"] = msg.content - else: - content_parts = [] - for part in msg.content: - if part.type == "text": - content_parts.append({"type": "text", "text": part.text}) - elif part.type == "image_url": - content_parts.append( - { - "type": "image_url", - "image_url": {"url": part.image_url.url}, - } - ) - out["content"] = content_parts - - messages.append(out) - return messages - - async def chat_completion( - self, - request: ChatCompletionRequest, - model: str, - ) -> ChatCompletionResponse: - """Generate a chat completion (non-streaming).""" - payload: dict[str, Any] = { - "model": model, - "messages": self._convert_messages(request), - "stream": False, - } - - # Add optional parameters - if request.temperature is not None: - payload["temperature"] = request.temperature - if request.max_tokens is not None: - payload["max_tokens"] = request.max_tokens - if request.top_p is not None: - payload["top_p"] = request.top_p - if request.frequency_penalty is not None: - payload["frequency_penalty"] = request.frequency_penalty - if request.presence_penalty is not None: - payload["presence_penalty"] = request.presence_penalty - if request.stop: - payload["stop"] = request.stop - if request.tools: - payload["tools"] = [ - {"type": "function", "function": t.function.model_dump()} - for t in request.tools - ] - if request.tool_choice is not None: - payload["tool_choice"] = ( - request.tool_choice - if isinstance(request.tool_choice, str) - else request.tool_choice.model_dump() - ) - - logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}") - - response = await self.client.post("/chat/completions", json=payload) - response.raise_for_status() - data = response.json() - - return ChatCompletionResponse( - id=data.get("id", ""), - model=f"{self.name}/{model}", - choices=[ - Choice( - index=choice.get("index", 0), - message=MessageResponse( - 
content=choice["message"].get("content"), - tool_calls=_tool_calls_from_openai( - choice["message"].get("tool_calls") - ), - ), - finish_reason=choice.get("finish_reason", "stop"), - ) - for choice in data.get("choices", []) - ], - usage=Usage( - prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0), - completion_tokens=data.get("usage", {}).get("completion_tokens", 0), - total_tokens=data.get("usage", {}).get("total_tokens", 0), - ), - ) - - async def chat_completion_stream( - self, - request: ChatCompletionRequest, - model: str, - ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Generate a chat completion (streaming).""" - payload: dict[str, Any] = { - "model": model, - "messages": self._convert_messages(request), - "stream": True, - } - - # Add optional parameters - if request.temperature is not None: - payload["temperature"] = request.temperature - if request.max_tokens is not None: - payload["max_tokens"] = request.max_tokens - if request.top_p is not None: - payload["top_p"] = request.top_p - if request.frequency_penalty is not None: - payload["frequency_penalty"] = request.frequency_penalty - if request.presence_penalty is not None: - payload["presence_penalty"] = request.presence_penalty - if request.stop: - payload["stop"] = request.stop - if request.tools: - payload["tools"] = [ - {"type": "function", "function": t.function.model_dump()} - for t in request.tools - ] - if request.tool_choice is not None: - payload["tool_choice"] = ( - request.tool_choice - if isinstance(request.tool_choice, str) - else request.tool_choice.model_dump() - ) - - logger.debug(f"{self.name} streaming request: {model}") - - async with self.client.stream("POST", "/chat/completions", json=payload) as response: - response.raise_for_status() - async for line in response.aiter_lines(): - if not line or not line.startswith("data: "): - continue - - data_str = line[6:] # Remove "data: " prefix - - if data_str == "[DONE]": - break - - try: - data = json.loads(data_str) - except json.JSONDecodeError: - logger.warning(f"Failed to parse stream line: {data_str}") - continue - - choices = data.get("choices", []) - if not choices: - continue - - choice = choices[0] - delta = choice.get("delta", {}) - - yield ChatCompletionStreamResponse( - id=data.get("id", ""), - model=f"{self.name}/{model}", - choices=[ - StreamChoice( - index=choice.get("index", 0), - delta=DeltaContent( - role=delta.get("role"), - content=delta.get("content"), - tool_calls=_tool_calls_from_openai( - delta.get("tool_calls") - ), - ), - finish_reason=choice.get("finish_reason"), - ) - ], - ) - - async def list_models(self) -> list[ModelInfo]: - """List available models.""" - try: - response = await self.client.get("/models") - response.raise_for_status() - data = response.json() - - models = [] - for model_data in data.get("data", []): - model_id = model_data.get("id", "") - models.append( - ModelInfo( - id=f"{self.name}/{model_id}", - owned_by=model_data.get("owned_by", self.name), - ) - ) - return models - except httpx.HTTPError as e: - logger.warning(f"Failed to list models from {self.name}: {e}") - return [] - - async def embeddings( - self, - request: EmbeddingRequest, - model: str, - ) -> EmbeddingResponse: - """Generate embeddings for input text.""" - payload = { - "model": model, - "input": request.input, - } - - response = await self.client.post("/embeddings", json=payload) - response.raise_for_status() - data = response.json() - - return EmbeddingResponse( - data=[ - EmbeddingData( - index=item.get("index", i), - 
embedding=item.get("embedding", []), - ) - for i, item in enumerate(data.get("data", [])) - ], - model=f"{self.name}/{model}", - usage=Usage( - prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0), - total_tokens=data.get("usage", {}).get("total_tokens", 0), - ), - ) - - async def health_check(self) -> dict[str, Any]: - """Check provider health status.""" - try: - response = await self.client.get("/models") - response.raise_for_status() - data = response.json() - model_count = len(data.get("data", [])) - return { - "status": "healthy", - "provider": self.name, - "url": self.base_url, - "models_available": model_count, - } - except Exception as e: - return { - "status": "unhealthy", - "provider": self.name, - "url": self.base_url, - "error": str(e), - } - - async def close(self) -> None: - """Close HTTP client.""" - if self._client and not self._client.is_closed: - await self._client.aclose() diff --git a/services/mana-llm/src/providers/router.py b/services/mana-llm/src/providers/router.py deleted file mode 100644 index e6efaa6a0..000000000 --- a/services/mana-llm/src/providers/router.py +++ /dev/null @@ -1,463 +0,0 @@ -"""Provider routing with alias resolution and health-aware fallback. - -The router is the single entry point that the FastAPI handlers use. Its -job is: - -1. Resolve the request's ``model`` field. If it lives in the ``mana/`` - namespace the :class:`AliasRegistry` returns an ordered chain of - concrete provider/model strings; everything else is treated as a - single-entry chain (caller passed a direct provider/model). -2. Walk the chain, skipping entries whose provider is either - unconfigured at this deployment (no API key) or currently marked - unhealthy in the :class:`ProviderHealthCache`. -3. Try each remaining entry. Connection errors, timeouts, 5xx, and rate - limits are retryable — record them in the cache and move to the next - entry. Capability/auth/blocked errors are caller-fixable and - propagate immediately without touching the health cache. -4. Return the first successful response. If every entry was skipped or - failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying - the full attempt log so debugging is straightforward. - -The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is -the M3 milestone. -""" - -from __future__ import annotations - -import logging -from collections.abc import AsyncIterator, Awaitable, Callable -from typing import Any, TypeVar - -import httpx - -from src.aliases import AliasRegistry -from src.config import settings -from src.health import ProviderHealthCache -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - EmbeddingRequest, - EmbeddingResponse, - ModelInfo, -) -from src.utils.metrics import ( - record_alias_resolved, - record_fallback, - set_provider_healthy, -) - -from .base import LLMProvider -from .errors import ( - NoHealthyProviderError, - ProviderAuthError, - ProviderBlockedError, - ProviderCapabilityError, - ProviderError, - ProviderRateLimitError, -) -from .ollama import OllamaProvider -from .openai_compat import OpenAICompatProvider - -logger = logging.getLogger(__name__) - -T = TypeVar("T") - - -class ProviderRouter: - """Health-aware provider router with alias resolution. - - Construct with the AliasRegistry and ProviderHealthCache from - application startup; both are external dependencies so tests can - inject mocks without going through global state. 
- """ - - def __init__( - self, - aliases: AliasRegistry, - health_cache: ProviderHealthCache, - ) -> None: - self.aliases = aliases - self.health_cache = health_cache - self.providers: dict[str, LLMProvider] = {} - self._initialize_providers() - - # ------------------------------------------------------------------ - # Provider initialisation - # ------------------------------------------------------------------ - - def _initialize_providers(self) -> None: - """Spin up provider adapters based on what's configured.""" - # Ollama: always present (talks to a local/proxied server). Whether - # it's actually reachable is the cache's job to figure out. - self.providers["ollama"] = OllamaProvider() - logger.info("Initialized Ollama provider at %s", settings.ollama_url) - - if settings.google_api_key: - from .google import GoogleProvider - - self.providers["google"] = GoogleProvider( - api_key=settings.google_api_key, - default_model=settings.google_default_model, - ) - logger.info("Initialized Google Gemini provider") - - if settings.openrouter_api_key: - self.providers["openrouter"] = OpenAICompatProvider( - name="openrouter", - base_url=settings.openrouter_base_url, - api_key=settings.openrouter_api_key, - default_model=settings.openrouter_default_model, - ) - logger.info("Initialized OpenRouter provider") - - if settings.groq_api_key: - self.providers["groq"] = OpenAICompatProvider( - name="groq", - base_url=settings.groq_base_url, - api_key=settings.groq_api_key, - ) - logger.info("Initialized Groq provider") - - if settings.together_api_key: - self.providers["together"] = OpenAICompatProvider( - name="together", - base_url=settings.together_base_url, - api_key=settings.together_api_key, - ) - logger.info("Initialized Together provider") - - # ------------------------------------------------------------------ - # Helpers - # ------------------------------------------------------------------ - - def _parse_model(self, model: str) -> tuple[str, str]: - """Split ``provider/model`` into its parts. - - Bare names (no prefix) default to Ollama for compatibility with - plain OpenAI-style requests. Aliases (``mana/...``) are resolved - before this is ever called. - """ - if "/" in model: - provider, _, model_name = model.partition("/") - return provider.lower(), model_name - return "ollama", model - - def _resolve_chain(self, model_or_alias: str) -> list[str]: - """Expand aliases to chains; pass everything else through unchanged.""" - if AliasRegistry.is_alias(model_or_alias): - return list(self.aliases.resolve_chain(model_or_alias)) - return [model_or_alias] - - @staticmethod - def _is_retryable(exc: BaseException) -> bool: - """Should we treat this exception as "try the next chain entry"? - - ConnectError / timeouts / 5xx / rate-limits = yes (provider blip). - Auth / capability / blocked / 4xx = no (caller has to fix - something; retrying with a different provider only hides the bug). 
- """ - if isinstance(exc, ProviderCapabilityError): - return False - if isinstance(exc, ProviderBlockedError): - return False - if isinstance(exc, ProviderAuthError): - return False - if isinstance(exc, ProviderRateLimitError): - return True - if isinstance( - exc, - ( - httpx.ConnectError, - httpx.ConnectTimeout, - httpx.ReadError, - httpx.ReadTimeout, - httpx.RemoteProtocolError, - httpx.WriteError, - httpx.WriteTimeout, - httpx.PoolTimeout, - ), - ): - return True - if isinstance(exc, httpx.HTTPStatusError): - return exc.response.status_code >= 500 - if isinstance(exc, ProviderError): - # Any other provider-side error — treat as retryable. - # Subclasses with explicit non-retry semantics are caught above. - return True - # Unknown exception types: do NOT silently retry. Better to - # surface a strange error than hide a real bug behind a fallback. - return False - - @staticmethod - def _exception_summary(exc: BaseException) -> str: - """Compact one-liner for cache.last_error and log entries.""" - return f"{type(exc).__name__}: {exc}" - - def _check_tool_capability( - self, - provider: LLMProvider, - model_name: str, - request: ChatCompletionRequest, - ) -> None: - """Refuse tool-bearing requests for models that don't support tools. - - Silent downgrade (dropping the ``tools`` payload) is more dangerous - than an explicit error — the caller would get plain text back and - have no way to tell the tools never reached the model. - """ - if not request.tools: - return - if not provider.model_supports_tools(model_name): - raise ProviderCapabilityError( - f"{provider.name}/{model_name} does not support tool calling. " - "Choose a tool-capable model (e.g. groq/llama-3.3-70b-versatile)" - ) - - # ------------------------------------------------------------------ - # Core fallback executor (non-streaming) - # ------------------------------------------------------------------ - - async def _execute_with_fallback( - self, - model_or_alias: str, - request: ChatCompletionRequest, - call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]], - ) -> T: - """Walk the resolved chain, returning the first successful result. - - ``call`` is the operation to run against each chain entry, e.g. - ``lambda p, m, req: p.chat_completion(req, m)``. The function - receives the provider instance, the model name (without the - provider prefix), and the original request. 
- """ - chain = self._resolve_chain(model_or_alias) - attempts: list[tuple[str, str]] = [] - last_exc: Exception | None = None - is_alias = AliasRegistry.is_alias(model_or_alias) - - for i, entry in enumerate(chain): - next_entry = chain[i + 1] if i + 1 < len(chain) else "" - provider_name, model_name = self._parse_model(entry) - if provider_name not in self.providers: - logger.debug( - "skip chain entry %s — provider %s not configured here", - entry, - provider_name, - ) - attempts.append((entry, "unconfigured")) - record_fallback(entry, next_entry, "unconfigured") - continue - - if not self.health_cache.is_healthy(provider_name): - logger.debug("skip chain entry %s — cache says unhealthy", entry) - attempts.append((entry, "cache-unhealthy")) - record_fallback(entry, next_entry, "cache-unhealthy") - continue - - provider = self.providers[provider_name] - self._check_tool_capability(provider, model_name, request) - - try: - logger.info( - "execute → %s (alias=%s)", - entry, - model_or_alias if is_alias else "", - ) - result = await call(provider, model_name, request) - self.health_cache.mark_healthy(provider_name) - set_provider_healthy(provider_name, True) - if is_alias: - record_alias_resolved(model_or_alias, entry) - return result - except Exception as e: - if not self._is_retryable(e): - # Caller error / non-retryable provider error — propagate - # without touching the health cache. The cache is for - # liveness, not for recording what the user asked for - # being wrong. - raise - self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) - set_provider_healthy( - provider_name, self.health_cache.is_healthy(provider_name) - ) - attempts.append((entry, type(e).__name__)) - record_fallback(entry, next_entry, type(e).__name__) - last_exc = e - logger.warning( - "execute %s failed (retryable, will try next): %s", - entry, - e, - ) - - raise NoHealthyProviderError(model_or_alias, attempts, last_exc) - - # ------------------------------------------------------------------ - # Public API — non-streaming - # ------------------------------------------------------------------ - - async def chat_completion( - self, - request: ChatCompletionRequest, - ) -> ChatCompletionResponse: - """Chat completion with alias resolution + health-aware fallback.""" - - async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest): - return await provider.chat_completion(req, model) - - return await self._execute_with_fallback(request.model, request, call) - - # ------------------------------------------------------------------ - # Public API — streaming (pre-first-byte fallback) - # ------------------------------------------------------------------ - - async def chat_completion_stream( - self, - request: ChatCompletionRequest, - ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Streaming variant. Falls back BEFORE the first chunk arrives; - once the first chunk has been yielded the provider is committed - and any further error propagates. - - Why pre-first-byte only: stitching half-streams from two different - providers would mix two voices in the output and is impossible to - sanity-check after the fact. 
- """ - chain = self._resolve_chain(request.model) - attempts: list[tuple[str, str]] = [] - last_exc: Exception | None = None - is_alias = AliasRegistry.is_alias(request.model) - - for i, entry in enumerate(chain): - next_entry = chain[i + 1] if i + 1 < len(chain) else "" - provider_name, model_name = self._parse_model(entry) - if provider_name not in self.providers: - attempts.append((entry, "unconfigured")) - record_fallback(entry, next_entry, "unconfigured") - continue - if not self.health_cache.is_healthy(provider_name): - attempts.append((entry, "cache-unhealthy")) - record_fallback(entry, next_entry, "cache-unhealthy") - continue - - provider = self.providers[provider_name] - self._check_tool_capability(provider, model_name, request) - - stream = provider.chat_completion_stream(request, model_name) - try: - first_chunk = await stream.__anext__() - except StopAsyncIteration: - # Empty stream is a successful but content-free response. - # Commit and exit cleanly. - self.health_cache.mark_healthy(provider_name) - set_provider_healthy(provider_name, True) - if is_alias: - record_alias_resolved(request.model, entry) - logger.info("stream %s yielded empty response", entry) - return - except Exception as e: - if not self._is_retryable(e): - raise - self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) - set_provider_healthy( - provider_name, self.health_cache.is_healthy(provider_name) - ) - attempts.append((entry, type(e).__name__)) - record_fallback(entry, next_entry, type(e).__name__) - last_exc = e - logger.warning( - "stream %s failed before first byte (retryable, trying next): %s", - entry, - e, - ) - continue - - # First byte landed — commit the provider, mark healthy, drain - # the rest of the stream. Any error from here on propagates; - # it is NOT safe to splice another provider's output in. - self.health_cache.mark_healthy(provider_name) - set_provider_healthy(provider_name, True) - if is_alias: - record_alias_resolved(request.model, entry) - logger.info("stream → %s (committed after first chunk)", entry) - yield first_chunk - async for chunk in stream: - yield chunk - return - - raise NoHealthyProviderError(request.model, attempts, last_exc) - - # ------------------------------------------------------------------ - # Embeddings — no fallback (out of scope for M3, separate concerns) - # ------------------------------------------------------------------ - - async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: - """Route an embeddings request directly. No alias / fallback.""" - provider_name, model_name = self._parse_model(request.model) - if provider_name not in self.providers: - available = list(self.providers) - raise ValueError( - f"Provider '{provider_name}' not available. Available: {available}" - ) - provider = self.providers[provider_name] - logger.info("embeddings → %s/%s", provider_name, model_name) - return await provider.embeddings(request, model_name) - - # ------------------------------------------------------------------ - # Discovery / introspection - # ------------------------------------------------------------------ - - async def list_models(self) -> list[ModelInfo]: - """List all available models from all configured providers. - - Best-effort: providers that error are skipped with a warning so a - single broken provider can't take down ``GET /v1/models``. 
- """ - all_models: list[ModelInfo] = [] - for provider in self.providers.values(): - try: - all_models.extend(await provider.list_models()) - except Exception as e: # noqa: BLE001 - logger.warning("Failed to list models from %s: %s", provider.name, e) - return all_models - - async def get_model(self, model_id: str) -> ModelInfo | None: - """Look up a single model by id, dispatching on the prefix.""" - provider_name, model_name = self._parse_model(model_id) - if provider_name not in self.providers: - return None - models = await self.providers[provider_name].list_models() - for m in models: - if m.id == model_id or m.id.endswith(f"/{model_name}"): - return m - return None - - async def health_check(self) -> dict[str, Any]: - """Snapshot of the per-provider liveness cache. - - Returns the same shape as before for backwards-compat with - ``GET /health`` (deprecated structure — M4 will swap to a - cleaner ``/v1/health`` endpoint). - """ - snapshot = self.health_cache.snapshot(expected=list(self.providers)) - providers_out: dict[str, Any] = {} - for name, state in snapshot.items(): - providers_out[name] = { - "status": "healthy" if state.healthy else "unhealthy", - "consecutive_failures": state.consecutive_failures, - "last_error": state.last_error, - "last_check_unix": state.last_check or None, - "unhealthy_until_unix": state.unhealthy_until or None, - } - - all_healthy = all(state.healthy for state in snapshot.values()) - any_healthy = any(state.healthy for state in snapshot.values()) - return { - "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"), - "providers": providers_out, - } - - async def close(self) -> None: - """Close all provider clients.""" - for provider in self.providers.values(): - await provider.close() diff --git a/services/mana-llm/src/streaming/__init__.py b/services/mana-llm/src/streaming/__init__.py deleted file mode 100644 index abc609671..000000000 --- a/services/mana-llm/src/streaming/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Streaming utilities for SSE responses.""" - -from .sse import stream_chat_completion - -__all__ = ["stream_chat_completion"] diff --git a/services/mana-llm/src/streaming/sse.py b/services/mana-llm/src/streaming/sse.py deleted file mode 100644 index f87e04c64..000000000 --- a/services/mana-llm/src/streaming/sse.py +++ /dev/null @@ -1,52 +0,0 @@ -"""Server-Sent Events (SSE) response handling.""" - -import json -import logging -from collections.abc import AsyncIterator - -from src.models import ChatCompletionRequest, ChatCompletionStreamResponse -from src.providers import ProviderRouter -from src.providers.errors import ProviderError - -logger = logging.getLogger(__name__) - - -async def stream_chat_completion( - router: ProviderRouter, - request: ChatCompletionRequest, -) -> AsyncIterator[dict]: - """ - Stream chat completion responses for SSE. 
- - Yields dicts that EventSourceResponse will serialize as: - data: {"choices":[{"delta":{"content":"Hello"}}]} - data: [DONE] - """ - try: - async for chunk in router.chat_completion_stream(request): - # Yield dict for EventSourceResponse to serialize - yield {"data": json.dumps(chunk.model_dump(exclude_none=True))} - - # Send final [DONE] marker - yield {"data": "[DONE]"} - - except ProviderError as e: - logger.warning(f"Streaming provider error: kind={e.kind} detail={e}") - error_data = { - "error": { - "message": str(e), - "type": e.kind, - } - } - yield {"data": json.dumps(error_data)} - yield {"data": "[DONE]"} - except Exception as e: - logger.error(f"Streaming error: {e}") - error_data = { - "error": { - "message": str(e), - "type": "server_error", - } - } - yield {"data": json.dumps(error_data)} - yield {"data": "[DONE]"} diff --git a/services/mana-llm/src/utils/__init__.py b/services/mana-llm/src/utils/__init__.py deleted file mode 100644 index a2f2d361b..000000000 --- a/services/mana-llm/src/utils/__init__.py +++ /dev/null @@ -1,5 +0,0 @@ -"""Utility modules.""" - -from .metrics import get_metrics, metrics_middleware - -__all__ = ["get_metrics", "metrics_middleware"] diff --git a/services/mana-llm/src/utils/cache.py b/services/mana-llm/src/utils/cache.py deleted file mode 100644 index 7427ec8c3..000000000 --- a/services/mana-llm/src/utils/cache.py +++ /dev/null @@ -1,85 +0,0 @@ -"""Redis caching utilities (optional).""" - -import hashlib -import json -import logging -from typing import Any - -from src.config import settings - -logger = logging.getLogger(__name__) - -# Redis client (lazy initialized) -_redis_client = None - - -async def get_redis_client(): - """Get or create Redis client.""" - global _redis_client - - if _redis_client is not None: - return _redis_client - - if not settings.redis_url: - return None - - try: - import redis.asyncio as redis - - _redis_client = redis.from_url(settings.redis_url) - # Test connection - await _redis_client.ping() - logger.info(f"Connected to Redis at {settings.redis_url}") - return _redis_client - except Exception as e: - logger.warning(f"Failed to connect to Redis: {e}") - return None - - -def generate_cache_key(prefix: str, data: dict[str, Any]) -> str: - """Generate a cache key from request data.""" - # Serialize and hash the data for consistent key - serialized = json.dumps(data, sort_keys=True) - hash_value = hashlib.sha256(serialized.encode()).hexdigest()[:16] - return f"mana-llm:{prefix}:{hash_value}" - - -async def get_cached(key: str) -> dict[str, Any] | None: - """Get cached value by key.""" - client = await get_redis_client() - if client is None: - return None - - try: - value = await client.get(key) - if value: - return json.loads(value) - except Exception as e: - logger.warning(f"Cache get failed: {e}") - - return None - - -async def set_cached(key: str, value: dict[str, Any], ttl: int | None = None) -> bool: - """Set cached value with optional TTL.""" - client = await get_redis_client() - if client is None: - return False - - try: - ttl = ttl or settings.cache_ttl - serialized = json.dumps(value) - await client.setex(key, ttl, serialized) - return True - except Exception as e: - logger.warning(f"Cache set failed: {e}") - return False - - -async def close_redis() -> None: - """Close Redis connection.""" - global _redis_client - - if _redis_client is not None: - await _redis_client.aclose() - _redis_client = None diff --git a/services/mana-llm/src/utils/metrics.py b/services/mana-llm/src/utils/metrics.py deleted file mode 
100644 index 057cc1656..000000000 --- a/services/mana-llm/src/utils/metrics.py +++ /dev/null @@ -1,158 +0,0 @@ -"""Prometheus metrics for mana-llm.""" - -import time -from collections.abc import Callable - -from fastapi import Request, Response -from prometheus_client import Counter, Gauge, Histogram, generate_latest -from starlette.middleware.base import BaseHTTPMiddleware - -# Request metrics -REQUEST_COUNT = Counter( - "mana_llm_requests_total", - "Total number of requests", - ["method", "endpoint", "status"], -) - -REQUEST_LATENCY = Histogram( - "mana_llm_request_latency_seconds", - "Request latency in seconds", - ["method", "endpoint"], - buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0], -) - -# LLM-specific metrics -LLM_REQUEST_COUNT = Counter( - "mana_llm_llm_requests_total", - "Total number of LLM requests", - ["provider", "model", "streaming"], -) - -LLM_TOKEN_COUNT = Counter( - "mana_llm_tokens_total", - "Total tokens processed", - ["provider", "model", "type"], # type: prompt, completion -) - -LLM_LATENCY = Histogram( - "mana_llm_llm_latency_seconds", - "LLM request latency in seconds", - ["provider", "model"], - buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0], -) - -LLM_ERRORS = Counter( - "mana_llm_llm_errors_total", - "Total LLM errors", - ["provider", "model", "error_type"], -) - -# --------------------------------------------------------------------------- -# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md. -# --------------------------------------------------------------------------- - -ALIAS_RESOLVED = Counter( - "mana_llm_alias_resolved_total", - "How often an alias resolved to a concrete provider/model. The `target` " - "label is the chain entry that actually served the request — useful for " - "spotting cases where the primary always falls through to a cloud entry.", - ["alias", "target"], -) - -FALLBACK_TRIGGERED = Counter( - "mana_llm_fallback_total", - "Fallback transitions: a chain entry failed (or was skipped via cache) " - "and the router moved to the next entry. `reason` is the exception class " - "name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry " - "that didn't serve, `to_model` is empty when no further entries existed.", - ["from_model", "to_model", "reason"], -) - -PROVIDER_HEALTHY = Gauge( - "mana_llm_provider_healthy", - "1 when the provider is currently considered healthy by the cache, " - "0 when in backoff. 
Refreshed on every probe tick and on every router " - "call-site state transition.", - ["provider"], -) - - -def get_metrics() -> bytes: - """Generate Prometheus metrics output.""" - return generate_latest() - - -class MetricsMiddleware(BaseHTTPMiddleware): - """Middleware for collecting HTTP metrics.""" - - async def dispatch(self, request: Request, call_next: Callable) -> Response: - start_time = time.time() - - response = await call_next(request) - - # Record metrics - duration = time.time() - start_time - endpoint = request.url.path - method = request.method - status = str(response.status_code) - - REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status).inc() - REQUEST_LATENCY.labels(method=method, endpoint=endpoint).observe(duration) - - return response - - -# Export middleware instance -metrics_middleware = MetricsMiddleware - - -def record_llm_request( - provider: str, - model: str, - streaming: bool, - prompt_tokens: int = 0, - completion_tokens: int = 0, - latency: float | None = None, -) -> None: - """Record LLM request metrics.""" - LLM_REQUEST_COUNT.labels( - provider=provider, - model=model, - streaming=str(streaming).lower(), - ).inc() - - if prompt_tokens > 0: - LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="prompt").inc(prompt_tokens) - - if completion_tokens > 0: - LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="completion").inc( - completion_tokens - ) - - if latency is not None: - LLM_LATENCY.labels(provider=provider, model=model).observe(latency) - - -def record_llm_error(provider: str, model: str, error_type: str) -> None: - """Record LLM error metrics.""" - LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc() - - -def record_alias_resolved(alias: str, target: str) -> None: - """Record which concrete model an alias resolved to for this request.""" - ALIAS_RESOLVED.labels(alias=alias, target=target).inc() - - -def record_fallback(from_model: str, to_model: str, reason: str) -> None: - """Record a fallback transition. ``to_model`` is empty when the chain - ran out (i.e. NoHealthyProviderError).""" - FALLBACK_TRIGGERED.labels( - from_model=from_model, - to_model=to_model, - reason=reason, - ).inc() - - -def set_provider_healthy(provider: str, healthy: bool) -> None: - """Mirror ``ProviderHealthCache`` state into a Prometheus gauge.""" - PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0) diff --git a/services/mana-llm/start.sh b/services/mana-llm/start.sh deleted file mode 100755 index 95717583b..000000000 --- a/services/mana-llm/start.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/bin/bash - -# Start mana-llm service -# Automatically creates venv and installs dependencies if needed - -cd "$(dirname "$0")" - -# Check if venv exists, create if not -if [ ! -d "venv" ]; then - echo "Creating virtual environment..." - python3 -m venv venv -fi - -# Activate venv -source venv/bin/activate - -# Install/update dependencies -pip install -q -r requirements.txt - -# Copy .env if not exists -if [ ! -f ".env" ] && [ -f ".env.example" ]; then - cp .env.example .env - echo "Created .env from .env.example" -fi - -# Start the service -echo "Starting mana-llm on port 3025..." 
-exec python -m uvicorn src.main:app --port 3025 --reload diff --git a/services/mana-llm/tests/__init__.py b/services/mana-llm/tests/__init__.py deleted file mode 100644 index a0a8deda8..000000000 --- a/services/mana-llm/tests/__init__.py +++ /dev/null @@ -1 +0,0 @@ -"""Tests for mana-llm service.""" diff --git a/services/mana-llm/tests/test_aliases.py b/services/mana-llm/tests/test_aliases.py deleted file mode 100644 index a20389262..000000000 --- a/services/mana-llm/tests/test_aliases.py +++ /dev/null @@ -1,300 +0,0 @@ -"""Tests for the model-alias registry.""" - -from __future__ import annotations - -from pathlib import Path - -import pytest - -from src.aliases import ( - ALIAS_PREFIX, - Alias, - AliasConfigError, - AliasRegistry, - UnknownAliasError, -) - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - - -def write_yaml(tmp_path: Path, body: str, name: str = "aliases.yaml") -> Path: - p = tmp_path / name - p.write_text(body, encoding="utf-8") - return p - - -VALID_CONFIG = """\ -aliases: - mana/fast-text: - description: "fast" - chain: - - ollama/qwen2.5:7b - - groq/llama-3.1-8b-instant - mana/long-form: - description: "long" - chain: - - ollama/gemma3:12b - - groq/llama-3.3-70b-versatile -default: mana/fast-text -""" - - -# --------------------------------------------------------------------------- -# Construction & happy-path resolution -# --------------------------------------------------------------------------- - - -class TestRegistryHappyPath: - def test_loads_valid_yaml(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, VALID_CONFIG) - reg = AliasRegistry(path) - assert reg.path == path - assert reg.default_alias == "mana/fast-text" - - def test_resolve_returns_alias_dataclass(self, tmp_path: Path) -> None: - reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG)) - alias = reg.resolve("mana/long-form") - assert isinstance(alias, Alias) - assert alias.name == "mana/long-form" - assert alias.description == "long" - assert alias.chain == ("ollama/gemma3:12b", "groq/llama-3.3-70b-versatile") - - def test_resolve_chain_returns_tuple(self, tmp_path: Path) -> None: - reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG)) - chain = reg.resolve_chain("mana/fast-text") - assert chain == ("ollama/qwen2.5:7b", "groq/llama-3.1-8b-instant") - # Tuples ensure callers can't mutate the registry's internal state. 
- assert isinstance(chain, tuple) - - def test_list_aliases_sorted(self, tmp_path: Path) -> None: - reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG)) - names = [a.name for a in reg.list_aliases()] - assert names == sorted(names) - assert names == ["mana/fast-text", "mana/long-form"] - - def test_unknown_alias_raises(self, tmp_path: Path) -> None: - reg = AliasRegistry(write_yaml(tmp_path, VALID_CONFIG)) - with pytest.raises(UnknownAliasError, match="mana/nope"): - reg.resolve("mana/nope") - - def test_default_optional(self, tmp_path: Path) -> None: - body = ( - "aliases:\n" - " mana/x:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/foo:1b\n" - ) - reg = AliasRegistry(write_yaml(tmp_path, body)) - assert reg.default_alias is None - - -class TestIsAlias: - """``is_alias`` is a cheap static syntactic check used by the router.""" - - @pytest.mark.parametrize( - "name", - ["mana/fast-text", "mana/anything", f"{ALIAS_PREFIX}foo"], - ) - def test_recognises_alias_namespace(self, name: str) -> None: - assert AliasRegistry.is_alias(name) is True - - @pytest.mark.parametrize( - "name", - ["ollama/gemma3:4b", "groq/llama", "gemma3:4b", "", "mana", "manaX/foo"], - ) - def test_rejects_non_alias(self, name: str) -> None: - assert AliasRegistry.is_alias(name) is False - - def test_static_no_instance_needed(self) -> None: - # Important: callers can hit this without instantiating, so it must - # be a free function or @staticmethod. - assert AliasRegistry.is_alias("mana/x") is True - - -# --------------------------------------------------------------------------- -# Schema validation — the YAML is user-edited, must fail loudly on typos -# --------------------------------------------------------------------------- - - -class TestSchemaValidation: - def test_missing_file_raises(self, tmp_path: Path) -> None: - with pytest.raises(AliasConfigError, match="not found"): - AliasRegistry(tmp_path / "absent.yaml") - - def test_invalid_yaml_raises(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, "aliases: [\n unclosed") - with pytest.raises(AliasConfigError, match="failed to parse"): - AliasRegistry(path) - - def test_root_not_a_mapping(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, "- just-a-list\n") - with pytest.raises(AliasConfigError, match="root must be a mapping"): - AliasRegistry(path) - - def test_aliases_must_be_mapping(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, "aliases: just-a-string\n") - with pytest.raises(AliasConfigError, match="`aliases` must be a mapping"): - AliasRegistry(path) - - def test_empty_aliases_rejected(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, "aliases: {}\n") - with pytest.raises(AliasConfigError, match="empty"): - AliasRegistry(path) - - def test_alias_name_must_use_mana_namespace(self, tmp_path: Path) -> None: - body = ( - "aliases:\n" - " fast-text:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/foo:1b\n" - ) - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="mana/"): - AliasRegistry(path) - - def test_alias_name_must_have_one_segment(self, tmp_path: Path) -> None: - body = ( - "aliases:\n" - " mana/foo/bar:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/foo:1b\n" - ) - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="exactly one segment"): - AliasRegistry(path) - - def test_chain_must_be_list(self, tmp_path: Path) -> None: - body = ( - "aliases:\n" - " mana/x:\n" - ' description: "x"\n' - ' chain: "ollama/gemma3:4b"\n' - ) - path = 
write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="chain must be a list"): - AliasRegistry(path) - - def test_empty_chain_rejected(self, tmp_path: Path) -> None: - body = "aliases:\n mana/x:\n chain: []\n" - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="must not be empty"): - AliasRegistry(path) - - def test_chain_entry_without_provider_prefix_rejected(self, tmp_path: Path) -> None: - # "gemma3:4b" without a provider/ prefix would silently default to - # ollama and confuse the health-cache; reject loudly at config-load. - body = "aliases:\n mana/x:\n chain:\n - gemma3:4b\n" - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="provider prefix"): - AliasRegistry(path) - - def test_chain_entry_must_be_string(self, tmp_path: Path) -> None: - body = "aliases:\n mana/x:\n chain:\n - 42\n" - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError): - AliasRegistry(path) - - def test_default_must_reference_known_alias(self, tmp_path: Path) -> None: - body = ( - "aliases:\n" - " mana/x:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/foo:1b\n" - "default: mana/missing\n" - ) - path = write_yaml(tmp_path, body) - with pytest.raises(AliasConfigError, match="references unknown alias"): - AliasRegistry(path) - - -# --------------------------------------------------------------------------- -# Reload semantics — SIGHUP should be safe even with typos -# --------------------------------------------------------------------------- - - -class TestReload: - def test_reload_picks_up_edits(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, VALID_CONFIG) - reg = AliasRegistry(path) - assert reg.resolve_chain("mana/long-form") == ( - "ollama/gemma3:12b", - "groq/llama-3.3-70b-versatile", - ) - - # Edit on disk: shrink the long-form chain. - new_body = ( - "aliases:\n" - " mana/long-form:\n" - ' description: "shorter"\n' - " chain:\n" - " - groq/llama-3.3-70b-versatile\n" - "default: mana/long-form\n" - ) - path.write_text(new_body, encoding="utf-8") - reg.reload() - - assert reg.resolve_chain("mana/long-form") == ("groq/llama-3.3-70b-versatile",) - assert reg.default_alias == "mana/long-form" - # Aliases that disappeared from the new file are gone. - with pytest.raises(UnknownAliasError): - reg.resolve("mana/fast-text") - - def test_reload_keeps_old_state_on_parse_error(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, VALID_CONFIG) - reg = AliasRegistry(path) - # First reload fine — establish a baseline. - reg.reload() - - # Now break the file with an obviously invalid yaml. - path.write_text("aliases: [unclosed\n", encoding="utf-8") - with pytest.raises(AliasConfigError): - reg.reload() - - # The previous good state must still be queryable — service stays up. - assert reg.resolve_chain("mana/fast-text") == ( - "ollama/qwen2.5:7b", - "groq/llama-3.1-8b-instant", - ) - assert reg.default_alias == "mana/fast-text" - - def test_reload_keeps_old_state_on_schema_error(self, tmp_path: Path) -> None: - path = write_yaml(tmp_path, VALID_CONFIG) - reg = AliasRegistry(path) - - # Empty aliases — would be rejected on first load, must also be - # rejected here without nuking the in-memory state. 
- path.write_text("aliases: {}\n", encoding="utf-8") - with pytest.raises(AliasConfigError): - reg.reload() - - assert "mana/fast-text" in [a.name for a in reg.list_aliases()] - - -# --------------------------------------------------------------------------- -# Repo-shipped aliases.yaml is itself valid -# --------------------------------------------------------------------------- - - -class TestShippedConfig: - def test_repo_aliases_yaml_loads(self) -> None: - # The yaml file checked into services/mana-llm/aliases.yaml is the - # one that runs in production. It must always parse cleanly — this - # test catches editor accidents before they ship. - repo_yaml = Path(__file__).resolve().parents[1] / "aliases.yaml" - assert repo_yaml.exists(), f"shipped config missing at {repo_yaml}" - reg = AliasRegistry(repo_yaml) - # Sanity: the five classes the plan calls out must exist. - for expected in ( - "mana/fast-text", - "mana/long-form", - "mana/structured", - "mana/reasoning", - "mana/vision", - ): - reg.resolve(expected) diff --git a/services/mana-llm/tests/test_api.py b/services/mana-llm/tests/test_api.py deleted file mode 100644 index e6ea4f8b7..000000000 --- a/services/mana-llm/tests/test_api.py +++ /dev/null @@ -1,41 +0,0 @@ -"""API endpoint tests.""" - -import pytest -from fastapi.testclient import TestClient - - -@pytest.fixture -def client(): - """Create test client.""" - from src.main import app - - with TestClient(app) as c: - yield c - - -def test_health_endpoint(client): - """Test health check endpoint.""" - response = client.get("/health") - assert response.status_code == 200 - data = response.json() - assert "status" in data - assert "service" in data - assert data["service"] == "mana-llm" - - -def test_metrics_endpoint(client): - """Test metrics endpoint.""" - response = client.get("/metrics") - assert response.status_code == 200 - assert "mana_llm" in response.text - - -def test_list_models_endpoint(client): - """Test list models endpoint.""" - response = client.get("/v1/models") - # May fail if Ollama is not running, but should return valid response structure - if response.status_code == 200: - data = response.json() - assert "data" in data - assert "object" in data - assert data["object"] == "list" diff --git a/services/mana-llm/tests/test_health.py b/services/mana-llm/tests/test_health.py deleted file mode 100644 index f3b72ccc5..000000000 --- a/services/mana-llm/tests/test_health.py +++ /dev/null @@ -1,203 +0,0 @@ -"""Tests for the provider health cache.""" - -from __future__ import annotations - -import pytest - -from src.health import ( - DEFAULT_FAILURE_THRESHOLD, - DEFAULT_UNHEALTHY_BACKOFF_SEC, - ProviderHealthCache, - ProviderState, -) - - -class FakeClock: - """Deterministic clock for circuit-breaker timing tests.""" - - def __init__(self, start: float = 1_000_000.0) -> None: - self.now = start - - def __call__(self) -> float: - return self.now - - def advance(self, seconds: float) -> None: - self.now += seconds - - -# --------------------------------------------------------------------------- -# Construction -# --------------------------------------------------------------------------- - - -class TestConstruction: - def test_defaults(self) -> None: - c = ProviderHealthCache() - assert c.failure_threshold == DEFAULT_FAILURE_THRESHOLD - assert c.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC - - def test_invalid_threshold_rejected(self) -> None: - with pytest.raises(ValueError, match="failure_threshold"): - ProviderHealthCache(failure_threshold=0) - - def 
test_invalid_backoff_rejected(self) -> None: - with pytest.raises(ValueError, match="backoff"): - ProviderHealthCache(unhealthy_backoff_sec=-1.0) - - -# --------------------------------------------------------------------------- -# Default-healthy behaviour -# --------------------------------------------------------------------------- - - -class TestDefaults: - def test_unknown_provider_is_healthy(self) -> None: - # The cache is observation-only — no entry means "no reason to skip". - c = ProviderHealthCache() - assert c.is_healthy("ollama") is True - - def test_unknown_provider_has_no_state(self) -> None: - c = ProviderHealthCache() - assert c.get_state("ollama") is None - - def test_snapshot_includes_expected_zero_state(self) -> None: - c = ProviderHealthCache() - snap = c.snapshot(expected=["ollama", "groq"]) - assert set(snap.keys()) == {"ollama", "groq"} - assert all(s.healthy for s in snap.values()) - assert all(s.consecutive_failures == 0 for s in snap.values()) - - -# --------------------------------------------------------------------------- -# Failure → unhealthy state machine -# --------------------------------------------------------------------------- - - -class TestFailureSemantics: - def test_first_failure_does_not_trip(self) -> None: - # Single transient blips shouldn't bounce a provider — wait for the - # next consecutive failure to confirm. - c = ProviderHealthCache(failure_threshold=2) - c.mark_unhealthy("ollama", "boom") - assert c.is_healthy("ollama") is True - state = c.get_state("ollama") - assert state is not None - assert state.consecutive_failures == 1 - assert state.healthy is True - - def test_threshold_reached_trips_breaker(self) -> None: - clock = FakeClock() - c = ProviderHealthCache( - failure_threshold=2, - unhealthy_backoff_sec=60.0, - clock=clock, - ) - c.mark_unhealthy("ollama", "fail-1") - c.mark_unhealthy("ollama", "fail-2") - assert c.is_healthy("ollama") is False - state = c.get_state("ollama") - assert state is not None - assert state.healthy is False - assert state.last_error == "fail-2" - assert state.unhealthy_until == clock.now + 60.0 - - def test_threshold_one_trips_immediately(self) -> None: - c = ProviderHealthCache(failure_threshold=1) - c.mark_unhealthy("ollama", "boom") - assert c.is_healthy("ollama") is False - - def test_mark_healthy_clears_state(self) -> None: - c = ProviderHealthCache(failure_threshold=2) - c.mark_unhealthy("ollama", "x") - c.mark_unhealthy("ollama", "x") - assert c.is_healthy("ollama") is False - c.mark_healthy("ollama") - assert c.is_healthy("ollama") is True - state = c.get_state("ollama") - assert state is not None - assert state.healthy is True - assert state.consecutive_failures == 0 - assert state.unhealthy_until == 0.0 - assert state.last_error is None - - -class TestBackoffWindow: - def test_is_healthy_false_during_backoff(self) -> None: - clock = FakeClock() - c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock) - c.mark_unhealthy("ollama", "boom") - assert c.is_healthy("ollama") is False - clock.advance(30.0) - assert c.is_healthy("ollama") is False - - def test_is_healthy_true_after_backoff_expires(self) -> None: - # After backoff: half-open. Router gets one more attempt; success - # mark_healthy resets, failure mark_unhealthy re-arms backoff. 
- clock = FakeClock() - c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock) - c.mark_unhealthy("ollama", "boom") - clock.advance(61.0) - assert c.is_healthy("ollama") is True - - def test_failure_during_backoff_extends_window(self) -> None: - clock = FakeClock() - c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock) - c.mark_unhealthy("ollama", "first") - original_until = c.get_state("ollama").unhealthy_until - clock.advance(20.0) - c.mark_unhealthy("ollama", "second") - new_until = c.get_state("ollama").unhealthy_until - assert new_until > original_until - assert new_until == clock.now + 60.0 - - def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None: - clock = FakeClock() - c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock) - c.mark_unhealthy("ollama", "boom") - with caplog.at_level("INFO"): - c.mark_healthy("ollama") - # Calling mark_healthy on an already-healthy provider must not - # spam the recovery log line. - c.mark_healthy("ollama") - c.mark_healthy("ollama") - recovery_lines = [r for r in caplog.records if "recovered" in r.message] - assert len(recovery_lines) == 1 - - -# --------------------------------------------------------------------------- -# Snapshot shape -# --------------------------------------------------------------------------- - - -class TestSnapshot: - def test_snapshot_returns_copies(self) -> None: - # Caller shouldn't be able to poke through to the cache's internal - # state via a snapshot reference. - c = ProviderHealthCache(failure_threshold=1) - c.mark_unhealthy("ollama", "x") - snap = c.snapshot() - snap["ollama"].healthy = True - # Original state untouched: - assert c.is_healthy("ollama") is False - - def test_snapshot_matches_recorded_state(self) -> None: - clock = FakeClock() - c = ProviderHealthCache( - failure_threshold=2, - unhealthy_backoff_sec=60.0, - clock=clock, - ) - c.mark_unhealthy("groq", "rate-limit") - snap = c.snapshot() - assert isinstance(snap["groq"], ProviderState) - assert snap["groq"].consecutive_failures == 1 - assert snap["groq"].last_error == "rate-limit" - assert snap["groq"].last_check == clock.now - - def test_snapshot_expected_does_not_overwrite_real_state(self) -> None: - c = ProviderHealthCache(failure_threshold=1) - c.mark_unhealthy("ollama", "real-boom") - snap = c.snapshot(expected=["ollama", "groq"]) - # ollama keeps its real (unhealthy) state, groq gets the zero-default. 
- assert snap["ollama"].consecutive_failures == 1 - assert snap["groq"].consecutive_failures == 0 diff --git a/services/mana-llm/tests/test_health_probe.py b/services/mana-llm/tests/test_health_probe.py deleted file mode 100644 index 9305f5013..000000000 --- a/services/mana-llm/tests/test_health_probe.py +++ /dev/null @@ -1,222 +0,0 @@ -"""Tests for the background health-probe loop.""" - -from __future__ import annotations - -import asyncio -from typing import Awaitable, Callable - -import pytest - -from src.health import ProviderHealthCache -from src.health_probe import HealthProbe, ProbeFn - - -def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn: - """Synthesise a probe function that always returns / raises the same.""" - - async def probe() -> bool: - if raises is not None: - raise raises("boom") - return returns - - return probe - - -def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn: - async def probe() -> bool: - await asyncio.sleep(delay) - return returns - - return probe - - -def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]: - """Probe that counts how many times it was awaited.""" - count = 0 - - async def probe() -> bool: - nonlocal count - count += 1 - return True - - return probe, lambda: count - - -# --------------------------------------------------------------------------- -# Construction -# --------------------------------------------------------------------------- - - -class TestConstruction: - def test_invalid_interval(self) -> None: - with pytest.raises(ValueError, match="interval"): - HealthProbe(ProviderHealthCache(), {}, interval=0.0) - - def test_invalid_timeout(self) -> None: - with pytest.raises(ValueError, match="timeout"): - HealthProbe(ProviderHealthCache(), {}, timeout=0.0) - - def test_provider_ids_exposes_keys(self) -> None: - cache = ProviderHealthCache() - probe = HealthProbe( - cache, - {"ollama": make_probe(), "groq": make_probe()}, - ) - assert sorted(probe.provider_ids) == ["groq", "ollama"] - - -# --------------------------------------------------------------------------- -# tick_once — the per-cycle behaviour -# --------------------------------------------------------------------------- - - -class TestTickOnce: - @pytest.mark.asyncio - async def test_healthy_probe_marks_healthy(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - # Pre-mark unhealthy so we can verify the probe recovers it. 
- cache.mark_unhealthy("ollama", "stale") - probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}) - await probe.tick_once() - assert cache.is_healthy("ollama") is True - - @pytest.mark.asyncio - async def test_returning_false_marks_unhealthy(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - probe = HealthProbe(cache, {"ollama": make_probe(returns=False)}) - await probe.tick_once() - assert cache.is_healthy("ollama") is False - state = cache.get_state("ollama") - assert state is not None - assert "false" in (state.last_error or "") - - @pytest.mark.asyncio - async def test_raising_marks_unhealthy_with_exc_info(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - probe = HealthProbe( - cache, {"ollama": make_probe(raises=ConnectionError)} - ) - await probe.tick_once() - assert cache.is_healthy("ollama") is False - state = cache.get_state("ollama") - assert state is not None - assert "ConnectionError" in (state.last_error or "") - - @pytest.mark.asyncio - async def test_timeout_marks_unhealthy(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - probe = HealthProbe( - cache, - {"ollama": make_slow_probe(delay=1.0)}, - timeout=0.05, - ) - await probe.tick_once() - assert cache.is_healthy("ollama") is False - state = cache.get_state("ollama") - assert state is not None - assert "timeout" in (state.last_error or "").lower() - - @pytest.mark.asyncio - async def test_one_bad_probe_does_not_sink_others(self) -> None: - # Probe 'ollama' raises — 'groq' must still be evaluated and marked - # healthy. Bug shape: an unhandled exception in gather() sinks the - # whole loop. - cache = ProviderHealthCache(failure_threshold=1) - probe = HealthProbe( - cache, - { - "ollama": make_probe(raises=RuntimeError), - "groq": make_probe(returns=True), - }, - ) - await probe.tick_once() - assert cache.is_healthy("ollama") is False - assert cache.is_healthy("groq") is True - - @pytest.mark.asyncio - async def test_concurrent_probes(self) -> None: - # All probes should run in parallel — total elapsed wall-clock for - # N x 100ms probes should be well under N*100ms. - import time - - cache = ProviderHealthCache() - probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)} - probe = HealthProbe(cache, probes, timeout=1.0) - t0 = time.perf_counter() - await probe.tick_once() - elapsed = time.perf_counter() - t0 - assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s" - - @pytest.mark.asyncio - async def test_empty_probes_is_noop(self) -> None: - cache = ProviderHealthCache() - probe = HealthProbe(cache, {}) - # No exception, no state mutation. - await probe.tick_once() - assert cache.snapshot() == {} - - -# --------------------------------------------------------------------------- -# start / stop lifecycle -# --------------------------------------------------------------------------- - - -class TestLifecycle: - @pytest.mark.asyncio - async def test_start_runs_initial_tick_immediately(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - cache.mark_unhealthy("ollama", "stale") - probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0) - await probe.start() - # Give the loop one event-loop turn to run the initial tick before - # blocking on the long sleep. 
- await asyncio.sleep(0.01) - assert cache.is_healthy("ollama") is True - await probe.stop() - - @pytest.mark.asyncio - async def test_stop_cancels_cleanly(self) -> None: - cache = ProviderHealthCache() - probe = HealthProbe( - cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0 - ) - await probe.start() - assert probe.running is True - await probe.stop() - assert probe.running is False - - @pytest.mark.asyncio - async def test_start_is_idempotent(self) -> None: - cache = ProviderHealthCache() - probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0) - await probe.start() - await probe.start() # must not spawn a second task - assert probe.running is True - await probe.stop() - - @pytest.mark.asyncio - async def test_stop_without_start_is_safe(self) -> None: - cache = ProviderHealthCache() - probe = HealthProbe(cache, {}) - await probe.stop() # idempotent / safe pre-start - - @pytest.mark.asyncio - async def test_loop_keeps_running_after_tick_error(self) -> None: - # Even if every probe explodes, the loop must keep ticking. - cache = ProviderHealthCache(failure_threshold=1) - fn, count = make_call_counter() - # Wrap with one that raises — but tick_once internally catches - # per-probe via gather(return_exceptions=True). Force an outer error - # via an evil probe key that the dict can't handle? Easier: use the - # call-counter to verify multiple ticks happened. - probe = HealthProbe( - cache, - {"counter": fn}, - interval=0.05, # short interval for test - timeout=1.0, - ) - await probe.start() - await asyncio.sleep(0.18) # ~3-4 ticks - await probe.stop() - # Initial tick + at least 2 interval ticks. - assert count() >= 3 diff --git a/services/mana-llm/tests/test_m4_observability.py b/services/mana-llm/tests/test_m4_observability.py deleted file mode 100644 index d0fe0c84b..000000000 --- a/services/mana-llm/tests/test_m4_observability.py +++ /dev/null @@ -1,415 +0,0 @@ -"""Tests for M4: observability + debug endpoints + reload.""" - -from __future__ import annotations - -import asyncio -import os -import signal -from pathlib import Path - -import httpx -import pytest -from fastapi.testclient import TestClient -from prometheus_client import REGISTRY - -from src.aliases import AliasRegistry -from src.health import ProviderHealthCache -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - Choice, - Message, - MessageResponse, -) -from src.providers import ProviderRouter -from src.utils.metrics import ( - record_alias_resolved, - record_fallback, - set_provider_healthy, -) - - -# --------------------------------------------------------------------------- -# Cache → listener → metric gauge -# --------------------------------------------------------------------------- - - -class TestHealthChangeListener: - def test_listener_fires_on_unhealthy_transition(self) -> None: - cache = ProviderHealthCache(failure_threshold=2) - events: list[tuple[str, bool]] = [] - cache.add_listener(lambda p, h: events.append((p, h))) - - # First failure: still healthy → no transition. - cache.mark_unhealthy("ollama", "blip") - assert events == [] - - # Second failure: transition healthy→unhealthy → fires. 
- cache.mark_unhealthy("ollama", "boom") - assert events == [("ollama", False)] - - def test_listener_fires_on_recovery(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - events: list[tuple[str, bool]] = [] - cache.add_listener(lambda p, h: events.append((p, h))) - - cache.mark_unhealthy("ollama", "boom") - assert events == [("ollama", False)] - - cache.mark_healthy("ollama") - assert events == [("ollama", False), ("ollama", True)] - - def test_steady_state_does_not_fire(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - events: list[tuple[str, bool]] = [] - cache.add_listener(lambda p, h: events.append((p, h))) - - # Three healthy ops in a row — no transitions, no events. - for _ in range(3): - cache.mark_healthy("ollama") - assert events == [] - - def test_listener_exception_does_not_break_cache(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - - def bad(_provider: str, _healthy: bool) -> None: - raise RuntimeError("listener boom") - - cache.add_listener(bad) - # Should NOT raise — the cache must keep working with a broken - # listener, otherwise one bad metric callback would brick the - # whole router. - cache.mark_unhealthy("ollama", "x") - assert cache.is_healthy("ollama") is False - - def test_multiple_listeners(self) -> None: - cache = ProviderHealthCache(failure_threshold=1) - a: list = [] - b: list = [] - cache.add_listener(lambda p, h: a.append((p, h))) - cache.add_listener(lambda p, h: b.append((p, h))) - - cache.mark_unhealthy("ollama", "x") - assert a == [("ollama", False)] - assert b == [("ollama", False)] - - -# --------------------------------------------------------------------------- -# Prometheus metrics — counters/gauges actually move -# --------------------------------------------------------------------------- - - -def _counter_value(name: str, labels: dict[str, str]) -> float: - """Helper: read the current value of a labeled Prometheus metric.""" - samples = REGISTRY.get_sample_value(name, labels=labels) - return samples or 0.0 - - -class TestMetricsRecording: - def test_record_alias_resolved_increments(self) -> None: - before = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "mana/test-class", "target": "ollama/x:1b"}, - ) - record_alias_resolved("mana/test-class", "ollama/x:1b") - after = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "mana/test-class", "target": "ollama/x:1b"}, - ) - assert after - before == pytest.approx(1.0) - - def test_record_fallback_increments(self) -> None: - before = _counter_value( - "mana_llm_fallback_total", - {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"}, - ) - record_fallback("ollama/x", "groq/y", "ConnectError") - after = _counter_value( - "mana_llm_fallback_total", - {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"}, - ) - assert after - before == pytest.approx(1.0) - - def test_set_provider_healthy_writes_gauge(self) -> None: - set_provider_healthy("test_provider_xyz", True) - v = REGISTRY.get_sample_value( - "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"} - ) - assert v == 1.0 - - set_provider_healthy("test_provider_xyz", False) - v = REGISTRY.get_sample_value( - "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"} - ) - assert v == 0.0 - - -# --------------------------------------------------------------------------- -# Router → metrics: end-to-end through a fallback -# --------------------------------------------------------------------------- - - -class 
_OkProvider: - """Minimal provider double — only what the router uses for chat.""" - - name = "ok-provider" - supports_tools = True - - def __init__(self, name: str, fail_with: BaseException | None = None) -> None: - self.name = name - self.fail_with = fail_with - self.calls = 0 - - def model_supports_tools(self, model: str) -> bool: - return True - - async def chat_completion(self, request, model): - self.calls += 1 - if self.fail_with is not None: - raise self.fail_with - return ChatCompletionResponse( - model=f"{self.name}/{model}", - choices=[Choice(message=MessageResponse(content="ok"))], - ) - - async def chat_completion_stream(self, request, model): # pragma: no cover - if False: # pragma: no cover - yield None - - async def list_models(self): - return [] - - async def embeddings(self, request, model): - raise NotImplementedError - - async def health_check(self): - return {"status": "healthy"} - - async def close(self): - pass - - -def _aliases(tmp_path: Path) -> AliasRegistry: - cfg = ( - "aliases:\n" - " mana/two-step:\n" - ' description: "x"\n' - " chain:\n" - " - alpha/m1\n" - " - beta/m2\n" - ) - p = tmp_path / "aliases.yaml" - p.write_text(cfg) - return AliasRegistry(p) - - -class TestRouterMetricsIntegration: - @pytest.mark.asyncio - async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None: - aliases = _aliases(tmp_path) - cache = ProviderHealthCache() - router = ProviderRouter(aliases=aliases, health_cache=cache) - router.providers = {"alpha": _OkProvider("alpha")} # beta not configured - - before = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "mana/two-step", "target": "alpha/m1"}, - ) - await router.chat_completion( - ChatCompletionRequest( - model="mana/two-step", - messages=[Message(role="user", content="hi")], - ) - ) - after = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "mana/two-step", "target": "alpha/m1"}, - ) - assert after - before == pytest.approx(1.0) - - @pytest.mark.asyncio - async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None: - aliases = _aliases(tmp_path) - cache = ProviderHealthCache() - router = ProviderRouter(aliases=aliases, health_cache=cache) - router.providers = { - "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")), - "beta": _OkProvider("beta"), - } - - before = _counter_value( - "mana_llm_fallback_total", - {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"}, - ) - await router.chat_completion( - ChatCompletionRequest( - model="mana/two-step", - messages=[Message(role="user", content="hi")], - ) - ) - after = _counter_value( - "mana_llm_fallback_total", - {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"}, - ) - assert after - before == pytest.approx(1.0) - - @pytest.mark.asyncio - async def test_direct_model_does_not_record_alias_metric( - self, tmp_path: Path - ) -> None: - # Direct provider/model is not an alias — ALIAS_RESOLVED counter - # must stay flat for those calls. 
- aliases = _aliases(tmp_path) - cache = ProviderHealthCache() - router = ProviderRouter(aliases=aliases, health_cache=cache) - router.providers = {"alpha": _OkProvider("alpha")} - - before = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "alpha/anything", "target": "alpha/anything"}, - ) - await router.chat_completion( - ChatCompletionRequest( - model="alpha/anything", - messages=[Message(role="user", content="hi")], - ) - ) - after = _counter_value( - "mana_llm_alias_resolved_total", - {"alias": "alpha/anything", "target": "alpha/anything"}, - ) - # Counter must have NOT increased — direct calls aren't aliases. - assert after == before - - -# --------------------------------------------------------------------------- -# Debug endpoints: GET /v1/aliases, GET /v1/health -# --------------------------------------------------------------------------- - - -@pytest.fixture -def client(): - from src.main import app - - with TestClient(app) as c: - yield c - - -class TestDebugEndpoints: - def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None: - resp = client.get("/v1/aliases") - assert resp.status_code == 200 - data = resp.json() - names = [a["name"] for a in data["aliases"]] - # The five canonical classes must always be present. - for expected in ( - "mana/fast-text", - "mana/long-form", - "mana/structured", - "mana/reasoning", - "mana/vision", - ): - assert expected in names - # Default is set in the shipped config. - assert data["default"] == "mana/fast-text" - - def test_v1_aliases_chain_format(self, client: TestClient) -> None: - resp = client.get("/v1/aliases") - data = resp.json() - long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form") - # Each chain entry is a `provider/model` string. - assert all("/" in entry for entry in long_form["chain"]) - assert len(long_form["chain"]) >= 2 # plan requires at least one cloud fallback - - def test_v1_health_includes_all_providers(self, client: TestClient) -> None: - resp = client.get("/v1/health") - assert resp.status_code == 200 - data = resp.json() - assert "status" in data - assert "providers" in data - # ollama is always configured (provider list is non-empty). - assert "ollama" in data["providers"] - for name, info in data["providers"].items(): - assert "status" in info - assert "consecutive_failures" in info - - -# --------------------------------------------------------------------------- -# X-Mana-LLM-Resolved header on non-streaming responses -# --------------------------------------------------------------------------- - - -class TestResolvedHeader: - """The header is the consumer's hook for token-cost attribution. - - Tested at the router level — wiring through main.py would need a - real provider connection, which isn't available in unit tests. - """ - - @pytest.mark.asyncio - async def test_response_model_field_carries_resolved_target( - self, tmp_path: Path - ) -> None: - # The header value is `response.model`; verify that field reflects - # the actual chain entry that served, not the requested alias. - aliases = _aliases(tmp_path) - cache = ProviderHealthCache() - router = ProviderRouter(aliases=aliases, health_cache=cache) - # Force fallback to beta. 
- router.providers = { - "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")), - "beta": _OkProvider("beta"), - } - - resp = await router.chat_completion( - ChatCompletionRequest( - model="mana/two-step", - messages=[Message(role="user", content="hi")], - ) - ) - # Even though the caller asked for `mana/two-step`, the resolved - # field shows the entry that actually answered. - assert resp.model == "beta/m2" - - -# --------------------------------------------------------------------------- -# SIGHUP reload — only meaningful on Unix; tested by signalling the proc -# --------------------------------------------------------------------------- - - -class TestSighupReload: - """SIGHUP triggers ``alias_registry.reload()``; reload-error keeps state. - - The signal-handler wiring lives in main.py and only installs when - the loop is running in the main thread. We exercise the reload - semantics here directly on the registry instead — the signal-handler - code path itself is a 4-line wrapper around ``reload()``. - """ - - def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None: - path = tmp_path / "aliases.yaml" - path.write_text( - "aliases:\n" - " mana/x:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/foo:1b\n" - ) - reg = AliasRegistry(path) - assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",) - - # Edit on disk, reload (this is exactly what the SIGHUP handler - # does — minus the signal plumbing). - path.write_text( - "aliases:\n" - " mana/x:\n" - ' description: "x"\n' - " chain:\n" - " - ollama/bar:1b\n" - " - groq/llama-3.1-8b-instant\n" - ) - reg.reload() - assert reg.resolve_chain("mana/x") == ( - "ollama/bar:1b", - "groq/llama-3.1-8b-instant", - ) diff --git a/services/mana-llm/tests/test_providers.py b/services/mana-llm/tests/test_providers.py deleted file mode 100644 index e5c3ffcb0..000000000 --- a/services/mana-llm/tests/test_providers.py +++ /dev/null @@ -1,123 +0,0 @@ -"""Provider tests.""" - -from pathlib import Path - -import pytest - -from src.aliases import AliasRegistry -from src.health import ProviderHealthCache -from src.models import ChatCompletionRequest, EmbeddingRequest, Message -from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter - - -@pytest.fixture -def shipped_aliases() -> AliasRegistry: - """The repo's real aliases.yaml — same one production uses.""" - return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml") - - -@pytest.fixture -def router(shipped_aliases: AliasRegistry) -> ProviderRouter: - return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache()) - - -class TestProviderRouter: - """Tests for the helpers exposed by the router.""" - - def test_parse_model_with_provider(self, router: ProviderRouter) -> None: - provider, model = router._parse_model("ollama/gemma3:4b") - assert provider == "ollama" - assert model == "gemma3:4b" - - def test_parse_model_without_provider(self, router: ProviderRouter) -> None: - # Bare names default to Ollama for OpenAI-style compat. 
- provider, model = router._parse_model("gemma3:4b") - assert provider == "ollama" - assert model == "gemma3:4b" - - def test_parse_model_openrouter(self, router: ProviderRouter) -> None: - provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct") - assert provider == "openrouter" - assert model == "meta-llama/llama-3.1-8b-instruct" - - @pytest.mark.asyncio - async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None: - # Embeddings don't go through the alias/fallback pipeline — they - # hit the requested provider directly. Asking for an unconfigured - # one is a config error and must raise loudly. - with pytest.raises(ValueError, match="not available"): - await router.embeddings( - EmbeddingRequest(model="bogus_provider/x", input="hi") - ) - - -class TestOllamaProvider: - """Test Ollama provider.""" - - def test_convert_simple_messages(self): - """Test converting simple text messages.""" - provider = OllamaProvider() - request = ChatCompletionRequest( - model="gemma3:4b", - messages=[ - Message(role="user", content="Hello"), - ], - ) - - messages = provider._convert_messages(request) - assert len(messages) == 1 - assert messages[0]["role"] == "user" - assert messages[0]["content"] == "Hello" - - def test_convert_multimodal_messages(self): - """Test converting multimodal messages.""" - provider = OllamaProvider() - request = ChatCompletionRequest( - model="llava:7b", - messages=[ - Message( - role="user", - content=[ - {"type": "text", "text": "What's in this image?"}, - { - "type": "image_url", - "image_url": {"url": "data:image/png;base64,iVBORw0KGgo="}, - }, - ], - ), - ], - ) - - messages = provider._convert_messages(request) - assert len(messages) == 1 - assert messages[0]["role"] == "user" - assert messages[0]["content"] == "What's in this image?" - assert "images" in messages[0] - assert len(messages[0]["images"]) == 1 - - -class TestOpenAICompatProvider: - """Test OpenAI-compatible provider.""" - - def test_convert_simple_messages(self): - """Test converting simple text messages.""" - provider = OpenAICompatProvider( - name="test", - base_url="http://localhost", - api_key="test-key", - ) - - request = ChatCompletionRequest( - model="test-model", - messages=[ - Message(role="system", content="You are helpful."), - Message(role="user", content="Hello"), - ], - ) - - messages = provider._convert_messages(request) - assert len(messages) == 2 - assert messages[0]["role"] == "system" - assert messages[0]["content"] == "You are helpful." 
- assert messages[1]["role"] == "user" - assert messages[1]["content"] == "Hello" diff --git a/services/mana-llm/tests/test_router_fallback.py b/services/mana-llm/tests/test_router_fallback.py deleted file mode 100644 index cb1f34743..000000000 --- a/services/mana-llm/tests/test_router_fallback.py +++ /dev/null @@ -1,526 +0,0 @@ -"""Tests for ProviderRouter fallback / alias execution (M3).""" - -from __future__ import annotations - -from collections.abc import AsyncIterator -from typing import Any - -import httpx -import pytest - -from src.aliases import AliasRegistry -from src.health import ProviderHealthCache -from src.models import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionStreamResponse, - Choice, - DeltaContent, - EmbeddingRequest, - EmbeddingResponse, - Message, - MessageResponse, - ModelInfo, - StreamChoice, -) -from src.providers import ProviderRouter -from src.providers.base import LLMProvider -from src.providers.errors import ( - NoHealthyProviderError, - ProviderAuthError, - ProviderCapabilityError, - ProviderRateLimitError, -) - - -# --------------------------------------------------------------------------- -# Test doubles -# --------------------------------------------------------------------------- - - -class MockProvider(LLMProvider): - """Provider that lets tests inject a sequence of behaviours. - - Each call pops one entry from ``behaviors``. Strings ``"ok"`` and - ``"empty"`` are sentinels for normal returns; everything else (a - BaseException instance / class) is raised. - """ - - supports_tools = True - - def __init__(self, name: str, behaviors: list[Any] | None = None) -> None: - self.name = name - self._behaviors: list[Any] = list(behaviors or []) - self.calls: list[str] = [] - - def push(self, *behaviors: Any) -> None: - self._behaviors.extend(behaviors) - - def _next(self) -> Any: - return self._behaviors.pop(0) if self._behaviors else "ok" - - async def chat_completion( - self, request: ChatCompletionRequest, model: str - ) -> ChatCompletionResponse: - self.calls.append(model) - b = self._next() - if isinstance(b, type) and issubclass(b, BaseException): - raise b("simulated") - if isinstance(b, BaseException): - raise b - return _ok_response(self.name, model) - - async def chat_completion_stream( - self, request: ChatCompletionRequest, model: str - ) -> AsyncIterator[ChatCompletionStreamResponse]: - self.calls.append(model) - b = self._next() - if isinstance(b, type) and issubclass(b, BaseException): - raise b("simulated") - if isinstance(b, BaseException): - raise b - if b == "empty": - return - for content in ("Hello", " ", "world"): - yield ChatCompletionStreamResponse( - model=f"{self.name}/{model}", - choices=[StreamChoice(delta=DeltaContent(content=content))], - ) - - async def list_models(self) -> list[ModelInfo]: - return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")] - - async def embeddings( - self, request: EmbeddingRequest, model: str - ) -> EmbeddingResponse: - raise NotImplementedError - - async def health_check(self) -> dict[str, Any]: - return {"status": "healthy"} - - -class FailFirstChunkProvider(MockProvider): - """Streaming provider that raises BEFORE the first chunk every time. - - Kept separate from MockProvider's behaviour list so the per-call - semantics stay simple — this one models a permanently-broken streamer. 
- """ - - def __init__(self, name: str, exc: BaseException) -> None: - super().__init__(name) - self._exc = exc - - async def chat_completion_stream(self, request, model): # type: ignore[override] - self.calls.append(model) - raise self._exc - # the yield is unreachable but keeps the function an async generator - yield # pragma: no cover - - -def _ok_response(provider: str, model: str) -> ChatCompletionResponse: - return ChatCompletionResponse( - model=f"{provider}/{model}", - choices=[ - Choice( - message=MessageResponse(content="ok"), - finish_reason="stop", - ) - ], - ) - - -def _request(model: str) -> ChatCompletionRequest: - return ChatCompletionRequest( - model=model, - messages=[Message(role="user", content="hi")], - ) - - -def _aliases_yaml(tmp_path) -> AliasRegistry: - """A two-alias config used across most tests.""" - cfg = ( - "aliases:\n" - " mana/long-form:\n" - ' description: "long"\n' - " chain:\n" - " - alpha/m1\n" - " - beta/m2\n" - " - gamma/m3\n" - " mana/single:\n" - ' description: "single-entry"\n' - " chain:\n" - " - alpha/solo\n" - ) - p = tmp_path / "aliases.yaml" - p.write_text(cfg) - return AliasRegistry(p) - - -def _make_router( - tmp_path, - *, - providers: dict[str, MockProvider], - cache: ProviderHealthCache | None = None, -) -> ProviderRouter: - aliases = _aliases_yaml(tmp_path) - router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache()) - # Replace the auto-initialised live providers with the test doubles. - router.providers = dict(providers) - return router - - -# --------------------------------------------------------------------------- -# Non-streaming chain walking -# --------------------------------------------------------------------------- - - -class TestChatCompletionChain: - @pytest.mark.asyncio - async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None: - alpha = MockProvider("alpha", ["ok"]) - beta = MockProvider("beta") - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - resp = await router.chat_completion(_request("mana/long-form")) - - assert resp.model == "alpha/m1" - assert alpha.calls == ["m1"] - assert beta.calls == [] # never reached - - @pytest.mark.asyncio - async def test_falls_through_on_connect_error(self, tmp_path) -> None: - alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - resp = await router.chat_completion(_request("mana/long-form")) - - assert resp.model == "beta/m2" - assert alpha.calls == ["m1"] - assert beta.calls == ["m2"] - - @pytest.mark.asyncio - async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None: - # gamma isn't configured at all → chain should silently skip it - # rather than raise. - alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) - beta = MockProvider("beta", [httpx.ConnectError("dead too")]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(NoHealthyProviderError) as exc_info: - await router.chat_completion(_request("mana/long-form")) - # All three entries appear in attempts: two as ConnectError, one - # as unconfigured (not a fatal error, just skipped). 
- attempts = exc_info.value.attempts - assert ("alpha/m1", "ConnectError") in attempts - assert ("beta/m2", "ConnectError") in attempts - assert ("gamma/m3", "unconfigured") in attempts - - @pytest.mark.asyncio - async def test_skips_cache_unhealthy(self, tmp_path) -> None: - cache = ProviderHealthCache(failure_threshold=1) - cache.mark_unhealthy("alpha", "stale") - alpha = MockProvider("alpha", ["ok"]) - beta = MockProvider("beta", ["ok"]) - router = _make_router( - tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache - ) - - resp = await router.chat_completion(_request("mana/long-form")) - - assert alpha.calls == [] # router skipped per cache - assert beta.calls == ["m2"] - assert resp.model == "beta/m2" - - @pytest.mark.asyncio - async def test_5xx_treated_as_retryable(self, tmp_path) -> None: - five_hundred = httpx.HTTPStatusError( - "boom", - request=httpx.Request("POST", "http://x"), - response=httpx.Response(503), - ) - alpha = MockProvider("alpha", [five_hundred]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - resp = await router.chat_completion(_request("mana/long-form")) - - assert resp.model == "beta/m2" - - @pytest.mark.asyncio - async def test_4xx_propagates(self, tmp_path) -> None: - four_hundred = httpx.HTTPStatusError( - "bad request", - request=httpx.Request("POST", "http://x"), - response=httpx.Response(422), - ) - alpha = MockProvider("alpha", [four_hundred]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(httpx.HTTPStatusError): - await router.chat_completion(_request("mana/long-form")) - # Beta never tried — caller's request needs fixing, retrying - # against another model would just hide the bug. - assert beta.calls == [] - - @pytest.mark.asyncio - async def test_capability_error_propagates(self, tmp_path) -> None: - alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(ProviderCapabilityError): - await router.chat_completion(_request("mana/long-form")) - assert beta.calls == [] - - @pytest.mark.asyncio - async def test_auth_error_propagates(self, tmp_path) -> None: - # Auth errors mean OUR setup is broken (wrong key); falling back - # to the next provider hides the misconfiguration. 
- alpha = MockProvider("alpha", [ProviderAuthError("bad key")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(ProviderAuthError): - await router.chat_completion(_request("mana/long-form")) - assert beta.calls == [] - - @pytest.mark.asyncio - async def test_rate_limit_is_retryable(self, tmp_path) -> None: - alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - resp = await router.chat_completion(_request("mana/long-form")) - - assert resp.model == "beta/m2" - - @pytest.mark.asyncio - async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None: - alpha = MockProvider("alpha", [httpx.ConnectError("a")]) - beta = MockProvider("beta", [httpx.ConnectError("b")]) - gamma = MockProvider("gamma", [httpx.ConnectError("c")]) - router = _make_router( - tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma} - ) - - with pytest.raises(NoHealthyProviderError) as exc_info: - await router.chat_completion(_request("mana/long-form")) - assert exc_info.value.model_or_alias == "mana/long-form" - assert isinstance(exc_info.value.last_exception, httpx.ConnectError) - # 503 status so calling code (mana-api etc.) can decide to retry - # later vs surface a clean error to the user. - assert exc_info.value.http_status == 503 - - @pytest.mark.asyncio - async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None: - # Caller bypasses aliases by passing a direct provider/model. - # No fallback chain — fail = fail. - alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(NoHealthyProviderError): - await router.chat_completion(_request("alpha/anything")) - # Beta would have served if this had been an alias — but it - # wasn't, so beta never gets touched. - assert beta.calls == [] - - -# --------------------------------------------------------------------------- -# Health-cache feedback: success clears, failure marks -# --------------------------------------------------------------------------- - - -class TestHealthCacheFeedback: - @pytest.mark.asyncio - async def test_success_marks_provider_healthy(self, tmp_path) -> None: - cache = ProviderHealthCache(failure_threshold=1) - cache.mark_unhealthy("alpha", "stale-from-probe") - # After the cache TTL the cache thinks alpha might be OK again, - # so the router will try it; success must fully clear the state. - # (Force half-open by zeroing backoff.) - alpha = MockProvider("alpha", ["ok"]) - router = _make_router( - tmp_path, - providers={"alpha": alpha}, - cache=ProviderHealthCache(), # fresh cache, alpha optimistic - ) - - await router.chat_completion(_request("mana/single")) - - assert router.health_cache.get_state("alpha").healthy is True - assert router.health_cache.get_state("alpha").consecutive_failures == 0 - - @pytest.mark.asyncio - async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None: - # threshold=1 so a single fail is enough to flip the breaker. 
- cache = ProviderHealthCache(failure_threshold=1) - alpha = MockProvider("alpha", [httpx.ConnectError("boom")]) - beta = MockProvider("beta", ["ok"]) - router = _make_router( - tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache - ) - - await router.chat_completion(_request("mana/long-form")) - - assert cache.get_state("alpha").healthy is False - assert cache.get_state("alpha").last_error is not None - assert "ConnectError" in cache.get_state("alpha").last_error - - @pytest.mark.asyncio - async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None: - # Auth/Capability errors are about CALLER state, not provider - # health — the cache must stay clean so a real outage later - # isn't masked by stale "marked unhealthy because of bad key". - cache = ProviderHealthCache(failure_threshold=1) - alpha = MockProvider("alpha", [ProviderAuthError("bad key")]) - router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache) - - with pytest.raises(ProviderAuthError): - await router.chat_completion(_request("mana/single")) - - # No state recorded. - assert cache.get_state("alpha") is None - - -# --------------------------------------------------------------------------- -# Streaming pre-first-byte fallback -# --------------------------------------------------------------------------- - - -class TestChatCompletionStream: - @pytest.mark.asyncio - async def test_first_provider_streams_normally(self, tmp_path) -> None: - alpha = MockProvider("alpha", ["ok"]) - beta = MockProvider("beta") - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - chunks = [ - c async for c in router.chat_completion_stream(_request("mana/long-form")) - ] - - assert beta.calls == [] - assert len(chunks) == 3 - assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world" - - @pytest.mark.asyncio - async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None: - alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead")) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - chunks = [ - c async for c in router.chat_completion_stream(_request("mana/long-form")) - ] - - assert alpha.calls == ["m1"] - assert beta.calls == ["m2"] - assert len(chunks) == 3 - assert all(c.model == "beta/m2" for c in chunks) - - @pytest.mark.asyncio - async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None: - alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools")) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - with pytest.raises(ProviderCapabilityError): - async for _ in router.chat_completion_stream(_request("mana/long-form")): - pass - assert beta.calls == [] - - @pytest.mark.asyncio - async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None: - # Empty-but-successful stream is a valid response, not a failure - # we should retry — committing avoids accidentally calling two - # providers and double-billing. 
- alpha = MockProvider("alpha", ["empty"]) - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - chunks = [ - c async for c in router.chat_completion_stream(_request("mana/long-form")) - ] - - assert chunks == [] - assert beta.calls == [] # didn't fall through - - @pytest.mark.asyncio - async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None: - # Custom provider that yields once then raises mid-stream — the - # router has already committed and must let the error propagate - # rather than splice in another provider's voice. - class MidStreamFailProvider(MockProvider): - async def chat_completion_stream(self, request, model): # type: ignore[override] - self.calls.append(model) - yield ChatCompletionStreamResponse( - model=f"{self.name}/{model}", - choices=[StreamChoice(delta=DeltaContent(content="halb"))], - ) - raise httpx.RemoteProtocolError("connection dropped") - - alpha = MidStreamFailProvider("alpha") - beta = MockProvider("beta", ["ok"]) - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - collected: list[str] = [] - with pytest.raises(httpx.RemoteProtocolError): - async for chunk in router.chat_completion_stream(_request("mana/long-form")): - collected.append(chunk.choices[0].delta.content or "") - - # We got the half-chunk that landed before the break; beta was - # NOT called as fallback. - assert collected == ["halb"] - assert beta.calls == [] - - @pytest.mark.asyncio - async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None: - alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a")) - beta = FailFirstChunkProvider("beta", httpx.ConnectError("b")) - gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c")) - router = _make_router( - tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma} - ) - - with pytest.raises(NoHealthyProviderError): - async for _ in router.chat_completion_stream(_request("mana/long-form")): - pass - - -# --------------------------------------------------------------------------- -# Health-check shape (still using the cache snapshot) -# --------------------------------------------------------------------------- - - -class TestHealthCheck: - @pytest.mark.asyncio - async def test_health_check_lists_known_providers(self, tmp_path) -> None: - # Even if no probe has run yet, every configured provider should - # appear in the snapshot (zero-defaults) so /health has a stable - # shape for monitors. 
- alpha = MockProvider("alpha") - beta = MockProvider("beta") - router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) - - out = await router.health_check() - - assert set(out["providers"].keys()) == {"alpha", "beta"} - assert out["status"] == "healthy" - assert all(p["status"] == "healthy" for p in out["providers"].values()) - - @pytest.mark.asyncio - async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None: - cache = ProviderHealthCache(failure_threshold=1) - cache.mark_unhealthy("alpha", "boom") - alpha = MockProvider("alpha") - beta = MockProvider("beta") - router = _make_router( - tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache - ) - - out = await router.health_check() - assert out["status"] == "degraded" - assert out["providers"]["alpha"]["status"] == "unhealthy" - assert out["providers"]["beta"]["status"] == "healthy" diff --git a/services/mana-llm/tests/test_streaming.py b/services/mana-llm/tests/test_streaming.py deleted file mode 100644 index 4c97abe00..000000000 --- a/services/mana-llm/tests/test_streaming.py +++ /dev/null @@ -1,57 +0,0 @@ -"""Streaming tests.""" - -import pytest - -from src.models import ChatCompletionStreamResponse, DeltaContent, StreamChoice - - -class TestStreamingModels: - """Test streaming response models.""" - - def test_stream_response_serialization(self): - """Test streaming response serializes correctly.""" - response = ChatCompletionStreamResponse( - id="test-id", - model="ollama/gemma3:4b", - choices=[ - StreamChoice( - delta=DeltaContent(content="Hello"), - ) - ], - ) - - data = response.model_dump(exclude_none=True) - assert data["id"] == "test-id" - assert data["model"] == "ollama/gemma3:4b" - assert data["choices"][0]["delta"]["content"] == "Hello" - - def test_stream_response_with_role(self): - """Test first chunk with role.""" - response = ChatCompletionStreamResponse( - id="test-id", - model="ollama/gemma3:4b", - choices=[ - StreamChoice( - delta=DeltaContent(role="assistant"), - ) - ], - ) - - data = response.model_dump(exclude_none=True) - assert data["choices"][0]["delta"]["role"] == "assistant" - - def test_stream_response_with_finish_reason(self): - """Test final chunk with finish_reason.""" - response = ChatCompletionStreamResponse( - id="test-id", - model="ollama/gemma3:4b", - choices=[ - StreamChoice( - delta=DeltaContent(), - finish_reason="stop", - ) - ], - ) - - data = response.model_dump(exclude_none=True) - assert data["choices"][0]["finish_reason"] == "stop"