mirror of https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 22:41:09 +02:00

feat(mana-llm): M3 — health-aware router with alias + chain fallback

Replaces the old Ollama→Google special-case auto-fallback with a unified pipeline: the caller passes either a direct provider/model or an alias from the `mana/` namespace; the router resolves it to a chain and walks it, skipping unhealthy providers (per the ProviderHealthCache from M2), trying each entry, marking the provider unhealthy on retryable errors, and falling through to the next.

Retryable: ConnectError, ReadTimeout, RemoteProtocolError, 5xx, ProviderRateLimitError. Propagated (don't fall back, don't poison the cache): ProviderCapabilityError, ProviderAuthError, ProviderBlockedError, 4xx, unknown exception types. The cache stays "what the network told us about this provider's liveness" — caller errors don't muddy that signal.

Streaming: pre-first-byte fallback only. Once a chunk has been yielded the provider is committed; mid-stream errors propagate as-is so we don't splice two voices into one output.

`NoHealthyProviderError` (HTTP 503) carries a structured attempt log — each chain entry shows up as `(model, reason)`, so the cause of a 503 is visible in the response and metrics, not only in service logs.

main.py wires the lifespan: aliases.yaml is loaded, a ProviderHealthCache is created, the ProviderRouter takes both as constructor deps, and a HealthProbe is spawned with cheap HTTP probes per configured provider (Ollama /api/tags, OpenAI-compat /v1/models with a Bearer header). Google is skipped — the google-genai SDK has no obvious cheap probe; the call-site fallback handles real errors.

22 new router tests (test_router_fallback.py): chain walking, capability & auth propagation, 5xx vs 4xx differentiation, rate-limit retry, all-fail → NoHealthyProviderError, direct provider strings bypassing aliases, streaming pre-first-byte fallback, mid-stream failure does NOT fall back, empty stream commits without retry, and cache feedback on success/failure/non-retryable. The existing test_providers.py is updated for the new constructor signature; all 99 service tests are green via the dev container (Python 3.12).

Legacy purged: `_ollama_concurrent`, `_ollama_health_cache`, `_can_fallback_to_google`, `_should_use_ollama`, `_fallback_to_google`, `_get_ollama_health_cached` are all gone. The `auto_fallback_enabled` / `ollama_max_concurrent` settings remain in config.py for now (M5 will remove them along with the per-feature env-var overrides).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent 59557e62d7
commit 3046da3b19

5 changed files with 952 additions and 241 deletions
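Before the diff, a minimal sketch of the alias mechanism described above, using the AliasRegistry API that appears below (is_alias, resolve_chain) and the YAML shape used in the new tests. The alias name and chain entries here are illustrative assumptions, not the repo's shipped aliases.yaml:

    from pathlib import Path

    from src.aliases import AliasRegistry

    # Hypothetical alias config; the shipped aliases.yaml may differ.
    cfg = (
        "aliases:\n"
        "  mana/chat:\n"
        '    description: "example alias"\n'
        "    chain:\n"
        "      - ollama/gemma3:4b\n"
        "      - groq/llama-3.3-70b-versatile\n"
    )
    path = Path("aliases-example.yaml")
    path.write_text(cfg)

    registry = AliasRegistry(path)
    assert AliasRegistry.is_alias("mana/chat")             # mana/ namespace -> alias
    assert not AliasRegistry.is_alias("ollama/gemma3:4b")  # direct provider/model
    print(list(registry.resolve_chain("mana/chat")))
    # ['ollama/gemma3:4b', 'groq/llama-3.3-70b-versatile']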
@@ -3,6 +3,7 @@
 import logging
 import time
 from contextlib import asynccontextmanager
+from pathlib import Path
 from typing import Any
 
 from fastapi import FastAPI, HTTPException, Request
@@ -10,8 +11,11 @@ from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import Response
 from sse_starlette.sse import EventSourceResponse
 
+from src.aliases import AliasRegistry
 from src.api_auth import ApiKeyMiddleware
 from src.config import settings
+from src.health import ProviderHealthCache
+from src.health_probe import HealthProbe, make_http_probe
 from src.models import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -33,25 +37,66 @@ logging.basicConfig(
 )
 logger = logging.getLogger(__name__)
 
-# Global router instance
+# Global service singletons
 router: ProviderRouter | None = None
+health_cache: ProviderHealthCache | None = None
+health_probe: HealthProbe | None = None
+alias_registry: AliasRegistry | None = None
+
+
+def _build_provider_probes(
+    providers: dict[str, Any],
+) -> dict[str, Any]:
+    """Wire each configured provider to a cheap HTTP probe."""
+    probes: dict[str, Any] = {}
+
+    if "ollama" in providers:
+        probes["ollama"] = make_http_probe(f"{settings.ollama_url}/api/tags")
+    if "openrouter" in providers:
+        probes["openrouter"] = make_http_probe(
+            f"{settings.openrouter_base_url}/models",
+            headers={"Authorization": f"Bearer {settings.openrouter_api_key}"},
+        )
+    if "groq" in providers:
+        probes["groq"] = make_http_probe(
+            f"{settings.groq_base_url}/models",
+            headers={"Authorization": f"Bearer {settings.groq_api_key}"},
+        )
+    if "together" in providers:
+        probes["together"] = make_http_probe(
+            f"{settings.together_base_url}/models",
+            headers={"Authorization": f"Bearer {settings.together_api_key}"},
+        )
+    # Google: skipped — google-genai SDK is opaque enough that a probe
+    # would amount to a real API call. Treat as healthy by default; the
+    # router's call-site fallback will mark it unhealthy on real errors.
+    return probes
+
 
 @asynccontextmanager
 async def lifespan(app: FastAPI):
-    """Application lifespan management."""
-    global router
+    """Application lifespan: load aliases, spin up router + health probe."""
+    global router, health_cache, health_probe, alias_registry
 
     # Startup
     logger.info("Starting mana-llm service...")
-    router = ProviderRouter()
-    logger.info(f"Initialized providers: {list(router.providers.keys())}")
+
+    aliases_path = Path(__file__).resolve().parent.parent / "aliases.yaml"
+    alias_registry = AliasRegistry(aliases_path)
+    logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)
+
+    health_cache = ProviderHealthCache()
+    router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
+    logger.info("Initialized providers: %s", list(router.providers))
+
+    health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
+    await health_probe.start()
 
     yield
 
     # Shutdown
     logger.info("Shutting down mana-llm service...")
-    if router:
+    if health_probe is not None:
+        await health_probe.stop()
+    if router is not None:
         await router.close()
     await close_redis()
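make_http_probe itself is not part of this diff; judging from the calls above, it is a factory that returns an async callable which GETs a cheap endpoint and reports liveness. A sketch under that assumption (the real src/health_probe.py may differ, e.g. in timeout or return type):

    import httpx

    def make_http_probe(url: str, headers: dict[str, str] | None = None):
        """Assumed shape: return an async callable; any non-error response counts as alive."""
        async def probe() -> bool:
            try:
                async with httpx.AsyncClient(timeout=5.0) as client:
                    resp = await client.get(url, headers=headers)
                return resp.status_code < 400  # cheap liveness check, not correctness
            except httpx.HTTPError:
                return False
        return probe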
@@ -76,3 +76,36 @@ class ProviderCapabilityError(ProviderError):
 
     kind = "capability"
     http_status = 400
+
+
+class NoHealthyProviderError(ProviderError):
+    """Every entry in the resolved chain has been tried and either was
+    unconfigured, marked unhealthy by the cache, or failed in flight.
+
+    Carries the full attempt log so the caller can see which providers
+    were tried and why each one was skipped or failed — invaluable when
+    a real outage hits and the API returns 503 instead of the usual 200.
+    """
+
+    kind = "no_healthy_provider"
+    http_status = 503
+
+    def __init__(
+        self,
+        model_or_alias: str,
+        attempts: list[tuple[str, str]],
+        last_exception: Exception | None = None,
+    ) -> None:
+        self.model_or_alias = model_or_alias
+        self.attempts = list(attempts)
+        self.last_exception = last_exception
+        if attempts:
+            attempt_log = ", ".join(f"{model}={reason}" for model, reason in attempts)
+        else:
+            attempt_log = "(no providers in resolved chain were configured)"
+        msg = (
+            f"no healthy provider could serve {model_or_alias!r}. Attempts: {attempt_log}"
+        )
+        if last_exception is not None:
+            msg += f". Last error: {type(last_exception).__name__}: {last_exception}"
+        super().__init__(msg)
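What the structured attempt log reads like in practice, built directly from the constructor above (the chain entries are illustrative):

    from src.providers.errors import NoHealthyProviderError

    err = NoHealthyProviderError(
        "mana/long-form",
        attempts=[("alpha/m1", "ConnectError"), ("gamma/m3", "unconfigured")],
    )
    assert err.http_status == 503
    print(err)
    # no healthy provider could serve 'mana/long-form'. Attempts: alpha/m1=ConnectError, gamma/m3=unconfigured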
@@ -1,12 +1,38 @@
-"""Provider routing logic for mana-llm with auto-fallback support."""
+"""Provider routing with alias resolution and health-aware fallback.
+
+The router is the single entry point that the FastAPI handlers use. Its
+job is:
+
+1. Resolve the request's ``model`` field. If it lives in the ``mana/``
+   namespace the :class:`AliasRegistry` returns an ordered chain of
+   concrete provider/model strings; everything else is treated as a
+   single-entry chain (caller passed a direct provider/model).
+2. Walk the chain, skipping entries whose provider is either
+   unconfigured at this deployment (no API key) or currently marked
+   unhealthy in the :class:`ProviderHealthCache`.
+3. Try each remaining entry. Connection errors, timeouts, 5xx, and rate
+   limits are retryable — record them in the cache and move to the next
+   entry. Capability/auth/blocked errors are caller-fixable and
+   propagate immediately without touching the health cache.
+4. Return the first successful response. If every entry was skipped or
+   failed, raise :class:`NoHealthyProviderError` (HTTP 503) carrying
+   the full attempt log so debugging is straightforward.
+
+The full design lives in ``docs/plans/llm-fallback-aliases.md``. This is
+the M3 milestone.
+"""
+
+from __future__ import annotations
 
 import asyncio
 import logging
 import time
-from collections.abc import AsyncIterator
-from typing import Any
+from collections.abc import AsyncIterator, Awaitable, Callable
+from typing import Any, TypeVar
 
 import httpx
 
+from src.aliases import AliasRegistry
 from src.config import settings
+from src.health import ProviderHealthCache
 from src.models import (
     ChatCompletionRequest,
     ChatCompletionResponse,
@@ -17,36 +43,51 @@ from src.models import (
 )
 
 from .base import LLMProvider
-from .errors import ProviderCapabilityError
+from .errors import (
+    NoHealthyProviderError,
+    ProviderAuthError,
+    ProviderBlockedError,
+    ProviderCapabilityError,
+    ProviderError,
+    ProviderRateLimitError,
+)
 from .ollama import OllamaProvider
 from .openai_compat import OpenAICompatProvider
 
 logger = logging.getLogger(__name__)
 
+T = TypeVar("T")
+
 
 class ProviderRouter:
-    """Routes requests to appropriate LLM providers with auto-fallback.
+    """Health-aware provider router with alias resolution.
 
-    When auto_fallback_enabled is True and a Google API key is configured:
-    - Ollama requests that fail or exceed max_concurrent are automatically
-      retried on Google Gemini with model mapping.
-    - Explicit provider requests (e.g., openrouter/...) are never fallback-routed.
+    Construct with the AliasRegistry and ProviderHealthCache from
+    application startup; both are external dependencies so tests can
+    inject mocks without going through global state.
     """
 
-    def __init__(self):
+    def __init__(
+        self,
+        aliases: AliasRegistry,
+        health_cache: ProviderHealthCache,
+    ) -> None:
+        self.aliases = aliases
+        self.health_cache = health_cache
         self.providers: dict[str, LLMProvider] = {}
-        self._ollama_concurrent: int = 0
-        self._ollama_health_cache: tuple[dict[str, Any] | None, float] = (None, 0)
-        self._health_cache_ttl: float = 5.0  # seconds
         self._initialize_providers()
 
-    def _initialize_providers(self) -> None:
-        """Initialize available providers based on configuration."""
-        # Ollama is always available (local)
-        self.providers["ollama"] = OllamaProvider()
-        logger.info(f"Initialized Ollama provider at {settings.ollama_url}")
+    # ------------------------------------------------------------------
+    # Provider initialisation
+    # ------------------------------------------------------------------
+
+    def _initialize_providers(self) -> None:
+        """Spin up provider adapters based on what's configured."""
+        # Ollama: always present (talks to a local/proxied server). Whether
+        # it's actually reachable is the cache's job to figure out.
+        self.providers["ollama"] = OllamaProvider()
+        logger.info("Initialized Ollama provider at %s", settings.ollama_url)
 
         # Google Gemini (fallback provider)
         if settings.google_api_key:
             from .google import GoogleProvider
@@ -54,9 +95,8 @@ class ProviderRouter:
                 api_key=settings.google_api_key,
                 default_model=settings.google_default_model,
             )
-            logger.info("Initialized Google Gemini provider (fallback)")
+            logger.info("Initialized Google Gemini provider")
 
-        # OpenRouter (if API key configured)
         if settings.openrouter_api_key:
             self.providers["openrouter"] = OpenAICompatProvider(
                 name="openrouter",
@@ -66,7 +106,6 @@ class ProviderRouter:
             )
             logger.info("Initialized OpenRouter provider")
 
-        # Groq (if API key configured)
         if settings.groq_api_key:
             self.providers["groq"] = OpenAICompatProvider(
                 name="groq",
@@ -75,7 +114,6 @@ class ProviderRouter:
             )
             logger.info("Initialized Groq provider")
 
-        # Together (if API key configured)
        if settings.together_api_key:
             self.providers["together"] = OpenAICompatProvider(
                 name="together",
@@ -84,83 +122,82 @@ class ProviderRouter:
         )
         logger.info("Initialized Together provider")
 
+    # ------------------------------------------------------------------
+    # Helpers
+    # ------------------------------------------------------------------
+
     def _parse_model(self, model: str) -> tuple[str, str]:
-        """Parse model string into (provider, model_name)."""
+        """Split ``provider/model`` into its parts.
+
+        Bare names (no prefix) default to Ollama for compatibility with
+        plain OpenAI-style requests. Aliases (``mana/...``) are resolved
+        before this is ever called.
+        """
         if "/" in model:
-            parts = model.split("/", 1)
-            provider = parts[0].lower()
-            model_name = parts[1]
-        else:
-            provider = "ollama"
-            model_name = model
+            provider, _, model_name = model.partition("/")
+            return provider.lower(), model_name
+        return "ollama", model
 
-        return provider, model_name
+    def _resolve_chain(self, model_or_alias: str) -> list[str]:
+        """Expand aliases to chains; pass everything else through unchanged."""
+        if AliasRegistry.is_alias(model_or_alias):
+            return list(self.aliases.resolve_chain(model_or_alias))
+        return [model_or_alias]
 
-    def _get_provider(self, provider_name: str) -> LLMProvider:
-        """Get provider by name, raise if not available."""
-        if provider_name not in self.providers:
-            available = list(self.providers.keys())
-            raise ValueError(
-                f"Provider '{provider_name}' not available. "
-                f"Available providers: {available}"
-            )
-        return self.providers[provider_name]
+    @staticmethod
+    def _is_retryable(exc: BaseException) -> bool:
+        """Should we treat this exception as "try the next chain entry"?
 
-    def _can_fallback_to_google(self, provider_name: str) -> bool:
-        """Check if a request can be fallback-routed to Google."""
-        return (
-            settings.auto_fallback_enabled
-            and provider_name == "ollama"
-            and "google" in self.providers
-        )
+        ConnectError / timeouts / 5xx / rate-limits = yes (provider blip).
+        Auth / capability / blocked / 4xx = no (caller has to fix
+        something; retrying with a different provider only hides the bug).
+        """
+        if isinstance(exc, ProviderCapabilityError):
+            return False
+        if isinstance(exc, ProviderBlockedError):
+            return False
+        if isinstance(exc, ProviderAuthError):
+            return False
+        if isinstance(exc, ProviderRateLimitError):
+            return True
+        if isinstance(
+            exc,
+            (
+                httpx.ConnectError,
+                httpx.ConnectTimeout,
+                httpx.ReadError,
+                httpx.ReadTimeout,
+                httpx.RemoteProtocolError,
+                httpx.WriteError,
+                httpx.WriteTimeout,
+                httpx.PoolTimeout,
+            ),
+        ):
+            return True
+        if isinstance(exc, httpx.HTTPStatusError):
+            return exc.response.status_code >= 500
+        if isinstance(exc, ProviderError):
+            # Any other provider-side error — treat as retryable.
+            # Subclasses with explicit non-retry semantics are caught above.
+            return True
+        # Unknown exception types: do NOT silently retry. Better to
+        # surface a strange error than hide a real bug behind a fallback.
+        return False
 
-    def _should_use_ollama(self) -> bool:
-        """Determine if Ollama should handle the request based on load."""
-        return self._ollama_concurrent < settings.ollama_max_concurrent
-
-    async def _get_ollama_health_cached(self) -> dict[str, Any]:
-        """Get Ollama health with caching (5s TTL)."""
-        cached, cached_at = self._ollama_health_cache
-        if cached is not None and (time.time() - cached_at) < self._health_cache_ttl:
-            return cached
-
-        try:
-            provider = self.providers.get("ollama")
-            if provider:
-                result = await provider.health_check()
-            else:
-                result = {"status": "unhealthy", "error": "no ollama provider"}
-        except Exception as e:
-            result = {"status": "unhealthy", "error": str(e)}
-
-        self._ollama_health_cache = (result, time.time())
-        return result
-
-    async def _fallback_to_google(
-        self,
-        request: ChatCompletionRequest,
-        model_name: str,
-        original_error: Exception | None = None,
-    ) -> ChatCompletionResponse:
-        """Route a request to Google Gemini as fallback."""
-        from .google import GoogleProvider
-
-        google = self.providers["google"]
-        assert isinstance(google, GoogleProvider)
-
-        gemini_model = google.map_model(model_name)
-        reason = f"error: {original_error}" if original_error else "overloaded"
-        logger.warning(
-            f"Falling back to Google Gemini ({gemini_model}) for ollama/{model_name} ({reason})"
-        )
-        return await google.chat_completion(request, gemini_model)
+    @staticmethod
+    def _exception_summary(exc: BaseException) -> str:
+        """Compact one-liner for cache.last_error and log entries."""
+        return f"{type(exc).__name__}: {exc}"
 
     def _check_tool_capability(
-        self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest
+        self,
+        provider: LLMProvider,
+        model_name: str,
+        request: ChatCompletionRequest,
     ) -> None:
-        """Refuse tool-bearing requests for providers/models without tool support.
+        """Refuse tool-bearing requests for models that don't support tools.
 
-        Silent downgrade (dropping the `tools` payload) is more dangerous
+        Silent downgrade (dropping the ``tools`` payload) is more dangerous
         than an explicit error — the caller would get plain text back and
         have no way to tell the tools never reached the model.
         """
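The classification above can be exercised directly, since _is_retryable is a staticmethod; a quick check in the style of the new test file:

    import httpx

    from src.providers import ProviderRouter
    from src.providers.errors import ProviderAuthError, ProviderRateLimitError

    req = httpx.Request("POST", "http://example")
    assert ProviderRouter._is_retryable(httpx.ConnectError("down"))
    assert ProviderRouter._is_retryable(ProviderRateLimitError("slow down"))
    assert ProviderRouter._is_retryable(
        httpx.HTTPStatusError("boom", request=req, response=httpx.Response(503))
    )
    assert not ProviderRouter._is_retryable(
        httpx.HTTPStatusError("bad", request=req, response=httpx.Response(422))
    )
    assert not ProviderRouter._is_retryable(ProviderAuthError("bad key"))
    assert not ProviderRouter._is_retryable(KeyError("unknown"))  # unknown type: propagate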
@@ -169,168 +206,228 @@ class ProviderRouter:
         if not provider.model_supports_tools(model_name):
             raise ProviderCapabilityError(
                 f"{provider.name}/{model_name} does not support tool calling. "
-                "Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)"
+                "Choose a tool-capable model (e.g. groq/llama-3.3-70b-versatile)"
             )
 
+    # ------------------------------------------------------------------
+    # Core fallback executor (non-streaming)
+    # ------------------------------------------------------------------
+
+    async def _execute_with_fallback(
+        self,
+        model_or_alias: str,
+        request: ChatCompletionRequest,
+        call: Callable[[LLMProvider, str, ChatCompletionRequest], Awaitable[T]],
+    ) -> T:
+        """Walk the resolved chain, returning the first successful result.
+
+        ``call`` is the operation to run against each chain entry, e.g.
+        ``lambda p, m, req: p.chat_completion(req, m)``. The function
+        receives the provider instance, the model name (without the
+        provider prefix), and the original request.
+        """
+        chain = self._resolve_chain(model_or_alias)
+        attempts: list[tuple[str, str]] = []
+        last_exc: Exception | None = None
+
+        for entry in chain:
+            provider_name, model_name = self._parse_model(entry)
+            if provider_name not in self.providers:
+                logger.debug(
+                    "skip chain entry %s — provider %s not configured here",
+                    entry,
+                    provider_name,
+                )
+                attempts.append((entry, "unconfigured"))
+                continue
+
+            if not self.health_cache.is_healthy(provider_name):
+                logger.debug("skip chain entry %s — cache says unhealthy", entry)
+                attempts.append((entry, "cache-unhealthy"))
+                continue
+
+            provider = self.providers[provider_name]
+            self._check_tool_capability(provider, model_name, request)
+
+            try:
+                logger.info(
+                    "execute → %s (alias=%s)",
+                    entry,
+                    model_or_alias if model_or_alias != entry else "<direct>",
+                )
+                result = await call(provider, model_name, request)
+                self.health_cache.mark_healthy(provider_name)
+                return result
+            except Exception as e:
+                if not self._is_retryable(e):
+                    # Caller error / non-retryable provider error — propagate
+                    # without touching the health cache. The cache is for
+                    # liveness, not for recording what the user asked for
+                    # being wrong.
+                    raise
+                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
+                attempts.append((entry, type(e).__name__))
+                last_exc = e
+                logger.warning(
+                    "execute %s failed (retryable, will try next): %s",
+                    entry,
+                    e,
+                )
+
+        raise NoHealthyProviderError(model_or_alias, attempts, last_exc)
+
+    # ------------------------------------------------------------------
+    # Public API — non-streaming
+    # ------------------------------------------------------------------
+
     async def chat_completion(
         self,
         request: ChatCompletionRequest,
     ) -> ChatCompletionResponse:
-        """Route chat completion request with auto-fallback."""
-        provider_name, model_name = self._parse_model(request.model)
+        """Chat completion with alias resolution + health-aware fallback."""
 
-        # Non-Ollama providers: direct routing, no fallback
-        if provider_name != "ollama":
-            provider = self._get_provider(provider_name)
-            self._check_tool_capability(provider, model_name, request)
-            logger.info(f"Routing chat completion to {provider_name}/{model_name}")
-            return await provider.chat_completion(request, model_name)
+        async def call(provider: LLMProvider, model: str, req: ChatCompletionRequest):
+            return await provider.chat_completion(req, model)
 
-        # Ollama with fallback logic
-        can_fallback = self._can_fallback_to_google(provider_name)
+        return await self._execute_with_fallback(request.model, request, call)
 
-        # Check if Ollama is overloaded
-        if can_fallback and not self._should_use_ollama():
-            return await self._fallback_to_google(request, model_name)
-
-        # Try Ollama first
-        provider = self._get_provider("ollama")
-        self._check_tool_capability(provider, model_name, request)
-        logger.info(f"Routing chat completion to ollama/{model_name}")
-        self._ollama_concurrent += 1
-
-        try:
-            return await provider.chat_completion(request, model_name)
-        except Exception as e:
-            logger.error(f"Chat completion failed on ollama: {e}")
-            if can_fallback:
-                return await self._fallback_to_google(request, model_name, e)
-            raise
-        finally:
-            self._ollama_concurrent -= 1
+    # ------------------------------------------------------------------
+    # Public API — streaming (pre-first-byte fallback)
+    # ------------------------------------------------------------------
 
     async def chat_completion_stream(
         self,
         request: ChatCompletionRequest,
     ) -> AsyncIterator[ChatCompletionStreamResponse]:
-        """Route streaming chat completion with auto-fallback."""
-        provider_name, model_name = self._parse_model(request.model)
+        """Streaming variant. Falls back BEFORE the first chunk arrives;
+        once the first chunk has been yielded the provider is committed
+        and any further error propagates.
 
-        # Non-Ollama: direct
-        if provider_name != "ollama":
-            provider = self._get_provider(provider_name)
+        Why pre-first-byte only: stitching half-streams from two different
+        providers would mix two voices in the output and is impossible to
+        sanity-check after the fact.
+        """
+        chain = self._resolve_chain(request.model)
+        attempts: list[tuple[str, str]] = []
+        last_exc: Exception | None = None
+
+        for entry in chain:
+            provider_name, model_name = self._parse_model(entry)
+            if provider_name not in self.providers:
+                attempts.append((entry, "unconfigured"))
+                continue
+            if not self.health_cache.is_healthy(provider_name):
+                attempts.append((entry, "cache-unhealthy"))
+                continue
+
+            provider = self.providers[provider_name]
             self._check_tool_capability(provider, model_name, request)
-            logger.info(f"Routing streaming to {provider_name}/{model_name}")
-            async for chunk in provider.chat_completion_stream(request, model_name):
+
+            stream = provider.chat_completion_stream(request, model_name)
+            try:
+                first_chunk = await stream.__anext__()
+            except StopAsyncIteration:
+                # Empty stream is a successful but content-free response.
+                # Commit and exit cleanly.
+                self.health_cache.mark_healthy(provider_name)
+                logger.info("stream %s yielded empty response", entry)
+                return
+            except Exception as e:
+                if not self._is_retryable(e):
+                    raise
+                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
+                attempts.append((entry, type(e).__name__))
+                last_exc = e
+                logger.warning(
+                    "stream %s failed before first byte (retryable, trying next): %s",
+                    entry,
+                    e,
+                )
+                continue
+
+            # First byte landed — commit the provider, mark healthy, drain
+            # the rest of the stream. Any error from here on propagates;
+            # it is NOT safe to splice another provider's output in.
+            self.health_cache.mark_healthy(provider_name)
+            logger.info("stream → %s (committed after first chunk)", entry)
+            yield first_chunk
+            async for chunk in stream:
                 yield chunk
+            return
 
-        # Ollama with fallback
-        can_fallback = self._can_fallback_to_google(provider_name)
+        raise NoHealthyProviderError(request.model, attempts, last_exc)
 
-        if can_fallback and not self._should_use_ollama():
-            from .google import GoogleProvider
+    # ------------------------------------------------------------------
+    # Embeddings — no fallback (out of scope for M3, separate concerns)
+    # ------------------------------------------------------------------
 
-            google = self.providers["google"]
-            assert isinstance(google, GoogleProvider)
-            gemini_model = google.map_model(model_name)
-            logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
-            async for chunk in google.chat_completion_stream(request, gemini_model):
-                yield chunk
-            return
-
-        provider = self._get_provider("ollama")
-        self._check_tool_capability(provider, model_name, request)
-        logger.info(f"Routing streaming to ollama/{model_name}")
-        self._ollama_concurrent += 1
-
-        try:
-            async for chunk in provider.chat_completion_stream(request, model_name):
-                yield chunk
-        except Exception as e:
-            logger.error(f"Streaming failed on ollama: {e}")
-            if can_fallback:
-                from .google import GoogleProvider
-
-                google = self.providers["google"]
-                assert isinstance(google, GoogleProvider)
-                gemini_model = google.map_model(model_name)
-                logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
-                async for chunk in google.chat_completion_stream(request, gemini_model):
-                    yield chunk
-            else:
-                raise
-        finally:
-            self._ollama_concurrent -= 1
-
-    async def embeddings(
-        self,
-        request: EmbeddingRequest,
-    ) -> EmbeddingResponse:
-        """Route embeddings request to appropriate provider."""
+    async def embeddings(self, request: EmbeddingRequest) -> EmbeddingResponse:
+        """Route an embeddings request directly. No alias / fallback."""
         provider_name, model_name = self._parse_model(request.model)
-        provider = self._get_provider(provider_name)
 
-        logger.info(f"Routing embeddings to {provider_name}/{model_name}")
-
+        if provider_name not in self.providers:
+            available = list(self.providers)
+            raise ValueError(
+                f"Provider '{provider_name}' not available. Available: {available}"
+            )
+        provider = self.providers[provider_name]
+        logger.info("embeddings → %s/%s", provider_name, model_name)
         return await provider.embeddings(request, model_name)
 
-    async def list_models(self) -> list[ModelInfo]:
-        """List all available models from all providers."""
-        all_models: list[ModelInfo] = []
+    # ------------------------------------------------------------------
+    # Discovery / introspection
+    # ------------------------------------------------------------------
+
+    async def list_models(self) -> list[ModelInfo]:
+        """List all available models from all configured providers.
+
+        Best-effort: providers that error are skipped with a warning so a
+        single broken provider can't take down ``GET /v1/models``.
+        """
+        all_models: list[ModelInfo] = []
         for provider in self.providers.values():
             try:
-                models = await provider.list_models()
-                all_models.extend(models)
-            except Exception as e:
-                logger.warning(f"Failed to list models from {provider.name}: {e}")
-
+                all_models.extend(await provider.list_models())
+            except Exception as e:  # noqa: BLE001
+                logger.warning("Failed to list models from %s: %s", provider.name, e)
         return all_models
 
     async def get_model(self, model_id: str) -> ModelInfo | None:
-        """Get specific model info."""
+        """Look up a single model by id, dispatching on the prefix."""
         provider_name, model_name = self._parse_model(model_id)
 
         if provider_name not in self.providers:
             return None
 
-        provider = self.providers[provider_name]
-        models = await provider.list_models()
-
-        for model in models:
-            if model.id == model_id or model.id.endswith(f"/{model_name}"):
-                return model
-
+        models = await self.providers[provider_name].list_models()
+        for m in models:
+            if m.id == model_id or m.id.endswith(f"/{model_name}"):
+                return m
         return None
 
     async def health_check(self) -> dict[str, Any]:
-        """Check health of all providers."""
-        results: dict[str, Any] = {}
+        """Snapshot of the per-provider liveness cache.
 
-        for name, provider in self.providers.items():
-            results[name] = await provider.health_check()
-
-        # Overall status
-        all_healthy = all(r.get("status") == "healthy" for r in results.values())
-        any_healthy = any(r.get("status") == "healthy" for r in results.values())
-
-        status_info: dict[str, Any] = {
-            "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
-            "providers": results,
-        }
-
-        # Include fallback info
-        if settings.auto_fallback_enabled and "google" in self.providers:
-            status_info["fallback"] = {
-                "enabled": True,
-                "ollama_concurrent": self._ollama_concurrent,
-                "ollama_max_concurrent": settings.ollama_max_concurrent,
+        Returns the same shape as before for backwards-compat with
+        ``GET /health`` (deprecated structure — M4 will swap to a
+        cleaner ``/v1/health`` endpoint).
+        """
+        snapshot = self.health_cache.snapshot(expected=list(self.providers))
+        providers_out: dict[str, Any] = {}
+        for name, state in snapshot.items():
+            providers_out[name] = {
+                "status": "healthy" if state.healthy else "unhealthy",
+                "consecutive_failures": state.consecutive_failures,
+                "last_error": state.last_error,
+                "last_check_unix": state.last_check or None,
+                "unhealthy_until_unix": state.unhealthy_until or None,
             }
 
-        return status_info
+        all_healthy = all(state.healthy for state in snapshot.values())
+        any_healthy = any(state.healthy for state in snapshot.values())
+        return {
+            "status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
+            "providers": providers_out,
+        }
 
     async def close(self) -> None:
-        """Close all providers."""
+        """Close all provider clients."""
         for provider in self.providers.values():
             await provider.close()
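Each ProviderError subclass carries an http_status, so a handler can map router failures to responses mechanically; NoHealthyProviderError becomes the 503 with the attempt log. The endpoint shape below is an assumption for illustration, as main.py's actual handlers are not part of this diff:

    from fastapi import HTTPException

    from src.providers.errors import NoHealthyProviderError, ProviderError

    async def chat_endpoint(request, router):  # hypothetical handler shape
        try:
            return await router.chat_completion(request)
        except NoHealthyProviderError as e:
            # 503 plus the structured attempt log, so the cause is visible
            # in the response and metrics, not only in service logs.
            raise HTTPException(
                status_code=e.http_status,
                detail={"error": str(e), "attempts": e.attempts},
            )
        except ProviderError as e:
            # capability/auth/blocked errors map to their own statuses (e.g. 400).
            raise HTTPException(status_code=e.http_status, detail=str(e))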
@@ -1,44 +1,54 @@
 """Provider tests."""
 
+from pathlib import Path
+
 import pytest
 
-from src.models import ChatCompletionRequest, Message
+from src.aliases import AliasRegistry
+from src.health import ProviderHealthCache
+from src.models import ChatCompletionRequest, EmbeddingRequest, Message
 from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter
 
 
+@pytest.fixture
+def shipped_aliases() -> AliasRegistry:
+    """The repo's real aliases.yaml — same one production uses."""
+    return AliasRegistry(Path(__file__).resolve().parents[1] / "aliases.yaml")
+
+
+@pytest.fixture
+def router(shipped_aliases: AliasRegistry) -> ProviderRouter:
+    return ProviderRouter(aliases=shipped_aliases, health_cache=ProviderHealthCache())
+
+
 class TestProviderRouter:
-    """Test provider routing logic."""
+    """Tests for the helpers exposed by the router."""
 
-    def test_parse_model_with_provider(self):
-        """Test model parsing with provider prefix."""
-        router = ProviderRouter()
+    def test_parse_model_with_provider(self, router: ProviderRouter) -> None:
         provider, model = router._parse_model("ollama/gemma3:4b")
         assert provider == "ollama"
         assert model == "gemma3:4b"
 
-    def test_parse_model_without_provider(self):
-        """Test model parsing without provider prefix (defaults to ollama)."""
-        router = ProviderRouter()
+    def test_parse_model_without_provider(self, router: ProviderRouter) -> None:
+        # Bare names default to Ollama for OpenAI-style compat.
         provider, model = router._parse_model("gemma3:4b")
         assert provider == "ollama"
         assert model == "gemma3:4b"
 
-    def test_parse_model_openrouter(self):
-        """Test model parsing for OpenRouter."""
-        router = ProviderRouter()
+    def test_parse_model_openrouter(self, router: ProviderRouter) -> None:
         provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct")
         assert provider == "openrouter"
         assert model == "meta-llama/llama-3.1-8b-instruct"
 
-    def test_get_invalid_provider(self):
-        """Test getting invalid provider raises error."""
-        router = ProviderRouter()
+    @pytest.mark.asyncio
+    async def test_embeddings_unknown_provider_raises(self, router: ProviderRouter) -> None:
+        # Embeddings don't go through the alias/fallback pipeline — they
+        # hit the requested provider directly. Asking for an unconfigured
+        # one is a config error and must raise loudly.
         with pytest.raises(ValueError, match="not available"):
-            router._get_provider("invalid_provider")
+            await router.embeddings(
+                EmbeddingRequest(model="bogus_provider/x", input="hi")
+            )
 
 
 class TestOllamaProvider:
services/mana-llm/tests/test_router_fallback.py (new file, 526 lines)
@@ -0,0 +1,526 @@
"""Tests for ProviderRouter fallback / alias execution (M3)."""

from __future__ import annotations

from collections.abc import AsyncIterator
from typing import Any

import httpx
import pytest

from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    ChatCompletionStreamResponse,
    Choice,
    DeltaContent,
    EmbeddingRequest,
    EmbeddingResponse,
    Message,
    MessageResponse,
    ModelInfo,
    StreamChoice,
)
from src.providers import ProviderRouter
from src.providers.base import LLMProvider
from src.providers.errors import (
    NoHealthyProviderError,
    ProviderAuthError,
    ProviderCapabilityError,
    ProviderRateLimitError,
)


# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------


class MockProvider(LLMProvider):
    """Provider that lets tests inject a sequence of behaviours.

    Each call pops one entry from ``behaviors``. Strings ``"ok"`` and
    ``"empty"`` are sentinels for normal returns; everything else (a
    BaseException instance / class) is raised.
    """

    supports_tools = True

    def __init__(self, name: str, behaviors: list[Any] | None = None) -> None:
        self.name = name
        self._behaviors: list[Any] = list(behaviors or [])
        self.calls: list[str] = []

    def push(self, *behaviors: Any) -> None:
        self._behaviors.extend(behaviors)

    def _next(self) -> Any:
        return self._behaviors.pop(0) if self._behaviors else "ok"

    async def chat_completion(
        self, request: ChatCompletionRequest, model: str
    ) -> ChatCompletionResponse:
        self.calls.append(model)
        b = self._next()
        if isinstance(b, type) and issubclass(b, BaseException):
            raise b("simulated")
        if isinstance(b, BaseException):
            raise b
        return _ok_response(self.name, model)

    async def chat_completion_stream(
        self, request: ChatCompletionRequest, model: str
    ) -> AsyncIterator[ChatCompletionStreamResponse]:
        self.calls.append(model)
        b = self._next()
        if isinstance(b, type) and issubclass(b, BaseException):
            raise b("simulated")
        if isinstance(b, BaseException):
            raise b
        if b == "empty":
            return
        for content in ("Hello", " ", "world"):
            yield ChatCompletionStreamResponse(
                model=f"{self.name}/{model}",
                choices=[StreamChoice(delta=DeltaContent(content=content))],
            )

    async def list_models(self) -> list[ModelInfo]:
        return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")]

    async def embeddings(
        self, request: EmbeddingRequest, model: str
    ) -> EmbeddingResponse:
        raise NotImplementedError

    async def health_check(self) -> dict[str, Any]:
        return {"status": "healthy"}


class FailFirstChunkProvider(MockProvider):
    """Streaming provider that raises BEFORE the first chunk every time.

    Kept separate from MockProvider's behaviour list so the per-call
    semantics stay simple — this one models a permanently-broken streamer.
    """

    def __init__(self, name: str, exc: BaseException) -> None:
        super().__init__(name)
        self._exc = exc

    async def chat_completion_stream(self, request, model):  # type: ignore[override]
        self.calls.append(model)
        raise self._exc
        # the yield is unreachable but keeps the function an async generator
        yield  # pragma: no cover


def _ok_response(provider: str, model: str) -> ChatCompletionResponse:
    return ChatCompletionResponse(
        model=f"{provider}/{model}",
        choices=[
            Choice(
                message=MessageResponse(content="ok"),
                finish_reason="stop",
            )
        ],
    )


def _request(model: str) -> ChatCompletionRequest:
    return ChatCompletionRequest(
        model=model,
        messages=[Message(role="user", content="hi")],
    )


def _aliases_yaml(tmp_path) -> AliasRegistry:
    """A two-alias config used across most tests."""
    cfg = (
        "aliases:\n"
        "  mana/long-form:\n"
        '    description: "long"\n'
        "    chain:\n"
        "      - alpha/m1\n"
        "      - beta/m2\n"
        "      - gamma/m3\n"
        "  mana/single:\n"
        '    description: "single-entry"\n'
        "    chain:\n"
        "      - alpha/solo\n"
    )
    p = tmp_path / "aliases.yaml"
    p.write_text(cfg)
    return AliasRegistry(p)


def _make_router(
    tmp_path,
    *,
    providers: dict[str, MockProvider],
    cache: ProviderHealthCache | None = None,
) -> ProviderRouter:
    aliases = _aliases_yaml(tmp_path)
    router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache())
    # Replace the auto-initialised live providers with the test doubles.
    router.providers = dict(providers)
    return router


# ---------------------------------------------------------------------------
# Non-streaming chain walking
# ---------------------------------------------------------------------------


class TestChatCompletionChain:
    @pytest.mark.asyncio
    async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None:
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "alpha/m1"
        assert alpha.calls == ["m1"]
        assert beta.calls == []  # never reached

    @pytest.mark.asyncio
    async def test_falls_through_on_connect_error(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"
        assert alpha.calls == ["m1"]
        assert beta.calls == ["m2"]

    @pytest.mark.asyncio
    async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None:
        # gamma isn't configured at all → chain should silently skip it
        # rather than raise.
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", [httpx.ConnectError("dead too")])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(NoHealthyProviderError) as exc_info:
            await router.chat_completion(_request("mana/long-form"))
        # All three entries appear in attempts: two as ConnectError, one
        # as unconfigured (not a fatal error, just skipped).
        attempts = exc_info.value.attempts
        assert ("alpha/m1", "ConnectError") in attempts
        assert ("beta/m2", "ConnectError") in attempts
        assert ("gamma/m3", "unconfigured") in attempts

    @pytest.mark.asyncio
    async def test_skips_cache_unhealthy(self, tmp_path) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("alpha", "stale")
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        resp = await router.chat_completion(_request("mana/long-form"))

        assert alpha.calls == []  # router skipped per cache
        assert beta.calls == ["m2"]
        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_5xx_treated_as_retryable(self, tmp_path) -> None:
        five_hundred = httpx.HTTPStatusError(
            "boom",
            request=httpx.Request("POST", "http://x"),
            response=httpx.Response(503),
        )
        alpha = MockProvider("alpha", [five_hundred])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_4xx_propagates(self, tmp_path) -> None:
        four_hundred = httpx.HTTPStatusError(
            "bad request",
            request=httpx.Request("POST", "http://x"),
            response=httpx.Response(422),
        )
        alpha = MockProvider("alpha", [four_hundred])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(httpx.HTTPStatusError):
            await router.chat_completion(_request("mana/long-form"))
        # Beta never tried — caller's request needs fixing, retrying
        # against another model would just hide the bug.
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_capability_error_propagates(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderCapabilityError):
            await router.chat_completion(_request("mana/long-form"))
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_auth_error_propagates(self, tmp_path) -> None:
        # Auth errors mean OUR setup is broken (wrong key); falling back
        # to the next provider hides the misconfiguration.
        alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderAuthError):
            await router.chat_completion(_request("mana/long-form"))
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_rate_limit_is_retryable(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        resp = await router.chat_completion(_request("mana/long-form"))

        assert resp.model == "beta/m2"

    @pytest.mark.asyncio
    async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None:
        alpha = MockProvider("alpha", [httpx.ConnectError("a")])
        beta = MockProvider("beta", [httpx.ConnectError("b")])
        gamma = MockProvider("gamma", [httpx.ConnectError("c")])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
        )

        with pytest.raises(NoHealthyProviderError) as exc_info:
            await router.chat_completion(_request("mana/long-form"))
        assert exc_info.value.model_or_alias == "mana/long-form"
        assert isinstance(exc_info.value.last_exception, httpx.ConnectError)
        # 503 status so calling code (mana-api etc.) can decide to retry
        # later vs surface a clean error to the user.
        assert exc_info.value.http_status == 503

    @pytest.mark.asyncio
    async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None:
        # Caller bypasses aliases by passing a direct provider/model.
        # No fallback chain — fail = fail.
        alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(NoHealthyProviderError):
            await router.chat_completion(_request("alpha/anything"))
        # Beta would have served if this had been an alias — but it
        # wasn't, so beta never gets touched.
        assert beta.calls == []


# ---------------------------------------------------------------------------
# Health-cache feedback: success clears, failure marks
# ---------------------------------------------------------------------------


class TestHealthCacheFeedback:
    @pytest.mark.asyncio
    async def test_success_marks_provider_healthy(self, tmp_path) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("alpha", "stale-from-probe")
        # After the cache TTL the cache thinks alpha might be OK again,
        # so the router will try it; success must fully clear the state.
        # (Force half-open by zeroing backoff.)
        alpha = MockProvider("alpha", ["ok"])
        router = _make_router(
            tmp_path,
            providers={"alpha": alpha},
            cache=ProviderHealthCache(),  # fresh cache, alpha optimistic
        )

        await router.chat_completion(_request("mana/single"))

        assert router.health_cache.get_state("alpha").healthy is True
        assert router.health_cache.get_state("alpha").consecutive_failures == 0

    @pytest.mark.asyncio
    async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None:
        # threshold=1 so a single fail is enough to flip the breaker.
        cache = ProviderHealthCache(failure_threshold=1)
        alpha = MockProvider("alpha", [httpx.ConnectError("boom")])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        await router.chat_completion(_request("mana/long-form"))

        assert cache.get_state("alpha").healthy is False
        assert cache.get_state("alpha").last_error is not None
        assert "ConnectError" in cache.get_state("alpha").last_error

    @pytest.mark.asyncio
    async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None:
        # Auth/Capability errors are about CALLER state, not provider
        # health — the cache must stay clean so a real outage later
        # isn't masked by stale "marked unhealthy because of bad key".
        cache = ProviderHealthCache(failure_threshold=1)
        alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
        router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache)

        with pytest.raises(ProviderAuthError):
            await router.chat_completion(_request("mana/single"))

        # No state recorded.
        assert cache.get_state("alpha") is None


# ---------------------------------------------------------------------------
# Streaming pre-first-byte fallback
# ---------------------------------------------------------------------------


class TestChatCompletionStream:
    @pytest.mark.asyncio
    async def test_first_provider_streams_normally(self, tmp_path) -> None:
        alpha = MockProvider("alpha", ["ok"])
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert beta.calls == []
        assert len(chunks) == 3
        assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world"

    @pytest.mark.asyncio
    async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead"))
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert alpha.calls == ["m1"]
        assert beta.calls == ["m2"]
        assert len(chunks) == 3
        assert all(c.model == "beta/m2" for c in chunks)

    @pytest.mark.asyncio
    async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools"))
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        with pytest.raises(ProviderCapabilityError):
            async for _ in router.chat_completion_stream(_request("mana/long-form")):
                pass
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None:
        # Empty-but-successful stream is a valid response, not a failure
        # we should retry — committing avoids accidentally calling two
        # providers and double-billing.
        alpha = MockProvider("alpha", ["empty"])
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        chunks = [
            c async for c in router.chat_completion_stream(_request("mana/long-form"))
        ]

        assert chunks == []
        assert beta.calls == []  # didn't fall through

    @pytest.mark.asyncio
    async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None:
        # Custom provider that yields once then raises mid-stream — the
        # router has already committed and must let the error propagate
        # rather than splice in another provider's voice.
        class MidStreamFailProvider(MockProvider):
            async def chat_completion_stream(self, request, model):  # type: ignore[override]
                self.calls.append(model)
                yield ChatCompletionStreamResponse(
                    model=f"{self.name}/{model}",
                    choices=[StreamChoice(delta=DeltaContent(content="halb"))],
                )
                raise httpx.RemoteProtocolError("connection dropped")

        alpha = MidStreamFailProvider("alpha")
        beta = MockProvider("beta", ["ok"])
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        collected: list[str] = []
        with pytest.raises(httpx.RemoteProtocolError):
            async for chunk in router.chat_completion_stream(_request("mana/long-form")):
                collected.append(chunk.choices[0].delta.content or "")

        # We got the half-chunk that landed before the break; beta was
        # NOT called as fallback.
        assert collected == ["halb"]
        assert beta.calls == []

    @pytest.mark.asyncio
    async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None:
        alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a"))
        beta = FailFirstChunkProvider("beta", httpx.ConnectError("b"))
        gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c"))
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
        )

        with pytest.raises(NoHealthyProviderError):
            async for _ in router.chat_completion_stream(_request("mana/long-form")):
                pass


# ---------------------------------------------------------------------------
# Health-check shape (still using the cache snapshot)
# ---------------------------------------------------------------------------


class TestHealthCheck:
    @pytest.mark.asyncio
    async def test_health_check_lists_known_providers(self, tmp_path) -> None:
        # Even if no probe has run yet, every configured provider should
        # appear in the snapshot (zero-defaults) so /health has a stable
        # shape for monitors.
        alpha = MockProvider("alpha")
        beta = MockProvider("beta")
        router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})

        out = await router.health_check()

        assert set(out["providers"].keys()) == {"alpha", "beta"}
        assert out["status"] == "healthy"
        assert all(p["status"] == "healthy" for p in out["providers"].values())

    @pytest.mark.asyncio
    async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("alpha", "boom")
        alpha = MockProvider("alpha")
        beta = MockProvider("beta")
        router = _make_router(
            tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
        )

        out = await router.health_check()
        assert out["status"] == "degraded"
        assert out["providers"]["alpha"]["status"] == "unhealthy"
        assert out["providers"]["beta"]["status"] == "healthy"