managarten/services/mana-llm/tests/test_router_fallback.py
Till JS 3046da3b19 feat(mana-llm): M3 — health-aware router with alias + chain fallback
Replaces the old Ollama→Google special-case auto-fallback with the
unified pipeline: caller passes either a direct provider/model or an
alias from the `mana/` namespace; the router resolves to a chain and
walks it skipping unhealthy providers (per ProviderHealthCache from M2),
trying each entry, marking provider unhealthy on retryable errors and
falling through to the next.

Retryable: ConnectError, ReadTimeout, RemoteProtocolError, 5xx,
ProviderRateLimitError. Propagated (don't fall back, don't poison the
cache): ProviderCapabilityError, ProviderAuthError, ProviderBlockedError,
4xx, unknown exception types. The cache stays "what the network told us
about this provider's liveness" — caller errors don't muddy that signal.

Streaming: pre-first-byte fallback only. Once a chunk has been yielded
the provider is committed; mid-stream errors propagate as-is so we
don't splice two voices into one output.

`NoHealthyProviderError` (HTTP 503) carries a structured attempt log —
each chain entry shows up as `(model, reason)` so the cause of a 503
is visible in the response and metrics, not only in service logs.

main.py wires the lifespan: aliases.yaml is loaded, ProviderHealthCache
created, ProviderRouter takes both as constructor deps, HealthProbe
spawned with cheap HTTP probes per configured provider (Ollama
/api/tags, OpenAI-compat /v1/models with Bearer header). Google is
skipped — google-genai SDK has no obvious cheap probe; the call-site
fallback handles real errors.

22 new router tests (test_router_fallback.py): chain walking, capability
& auth propagation, 5xx vs 4xx differentiation, rate-limit retry,
all-fail → NoHealthyProviderError, direct provider strings bypass
aliases, streaming pre-first-byte fallback, mid-stream-failure does
NOT fall back, empty stream commits without retry, cache feedback on
success/failure/non-retryable. Existing test_providers.py updated for
the new constructor signature; all 99 service tests green via the dev
container (Python 3.12).

Legacy purged: `_ollama_concurrent`, `_ollama_health_cache`,
`_can_fallback_to_google`, `_should_use_ollama`, `_fallback_to_google`,
`_get_ollama_health_cached` all gone. The `auto_fallback_enabled` /
`ollama_max_concurrent` settings remain in config.py for now (M5 will
remove them along with the per-feature env-var overrides).

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-26 20:44:16 +02:00

526 lines
20 KiB
Python

"""Tests for ProviderRouter fallback / alias execution (M3)."""
from __future__ import annotations
from collections.abc import AsyncIterator
from typing import Any
import httpx
import pytest
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingRequest,
EmbeddingResponse,
Message,
MessageResponse,
ModelInfo,
StreamChoice,
)
from src.providers import ProviderRouter
from src.providers.base import LLMProvider
from src.providers.errors import (
NoHealthyProviderError,
ProviderAuthError,
ProviderCapabilityError,
ProviderRateLimitError,
)
# ---------------------------------------------------------------------------
# Test doubles
# ---------------------------------------------------------------------------
class MockProvider(LLMProvider):
"""Provider that lets tests inject a sequence of behaviours.
Each call pops one entry from ``behaviors``. Strings ``"ok"`` and
``"empty"`` are sentinels for normal returns; everything else (a
BaseException instance / class) is raised.
"""
supports_tools = True
def __init__(self, name: str, behaviors: list[Any] | None = None) -> None:
self.name = name
self._behaviors: list[Any] = list(behaviors or [])
self.calls: list[str] = []
def push(self, *behaviors: Any) -> None:
self._behaviors.extend(behaviors)
def _next(self) -> Any:
return self._behaviors.pop(0) if self._behaviors else "ok"
async def chat_completion(
self, request: ChatCompletionRequest, model: str
) -> ChatCompletionResponse:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
return _ok_response(self.name, model)
async def chat_completion_stream(
self, request: ChatCompletionRequest, model: str
) -> AsyncIterator[ChatCompletionStreamResponse]:
self.calls.append(model)
b = self._next()
if isinstance(b, type) and issubclass(b, BaseException):
raise b("simulated")
if isinstance(b, BaseException):
raise b
if b == "empty":
return
for content in ("Hello", " ", "world"):
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content=content))],
)
async def list_models(self) -> list[ModelInfo]:
return [ModelInfo(id=f"{self.name}/{m}") for m in ("modelA", "modelB")]
async def embeddings(
self, request: EmbeddingRequest, model: str
) -> EmbeddingResponse:
raise NotImplementedError
async def health_check(self) -> dict[str, Any]:
return {"status": "healthy"}
class FailFirstChunkProvider(MockProvider):
"""Streaming provider that raises BEFORE the first chunk every time.
Kept separate from MockProvider's behaviour list so the per-call
semantics stay simple — this one models a permanently-broken streamer.
"""
def __init__(self, name: str, exc: BaseException) -> None:
super().__init__(name)
self._exc = exc
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
raise self._exc
# the yield is unreachable but keeps the function an async generator
yield # pragma: no cover
def _ok_response(provider: str, model: str) -> ChatCompletionResponse:
return ChatCompletionResponse(
model=f"{provider}/{model}",
choices=[
Choice(
message=MessageResponse(content="ok"),
finish_reason="stop",
)
],
)
def _request(model: str) -> ChatCompletionRequest:
return ChatCompletionRequest(
model=model,
messages=[Message(role="user", content="hi")],
)
def _aliases_yaml(tmp_path) -> AliasRegistry:
"""A two-alias config used across most tests."""
cfg = (
"aliases:\n"
" mana/long-form:\n"
' description: "long"\n'
" chain:\n"
" - alpha/m1\n"
" - beta/m2\n"
" - gamma/m3\n"
" mana/single:\n"
' description: "single-entry"\n'
" chain:\n"
" - alpha/solo\n"
)
p = tmp_path / "aliases.yaml"
p.write_text(cfg)
return AliasRegistry(p)
def _make_router(
tmp_path,
*,
providers: dict[str, MockProvider],
cache: ProviderHealthCache | None = None,
) -> ProviderRouter:
aliases = _aliases_yaml(tmp_path)
router = ProviderRouter(aliases=aliases, health_cache=cache or ProviderHealthCache())
# Replace the auto-initialised live providers with the test doubles.
router.providers = dict(providers)
return router
# ---------------------------------------------------------------------------
# Non-streaming chain walking
# ---------------------------------------------------------------------------
class TestChatCompletionChain:
@pytest.mark.asyncio
async def test_first_provider_ok_returns_immediately(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "alpha/m1"
assert alpha.calls == ["m1"]
assert beta.calls == [] # never reached
@pytest.mark.asyncio
async def test_falls_through_on_connect_error(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
@pytest.mark.asyncio
async def test_skips_unconfigured_chain_entries(self, tmp_path) -> None:
# gamma isn't configured at all → chain should silently skip it
# rather than raise.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", [httpx.ConnectError("dead too")])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
# All three entries appear in attempts: two as ConnectError, one
# as unconfigured (not a fatal error, just skipped).
attempts = exc_info.value.attempts
assert ("alpha/m1", "ConnectError") in attempts
assert ("beta/m2", "ConnectError") in attempts
assert ("gamma/m3", "unconfigured") in attempts
@pytest.mark.asyncio
async def test_skips_cache_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "stale")
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
resp = await router.chat_completion(_request("mana/long-form"))
assert alpha.calls == [] # router skipped per cache
assert beta.calls == ["m2"]
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_5xx_treated_as_retryable(self, tmp_path) -> None:
five_hundred = httpx.HTTPStatusError(
"boom",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(503),
)
alpha = MockProvider("alpha", [five_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_4xx_propagates(self, tmp_path) -> None:
four_hundred = httpx.HTTPStatusError(
"bad request",
request=httpx.Request("POST", "http://x"),
response=httpx.Response(422),
)
alpha = MockProvider("alpha", [four_hundred])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(httpx.HTTPStatusError):
await router.chat_completion(_request("mana/long-form"))
# Beta never tried — caller's request needs fixing, retrying
# against another model would just hide the bug.
assert beta.calls == []
@pytest.mark.asyncio
async def test_capability_error_propagates(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderCapabilityError("no tools")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_auth_error_propagates(self, tmp_path) -> None:
# Auth errors mean OUR setup is broken (wrong key); falling back
# to the next provider hides the misconfiguration.
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/long-form"))
assert beta.calls == []
@pytest.mark.asyncio
async def test_rate_limit_is_retryable(self, tmp_path) -> None:
alpha = MockProvider("alpha", [ProviderRateLimitError("slow down")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
resp = await router.chat_completion(_request("mana/long-form"))
assert resp.model == "beta/m2"
@pytest.mark.asyncio
async def test_all_fail_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = MockProvider("alpha", [httpx.ConnectError("a")])
beta = MockProvider("beta", [httpx.ConnectError("b")])
gamma = MockProvider("gamma", [httpx.ConnectError("c")])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError) as exc_info:
await router.chat_completion(_request("mana/long-form"))
assert exc_info.value.model_or_alias == "mana/long-form"
assert isinstance(exc_info.value.last_exception, httpx.ConnectError)
# 503 status so calling code (mana-api etc.) can decide to retry
# later vs surface a clean error to the user.
assert exc_info.value.http_status == 503
@pytest.mark.asyncio
async def test_direct_provider_string_no_alias_resolution(self, tmp_path) -> None:
# Caller bypasses aliases by passing a direct provider/model.
# No fallback chain — fail = fail.
alpha = MockProvider("alpha", [httpx.ConnectError("dead")])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(NoHealthyProviderError):
await router.chat_completion(_request("alpha/anything"))
# Beta would have served if this had been an alias — but it
# wasn't, so beta never gets touched.
assert beta.calls == []
# ---------------------------------------------------------------------------
# Health-cache feedback: success clears, failure marks
# ---------------------------------------------------------------------------
class TestHealthCacheFeedback:
@pytest.mark.asyncio
async def test_success_marks_provider_healthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "stale-from-probe")
# After the cache TTL the cache thinks alpha might be OK again,
# so the router will try it; success must fully clear the state.
# (Force half-open by zeroing backoff.)
alpha = MockProvider("alpha", ["ok"])
router = _make_router(
tmp_path,
providers={"alpha": alpha},
cache=ProviderHealthCache(), # fresh cache, alpha optimistic
)
await router.chat_completion(_request("mana/single"))
assert router.health_cache.get_state("alpha").healthy is True
assert router.health_cache.get_state("alpha").consecutive_failures == 0
@pytest.mark.asyncio
async def test_failure_marks_provider_unhealthy(self, tmp_path) -> None:
# threshold=1 so a single fail is enough to flip the breaker.
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [httpx.ConnectError("boom")])
beta = MockProvider("beta", ["ok"])
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
await router.chat_completion(_request("mana/long-form"))
assert cache.get_state("alpha").healthy is False
assert cache.get_state("alpha").last_error is not None
assert "ConnectError" in cache.get_state("alpha").last_error
@pytest.mark.asyncio
async def test_propagating_error_does_not_touch_cache(self, tmp_path) -> None:
# Auth/Capability errors are about CALLER state, not provider
# health — the cache must stay clean so a real outage later
# isn't masked by stale "marked unhealthy because of bad key".
cache = ProviderHealthCache(failure_threshold=1)
alpha = MockProvider("alpha", [ProviderAuthError("bad key")])
router = _make_router(tmp_path, providers={"alpha": alpha}, cache=cache)
with pytest.raises(ProviderAuthError):
await router.chat_completion(_request("mana/single"))
# No state recorded.
assert cache.get_state("alpha") is None
# ---------------------------------------------------------------------------
# Streaming pre-first-byte fallback
# ---------------------------------------------------------------------------
class TestChatCompletionStream:
@pytest.mark.asyncio
async def test_first_provider_streams_normally(self, tmp_path) -> None:
alpha = MockProvider("alpha", ["ok"])
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert beta.calls == []
assert len(chunks) == 3
assert "".join(c.choices[0].delta.content or "" for c in chunks) == "Hello world"
@pytest.mark.asyncio
async def test_pre_first_byte_failure_falls_back(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("dead"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert alpha.calls == ["m1"]
assert beta.calls == ["m2"]
assert len(chunks) == 3
assert all(c.model == "beta/m2" for c in chunks)
@pytest.mark.asyncio
async def test_pre_first_byte_4xx_propagates_no_fallback(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", ProviderCapabilityError("no tools"))
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
with pytest.raises(ProviderCapabilityError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
assert beta.calls == []
@pytest.mark.asyncio
async def test_empty_stream_commits_without_fallback(self, tmp_path) -> None:
# Empty-but-successful stream is a valid response, not a failure
# we should retry — committing avoids accidentally calling two
# providers and double-billing.
alpha = MockProvider("alpha", ["empty"])
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
chunks = [
c async for c in router.chat_completion_stream(_request("mana/long-form"))
]
assert chunks == []
assert beta.calls == [] # didn't fall through
@pytest.mark.asyncio
async def test_mid_stream_failure_does_not_fall_back(self, tmp_path) -> None:
# Custom provider that yields once then raises mid-stream — the
# router has already committed and must let the error propagate
# rather than splice in another provider's voice.
class MidStreamFailProvider(MockProvider):
async def chat_completion_stream(self, request, model): # type: ignore[override]
self.calls.append(model)
yield ChatCompletionStreamResponse(
model=f"{self.name}/{model}",
choices=[StreamChoice(delta=DeltaContent(content="halb"))],
)
raise httpx.RemoteProtocolError("connection dropped")
alpha = MidStreamFailProvider("alpha")
beta = MockProvider("beta", ["ok"])
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
collected: list[str] = []
with pytest.raises(httpx.RemoteProtocolError):
async for chunk in router.chat_completion_stream(_request("mana/long-form")):
collected.append(chunk.choices[0].delta.content or "")
# We got the half-chunk that landed before the break; beta was
# NOT called as fallback.
assert collected == ["halb"]
assert beta.calls == []
@pytest.mark.asyncio
async def test_all_fail_streaming_raises_no_healthy_provider(self, tmp_path) -> None:
alpha = FailFirstChunkProvider("alpha", httpx.ConnectError("a"))
beta = FailFirstChunkProvider("beta", httpx.ConnectError("b"))
gamma = FailFirstChunkProvider("gamma", httpx.ConnectError("c"))
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta, "gamma": gamma}
)
with pytest.raises(NoHealthyProviderError):
async for _ in router.chat_completion_stream(_request("mana/long-form")):
pass
# ---------------------------------------------------------------------------
# Health-check shape (still using the cache snapshot)
# ---------------------------------------------------------------------------
class TestHealthCheck:
@pytest.mark.asyncio
async def test_health_check_lists_known_providers(self, tmp_path) -> None:
# Even if no probe has run yet, every configured provider should
# appear in the snapshot (zero-defaults) so /health has a stable
# shape for monitors.
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(tmp_path, providers={"alpha": alpha, "beta": beta})
out = await router.health_check()
assert set(out["providers"].keys()) == {"alpha", "beta"}
assert out["status"] == "healthy"
assert all(p["status"] == "healthy" for p in out["providers"].values())
@pytest.mark.asyncio
async def test_health_check_degraded_when_one_unhealthy(self, tmp_path) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("alpha", "boom")
alpha = MockProvider("alpha")
beta = MockProvider("beta")
router = _make_router(
tmp_path, providers={"alpha": alpha, "beta": beta}, cache=cache
)
out = await router.health_check()
assert out["status"] == "degraded"
assert out["providers"]["alpha"]["status"] == "unhealthy"
assert out["providers"]["beta"]["status"] == "healthy"