feat(mana-llm): M3 — health-aware router with alias + chain fallback

Replaces the old Ollama→Google special-case auto-fallback with the
unified pipeline: the caller passes either a direct provider/model or
an alias from the `mana/` namespace; the router resolves it to a chain
and walks that chain, skipping providers marked unhealthy by the
ProviderHealthCache (from M2), trying each remaining entry, marking a
provider unhealthy on retryable errors, and falling through to the
next.
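
For context, a minimal resolution sketch (assumes the `src` package is
importable; the alias name and chain entries are illustrative, and
`_resolve_chain` is a private helper shown only to make the shape
concrete):

    from pathlib import Path
    from src.aliases import AliasRegistry
    from src.health import ProviderHealthCache
    from src.providers import ProviderRouter

    router = ProviderRouter(
        aliases=AliasRegistry(Path("aliases.yaml")),
        health_cache=ProviderHealthCache(),
    )
    router._resolve_chain("mana/long-form")
    # -> ["alpha/m1", "beta/m2", "gamma/m3"]  (the alias's ordered chain)
    router._resolve_chain("groq/llama-3.3-70b-versatile")
    # -> ["groq/llama-3.3-70b-versatile"]  (direct string: single-entry chain)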

Retryable: ConnectError, ReadTimeout, RemoteProtocolError, 5xx,
ProviderRateLimitError. Propagated (don't fall back, don't poison the
cache): ProviderCapabilityError, ProviderAuthError, ProviderBlockedError,
4xx, unknown exception types. The cache stays "what the network told us
about this provider's liveness" — caller errors don't muddy that signal.
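
A quick sketch of the split, poking the (private) classifier directly;
the exception values are illustrative:

    import httpx
    from src.providers import ProviderRouter
    from src.providers.errors import ProviderAuthError, ProviderRateLimitError

    req = httpx.Request("POST", "http://example")
    is_retryable = ProviderRouter._is_retryable

    assert is_retryable(httpx.ConnectError("refused"))
    assert is_retryable(ProviderRateLimitError("slow down"))
    assert is_retryable(
        httpx.HTTPStatusError("boom", request=req, response=httpx.Response(503))
    )
    assert not is_retryable(ProviderAuthError("wrong key"))
    assert not is_retryable(
        httpx.HTTPStatusError("bad", request=req, response=httpx.Response(422))
    )
    assert not is_retryable(ValueError("weird"))  # unknown type: surface, don't hide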

Streaming: pre-first-byte fallback only. Once a chunk has been yielded
the provider is committed; mid-stream errors propagate as-is so we
don't splice two voices into one output.
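
The commit point in miniature (a generic sketch of the pattern, not the
router's exact code):

    from collections.abc import AsyncIterator

    async def stream_committed(stream, on_commit) -> AsyncIterator:
        try:
            first = await stream.__anext__()
        except StopAsyncIteration:
            on_commit()  # empty-but-successful stream still commits
            return
        on_commit()      # first byte landed: provider is committed
        yield first
        async for chunk in stream:
            yield chunk  # errors from here on propagate untouched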

`NoHealthyProviderError` (HTTP 503) carries a structured attempt log —
each chain entry shows up as `(model, reason)` so the cause of a 503
is visible in the response and metrics, not only in service logs.
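
What that looks like from the error class itself (attempt values
illustrative, message format as implemented in errors.py):

    from src.providers.errors import NoHealthyProviderError

    err = NoHealthyProviderError(
        "mana/long-form",
        attempts=[
            ("alpha/m1", "ConnectError"),
            ("beta/m2", "cache-unhealthy"),
            ("gamma/m3", "unconfigured"),
        ],
    )
    assert err.http_status == 503
    str(err)
    # "no healthy provider could serve 'mana/long-form'. Attempts:
    #  alpha/m1=ConnectError, beta/m2=cache-unhealthy, gamma/m3=unconfigured"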

main.py wires the lifespan: aliases.yaml is loaded, a
ProviderHealthCache is created, the ProviderRouter takes both as
constructor deps, and a HealthProbe is spawned with a cheap HTTP probe
per configured provider (Ollama /api/tags, OpenAI-compat /v1/models
with a Bearer header). Google is skipped — google-genai SDK has no
obvious cheap probe; the call-site fallback handles real errors.
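
A probe is just a cheap GET wired to the cache (mirrors
_build_provider_probes; the localhost URL is Ollama's usual default,
assumed here):

    from src.health import ProviderHealthCache
    from src.health_probe import HealthProbe, make_http_probe

    cache = ProviderHealthCache()
    probe = HealthProbe(cache, {
        "ollama": make_http_probe("http://localhost:11434/api/tags"),
    })
    # await probe.start() in the lifespan, await probe.stop() on shutdown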

22 new router tests (test_router_fallback.py): chain walking, capability
& auth propagation, 5xx vs 4xx differentiation, rate-limit retry,
all-fail → NoHealthyProviderError, direct provider strings bypass
aliases, streaming pre-first-byte fallback, mid-stream failure does
NOT fall back, empty stream commits without retry, cache feedback on
success/failure/non-retryable. Existing test_providers.py updated for
the new constructor signature; all 99 service tests green via the dev
container (Python 3.12).

Legacy purged: `_ollama_concurrent`, `_ollama_health_cache`,
`_can_fallback_to_google`, `_should_use_ollama`, `_fallback_to_google`,
`_get_ollama_health_cached` all gone. The `auto_fallback_enabled` /
`ollama_max_concurrent` settings remain in config.py for now (M5 will
remove them along with the per-feature env-var overrides).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Till JS 2026-04-26 20:44:16 +02:00
parent 59557e62d7
commit 3046da3b19
5 changed files with 952 additions and 241 deletions

src/main.py

@@ -3,6 +3,7 @@
import logging
import time
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
from fastapi import FastAPI, HTTPException, Request
@@ -10,8 +11,11 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sse_starlette.sse import EventSourceResponse
from src.aliases import AliasRegistry
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, make_http_probe
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -33,25 +37,66 @@ logging.basicConfig(
)
logger = logging.getLogger(__name__)
# Global router instance
# Global service singletons
router: ProviderRouter | None = None
health_cache: ProviderHealthCache | None = None
health_probe: HealthProbe | None = None
alias_registry: AliasRegistry | None = None
def _build_provider_probes(
providers: dict[str, Any],
) -> dict[str, Any]:
"""Wire each configured provider to a cheap HTTP probe."""
probes: dict[str, Any] = {}
if "ollama" in providers:
probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags")
if "openrouter" in providers:
probes["openrouter"] = make_http_probe(
f"{settings.openrouter_base_url}/models",
headers={"Authorization": f"Bearer {settings.openrouter_api_key}"},
)
if "groq" in providers:
probes["groq"] = make_http_probe(
f"{settings.groq_base_url}/models",
headers={"Authorization": f"Bearer {settings.groq_api_key}"},
)
if "together" in providers:
probes["together"] = make_http_probe(
f"{settings.together_base_url}/models",
headers={"Authorization": f"Bearer {settings.together_api_key}"},
)
# Google: skipped — google-genai SDK is opaque enough that a probe
# would amount to a real API call. Treat as healthy by default; the
# router's call-site fallback will mark it unhealthy on real errors.
return probes
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan management."""
global router
"""Application lifespan: load aliases, spin up router + health probe."""
global router, health_cache, health_probe, alias_registry
# Startup
logger.info("Starting mana-llm service...")
router = ProviderRouter()
logger.info(f"Initialized providers: {list(router.providers.keys())}")
aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml"
alias_registry = AliasRegistry(aliases_path)
logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)
health_cache = ProviderHealthCache()
router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
logger.info("Initialized providers: %s", list(router.providers))
health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
await health_probe.start()
yield
# Shutdown
logger.info("Shutting down mana-llm service...")
if router:
if health_probe is not None:
await health_probe.stop()
if router is not None:
await router.close()
await close_redis()

src/providers/errors.py

@@ -76,3 +76,36 @@ class ProviderCapabilityError(ProviderError):
kind = "capability"
http_status = 400
class NoHealthyProviderError(ProviderError):
"""Every entry in the resolved chain has been tried and either was
unconfigured, marked unhealthy by the cache, or failed in flight.
Carries the full attempt log so the caller can see which providers
were tried and why each one was skipped or failed — invaluable when
a real outage hits and the API returns 503 instead of the usual 200.
"""
kind = "no_healthy_provider"
http_status = 503
def __init__(
self,
model_or_alias: str,
attempts: list[tuple[str, str]],
last_exception: Exception | None = None,
) -> None:
self.model_or_alias = model_or_alias
self.attempts = list(attempts)
self.last_exception = last_exception
if attempts:
attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts)
else:
attempt_log = "(no providers in resolved chain were configured)"
msg = (
f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}"
)
if last_exception is not None:
msg += f". Last error: {type(last_exception).__name__}: {last_exception}"
super().__init__(msg)

src/providers/router.py

@@ -1,12 +1,38 @@
"""Provider routing logic for mana-llm with auto-fallback support."""
"""Provider routing with alias resolution and health-aware fallback.
The router is the single entry point that the FastAPI handlers use. Its
job is:
1. Resolve the request's ``model`` field. If it lives in the ``mana/``
namespace the :class:`AliasRegistry` returns an ordered chain of
concrete provider/model strings; everything else is treated as a
single-entry chain (caller passed a direct provider/model).
2. Walk the chain, skipping entries whose provider is either
unconfigured at this deployment (no API key) or currently marked
unhealthy in the :class:`ProviderHealthCache`.
3. Try each remaining entry. Connection errors, timeouts, 5xx, and rate
limits are retryable — record them in the cache and move to the next
entry. Capability/auth/blocked errors are caller-fixable and
propagate immediately without touching the health cache.
4. Return the first successful response. If every entry was skipped or
failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying
the full attempt log so debugging is straightforward.
The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is
the M3 milestone.
"""
from __future__ import annotations
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from typing import Any
from collections.abc import AsyncIterator, Awaitable, Callable
from typing import Any, TypeVar
import httpx
from src.aliases import AliasRegistry
from src.config import settings
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
@@ -17,36 +43,51 @@ from src.models import (
)
from .base import LLMProvider
from .errors import ProviderCapabilityError
from .errors import (
NoHealthyProviderError,
ProviderAuthError,
ProviderBlockedError,
ProviderCapabilityError,
ProviderError,
ProviderRateLimitError,
)
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
logger = logging.getLogger(__name__)
T = TypeVar("T")
class ProviderRouter:
"""Routes requests to appropriate LLM providers with auto-fallback.
"""Health-aware provider router with alias resolution.
When auto_fallback_enabled is True and a Google API key is configured:
- Ollama requests that fail or exceed max_concurrent are automatically
retried on Google Gemini with model mapping.
- Explicit provider requests (e.g., openrouter/...) are never fallback-routed.
Construct with the AliasRegistry and ProviderHealthCache from
application startup; both are external dependencies so tests can
inject mocks without going through global state.
"""
def __init__(self):
def __init__(
self,
aliases: AliasRegistry,
health_cache: ProviderHealthCache,
) -> None:
self.aliases = aliases
self.health_cache = health_cache
self.providers: dict[str, LLMProvider] = {}
self._ollama_concurrent: int = 0
self._ollama_health_cache: tuple[dict[str, Any] | None, float] = (None, 0)
self._health_cache_ttl: float = 5.0 # seconds
self._initialize_providers()
def _initialize_providers(self) -> None:
"""Initialize available providers based on configuration."""
# Ollama is always available (local)
self.providers["ollama"] = OllamaProvider()
logger.info(f"Initialized Ollama provider at {settings.ollama_url}")
# ------------------------------------------------------------------
# Provider initialisation
# ------------------------------------------------------------------
def _initialize_providers(self) -> None:
"""Spin up provider adapters based on what's configured."""
# Ollama: always present (talks to a local/proxied server). Whether
# it's actually reachable is the cache's job to figure out.
self.providers["ollama"] = OllamaProvider()
logger.info("Initialized Ollama provider at %s", settings.ollama_url)
# Google Gemini (fallback provider)
if settings.google_api_key:
from .google import GoogleProvider
@@ -54,9 +95,8 @@ class ProviderRouter:
api_key=settings.google_api_key,
default_model=settings.google_default_model,
)
logger.info("Initialized Google Gemini provider (fallback)")
logger.info("Initialized Google Gemini provider")
# OpenRouter (if API key configured)
if settings.openrouter_api_key:
self.providers["openrouter"] = OpenAICompatProvider(
name="openrouter",
@@ -66,7 +106,6 @@ )
)
logger.info("Initialized OpenRouter provider")
# Groq (if API key configured)
if settings.groq_api_key:
self.providers["groq"] = OpenAICompatProvider(
name="groq",
@@ -75,7 +114,6 @@ )
)
logger.info("Initialized Groq provider")
# Together (if API key configured)
if settings.together_api_key:
self.providers["together"] = OpenAICompatProvider(
name="together",
@@ -84,83 +122,82 @@ )
)
logger.info("Initialized Together provider")
# ------------------------------------------------------------------
# Helpers
# ------------------------------------------------------------------
def _parse_model(self, model: str) -> tuple[str, str]:
"""Parse model string into (provider, model_name)."""
"""Split ``provider/model`` into its parts.
Bare names (no prefix) default to Ollama for compatibility with
plain OpenAI-style requests. Aliases (``mana/...``) are resolved
before this is ever called.
"""
if "/" in model:
parts = model.split("/", 1)
provider = parts[0].lower()
model_name = parts[1]
else:
provider = "ollama"
model_name = model
provider, _, model_name = model.partition("/")
return provider.lower(), model_name
return "ollama", model
return provider, model_name
def _resolve_chain(self, model_or_alias: str) -> list[str]:
"""Expand aliases to chains; pass everything else through unchanged."""
if AliasRegistry.is_alias(model_or_alias):
return list(self.aliases.resolve_chain(model_or_alias))
return [model_or_alias]
def _get_provider(self, provider_name: str) -> LLMProvider:
"""Get provider by name, raise if not available."""
if provider_name not in self.providers:
available = list(self.providers.keys())
raise ValueError(
f"Provider '{provider_name}' not available. "
f"Available providers: {available}"
)
return self.providers[provider_name]
@staticmethod
def _is_retryable(exc: BaseException) -> bool:
"""Should we treat this exception as "try the next chain entry"?
def _can_fallback_to_google(self, provider_name: str) -> bool:
"""Check if a request can be fallback-routed to Google."""
return (
settings.auto_fallback_enabled
and provider_name == "ollama"
and "google" in self.providers
)
ConnectError / timeouts / 5xx / rate-limits = yes (provider blip).
Auth / capability / blocked / 4xx = no (caller has to fix
something; retrying with a different provider only hides the bug).
"""
if isinstance(exc, ProviderCapabilityError):
return False
if isinstance(exc, ProviderBlockedError):
return False
if isinstance(exc, ProviderAuthError):
return False
if isinstance(exc, ProviderRateLimitError):
return True
if isinstance(
exc,
(
httpx.ConnectError,
httpx.ConnectTimeout,
httpx.ReadError,
httpx.ReadTimeout,
httpx.RemoteProtocolError,
httpx.WriteError,
httpx.WriteTimeout,
httpx.PoolTimeout,
),
):
return True
if isinstance(exc, httpx.HTTPStatusError):
return exc.response.status_code >= 500
if isinstance(exc, ProviderError):
# Any other provider-side error — treat as retryable.
# Subclasses with explicit non-retry semantics are caught above.
return True
# Unknown exception types: do NOT silently retry. Better to
# surface a strange error than hide a real bug behind a fallback.
return False
def _should_use_ollama(self) -> bool:
"""Determine if Ollama should handle the request based on load."""
return self._ollama_concurrent < settings.ollama_max_concurrent
async def _get_ollama_health_cached(self) -> dict[str, Any]:
"""Get Ollama health with caching (5s TTL)."""
cached, cached_at = self._ollama_health_cache
if cached is not None and (time.time() - cached_at) < self._health_cache_ttl:
return cached
try:
provider = self.providers.get("ollama")
if provider:
result = await provider.health_check()
else:
result = {"status": "unhealthy", "error": "no ollama provider"}
except Exception as e:
result = {"status": "unhealthy", "error": str(e)}
self._ollama_health_cache = (result, time.time())
return result
async def _fallback_to_google(
self,
request: ChatCompletionRequest,
model_name: str,
original_error: Exception | None = None,
) -> ChatCompletionResponse:
"""Route a request to Google Gemini as fallback."""
from .google import GoogleProvider
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
reason = f"error: {original_error}" if original_error else "overloaded"
logger.warning(
f"Falling back to Google Gemini ({gemini_model}) for ollama/{model_name} ({reason})"
)
return await google.chat_completion(request, gemini_model)
@staticmethod
def _exception_summary(exc: BaseException) -> str:
"""Compact one-liner for cache.last_error and log entries."""
return f"{type(exc).__name__}: {exc}"
def _check_tool_capability(
self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest
self,
provider: LLMProvider,
model_name: str,
request: ChatCompletionRequest,
) -> None:
"""Refuse tool-bearing requests for providers/models without tool support.
"""Refuse tool-bearing requests for models that don't support tools.
Silent downgrade (dropping the `tools` payload) is more dangerous
Silent downgrade (dropping the ``tools`` payload) is more dangerous
than an explicit error — the caller would get plain text back and
have no way to tell the tools never reached the model.
"""
@@ -169,168 +206,228 @@ class ProviderRouter:
if not provider.model_supports_tools(model_name):
raise ProviderCapabilityError(
f"{provider.name}/{model_name} does not support tool calling. "
"Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)"
"Choose a tool-capable model (e.g. groq/llama-3.3-70b-versatile)"
)
# ------------------------------------------------------------------
# Core fallback executor (non-streaming)
# ------------------------------------------------------------------
async def _execute_with_fallback(
self,
model_or_alias: str,
request: ChatCompletionRequest,
call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]],
) -> T:
"""Walk the resolved chain, returning the first successful result.
``call`` is the operation to run against each chain entry, e.g.
``lambda p, m, req: p.chat_completion(req, m)``. The function
receives the provider instance, the model name (without the
provider prefix), and the original request.
"""
chain = self._resolve_chain(model_or_alias)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
for entry in chain:
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
logger.debug(
"skip chain entry %s — provider %s not configured here",
entry,
provider_name,
)
attempts.append((entry, "unconfigured"))
continue
if not self.health_cache.is_healthy(provider_name):
logger.debug("skip chain entry %s — cache says unhealthy", entry)
attempts.append((entry, "cache-unhealthy"))
continue
provider = self.providers[provider_name]
self._check_tool_capability(provider, model_name, request)
try:
logger.info(
"execute → %s (alias=%s)",
entry,
model_or_alias if model_or_alias != entry else "<direct>",
)
result = await call(provider, model_name, request)
self.health_cache.mark_healthy(provider_name)
return result
except Exception as e:
if not self._is_retryable(e):
# Caller error / non-retryable provider error — propagate
# without touching the health cache. The cache is for
# liveness, not for recording what the user asked for
# being wrong.
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
attempts.append((entry, type(e).__name__))
last_exc = e
logger.warning(
"execute %s failed (retryable, will try next): %s",
entry,
e,
)
raise NoHealthyProviderError(model_or_alias, attempts, last_exc)
# ------------------------------------------------------------------
# Public API — non-streaming
# ------------------------------------------------------------------
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
"""Route chat completion request with auto-fallback."""
provider_name, model_name = self._parse_model(request.model)
"""Chat completion with alias resolution + health-aware fallback."""
# Non-Ollama providers: direct routing, no fallback
if provider_name != "ollama":
provider = self._get_provider(provider_name)
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
return await provider.chat_completion(request, model_name)
async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest):
return await provider.chat_completion(req, model)
# Ollama with fallback logic
can_fallback = self._can_fallback_to_google(provider_name)
return await self._execute_with_fallback(request.model, request, call)
# Check if Ollama is overloaded
if can_fallback and not self._should_use_ollama():
return await self._fallback_to_google(request, model_name)
# Try Ollama first
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to ollama/{model_name}")
self._ollama_concurrent += 1
try:
return await provider.chat_completion(request, model_name)
except Exception as e:
logger.error(f"Chat completion failed on ollama: {e}")
if can_fallback:
return await self._fallback_to_google(request, model_name, e)
raise
finally:
self._ollama_concurrent -= 1
# ------------------------------------------------------------------
# Public API — streaming (pre-first-byte fallback)
# ------------------------------------------------------------------
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Route streaming chat completion with auto-fallback."""
provider_name, model_name = self._parse_model(request.model)
"""Streaming variant. Falls back BEFORE the first chunk arrives;
once the first chunk has been yielded the provider is committed
and any further error propagates.
# Non-Ollama: direct
if provider_name != "ollama":
provider = self._get_provider(provider_name)
Why pre-first-byte only: stitching half-streams from two different
providers would mix two voices in the output and is impossible to
sanity-check after the fact.
"""
chain = self._resolve_chain(request.model)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
for entry in chain:
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
attempts.append((entry, "unconfigured"))
continue
if not self.health_cache.is_healthy(provider_name):
attempts.append((entry, "cache-unhealthy"))
continue
provider = self.providers[provider_name]
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to {provider_name}/{model_name}")
async for chunk in provider.chat_completion_stream(request, model_name):
stream = provider.chat_completion_stream(request, model_name)
try:
first_chunk = await stream.__anext__()
except StopAsyncIteration:
# Empty stream is a successful but content-free response.
# Commit and exit cleanly.
self.health_cache.mark_healthy(provider_name)
logger.info("stream %s yielded empty response", entry)
return
except Exception as e:
if not self._is_retryable(e):
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
attempts.append((entry, type(e).__name__))
last_exc = e
logger.warning(
"stream %s failed before first byte (retryable, trying next): %s",
entry,
e,
)
continue
# First byte landed — commit the provider, mark healthy, drain
# the rest of the stream. Any error from here on propagates;
# it is NOT safe to splice another provider's output in.
self.health_cache.mark_healthy(provider_name)
logger.info("stream → %s (committed after first chunk)", entry)
yield first_chunk
async for chunk in stream:
yield chunk
return
# Ollama with fallback
can_fallback = self._can_fallback_to_google(provider_name)
raise NoHealthyProviderError(request.model, attempts, last_exc)
if can_fallback and not self._should_use_ollama():
from .google import GoogleProvider
# ------------------------------------------------------------------
# Embeddings — no fallback (out of scope for M3, separate concerns)
# ------------------------------------------------------------------
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
async for chunk in google.chat_completion_stream(request, gemini_model):
yield chunk
return
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to ollama/{model_name}")
self._ollama_concurrent += 1
try:
async for chunk in provider.chat_completion_stream(request, model_name):
yield chunk
except Exception as e:
logger.error(f"Streaming failed on ollama: {e}")
if can_fallback:
from .google import GoogleProvider
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
async for chunk in google.chat_completion_stream(request, gemini_model):
yield chunk
else:
raise
finally:
self._ollama_concurrent -= 1
async def embeddings(
self,
request: EmbeddingRequest,
) -> EmbeddingResponse:
"""Route embeddings request to appropriate provider."""
async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
"""Route an embeddings request directly. No alias / fallback."""
provider_name, model_name = self._parse_model(request.model)
provider = self._get_provider(provider_name)
logger.info(f"Routing embeddings to {provider_name}/{model_name}")
if provider_name not in self.providers:
available = list(self.providers)
raise ValueError(
f"Provider '{provider_name}' not available. Available: {available}"
)
provider = self.providers[provider_name]
logger.info("embeddings → %s/%s", provider_name, model_name)
return await provider.embeddings(request, model_name)
async def list_models(self) -> list[ModelInfo]:
"""List all available models from all providers."""
all_models: list[ModelInfo] = []
# ------------------------------------------------------------------
# Discovery / introspection
# ------------------------------------------------------------------
async def list_models(self) -> list[ModelInfo]:
"""List all available models from all configured providers.
Best-effort: providers that error are skipped with a warning so a
single broken provider can't take down ``GET /v1/models``.
"""
all_models: list[ModelInfo] = []
for provider in self.providers.values():
try:
models = await provider.list_models()
all_models.extend(models)
except Exception as e:
logger.warning(f"Failed to list models from {provider.name}: {e}")
all_models.extend(await provider.list_models())
except Exception as e: # noqa: BLE001
logger.warning("Failed to list models from %s: %s", provider.name, e)
return all_models
async def get_model(self, model_id: str) -> ModelInfo | None:
"""Get specific model info."""
"""Look up a single model by id, dispatching on the prefix."""
provider_name, model_name = self._parse_model(model_id)
if provider_name not in self.providers:
return None
provider = self.providers[provider_name]
models = await provider.list_models()
for model in models:
if model.id == model_id or model.id.endswith(f"/{model_name}"):
return model
models = await self.providers[provider_name].list_models()
for m in models:
if m.id == model_id or m.id.endswith(f"/{model_name}"):
return m
return None
async def health_check(self) -> dict[str, Any]:
"""Check health of all providers."""
results: dict[str, Any] = {}
"""Snapshot of the per-provider liveness cache.
for name, provider in self.providers.items():
results[name] = await provider.health_check()
# Overall status
all_healthy = all(r.get("status") == "healthy" for r in results.values())
any_healthy = any(r.get("status") == "healthy" for r in results.values())
status_info: dict[str, Any] = {
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
"providers": results,
}
# Include fallback info
if settings.auto_fallback_enabled and "google" in self.providers:
status_info["fallback"] = {
"enabled": True,
"ollama_concurrent": self._ollama_concurrent,
"ollama_max_concurrent": settings.ollama_max_concurrent,
Returns the same shape as before for backwards-compat with
``GET /health`` (deprecated structure — M4 will swap to a
cleaner ``/v1/health`` endpoint).
"""
snapshot = self.health_cache.snapshot(expected=list(self.providers))
providers_out: dict[str, Any] = {}
for name, state in snapshot.items():
providers_out[name] = {
"status": "healthy" if state.healthy else "unhealthy",
"consecutive_failures": state.consecutive_failures,
"last_error": state.last_error,
"last_check_unix": state.last_check or None,
"unhealthy_until_unix": state.unhealthy_until or None,
}
return status_info
all_healthy = all(state.healthy for state in snapshot.values())
any_healthy = any(state.healthy for state in snapshot.values())
return {
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
"providers": providers_out,
}
async def close(self) -> None:
"""Close all providers."""
"""Close all provider clients."""
for provider in self.providers.values():
await provider.close()

tests/test_providers.py

@@ -1,44 +1,54 @@
"""Provider tests."""
from pathlib import Path
import pytest
from src.models import ChatCompletionRequest, Message
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import ChatCompletionRequest, EmbeddingRequest, Message
from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter
@pytest.fixture
def shipped_aliases() -> AliasRegistry:
"""The repo's real aliases.yaml — same one production uses."""
return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml")
@pytest.fixture
def router(shipped_aliases: AliasRegistry) -> ProviderRouter:
return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache())
class TestProviderRouter:
"""Test provider routing logic."""
def test_parse_model_with_provider(self):
"""Test model parsing with provider prefix."""
router = ProviderRouter()
"""Tests for the helpers exposed by the router."""
def test_parse_model_with_provider(self, router: ProviderRouter) -> None:
provider, model = router._parse_model("ollama/gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_without_provider(self):
"""Test model parsing without provider prefix (defaults to ollama)."""
router = ProviderRouter()
def test_parse_model_without_provider(self, router: ProviderRouter) -> None:
# Bare names default to Ollama for OpenAI-style compat.
provider, model = router._parse_model("gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_openrouter(self):
"""Test model parsing for OpenRouter."""
router = ProviderRouter()
def test_parse_model_openrouter(self, router: ProviderRouter) -> None:
provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct")
assert provider == "openrouter"
assert model == "meta-llama/llama-3.1-8b-instruct"
def test_get_invalid_provider(self):
"""Test getting invalid provider raises error."""
router = ProviderRouter()
@pytest.mark.asyncio
async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None:
# Embeddings don't go through the alias/fallback pipeline — they
# hit the requested provider directly. Asking for an unconfigured
# one is a config error and must raise loudly.
with pytest.raises(ValueError, match="not available"):
router._get_provider("invalid_provider")
await router.embeddings(
EmbeddingRequest(model="bogus_provider/x", input="hi")
)
class TestOllamaProvider:

tests/test_router_fallback.py

@@ -0,0 +1,526 @@
"""Tests for ProviderRouter fallback / alias execution (M3)."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import httpx
import pytest
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingRequest,
EmbeddingResponse,
Message,
MessageResponse,
ModelInfo,
StreamChoice,
)
from src.providers import ProviderRouter
from src.providers.base import LLMProvider
from src.providers.errors import (
NoHealthyProviderError,
ProviderAuthError,
ProviderCapabilityError,
ProviderRateLimitError,
)
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class MockProvider(LLMProvider):
"""Provider that lets tests inject a sequence of behaviours.
Each call pops one entry from ``behaviors``. Strings ``"ok"`` and
``"empty"`` are sentinels for normal returns; everything else (a
BaseException instance / class) is raised.
"""
supports_tools = True
def __init__(self, name: str, behaviors: list[Any] | None = None) -> None:
self.name = name
self._behaviors: list[Any] = list(behaviors or [])
self.calls: list[str] = []
def push(self, *behaviors: Any) -> None:
self._behaviors.extend(behaviors)
def _next(self) -> Any:
return self._behaviors.pop(0) if self._behaviors else "ok"
async def chat_completion(
self, request: ChatCompletionRequest, model: str
) -> ChatCompletionResponse:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
return _ok_response(self.name, model)
async def chat_completion_stream(
self, request: ChatCompletionRequest, model: str
) -> AsyncIterator[ChatCompletionStreamResponse]:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
if b == "empty":
return
for content in ("Hello", " ", "world"):
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content=content))],
)
async def list_models(self) -> list[ModelInfo]:
return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")]
async def embeddings(
self, request: EmbeddingRequest, model: str
) -> EmbeddingResponse:
raise NotImplementedError
async def health_check(self) -> dict[str, Any]:
return {"status": "healthy"}
class FailFirstChunkProvider(MockProvider):
"""Streaming provider that raises BEFORE the first chunk every time.
Kept separate from MockProvider's behaviour list so the per-call
semantics stay simple — this one models a permanently-broken streamer.
"""
def __init__(self, name: str, exc: BaseException) -> None:
super().__init__(name)
self._exc = exc
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
raise self._exc
# the yield is unreachable but keeps the function an async generator
yield # pragma: no cover
def _ok_response(provider: str, model: str) -> ChatCompletionResponse:
return ChatCompletionResponse(
model=f"{provider}/{model}",
choices=[
Choice(
message=MessageResponse(content="ok"),
finish_reason="stop",
)
],
)
def _request(model: str) -> ChatCompletionRequest:
return ChatCompletionRequest(
model=model,
messages=[Message(role="user", content="hi")],
)
def _aliases_yaml(tmp_path) -> AliasRegistry:
"""A two-alias config used across most tests."""
cfg = (
"aliases:\n"
" mana/long-form:\n"
' description: "long"\n'
" chain:\n"
" - alpha/m1\n"
" - beta/m2\n"
" - gamma/m3\n"
" mana/single:\n"
' description: "single-entry"\n'
" chain:\n"
" - alpha/solo\n"
)
p = tmp_path / "aliases.yaml"
p.write_text(cfg)
return AliasRegistry(p)
def _make_router(
tmp_path,
*,
providers: dict[str, MockProvider],
cache: ProviderHealthCache | None = None,
) -> ProviderRouter:
aliases = _aliases_yaml(tmp_path)
router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache())
# Replace the auto-initialised live providers with the test doubles.
router.providers = dict(providers)
return router
# ---------------------------------------------------------------------------
# Non-streaming chain walking
# ---------------------------------------------------------------------------
class TestChatCompletionChain:
@pytest.mark.asyncio
async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "alpha/m1"
assert alpha.calls == ["m1"]
assert beta.calls == [] # never reached
@pytest.mark.asyncio
async def test_falls_through_on_connect_error(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
@pytest.mark.asyncio
async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None:
# gamma isn't configured at all → chain should silently skip it
# rather than raise.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", [httpx.ConnectError("dead too")])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
# All three entries appear in attempts: two as ConnectError, one
# as unconfigured (not a fatal error, just skipped).
attempts = exc_info.value.attempts
assert ("alpha/m1", "ConnectError") in attempts
assert ("beta/m2", "ConnectError") in attempts
assert ("gamma/m3", "unconfigured") in attempts
@pytest.mark.asyncio
async def test_skips_cache_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "stale")
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
resp = await router.chat_completion(_request("mana/long-form"))
assert alpha.calls == [] # router skipped per cache
assert beta.calls == ["m2"]
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_5xx_treated_as_retryable(self, tmp_path) -> None:
five_hundred = httpx.HTTPStatusError(
"boom",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(503),
)
alpha = MockProvider("alpha", [five_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_4xx_propagates(self, tmp_path) -> None:
four_hundred = httpx.HTTPStatusError(
"bad request",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(422),
)
alpha = MockProvider("alpha", [four_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(httpx.HTTPStatusError):
await router.chat_completion(_request("mana/long-form"))
# Beta never tried — caller's request needs fixing, retrying
# against another model would just hide the bug.
assert beta.calls == []
@pytest.mark.asyncio
async def test_capability_error_propagates(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_auth_error_propagates(self, tmp_path) -> None:
# Auth errors mean OUR setup is broken (wrong key); falling back
# to the next provider hides the misconfiguration.
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_rate_limit_is_retryable(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("a")])
beta = MockProvider("beta", [httpx.ConnectError("b")])
gamma = MockProvider("gamma", [httpx.ConnectError("c")])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
assert exc_info.value.model_or_alias == "mana/long-form"
assert isinstance(exc_info.value.last_exception, httpx.ConnectError)
# 503 status so calling code (mana-api etc.) can decide to retry
# later vs surface a clean error to the user.
assert exc_info.value.http_status == 503
@pytest.mark.asyncio
async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None:
# Caller bypasses aliases by passing a direct provider/model.
# No fallback chain — fail = fail.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError):
await router.chat_completion(_request("alpha/anything"))
# Beta would have served if this had been an alias — but it
# wasn't, so beta never gets touched.
assert beta.calls == []
# ---------------------------------------------------------------------------
# Health-cache feedback: success clears, failure marks
# ---------------------------------------------------------------------------
class TestHealthCacheFeedback:
@pytest.mark.asyncio
async def test_success_marks_provider_healthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "stale-from-probe")
# After the cache TTL the cache thinks alpha might be OK again,
# so the router will try it; success must fully clear the state.
# (Force half-open by zeroing backoff.)
alpha = MockProvider("alpha", ["ok"])
router = _make_router(
tmp_path,
providers={"alpha": alpha},
cache=ProviderHealthCache(), # fresh cache, alpha optimistic
)
await router.chat_completion(_request("mana/single"))
assert router.health_cache.get_state("alpha").healthy is True
assert router.health_cache.get_state("alpha").consecutive_failures == 0
@pytest.mark.asyncio
async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None:
# threshold=1 so a single fail is enough to flip the breaker.
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [httpx.ConnectError("boom")])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
await router.chat_completion(_request("mana/long-form"))
assert cache.get_state("alpha").healthy is False
assert cache.get_state("alpha").last_error is not None
assert "ConnectError" in cache.get_state("alpha").last_error
@pytest.mark.asyncio
async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None:
# Auth/Capability errors are about CALLER state, not provider
# health — the cache must stay clean so a real outage later
# isn't masked by stale "marked unhealthy because of bad key".
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache)
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/single"))
# No state recorded.
assert cache.get_state("alpha") is None
# ---------------------------------------------------------------------------
# Streaming pre-first-byte fallback
# ---------------------------------------------------------------------------
class TestChatCompletionStream:
@pytest.mark.asyncio
async def test_first_provider_streams_normally(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert beta.calls == []
assert len(chunks) == 3
assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world"
@pytest.mark.asyncio
async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
assert len(chunks) == 3
assert all(c.model == "beta/m2" for c in chunks)
@pytest.mark.asyncio
async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
assert beta.calls == []
@pytest.mark.asyncio
async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None:
# Empty-but-successful stream is a valid response, not a failure
# we should retry — committing avoids accidentally calling two
# providers and double-billing.
alpha = MockProvider("alpha", ["empty"])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert chunks == []
assert beta.calls == [] # didn't fall through
@pytest.mark.asyncio
async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None:
# Custom provider that yields once then raises mid-stream — the
# router has already committed and must let the error propagate
# rather than splice in another provider's voice.
class MidStreamFailProvider(MockProvider):
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content="halb"))],
)
raise httpx.RemoteProtocolError("connection dropped")
alpha = MidStreamFailProvider("alpha")
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
collected: list[str] = []
with pytest.raises(httpx.RemoteProtocolError):
async for chunk in router.chat_completion_stream(_request("mana/long-form")):
collected.append(chunk.choices[0].delta.content or "")
# We got the half-chunk that landed before the break; beta was
# NOT called as fallback.
assert collected == ["halb"]
assert beta.calls == []
@pytest.mark.asyncio
async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a"))
beta = FailFirstChunkProvider("beta", httpx.ConnectError("b"))
gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c"))
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
# ---------------------------------------------------------------------------
# Health-check shape (still using the cache snapshot)
# ---------------------------------------------------------------------------
class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_lists_known_providers(self, tmp_path) -> None:
# Even if no probe has run yet, every configured provider should
# appear in the snapshot (zero-defaults) so /health has a stable
# shape for monitors.
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
out = await router.health_check()
assert set(out["providers"].keys()) == {"alpha", "beta"}
assert out["status"] == "healthy"
assert all(p["status"] == "healthy" for p in out["providers"].values())
@pytest.mark.asyncio
async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "boom")
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
out = await router.health_check()
assert out["status"] == "degraded"
assert out["providers"]["alpha"]["status"] == "unhealthy"
assert out["providers"]["beta"]["status"] == "healthy"