diff --git a/services/mana-llm/src/main.py b/services/mana-llm/src/main.py index f04c4aaea..757105865 100644 --- a/services/mana-llm/src/main.py +++ b/services/mana-llm/src/main.py @@ -3,6 +3,7 @@ import logging import time from contextlib import asynccontextmanager +from pathlib import Path from typing import Any from fastapi import FastAPI, HTTPException, Request @@ -10,8 +11,11 @@ from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import Response from sse_starlette.sse import EventSourceResponse +from src.aliases import AliasRegistry from src.api_auth import ApiKeyMiddleware from src.config import settings +from src.health import ProviderHealthCache +from src.health_probe import HealthProbe, make_http_probe from src.models import ( ChatCompletionRequest, ChatCompletionResponse, @@ -33,25 +37,66 @@ logging.basicConfig( ) logger = logging.getLogger(__name__) -# Global router instance +# Global service singletons router: ProviderRouter | None = None +health_cache: ProviderHealthCache | None = None +health_probe: HealthProbe | None = None +alias_registry: AliasRegistry | None = None + + +def _build_provider_probes( + providers: dict[str, Any], +) -> dict[str, Any]: + """Wire each configured provider to a cheap HTTP probe.""" + probes: dict[str, Any] = {} + + if "ollama" in providers: + probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags") + if "openrouter" in providers: + probes["openrouter"] = make_http_probe( + f"{settings.openrouter_base_url}/models", + headers={"Authorization": f"Bearer {settings.openrouter_api_key}"}, + ) + if "groq" in providers: + probes["groq"] = make_http_probe( + f"{settings.groq_base_url}/models", + headers={"Authorization": f"Bearer {settings.groq_api_key}"}, + ) + if "together" in providers: + probes["together"] = make_http_probe( + f"{settings.together_base_url}/models", + headers={"Authorization": f"Bearer {settings.together_api_key}"}, + ) + # Google: skipped — google-genai SDK is opaque enough that a probe + # would amount to a real API call. Treat as healthy by default; the + # router's call-site fallback will mark it unhealthy on real errors. 
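+    # Shape assumption: make_http_probe(url, headers=...) lives in
+    # src/health_probe.py (not shown in this diff) and is expected to
+    # return a zero-arg async callable that the HealthProbe loop awaits,
+    # roughly:
+    #
+    #   async def probe() -> bool:
+    #       async with httpx.AsyncClient(timeout=5.0) as client:
+    #           resp = await client.get(url, headers=headers)
+    #           return resp.status_code < 500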
+ return probes @asynccontextmanager async def lifespan(app: FastAPI): - """Application lifespan management.""" - global router + """Application lifespan: load aliases, spin up router + health probe.""" + global router, health_cache, health_probe, alias_registry - # Startup logger.info("Starting mana-llm service...") - router = ProviderRouter() - logger.info(f"Initialized providers: {list(router.providers.keys())}") + + aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml" + alias_registry = AliasRegistry(aliases_path) + logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path) + + health_cache = ProviderHealthCache() + router = ProviderRouter(aliases=alias_registry, health_cache=health_cache) + logger.info("Initialized providers: %s", list(router.providers)) + + health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers)) + await health_probe.start() yield - # Shutdown logger.info("Shutting down mana-llm service...") - if router: + if health_probe is not None: + await health_probe.stop() + if router is not None: await router.close() await close_redis() diff --git a/services/mana-llm/src/providers/errors.py b/services/mana-llm/src/providers/errors.py index f63efa17a..440651270 100644 --- a/services/mana-llm/src/providers/errors.py +++ b/services/mana-llm/src/providers/errors.py @@ -76,3 +76,36 @@ class ProviderCapabilityError(ProviderError): kind = "capability" http_status = 400 + + +class NoHealthyProviderError(ProviderError): + """Every entry in the resolved chain has been tried and either was + unconfigured, marked unhealthy by the cache, or failed in flight. + + Carries the full attempt log so the caller can see which providers + were tried and why each one was skipped or failed — invaluable when + a real outage hits and the API returns 503 instead of the usual 200. + """ + + kind = "no_healthy_provider" + http_status = 503 + + def __init__( + self, + model_or_alias: str, + attempts: list[tuple[str, str]], + last_exception: Exception | None = None, + ) -> None: + self.model_or_alias = model_or_alias + self.attempts = list(attempts) + self.last_exception = last_exception + if attempts: + attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts) + else: + attempt_log = "(no providers in resolved chain were configured)" + msg = ( + f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}" + ) + if last_exception is not None: + msg += f". Last error: {type(last_exception).__name__}: {last_exception}" + super().__init__(msg) diff --git a/services/mana-llm/src/providers/router.py b/services/mana-llm/src/providers/router.py index dfec1ac07..0f0dfe00c 100644 --- a/services/mana-llm/src/providers/router.py +++ b/services/mana-llm/src/providers/router.py @@ -1,12 +1,38 @@ -"""Provider routing logic for mana-llm with auto-fallback support.""" +"""Provider routing with alias resolution and health-aware fallback. + +The router is the single entry point that the FastAPI handlers use. Its +job is: + +1. Resolve the request's ``model`` field. If it lives in the ``mana/`` + namespace the :class:`AliasRegistry` returns an ordered chain of + concrete provider/model strings; everything else is treated as a + single-entry chain (caller passed a direct provider/model). +2. Walk the chain, skipping entries whose provider is either + unconfigured at this deployment (no API key) or currently marked + unhealthy in the :class:`ProviderHealthCache`. +3. Try each remaining entry. 
Connection errors, timeouts, 5xx, and rate + limits are retryable — record them in the cache and move to the next + entry. Capability/auth/blocked errors are caller-fixable and + propagate immediately without touching the health cache. +4. Return the first successful response. If every entry was skipped or + failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying + the full attempt log so debugging is straightforward. + +The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is +the M3 milestone. +""" + +from __future__ import annotations -import asyncio import logging -import time -from collections.abc import AsyncIterator -from typing import Any +from collections.abc import AsyncIterator, Awaitable, Callable +from typing import Any, TypeVar +import httpx + +from src.aliases import AliasRegistry from src.config import settings +from src.health import ProviderHealthCache from src.models import ( ChatCompletionRequest, ChatCompletionResponse, @@ -17,36 +43,51 @@ from src.models import ( ) from .base import LLMProvider -from .errors import ProviderCapabilityError +from .errors import ( + NoHealthyProviderError, + ProviderAuthError, + ProviderBlockedError, + ProviderCapabilityError, + ProviderError, + ProviderRateLimitError, +) from .ollama import OllamaProvider from .openai_compat import OpenAICompatProvider logger = logging.getLogger(__name__) +T = TypeVar("T") + class ProviderRouter: - """Routes requests to appropriate LLM providers with auto-fallback. + """Health-aware provider router with alias resolution. - When auto_fallback_enabled is True and a Google API key is configured: - - Ollama requests that fail or exceed max_concurrent are automatically - retried on Google Gemini with model mapping. - - Explicit provider requests (e.g., openrouter/...) are never fallback-routed. + Construct with the AliasRegistry and ProviderHealthCache from + application startup; both are external dependencies so tests can + inject mocks without going through global state. """ - def __init__(self): + def __init__( + self, + aliases: AliasRegistry, + health_cache: ProviderHealthCache, + ) -> None: + self.aliases = aliases + self.health_cache = health_cache self.providers: dict[str, LLMProvider] = {} - self._ollama_concurrent: int = 0 - self._ollama_health_cache: tuple[dict[str, Any] | None, float] = (None, 0) - self._health_cache_ttl: float = 5.0 # seconds self._initialize_providers() - def _initialize_providers(self) -> None: - """Initialize available providers based on configuration.""" - # Ollama is always available (local) - self.providers["ollama"] = OllamaProvider() - logger.info(f"Initialized Ollama provider at {settings.ollama_url}") + # ------------------------------------------------------------------ + # Provider initialisation + # ------------------------------------------------------------------ + + def _initialize_providers(self) -> None: + """Spin up provider adapters based on what's configured.""" + # Ollama: always present (talks to a local/proxied server). Whether + # it's actually reachable is the cache's job to figure out. 
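+        # (OllamaProvider is constructed with no arguments and is assumed
+        # to read settings.ollama_url itself; the log line below reports
+        # that same URL.)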
+ self.providers["ollama"] = OllamaProvider() + logger.info("Initialized Ollama provider at %s", settings.ollama_url) - # Google Gemini (fallback provider) if settings.google_api_key: from .google import GoogleProvider @@ -54,9 +95,8 @@ class ProviderRouter: api_key=settings.google_api_key, default_model=settings.google_default_model, ) - logger.info("Initialized Google Gemini provider (fallback)") + logger.info("Initialized Google Gemini provider") - # OpenRouter (if API key configured) if settings.openrouter_api_key: self.providers["openrouter"] = OpenAICompatProvider( name="openrouter", @@ -66,7 +106,6 @@ class ProviderRouter: ) logger.info("Initialized OpenRouter provider") - # Groq (if API key configured) if settings.groq_api_key: self.providers["groq"] = OpenAICompatProvider( name="groq", @@ -75,7 +114,6 @@ class ProviderRouter: ) logger.info("Initialized Groq provider") - # Together (if API key configured) if settings.together_api_key: self.providers["together"] = OpenAICompatProvider( name="together", @@ -84,83 +122,82 @@ class ProviderRouter: ) logger.info("Initialized Together provider") + # ------------------------------------------------------------------ + # Helpers + # ------------------------------------------------------------------ + def _parse_model(self, model: str) -> tuple[str, str]: - """Parse model string into (provider, model_name).""" + """Split ``provider/model`` into its parts. + + Bare names (no prefix) default to Ollama for compatibility with + plain OpenAI-style requests. Aliases (``mana/...``) are resolved + before this is ever called. + """ if "/" in model: - parts = model.split("/", 1) - provider = parts[0].lower() - model_name = parts[1] - else: - provider = "ollama" - model_name = model + provider, _, model_name = model.partition("/") + return provider.lower(), model_name + return "ollama", model - return provider, model_name + def _resolve_chain(self, model_or_alias: str) -> list[str]: + """Expand aliases to chains; pass everything else through unchanged.""" + if AliasRegistry.is_alias(model_or_alias): + return list(self.aliases.resolve_chain(model_or_alias)) + return [model_or_alias] - def _get_provider(self, provider_name: str) -> LLMProvider: - """Get provider by name, raise if not available.""" - if provider_name not in self.providers: - available = list(self.providers.keys()) - raise ValueError( - f"Provider '{provider_name}' not available. " - f"Available providers: {available}" - ) - return self.providers[provider_name] + @staticmethod + def _is_retryable(exc: BaseException) -> bool: + """Should we treat this exception as "try the next chain entry"? - def _can_fallback_to_google(self, provider_name: str) -> bool: - """Check if a request can be fallback-routed to Google.""" - return ( - settings.auto_fallback_enabled - and provider_name == "ollama" - and "google" in self.providers - ) + ConnectError / timeouts / 5xx / rate-limits = yes (provider blip). + Auth / capability / blocked / 4xx = no (caller has to fix + something; retrying with a different provider only hides the bug). 
+ """ + if isinstance(exc, ProviderCapabilityError): + return False + if isinstance(exc, ProviderBlockedError): + return False + if isinstance(exc, ProviderAuthError): + return False + if isinstance(exc, ProviderRateLimitError): + return True + if isinstance( + exc, + ( + httpx.ConnectError, + httpx.ConnectTimeout, + httpx.ReadError, + httpx.ReadTimeout, + httpx.RemoteProtocolError, + httpx.WriteError, + httpx.WriteTimeout, + httpx.PoolTimeout, + ), + ): + return True + if isinstance(exc, httpx.HTTPStatusError): + return exc.response.status_code >= 500 + if isinstance(exc, ProviderError): + # Any other provider-side error — treat as retryable. + # Subclasses with explicit non-retry semantics are caught above. + return True + # Unknown exception types: do NOT silently retry. Better to + # surface a strange error than hide a real bug behind a fallback. + return False - def _should_use_ollama(self) -> bool: - """Determine if Ollama should handle the request based on load.""" - return self._ollama_concurrent < settings.ollama_max_concurrent - - async def _get_ollama_health_cached(self) -> dict[str, Any]: - """Get Ollama health with caching (5s TTL).""" - cached, cached_at = self._ollama_health_cache - if cached is not None and (time.time() - cached_at) < self._health_cache_ttl: - return cached - - try: - provider = self.providers.get("ollama") - if provider: - result = await provider.health_check() - else: - result = {"status": "unhealthy", "error": "no ollama provider"} - except Exception as e: - result = {"status": "unhealthy", "error": str(e)} - - self._ollama_health_cache = (result, time.time()) - return result - - async def _fallback_to_google( - self, - request: ChatCompletionRequest, - model_name: str, - original_error: Exception | None = None, - ) -> ChatCompletionResponse: - """Route a request to Google Gemini as fallback.""" - from .google import GoogleProvider - - google = self.providers["google"] - assert isinstance(google, GoogleProvider) - - gemini_model = google.map_model(model_name) - reason = f"error: {original_error}" if original_error else "overloaded" - logger.warning( - f"Falling back to Google Gemini ({gemini_model}) for ollama/{model_name} ({reason})" - ) - return await google.chat_completion(request, gemini_model) + @staticmethod + def _exception_summary(exc: BaseException) -> str: + """Compact one-liner for cache.last_error and log entries.""" + return f"{type(exc).__name__}: {exc}" def _check_tool_capability( - self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest + self, + provider: LLMProvider, + model_name: str, + request: ChatCompletionRequest, ) -> None: - """Refuse tool-bearing requests for providers/models without tool support. + """Refuse tool-bearing requests for models that don't support tools. - Silent downgrade (dropping the `tools` payload) is more dangerous + Silent downgrade (dropping the ``tools`` payload) is more dangerous than an explicit error — the caller would get plain text back and have no way to tell the tools never reached the model. """ @@ -169,168 +206,228 @@ class ProviderRouter: if not provider.model_supports_tools(model_name): raise ProviderCapabilityError( f"{provider.name}/{model_name} does not support tool calling. " - "Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)" + "Choose a tool-capable model (e.g. 
groq/llama-3.3-70b-versatile)" ) + # ------------------------------------------------------------------ + # Core fallback executor (non-streaming) + # ------------------------------------------------------------------ + + async def _execute_with_fallback( + self, + model_or_alias: str, + request: ChatCompletionRequest, + call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]], + ) -> T: + """Walk the resolved chain, returning the first successful result. + + ``call`` is the operation to run against each chain entry, e.g. + ``lambda p, m, req: p.chat_completion(req, m)``. The function + receives the provider instance, the model name (without the + provider prefix), and the original request. + """ + chain = self._resolve_chain(model_or_alias) + attempts: list[tuple[str, str]] = [] + last_exc: Exception | None = None + + for entry in chain: + provider_name, model_name = self._parse_model(entry) + if provider_name not in self.providers: + logger.debug( + "skip chain entry %s — provider %s not configured here", + entry, + provider_name, + ) + attempts.append((entry, "unconfigured")) + continue + + if not self.health_cache.is_healthy(provider_name): + logger.debug("skip chain entry %s — cache says unhealthy", entry) + attempts.append((entry, "cache-unhealthy")) + continue + + provider = self.providers[provider_name] + self._check_tool_capability(provider, model_name, request) + + try: + logger.info( + "execute → %s (alias=%s)", + entry, + model_or_alias if model_or_alias != entry else "", + ) + result = await call(provider, model_name, request) + self.health_cache.mark_healthy(provider_name) + return result + except Exception as e: + if not self._is_retryable(e): + # Caller error / non-retryable provider error — propagate + # without touching the health cache. The cache is for + # liveness, not for recording what the user asked for + # being wrong. 
+ raise + self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) + attempts.append((entry, type(e).__name__)) + last_exc = e + logger.warning( + "execute %s failed (retryable, will try next): %s", + entry, + e, + ) + + raise NoHealthyProviderError(model_or_alias, attempts, last_exc) + + # ------------------------------------------------------------------ + # Public API — non-streaming + # ------------------------------------------------------------------ + async def chat_completion( self, request: ChatCompletionRequest, ) -> ChatCompletionResponse: - """Route chat completion request with auto-fallback.""" - provider_name, model_name = self._parse_model(request.model) + """Chat completion with alias resolution + health-aware fallback.""" - # Non-Ollama providers: direct routing, no fallback - if provider_name != "ollama": - provider = self._get_provider(provider_name) - self._check_tool_capability(provider, model_name, request) - logger.info(f"Routing chat completion to {provider_name}/{model_name}") - return await provider.chat_completion(request, model_name) + async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest): + return await provider.chat_completion(req, model) - # Ollama with fallback logic - can_fallback = self._can_fallback_to_google(provider_name) + return await self._execute_with_fallback(request.model, request, call) - # Check if Ollama is overloaded - if can_fallback and not self._should_use_ollama(): - return await self._fallback_to_google(request, model_name) - - # Try Ollama first - provider = self._get_provider("ollama") - self._check_tool_capability(provider, model_name, request) - logger.info(f"Routing chat completion to ollama/{model_name}") - self._ollama_concurrent += 1 - - try: - return await provider.chat_completion(request, model_name) - except Exception as e: - logger.error(f"Chat completion failed on ollama: {e}") - if can_fallback: - return await self._fallback_to_google(request, model_name, e) - raise - finally: - self._ollama_concurrent -= 1 + # ------------------------------------------------------------------ + # Public API — streaming (pre-first-byte fallback) + # ------------------------------------------------------------------ async def chat_completion_stream( self, request: ChatCompletionRequest, ) -> AsyncIterator[ChatCompletionStreamResponse]: - """Route streaming chat completion with auto-fallback.""" - provider_name, model_name = self._parse_model(request.model) + """Streaming variant. Falls back BEFORE the first chunk arrives; + once the first chunk has been yielded the provider is committed + and any further error propagates. - # Non-Ollama: direct - if provider_name != "ollama": - provider = self._get_provider(provider_name) + Why pre-first-byte only: stitching half-streams from two different + providers would mix two voices in the output and is impossible to + sanity-check after the fact. 
+ """ + chain = self._resolve_chain(request.model) + attempts: list[tuple[str, str]] = [] + last_exc: Exception | None = None + + for entry in chain: + provider_name, model_name = self._parse_model(entry) + if provider_name not in self.providers: + attempts.append((entry, "unconfigured")) + continue + if not self.health_cache.is_healthy(provider_name): + attempts.append((entry, "cache-unhealthy")) + continue + + provider = self.providers[provider_name] self._check_tool_capability(provider, model_name, request) - logger.info(f"Routing streaming to {provider_name}/{model_name}") - async for chunk in provider.chat_completion_stream(request, model_name): + + stream = provider.chat_completion_stream(request, model_name) + try: + first_chunk = await stream.__anext__() + except StopAsyncIteration: + # Empty stream is a successful but content-free response. + # Commit and exit cleanly. + self.health_cache.mark_healthy(provider_name) + logger.info("stream %s yielded empty response", entry) + return + except Exception as e: + if not self._is_retryable(e): + raise + self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) + attempts.append((entry, type(e).__name__)) + last_exc = e + logger.warning( + "stream %s failed before first byte (retryable, trying next): %s", + entry, + e, + ) + continue + + # First byte landed — commit the provider, mark healthy, drain + # the rest of the stream. Any error from here on propagates; + # it is NOT safe to splice another provider's output in. + self.health_cache.mark_healthy(provider_name) + logger.info("stream → %s (committed after first chunk)", entry) + yield first_chunk + async for chunk in stream: yield chunk return - # Ollama with fallback - can_fallback = self._can_fallback_to_google(provider_name) + raise NoHealthyProviderError(request.model, attempts, last_exc) - if can_fallback and not self._should_use_ollama(): - from .google import GoogleProvider + # ------------------------------------------------------------------ + # Embeddings — no fallback (out of scope for M3, separate concerns) + # ------------------------------------------------------------------ - google = self.providers["google"] - assert isinstance(google, GoogleProvider) - gemini_model = google.map_model(model_name) - logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})") - async for chunk in google.chat_completion_stream(request, gemini_model): - yield chunk - return - - provider = self._get_provider("ollama") - self._check_tool_capability(provider, model_name, request) - logger.info(f"Routing streaming to ollama/{model_name}") - self._ollama_concurrent += 1 - - try: - async for chunk in provider.chat_completion_stream(request, model_name): - yield chunk - except Exception as e: - logger.error(f"Streaming failed on ollama: {e}") - if can_fallback: - from .google import GoogleProvider - - google = self.providers["google"] - assert isinstance(google, GoogleProvider) - gemini_model = google.map_model(model_name) - logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})") - async for chunk in google.chat_completion_stream(request, gemini_model): - yield chunk - else: - raise - finally: - self._ollama_concurrent -= 1 - - async def embeddings( - self, - request: EmbeddingRequest, - ) -> EmbeddingResponse: - """Route embeddings request to appropriate provider.""" + async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse: + """Route an embeddings request directly. 
No alias / fallback.""" provider_name, model_name = self._parse_model(request.model) - provider = self._get_provider(provider_name) - - logger.info(f"Routing embeddings to {provider_name}/{model_name}") - + if provider_name not in self.providers: + available = list(self.providers) + raise ValueError( + f"Provider '{provider_name}' not available. Available: {available}" + ) + provider = self.providers[provider_name] + logger.info("embeddings → %s/%s", provider_name, model_name) return await provider.embeddings(request, model_name) - async def list_models(self) -> list[ModelInfo]: - """List all available models from all providers.""" - all_models: list[ModelInfo] = [] + # ------------------------------------------------------------------ + # Discovery / introspection + # ------------------------------------------------------------------ + async def list_models(self) -> list[ModelInfo]: + """List all available models from all configured providers. + + Best-effort: providers that error are skipped with a warning so a + single broken provider can't take down ``GET /v1/models``. + """ + all_models: list[ModelInfo] = [] for provider in self.providers.values(): try: - models = await provider.list_models() - all_models.extend(models) - except Exception as e: - logger.warning(f"Failed to list models from {provider.name}: {e}") - + all_models.extend(await provider.list_models()) + except Exception as e: # noqa: BLE001 + logger.warning("Failed to list models from %s: %s", provider.name, e) return all_models async def get_model(self, model_id: str) -> ModelInfo | None: - """Get specific model info.""" + """Look up a single model by id, dispatching on the prefix.""" provider_name, model_name = self._parse_model(model_id) - if provider_name not in self.providers: return None - - provider = self.providers[provider_name] - models = await provider.list_models() - - for model in models: - if model.id == model_id or model.id.endswith(f"/{model_name}"): - return model - + models = await self.providers[provider_name].list_models() + for m in models: + if m.id == model_id or m.id.endswith(f"/{model_name}"): + return m return None async def health_check(self) -> dict[str, Any]: - """Check health of all providers.""" - results: dict[str, Any] = {} + """Snapshot of the per-provider liveness cache. - for name, provider in self.providers.items(): - results[name] = await provider.health_check() - - # Overall status - all_healthy = all(r.get("status") == "healthy" for r in results.values()) - any_healthy = any(r.get("status") == "healthy" for r in results.values()) - - status_info: dict[str, Any] = { - "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"), - "providers": results, - } - - # Include fallback info - if settings.auto_fallback_enabled and "google" in self.providers: - status_info["fallback"] = { - "enabled": True, - "ollama_concurrent": self._ollama_concurrent, - "ollama_max_concurrent": settings.ollama_max_concurrent, + Returns the same shape as before for backwards-compat with + ``GET /health`` (deprecated structure — M4 will swap to a + cleaner ``/v1/health`` endpoint). 
+ """ + snapshot = self.health_cache.snapshot(expected=list(self.providers)) + providers_out: dict[str, Any] = {} + for name, state in snapshot.items(): + providers_out[name] = { + "status": "healthy" if state.healthy else "unhealthy", + "consecutive_failures": state.consecutive_failures, + "last_error": state.last_error, + "last_check_unix": state.last_check or None, + "unhealthy_until_unix": state.unhealthy_until or None, } - return status_info + all_healthy = all(state.healthy for state in snapshot.values()) + any_healthy = any(state.healthy for state in snapshot.values()) + return { + "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"), + "providers": providers_out, + } async def close(self) -> None: - """Close all providers.""" + """Close all provider clients.""" for provider in self.providers.values(): await provider.close() diff --git a/services/mana-llm/tests/test_providers.py b/services/mana-llm/tests/test_providers.py index b0f85cc70..e5c3ffcb0 100644 --- a/services/mana-llm/tests/test_providers.py +++ b/services/mana-llm/tests/test_providers.py @@ -1,44 +1,54 @@ """Provider tests.""" +from pathlib import Path + import pytest -from src.models import ChatCompletionRequest, Message +from src.aliases import AliasRegistry +from src.health import ProviderHealthCache +from src.models import ChatCompletionRequest, EmbeddingRequest, Message from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter +@pytest.fixture +def shipped_aliases() -> AliasRegistry: + """The repo's real aliases.yaml — same one production uses.""" + return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml") + + +@pytest.fixture +def router(shipped_aliases: AliasRegistry) -> ProviderRouter: + return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache()) + + class TestProviderRouter: - """Test provider routing logic.""" - - def test_parse_model_with_provider(self): - """Test model parsing with provider prefix.""" - router = ProviderRouter() + """Tests for the helpers exposed by the router.""" + def test_parse_model_with_provider(self, router: ProviderRouter) -> None: provider, model = router._parse_model("ollama/gemma3:4b") assert provider == "ollama" assert model == "gemma3:4b" - def test_parse_model_without_provider(self): - """Test model parsing without provider prefix (defaults to ollama).""" - router = ProviderRouter() - + def test_parse_model_without_provider(self, router: ProviderRouter) -> None: + # Bare names default to Ollama for OpenAI-style compat. provider, model = router._parse_model("gemma3:4b") assert provider == "ollama" assert model == "gemma3:4b" - def test_parse_model_openrouter(self): - """Test model parsing for OpenRouter.""" - router = ProviderRouter() - + def test_parse_model_openrouter(self, router: ProviderRouter) -> None: provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct") assert provider == "openrouter" assert model == "meta-llama/llama-3.1-8b-instruct" - def test_get_invalid_provider(self): - """Test getting invalid provider raises error.""" - router = ProviderRouter() - + @pytest.mark.asyncio + async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None: + # Embeddings don't go through the alias/fallback pipeline — they + # hit the requested provider directly. Asking for an unconfigured + # one is a config error and must raise loudly. 
with pytest.raises(ValueError, match="not available"): - router._get_provider("invalid_provider") + await router.embeddings( + EmbeddingRequest(model="bogus_provider/x", input="hi") + ) class TestOllamaProvider: diff --git a/services/mana-llm/tests/test_router_fallback.py b/services/mana-llm/tests/test_router_fallback.py new file mode 100644 index 000000000..cb1f34743 --- /dev/null +++ b/services/mana-llm/tests/test_router_fallback.py @@ -0,0 +1,526 @@ +"""Tests for ProviderRouter fallback / alias execution (M3).""" + +from __future__ import annotations + +from collections.abc import AsyncIterator +from typing import Any + +import httpx +import pytest + +from src.aliases import AliasRegistry +from src.health import ProviderHealthCache +from src.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + ChatCompletionStreamResponse, + Choice, + DeltaContent, + EmbeddingRequest, + EmbeddingResponse, + Message, + MessageResponse, + ModelInfo, + StreamChoice, +) +from src.providers import ProviderRouter +from src.providers.base import LLMProvider +from src.providers.errors import ( + NoHealthyProviderError, + ProviderAuthError, + ProviderCapabilityError, + ProviderRateLimitError, +) + + +# --------------------------------------------------------------------------- +# Test doubles +# --------------------------------------------------------------------------- + + +class MockProvider(LLMProvider): + """Provider that lets tests inject a sequence of behaviours. + + Each call pops one entry from ``behaviors``. Strings ``"ok"`` and + ``"empty"`` are sentinels for normal returns; everything else (a + BaseException instance / class) is raised. + """ + + supports_tools = True + + def __init__(self, name: str, behaviors: list[Any] | None = None) -> None: + self.name = name + self._behaviors: list[Any] = list(behaviors or []) + self.calls: list[str] = [] + + def push(self, *behaviors: Any) -> None: + self._behaviors.extend(behaviors) + + def _next(self) -> Any: + return self._behaviors.pop(0) if self._behaviors else "ok" + + async def chat_completion( + self, request: ChatCompletionRequest, model: str + ) -> ChatCompletionResponse: + self.calls.append(model) + b = self._next() + if isinstance(b, type) and issubclass(b, BaseException): + raise b("simulated") + if isinstance(b, BaseException): + raise b + return _ok_response(self.name, model) + + async def chat_completion_stream( + self, request: ChatCompletionRequest, model: str + ) -> AsyncIterator[ChatCompletionStreamResponse]: + self.calls.append(model) + b = self._next() + if isinstance(b, type) and issubclass(b, BaseException): + raise b("simulated") + if isinstance(b, BaseException): + raise b + if b == "empty": + return + for content in ("Hello", " ", "world"): + yield ChatCompletionStreamResponse( + model=f"{self.name}/{model}", + choices=[StreamChoice(delta=DeltaContent(content=content))], + ) + + async def list_models(self) -> list[ModelInfo]: + return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")] + + async def embeddings( + self, request: EmbeddingRequest, model: str + ) -> EmbeddingResponse: + raise NotImplementedError + + async def health_check(self) -> dict[str, Any]: + return {"status": "healthy"} + + +class FailFirstChunkProvider(MockProvider): + """Streaming provider that raises BEFORE the first chunk every time. + + Kept separate from MockProvider's behaviour list so the per-call + semantics stay simple — this one models a permanently-broken streamer. 
+ """ + + def __init__(self, name: str, exc: BaseException) -> None: + super().__init__(name) + self._exc = exc + + async def chat_completion_stream(self, request, model): # type: ignore[override] + self.calls.append(model) + raise self._exc + # the yield is unreachable but keeps the function an async generator + yield # pragma: no cover + + +def _ok_response(provider: str, model: str) -> ChatCompletionResponse: + return ChatCompletionResponse( + model=f"{provider}/{model}", + choices=[ + Choice( + message=MessageResponse(content="ok"), + finish_reason="stop", + ) + ], + ) + + +def _request(model: str) -> ChatCompletionRequest: + return ChatCompletionRequest( + model=model, + messages=[Message(role="user", content="hi")], + ) + + +def _aliases_yaml(tmp_path) -> AliasRegistry: + """A two-alias config used across most tests.""" + cfg = ( + "aliases:\n" + " mana/long-form:\n" + ' description: "long"\n' + " chain:\n" + " - alpha/m1\n" + " - beta/m2\n" + " - gamma/m3\n" + " mana/single:\n" + ' description: "single-entry"\n' + " chain:\n" + " - alpha/solo\n" + ) + p = tmp_path / "aliases.yaml" + p.write_text(cfg) + return AliasRegistry(p) + + +def _make_router( + tmp_path, + *, + providers: dict[str, MockProvider], + cache: ProviderHealthCache | None = None, +) -> ProviderRouter: + aliases = _aliases_yaml(tmp_path) + router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache()) + # Replace the auto-initialised live providers with the test doubles. + router.providers = dict(providers) + return router + + +# --------------------------------------------------------------------------- +# Non-streaming chain walking +# --------------------------------------------------------------------------- + + +class TestChatCompletionChain: + @pytest.mark.asyncio + async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None: + alpha = MockProvider("alpha", ["ok"]) + beta = MockProvider("beta") + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + resp = await router.chat_completion(_request("mana/long-form")) + + assert resp.model == "alpha/m1" + assert alpha.calls == ["m1"] + assert beta.calls == [] # never reached + + @pytest.mark.asyncio + async def test_falls_through_on_connect_error(self, tmp_path) -> None: + alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + resp = await router.chat_completion(_request("mana/long-form")) + + assert resp.model == "beta/m2" + assert alpha.calls == ["m1"] + assert beta.calls == ["m2"] + + @pytest.mark.asyncio + async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None: + # gamma isn't configured at all → chain should silently skip it + # rather than raise. + alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) + beta = MockProvider("beta", [httpx.ConnectError("dead too")]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(NoHealthyProviderError) as exc_info: + await router.chat_completion(_request("mana/long-form")) + # All three entries appear in attempts: two as ConnectError, one + # as unconfigured (not a fatal error, just skipped). 
+ attempts = exc_info.value.attempts + assert ("alpha/m1", "ConnectError") in attempts + assert ("beta/m2", "ConnectError") in attempts + assert ("gamma/m3", "unconfigured") in attempts + + @pytest.mark.asyncio + async def test_skips_cache_unhealthy(self, tmp_path) -> None: + cache = ProviderHealthCache(failure_threshold=1) + cache.mark_unhealthy("alpha", "stale") + alpha = MockProvider("alpha", ["ok"]) + beta = MockProvider("beta", ["ok"]) + router = _make_router( + tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache + ) + + resp = await router.chat_completion(_request("mana/long-form")) + + assert alpha.calls == [] # router skipped per cache + assert beta.calls == ["m2"] + assert resp.model == "beta/m2" + + @pytest.mark.asyncio + async def test_5xx_treated_as_retryable(self, tmp_path) -> None: + five_hundred = httpx.HTTPStatusError( + "boom", + request=httpx.Request("POST", "http://x"), + response=httpx.Response(503), + ) + alpha = MockProvider("alpha", [five_hundred]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + resp = await router.chat_completion(_request("mana/long-form")) + + assert resp.model == "beta/m2" + + @pytest.mark.asyncio + async def test_4xx_propagates(self, tmp_path) -> None: + four_hundred = httpx.HTTPStatusError( + "bad request", + request=httpx.Request("POST", "http://x"), + response=httpx.Response(422), + ) + alpha = MockProvider("alpha", [four_hundred]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(httpx.HTTPStatusError): + await router.chat_completion(_request("mana/long-form")) + # Beta never tried — caller's request needs fixing, retrying + # against another model would just hide the bug. + assert beta.calls == [] + + @pytest.mark.asyncio + async def test_capability_error_propagates(self, tmp_path) -> None: + alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(ProviderCapabilityError): + await router.chat_completion(_request("mana/long-form")) + assert beta.calls == [] + + @pytest.mark.asyncio + async def test_auth_error_propagates(self, tmp_path) -> None: + # Auth errors mean OUR setup is broken (wrong key); falling back + # to the next provider hides the misconfiguration. 
+ alpha = MockProvider("alpha", [ProviderAuthError("bad key")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(ProviderAuthError): + await router.chat_completion(_request("mana/long-form")) + assert beta.calls == [] + + @pytest.mark.asyncio + async def test_rate_limit_is_retryable(self, tmp_path) -> None: + alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + resp = await router.chat_completion(_request("mana/long-form")) + + assert resp.model == "beta/m2" + + @pytest.mark.asyncio + async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None: + alpha = MockProvider("alpha", [httpx.ConnectError("a")]) + beta = MockProvider("beta", [httpx.ConnectError("b")]) + gamma = MockProvider("gamma", [httpx.ConnectError("c")]) + router = _make_router( + tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma} + ) + + with pytest.raises(NoHealthyProviderError) as exc_info: + await router.chat_completion(_request("mana/long-form")) + assert exc_info.value.model_or_alias == "mana/long-form" + assert isinstance(exc_info.value.last_exception, httpx.ConnectError) + # 503 status so calling code (mana-api etc.) can decide to retry + # later vs surface a clean error to the user. + assert exc_info.value.http_status == 503 + + @pytest.mark.asyncio + async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None: + # Caller bypasses aliases by passing a direct provider/model. + # No fallback chain — fail = fail. + alpha = MockProvider("alpha", [httpx.ConnectError("dead")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(NoHealthyProviderError): + await router.chat_completion(_request("alpha/anything")) + # Beta would have served if this had been an alias — but it + # wasn't, so beta never gets touched. + assert beta.calls == [] + + +# --------------------------------------------------------------------------- +# Health-cache feedback: success clears, failure marks +# --------------------------------------------------------------------------- + + +class TestHealthCacheFeedback: + @pytest.mark.asyncio + async def test_success_marks_provider_healthy(self, tmp_path) -> None: + cache = ProviderHealthCache(failure_threshold=1) + cache.mark_unhealthy("alpha", "stale-from-probe") + # After the cache TTL the cache thinks alpha might be OK again, + # so the router will try it; success must fully clear the state. + # (Force half-open by zeroing backoff.) + alpha = MockProvider("alpha", ["ok"]) + router = _make_router( + tmp_path, + providers={"alpha": alpha}, + cache=ProviderHealthCache(), # fresh cache, alpha optimistic + ) + + await router.chat_completion(_request("mana/single")) + + assert router.health_cache.get_state("alpha").healthy is True + assert router.health_cache.get_state("alpha").consecutive_failures == 0 + + @pytest.mark.asyncio + async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None: + # threshold=1 so a single fail is enough to flip the breaker. 
+ cache = ProviderHealthCache(failure_threshold=1) + alpha = MockProvider("alpha", [httpx.ConnectError("boom")]) + beta = MockProvider("beta", ["ok"]) + router = _make_router( + tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache + ) + + await router.chat_completion(_request("mana/long-form")) + + assert cache.get_state("alpha").healthy is False + assert cache.get_state("alpha").last_error is not None + assert "ConnectError" in cache.get_state("alpha").last_error + + @pytest.mark.asyncio + async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None: + # Auth/Capability errors are about CALLER state, not provider + # health — the cache must stay clean so a real outage later + # isn't masked by stale "marked unhealthy because of bad key". + cache = ProviderHealthCache(failure_threshold=1) + alpha = MockProvider("alpha", [ProviderAuthError("bad key")]) + router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache) + + with pytest.raises(ProviderAuthError): + await router.chat_completion(_request("mana/single")) + + # No state recorded. + assert cache.get_state("alpha") is None + + +# --------------------------------------------------------------------------- +# Streaming pre-first-byte fallback +# --------------------------------------------------------------------------- + + +class TestChatCompletionStream: + @pytest.mark.asyncio + async def test_first_provider_streams_normally(self, tmp_path) -> None: + alpha = MockProvider("alpha", ["ok"]) + beta = MockProvider("beta") + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + chunks = [ + c async for c in router.chat_completion_stream(_request("mana/long-form")) + ] + + assert beta.calls == [] + assert len(chunks) == 3 + assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world" + + @pytest.mark.asyncio + async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None: + alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead")) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + chunks = [ + c async for c in router.chat_completion_stream(_request("mana/long-form")) + ] + + assert alpha.calls == ["m1"] + assert beta.calls == ["m2"] + assert len(chunks) == 3 + assert all(c.model == "beta/m2" for c in chunks) + + @pytest.mark.asyncio + async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None: + alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools")) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + with pytest.raises(ProviderCapabilityError): + async for _ in router.chat_completion_stream(_request("mana/long-form")): + pass + assert beta.calls == [] + + @pytest.mark.asyncio + async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None: + # Empty-but-successful stream is a valid response, not a failure + # we should retry — committing avoids accidentally calling two + # providers and double-billing. 
+ alpha = MockProvider("alpha", ["empty"]) + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + chunks = [ + c async for c in router.chat_completion_stream(_request("mana/long-form")) + ] + + assert chunks == [] + assert beta.calls == [] # didn't fall through + + @pytest.mark.asyncio + async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None: + # Custom provider that yields once then raises mid-stream — the + # router has already committed and must let the error propagate + # rather than splice in another provider's voice. + class MidStreamFailProvider(MockProvider): + async def chat_completion_stream(self, request, model): # type: ignore[override] + self.calls.append(model) + yield ChatCompletionStreamResponse( + model=f"{self.name}/{model}", + choices=[StreamChoice(delta=DeltaContent(content="halb"))], + ) + raise httpx.RemoteProtocolError("connection dropped") + + alpha = MidStreamFailProvider("alpha") + beta = MockProvider("beta", ["ok"]) + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + collected: list[str] = [] + with pytest.raises(httpx.RemoteProtocolError): + async for chunk in router.chat_completion_stream(_request("mana/long-form")): + collected.append(chunk.choices[0].delta.content or "") + + # We got the half-chunk that landed before the break; beta was + # NOT called as fallback. + assert collected == ["halb"] + assert beta.calls == [] + + @pytest.mark.asyncio + async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None: + alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a")) + beta = FailFirstChunkProvider("beta", httpx.ConnectError("b")) + gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c")) + router = _make_router( + tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma} + ) + + with pytest.raises(NoHealthyProviderError): + async for _ in router.chat_completion_stream(_request("mana/long-form")): + pass + + +# --------------------------------------------------------------------------- +# Health-check shape (still using the cache snapshot) +# --------------------------------------------------------------------------- + + +class TestHealthCheck: + @pytest.mark.asyncio + async def test_health_check_lists_known_providers(self, tmp_path) -> None: + # Even if no probe has run yet, every configured provider should + # appear in the snapshot (zero-defaults) so /health has a stable + # shape for monitors. + alpha = MockProvider("alpha") + beta = MockProvider("beta") + router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta}) + + out = await router.health_check() + + assert set(out["providers"].keys()) == {"alpha", "beta"} + assert out["status"] == "healthy" + assert all(p["status"] == "healthy" for p in out["providers"].values()) + + @pytest.mark.asyncio + async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None: + cache = ProviderHealthCache(failure_threshold=1) + cache.mark_unhealthy("alpha", "boom") + alpha = MockProvider("alpha") + beta = MockProvider("beta") + router = _make_router( + tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache + ) + + out = await router.health_check() + assert out["status"] == "degraded" + assert out["providers"]["alpha"]["status"] == "unhealthy" + assert out["providers"]["beta"]["status"] == "healthy"