feat(mana-llm): M4 — observability, debug endpoints, SIGHUP reload

- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
  responses. Streaming clients read the same info from each chunk's
  `model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics: `mana_llm_alias_resolved_total{alias,
  target}` (which concrete model an alias resolved to per request),
  `mana_llm_fallback_total{from_model, to_model, reason}` (each
  fallback transition), `mana_llm_provider_healthy{provider}` (gauge,
  mirrors the circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain
  + description per alias, useful for confirming SIGHUP reloads),
  `GET /v1/health` (full per-provider liveness snapshot — failure
  counter, last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
  previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling:
  the gauge is updated via a transition-only listener wired in main.py
  rather than the cache importing prometheus_client itself.
- Request-side metrics now use the requested model string; success-side
  metrics use the resolved one. So `mana_llm_llm_requests_total{provider="ollama",
  model="gemma3:12b"}` reflects actual upstream load even when callers
  used `mana/long-form` aliases.

16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
`/v1/aliases` + `/v1/health` endpoint shapes, `response.model` carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Till JS 2026-04-26 20:52:28 +02:00
parent 3046da3b19
commit 8a49e3ffd5
6 changed files with 749 additions and 29 deletions


@@ -143,13 +143,99 @@ curl -X POST http://localhost:3025/v1/embeddings \
### Health & Metrics
```bash
# Health check
# Liveness summary (legacy, terse shape — only status + per-provider status string)
curl http://localhost:3025/health
# Detailed per-provider liveness snapshot (M4)
curl http://localhost:3025/v1/health
# Prometheus metrics
curl http://localhost:3025/metrics
```
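The M4 endpoint returns a per-provider map. A minimal sketch of reading it with `httpx` (default port assumed; `status` and `consecutive_failures` are the keys the M4 tests assert on, and the last error / backoff info travels in the same per-provider dict):
```python
import httpx

# Sketch: default port 3025 assumed, as in the curl examples above.
data = httpx.get("http://localhost:3025/v1/health").json()
print(data["status"])  # overall service status
for name, info in data["providers"].items():
    # Per-provider snapshot: status string plus circuit-breaker internals.
    print(f"{name}: {info['status']} ({info['consecutive_failures']} consecutive failures)")
```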
### Aliases (M4 — see `aliases.yaml` and the fallback section below)
```bash
# What does each `mana/<class>` resolve to?
curl http://localhost:3025/v1/aliases
```
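The response mirrors the registry: a `default` alias plus one entry per alias carrying `name`, `description`, and the ordered `chain`. A sketch of walking it (values illustrative):
```python
import httpx

data = httpx.get("http://localhost:3025/v1/aliases").json()
print(data["default"])  # e.g. "mana/fast-text"
for alias in data["aliases"]:
    # Each entry: name, description, and the ordered fallback chain.
    print(alias["name"], "->", " -> ".join(alias["chain"]))
```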
## Aliases & Fallback
> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md)
### What callers send
Two acceptable shapes for the `model` field of `/v1/chat/completions` (both sketched below the table):
1. **Aliases** in the reserved `mana/` namespace — recommended for product code.
The router resolves them via `aliases.yaml` to a chain of concrete
`provider/model` strings and tries them in order.
2. **Direct `provider/model`** — bypasses the alias layer, no fallback.
Useful for tests, debugging, and one-off integrations.
| Alias | Class |
|---|---|
| `mana/fast-text` | Short answers, classification, single-shot Q&A |
| `mana/long-form` | Writing, essays, stories, longer prose |
| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) |
| `mana/reasoning` | Agent missions, tool calls, multi-step plans |
| `mana/vision` | Multimodal (image + text) |
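Both shapes, sketched with `httpx` (model strings illustrative, not the shipped chains):
```python
import httpx

BASE = "http://localhost:3025"
payload = {"messages": [{"role": "user", "content": "Summarize this in one line."}]}

# 1. Alias: the router resolves mana/fast-text via aliases.yaml and
#    walks the chain in order until one entry serves the request.
r1 = httpx.post(f"{BASE}/v1/chat/completions", json={**payload, "model": "mana/fast-text"})

# 2. Direct provider/model: bypasses the alias layer, no fallback.
r2 = httpx.post(f"{BASE}/v1/chat/completions", json={**payload, "model": "ollama/gemma3:4b"})
```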
The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit
the file and `kill -HUP <pid>` to reload — no restart needed. Reload
errors keep the previous good state; check the service logs.
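In registry terms the SIGHUP handler is a thin wrapper; a sketch of the same semantics using the `AliasRegistry` API the test suite exercises (path illustrative):
```python
from pathlib import Path

from src.aliases import AliasConfigError, AliasRegistry

reg = AliasRegistry(Path("services/mana-llm/aliases.yaml"))
print(reg.resolve_chain("mana/fast-text"))  # current chain, as a tuple

# ... edit aliases.yaml on disk, then reload (what the SIGHUP handler calls):
try:
    reg.reload()
except AliasConfigError as e:
    # Rejected: the previous good state stays in memory.
    print(f"reload rejected: {e}")
```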
### Fallback semantics
Every chain is tried in order. The router skips an entry if the provider
isn't configured at this deployment (no API key) or is currently marked
unhealthy by the health-cache. For each remaining entry the request is
attempted; on a **retryable** error (connection failure, timeout, 5xx,
rate-limit, RemoteProtocolError) the provider is marked unhealthy and
the next entry is tried. **Non-retryable** errors (auth, capability,
content-blocked, 4xx, unknown exception types) propagate immediately —
no fallback, the cache is not poisoned.
Streaming follows the same logic up to the **first byte**. Once a chunk
has been yielded the provider is committed; mid-stream errors surface
as-is so we never splice two providers' voices into one output.
If every entry was skipped or failed, the response is `503` carrying a
structured `attempts: list[(model, reason)]` log so the cause is
visible to the caller, not only in service logs.
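Condensed to its control flow, the walk looks like this. This is a sketch of the semantics above, not the router's literal code; `is_retryable` and the `NoHealthyProviderError` constructor stand in for the router's internals:
```python
async def walk_chain(chain, providers, health_cache, request):
    attempts = []  # (model, reason) pairs surfaced in the 503 payload
    for entry in chain:
        provider_name, model_name = entry.split("/", 1)
        if provider_name not in providers:              # no API key at this deployment
            attempts.append((entry, "unconfigured"))
            continue
        if not health_cache.is_healthy(provider_name):  # breaker open
            attempts.append((entry, "cache-unhealthy"))
            continue
        try:
            result = await providers[provider_name].chat_completion(request, model_name)
        except Exception as e:
            if not is_retryable(e):  # auth / capability / 4xx: no fallback
                raise
            health_cache.mark_unhealthy(provider_name, repr(e))
            attempts.append((entry, type(e).__name__))
            continue
        health_cache.mark_healthy(provider_name)
        return result
    raise NoHealthyProviderError(attempts)  # rendered as the structured 503
```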
### Resolved-model header
Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>`
(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that
actually answered. Use this for token-cost attribution when the request
used an alias. For streaming, each chunk's `model` field carries the
same info (headers go out before the chain is walked).
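On the caller side, cost attribution becomes one header read (sketch; `bill_tokens_to` is a hypothetical accounting hook):
```python
import httpx

resp = httpx.post(
    "http://localhost:3025/v1/chat/completions",
    json={
        "model": "mana/long-form",
        "messages": [{"role": "user", "content": "Write a haiku."}],
    },
)
resolved = resp.headers["X-Mana-LLM-Resolved"]  # e.g. "groq/llama-3.3-70b-versatile"
usage = resp.json()["usage"]                    # prompt/completion token counts
bill_tokens_to(resolved, usage)                 # hypothetical accounting hook
```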
### Health-cache + probe
`ProviderHealthCache` keeps a per-provider circuit-breaker:
* 1 failure: still healthy (transient blip, don't bounce).
* 2 consecutive failures: `is_healthy → False` for 60 s; the router
fails fast, skipping straight to the next chain entry.
* After 60 s: half-open. Next call exercises the provider; success
fully resets, failure re-arms the backoff.
A background `HealthProbe` task runs every 30 s with a 3 s timeout per
provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models`
for OpenAI-compat). One bad probe can't sink the loop; results feed
into the same cache as the call-site fallback.
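The thresholds are observable straight through the cache API (the same calls the M4 tests make):
```python
from src.health import ProviderHealthCache

cache = ProviderHealthCache()              # defaults: threshold 2, backoff 60 s
cache.mark_unhealthy("ollama", "timeout")  # 1st failure: still healthy
assert cache.is_healthy("ollama")
cache.mark_unhealthy("ollama", "timeout")  # 2nd: breaker trips, 60 s backoff
assert not cache.is_healthy("ollama")
cache.mark_healthy("ollama")               # success fully resets the state
assert cache.is_healthy("ollama")
```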
### Prometheus metrics added in M4
| Metric | Labels | Purpose |
|---|---|---|
| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. |
| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. |
| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker. |
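In tests or a REPL the counters read back through the default registry, the pattern `test_m4_observability.py` uses (label values illustrative):
```python
from prometheus_client import REGISTRY

value = REGISTRY.get_sample_value(
    "mana_llm_fallback_total",
    labels={
        "from_model": "ollama/gemma3:12b",
        "to_model": "groq/llama-3.1-8b-instant",
        "reason": "ConnectError",
    },
)
print(value or 0.0)  # None until that label combination has fired once
```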
## Provider Routing
Models use the format `provider/model`:


@@ -31,10 +31,15 @@ import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Iterable
from typing import Callable, Iterable
logger = logging.getLogger(__name__)
#: Notification fired whenever a provider transitions between healthy and
#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the
#: cache itself stays metrics-agnostic so tests don't need to mock it.
HealthChangeListener = Callable[[str, bool], None]
DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
@@ -77,6 +82,22 @@ class ProviderHealthCache:
self._clock = clock
self._lock = threading.Lock()
self._states: dict[str, ProviderState] = {}
self._listeners: list[HealthChangeListener] = []
def add_listener(self, listener: HealthChangeListener) -> None:
"""Register a callback fired with ``(provider_id, healthy: bool)``
whenever a provider's healthy-flag transitions. Listeners run
outside the cache's lock; exceptions are swallowed and logged so a
bad listener can't break the underlying state machine.
"""
self._listeners.append(listener)
def _notify(self, provider_id: str, healthy: bool) -> None:
for listener in self._listeners:
try:
listener(provider_id, healthy)
except Exception as e: # noqa: BLE001
logger.error("health-change listener raised: %s", e)
@property
def failure_threshold(self) -> int:
@@ -129,19 +150,22 @@ class ProviderHealthCache:
def mark_healthy(self, provider_id: str) -> None:
"""Provider answered correctly — clear any failure state."""
transitioned = False
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
previously_unhealthy = not state.healthy
transitioned = not state.healthy
state.healthy = True
state.consecutive_failures = 0
state.last_check = self._clock()
state.last_error = None
state.unhealthy_until = 0.0
if previously_unhealthy:
logger.info("provider %s recovered", provider_id)
if transitioned:
logger.info("provider %s recovered", provider_id)
self._notify(provider_id, True)
def mark_unhealthy(self, provider_id: str, reason: str) -> None:
"""Record a failure. Trips the breaker after the threshold."""
transitioned = False
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
state.consecutive_failures += 1
@@ -151,6 +175,7 @@
if tripped and state.healthy:
state.healthy = False
state.unhealthy_until = self._clock() + self._unhealthy_backoff
transitioned = True
logger.warning(
"provider %s marked unhealthy after %d consecutive failures (%s); "
"backoff %.0fs",
@@ -163,6 +188,8 @@
# Still in unhealthy window; refresh the backoff so a flapping
# provider doesn't get re-tried every probe tick.
state.unhealthy_until = self._clock() + self._unhealthy_backoff
if transitioned:
self._notify(provider_id, False)
def _copy(state: ProviderState) -> ProviderState:


@@ -1,6 +1,8 @@
"""Main FastAPI application for mana-llm service."""
import asyncio
import logging
import signal
import time
from contextlib import asynccontextmanager
from pathlib import Path
@@ -8,10 +10,10 @@ from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.responses import JSONResponse, Response
from sse_starlette.sse import EventSourceResponse
from src.aliases import AliasRegistry
from src.aliases import AliasConfigError, AliasRegistry
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.health import ProviderHealthCache
@@ -28,7 +30,18 @@ from src.providers import ProviderRouter
from src.providers.errors import ProviderError
from src.streaming import stream_chat_completion
from src.utils.cache import close_redis
from src.utils.metrics import get_metrics, record_llm_error, record_llm_request
from src.utils.metrics import (
get_metrics,
record_llm_error,
record_llm_request,
set_provider_healthy,
)
#: Header carrying the concrete provider/model that actually served a
#: non-streaming response — useful for token-cost accounting on the
#: caller side, since `mana/long-form` could resolve to ollama, groq,
#: or claude depending on which providers were healthy at request time.
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"
# Configure logging
logging.basicConfig(
@@ -73,6 +86,35 @@ def _build_provider_probes(
return probes
def _on_health_change(provider: str, healthy: bool) -> None:
"""Mirror health-cache transitions into the Prometheus gauge."""
set_provider_healthy(provider, healthy)
def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None:
"""Reload ``aliases.yaml`` when the process receives SIGHUP.
Reload errors keep the previous good state in memory (see
AliasRegistry.reload). SIGHUP isn't available on Windows; we just
log and skip in that case.
"""
def handler() -> None:
if alias_registry is None:
return
try:
alias_registry.reload()
except AliasConfigError as e:
logger.error("alias reload rejected, keeping previous state: %s", e)
try:
loop.add_signal_handler(signal.SIGHUP, handler)
except (NotImplementedError, AttributeError, RuntimeError):
# NotImplementedError on Windows; RuntimeError when the loop is
# not running in the main thread (TestClient does this).
logger.info("SIGHUP reload not available in this context — skipping")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: load aliases, spin up router + health probe."""
@@ -85,12 +127,20 @@ async def lifespan(app: FastAPI):
logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)
health_cache = ProviderHealthCache()
health_cache.add_listener(_on_health_change)
router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
logger.info("Initialized providers: %s", list(router.providers))
# Initial gauge values so dashboards render before the first probe
# transition fires the listener.
for provider_name in router.providers:
set_provider_healthy(provider_name, True)
health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
await health_probe.start()
_install_sighup_reload(asyncio.get_running_loop())
yield
logger.info("Shutting down mana-llm service...")
@@ -143,6 +193,43 @@ async def metrics() -> Response:
return Response(content=get_metrics(), media_type="text/plain")
# ----- Alias / health debug endpoints (M4) ---------------------------------
@app.get("/v1/aliases")
async def list_aliases() -> dict[str, Any]:
"""Inspect the alias registry — what each ``mana/<class>`` resolves to.
Useful for debugging "which model actually answered my request" and
for confirming SIGHUP reloads picked up edits to ``aliases.yaml``.
"""
if alias_registry is None:
raise HTTPException(status_code=503, detail="Service not ready")
return {
"default": alias_registry.default_alias,
"aliases": [
{
"name": a.name,
"description": a.description,
"chain": list(a.chain),
}
for a in alias_registry.list_aliases()
],
}
@app.get("/v1/health")
async def detailed_health() -> dict[str, Any]:
"""Full per-provider liveness snapshot.
Includes the failure counter, last error, and the unhealthy-until
backoff timestamp that the original ``/health`` endpoint hides.
"""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
return await router.health_check()
# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
@@ -182,10 +269,12 @@ async def chat_completions(
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
# The request's `model` field is what the caller asked for — could be
# `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For
# error-path metrics we use that value (it's what the caller will
# search for); for success-path metrics we use the resolved provider
# so token-cost / latency are attributed to the model that actually ran.
requested_provider, requested_model = _split_model(request.model)
start_time = time.time()
@@ -198,7 +287,11 @@
async for chunk in stream_chat_completion(router, request):
yield chunk
record_llm_request(provider, model, streaming=True)
# Streaming metrics: we don't yet know which provider answered
# at request-record time. Each chunk's `model` field carries
# the resolved name; per-token latency is harder to attribute
# cleanly so we skip it for streams.
record_llm_request(requested_provider, requested_model, streaming=True)
return EventSourceResponse(
generate(),
@@ -209,35 +302,43 @@
logger.info(f"Chat completion: {request.model}")
response = await router.chat_completion(request)
# Record metrics
resolved_provider, resolved_model = _split_model(response.model)
latency = time.time() - start_time
record_llm_request(
provider=provider,
model=model,
provider=resolved_provider,
model=resolved_model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
latency=latency,
)
return response
# `response.model` is the concrete provider/model the chain
# actually resolved to. Surface it via header so the caller
# can attribute token cost to the right model even when the
# request used an alias.
return JSONResponse(
content=response.model_dump(),
headers={RESOLVED_MODEL_HEADER: response.model},
)
except ValueError as e:
logger.error(f"Invalid request: {e}")
record_llm_error(provider, model, "invalid_request")
record_llm_error(requested_provider, requested_model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except ProviderError as e:
logger.warning(
f"Provider error on {provider}/{model}: kind={e.kind} detail={e}"
f"Provider error on {requested_provider}/{requested_model}: "
f"kind={e.kind} detail={e}"
)
record_llm_error(provider, model, e.kind)
record_llm_error(requested_provider, requested_model, e.kind)
raise HTTPException(
status_code=e.http_status,
detail={"kind": e.kind, "message": str(e)},
)
except Exception as e:
logger.error(f"Chat completion failed: {e}")
record_llm_error(provider, model, "server_error")
record_llm_error(requested_provider, requested_model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
@@ -248,10 +349,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
provider, model = _split_model(request.model)
start_time = time.time()
@@ -280,6 +378,21 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
raise HTTPException(status_code=500, detail=str(e))
def _split_model(model: str) -> tuple[str, str]:
"""Split a ``provider/model`` string for metric labelling.
Bare names with no slash default to ``ollama`` to match the legacy
OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace
in the metrics; that's intentional, so request-side counters tell
you what callers ASKED for, while the resolved-side counters
(``mana_llm_alias_resolved_total``) tell you what they GOT.
"""
if "/" in model:
provider, _, name = model.partition("/")
return provider.lower(), name
return "ollama", model
if __name__ == "__main__":
import uvicorn


@@ -41,6 +41,11 @@ from src.models import (
EmbeddingResponse,
ModelInfo,
)
from src.utils.metrics import (
record_alias_resolved,
record_fallback,
set_provider_healthy,
)
from .base import LLMProvider
from .errors import (
@@ -229,8 +234,10 @@ class ProviderRouter:
chain = self._resolve_chain(model_or_alias)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(model_or_alias)
for entry in chain:
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
logger.debug(
@@ -239,11 +246,13 @@
provider_name,
)
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
logger.debug("skip chain entry %s — cache says unhealthy", entry)
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
@@ -253,10 +262,13 @@
logger.info(
"execute → %s (alias=%s)",
entry,
model_or_alias if model_or_alias != entry else "<direct>",
model_or_alias if is_alias else "<direct>",
)
result = await call(provider, model_name, request)
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(model_or_alias, entry)
return result
except Exception as e:
if not self._is_retryable(e):
@@ -266,7 +278,11 @@
# being wrong.
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"execute %s failed (retryable, will try next): %s",
@@ -310,14 +326,18 @@
chain = self._resolve_chain(request.model)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(request.model)
for entry in chain:
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
@@ -330,13 +350,20 @@
# Empty stream is a successful but content-free response.
# Commit and exit cleanly.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream %s yielded empty response", entry)
return
except Exception as e:
if not self._is_retryable(e):
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"stream %s failed before first byte (retryable, trying next): %s",
@@ -349,6 +376,9 @@
# the rest of the stream. Any error from here on propagates;
# it is NOT safe to splice another provider's output in.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream → %s (committed after first chunk)", entry)
yield first_chunk
async for chunk in stream:


@@ -4,7 +4,7 @@ import time
from collections.abc import Callable
from fastapi import Request, Response
from prometheus_client import Counter, Histogram, generate_latest
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware
# Request metrics
@@ -47,6 +47,35 @@ LLM_ERRORS = Counter(
["provider", "model", "error_type"],
)
# ---------------------------------------------------------------------------
# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md.
# ---------------------------------------------------------------------------
ALIAS_RESOLVED = Counter(
"mana_llm_alias_resolved_total",
"How often an alias resolved to a concrete provider/model. The `target` "
"label is the chain entry that actually served the request — useful for "
"spotting cases where the primary always falls through to a cloud entry.",
["alias", "target"],
)
FALLBACK_TRIGGERED = Counter(
"mana_llm_fallback_total",
"Fallback transitions: a chain entry failed (or was skipped via cache) "
"and the router moved to the next entry. `reason` is the exception class "
"name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry "
"that didn't serve, `to_model` is empty when no further entries existed.",
["from_model", "to_model", "reason"],
)
PROVIDER_HEALTHY = Gauge(
"mana_llm_provider_healthy",
"1 when the provider is currently considered healthy by the cache, "
"0 when in backoff. Refreshed on every probe tick and on every router "
"call-site state transition.",
["provider"],
)
def get_metrics() -> bytes:
"""Generate Prometheus metrics output."""
@@ -107,3 +136,23 @@
def record_llm_error(provider: str, model: str, error_type: str) -> None:
"""Record LLM error metrics."""
LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc()
def record_alias_resolved(alias: str, target: str) -> None:
"""Record which concrete model an alias resolved to for this request."""
ALIAS_RESOLVED.labels(alias=alias, target=target).inc()
def record_fallback(from_model: str, to_model: str, reason: str) -> None:
"""Record a fallback transition. ``to_model`` is empty when the chain
ran out (i.e. NoHealthyProviderError)."""
FALLBACK_TRIGGERED.labels(
from_model=from_model,
to_model=to_model,
reason=reason,
).inc()
def set_provider_healthy(provider: str, healthy: bool) -> None:
"""Mirror ``ProviderHealthCache`` state into a Prometheus gauge."""
PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0)


@@ -0,0 +1,415 @@
"""Tests for M4: observability + debug endpoints + reload."""
from __future__ import annotations
import asyncio
import os
import signal
from pathlib import Path
import httpx
import pytest
from fastapi.testclient import TestClient
from prometheus_client import REGISTRY
from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
Choice,
Message,
MessageResponse,
)
from src.providers import ProviderRouter
from src.utils.metrics import (
record_alias_resolved,
record_fallback,
set_provider_healthy,
)
# ---------------------------------------------------------------------------
# Cache → listener → metric gauge
# ---------------------------------------------------------------------------
class TestHealthChangeListener:
def test_listener_fires_on_unhealthy_transition(self) -> None:
cache = ProviderHealthCache(failure_threshold=2)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
# First failure: still healthy → no transition.
cache.mark_unhealthy("ollama", "blip")
assert events == []
# Second failure: transition healthy→unhealthy → fires.
cache.mark_unhealthy("ollama", "boom")
assert events == [("ollama", False)]
def test_listener_fires_on_recovery(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
cache.mark_unhealthy("ollama", "boom")
assert events == [("ollama", False)]
cache.mark_healthy("ollama")
assert events == [("ollama", False), ("ollama", True)]
def test_steady_state_does_not_fire(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
events: list[tuple[str, bool]] = []
cache.add_listener(lambda p, h: events.append((p, h)))
# Three healthy ops in a row — no transitions, no events.
for _ in range(3):
cache.mark_healthy("ollama")
assert events == []
def test_listener_exception_does_not_break_cache(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
def bad(_provider: str, _healthy: bool) -> None:
raise RuntimeError("listener boom")
cache.add_listener(bad)
# Should NOT raise — the cache must keep working with a broken
# listener, otherwise one bad metric callback would brick the
# whole router.
cache.mark_unhealthy("ollama", "x")
assert cache.is_healthy("ollama") is False
def test_multiple_listeners(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
a: list = []
b: list = []
cache.add_listener(lambda p, h: a.append((p, h)))
cache.add_listener(lambda p, h: b.append((p, h)))
cache.mark_unhealthy("ollama", "x")
assert a == [("ollama", False)]
assert b == [("ollama", False)]
# ---------------------------------------------------------------------------
# Prometheus metrics — counters/gauges actually move
# ---------------------------------------------------------------------------
def _counter_value(name: str, labels: dict[str, str]) -> float:
"""Helper: read the current value of a labeled Prometheus metric."""
samples = REGISTRY.get_sample_value(name, labels=labels)
return samples or 0.0
class TestMetricsRecording:
def test_record_alias_resolved_increments(self) -> None:
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/test-class", "target": "ollama/x:1b"},
)
record_alias_resolved("mana/test-class", "ollama/x:1b")
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/test-class", "target": "ollama/x:1b"},
)
assert after - before == pytest.approx(1.0)
def test_record_fallback_increments(self) -> None:
before = _counter_value(
"mana_llm_fallback_total",
{"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
)
record_fallback("ollama/x", "groq/y", "ConnectError")
after = _counter_value(
"mana_llm_fallback_total",
{"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
)
assert after - before == pytest.approx(1.0)
def test_set_provider_healthy_writes_gauge(self) -> None:
set_provider_healthy("test_provider_xyz", True)
v = REGISTRY.get_sample_value(
"mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
)
assert v == 1.0
set_provider_healthy("test_provider_xyz", False)
v = REGISTRY.get_sample_value(
"mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
)
assert v == 0.0
# ---------------------------------------------------------------------------
# Router → metrics: end-to-end through a fallback
# ---------------------------------------------------------------------------
class _OkProvider:
"""Minimal provider double — only what the router uses for chat."""
name = "ok-provider"
supports_tools = True
def __init__(self, name: str, fail_with: BaseException | None = None) -> None:
self.name = name
self.fail_with = fail_with
self.calls = 0
def model_supports_tools(self, model: str) -> bool:
return True
async def chat_completion(self, request, model):
self.calls += 1
if self.fail_with is not None:
raise self.fail_with
return ChatCompletionResponse(
model=f"{self.name}/{model}",
choices=[Choice(message=MessageResponse(content="ok"))],
)
async def chat_completion_stream(self, request, model): # pragma: no cover
if False: # pragma: no cover
yield None
async def list_models(self):
return []
async def embeddings(self, request, model):
raise NotImplementedError
async def health_check(self):
return {"status": "healthy"}
async def close(self):
pass
def _aliases(tmp_path: Path) -> AliasRegistry:
cfg = (
"aliases:\n"
" mana/two-step:\n"
' description: "x"\n'
" chain:\n"
" - alpha/m1\n"
" - beta/m2\n"
)
p = tmp_path / "aliases.yaml"
p.write_text(cfg)
return AliasRegistry(p)
class TestRouterMetricsIntegration:
@pytest.mark.asyncio
async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None:
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {"alpha": _OkProvider("alpha")} # beta not configured
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/two-step", "target": "alpha/m1"},
)
await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "mana/two-step", "target": "alpha/m1"},
)
assert after - before == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None:
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {
"alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")),
"beta": _OkProvider("beta"),
}
before = _counter_value(
"mana_llm_fallback_total",
{"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
)
await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_fallback_total",
{"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
)
assert after - before == pytest.approx(1.0)
@pytest.mark.asyncio
async def test_direct_model_does_not_record_alias_metric(
self, tmp_path: Path
) -> None:
# Direct provider/model is not an alias — ALIAS_RESOLVED counter
# must stay flat for those calls.
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
router.providers = {"alpha": _OkProvider("alpha")}
before = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "alpha/anything", "target": "alpha/anything"},
)
await router.chat_completion(
ChatCompletionRequest(
model="alpha/anything",
messages=[Message(role="user", content="hi")],
)
)
after = _counter_value(
"mana_llm_alias_resolved_total",
{"alias": "alpha/anything", "target": "alpha/anything"},
)
# Counter must have NOT increased — direct calls aren't aliases.
assert after == before
# ---------------------------------------------------------------------------
# Debug endpoints: GET /v1/aliases, GET /v1/health
# ---------------------------------------------------------------------------
@pytest.fixture
def client():
from src.main import app
with TestClient(app) as c:
yield c
class TestDebugEndpoints:
def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None:
resp = client.get("/v1/aliases")
assert resp.status_code == 200
data = resp.json()
names = [a["name"] for a in data["aliases"]]
# The five canonical classes must always be present.
for expected in (
"mana/fast-text",
"mana/long-form",
"mana/structured",
"mana/reasoning",
"mana/vision",
):
assert expected in names
# Default is set in the shipped config.
assert data["default"] == "mana/fast-text"
def test_v1_aliases_chain_format(self, client: TestClient) -> None:
resp = client.get("/v1/aliases")
data = resp.json()
long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form")
# Each chain entry is a `provider/model` string.
assert all("/" in entry for entry in long_form["chain"])
assert len(long_form["chain"]) >= 2 # plan requires at least one cloud fallback
def test_v1_health_includes_all_providers(self, client: TestClient) -> None:
resp = client.get("/v1/health")
assert resp.status_code == 200
data = resp.json()
assert "status" in data
assert "providers" in data
# ollama is always configured (provider list is non-empty).
assert "ollama" in data["providers"]
for name, info in data["providers"].items():
assert "status" in info
assert "consecutive_failures" in info
# ---------------------------------------------------------------------------
# X-Mana-LLM-Resolved header on non-streaming responses
# ---------------------------------------------------------------------------
class TestResolvedHeader:
"""The header is the consumer's hook for token-cost attribution.
Tested at the router level; wiring through main.py would need a
real provider connection, which isn't available in unit tests.
"""
@pytest.mark.asyncio
async def test_response_model_field_carries_resolved_target(
self, tmp_path: Path
) -> None:
# The header value is `response.model`; verify that field reflects
# the actual chain entry that served, not the requested alias.
aliases = _aliases(tmp_path)
cache = ProviderHealthCache()
router = ProviderRouter(aliases=aliases, health_cache=cache)
# Force fallback to beta.
router.providers = {
"alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")),
"beta": _OkProvider("beta"),
}
resp = await router.chat_completion(
ChatCompletionRequest(
model="mana/two-step",
messages=[Message(role="user", content="hi")],
)
)
# Even though the caller asked for `mana/two-step`, the resolved
# field shows the entry that actually answered.
assert resp.model == "beta/m2"
# ---------------------------------------------------------------------------
# SIGHUP reload: only meaningful on Unix; semantics exercised on the registry
# ---------------------------------------------------------------------------
class TestSighupReload:
"""SIGHUP triggers ``alias_registry.reload()``; reload-error keeps state.
The signal-handler wiring lives in main.py and only installs when
the loop is running in the main thread. We exercise the reload
semantics here directly on the registry instead; the signal-handler
code path itself is a 4-line wrapper around ``reload()``.
"""
def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None:
path = tmp_path / "aliases.yaml"
path.write_text(
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/foo:1b\n"
)
reg = AliasRegistry(path)
assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",)
# Edit on disk, reload (this is exactly what the SIGHUP handler
# does — minus the signal plumbing).
path.write_text(
"aliases:\n"
" mana/x:\n"
' description: "x"\n'
" chain:\n"
" - ollama/bar:1b\n"
" - groq/llama-3.1-8b-instant\n"
)
reg.reload()
assert reg.resolve_chain("mana/x") == (
"ollama/bar:1b",
"groq/llama-3.1-8b-instant",
)