feat(mana-llm): M4 — observability, debug endpoints, SIGHUP reload

- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
  responses. Streaming clients read the same info from each chunk's
  `model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics: `mana_llm_alias_resolved_total{alias,
  target}` (which concrete model an alias resolved to per request),
  `mana_llm_fallback_total{from_model, to_model, reason}` (each
  fallback transition), `mana_llm_provider_healthy{provider}` (gauge,
  mirrors the circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain
  + description per alias, useful for confirming SIGHUP reloads),
  `GET /v1/health` (full per-provider liveness snapshot — failure
  counter, last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
  previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling:
  the gauge is updated via a transition-only listener wired in main.py
  rather than the cache importing prometheus_client itself.
- Error-path and streaming metrics now use the requested model string;
  non-streaming success-path metrics use the resolved one. So
  `mana_llm_llm_requests_total{provider="ollama", model="gemma3:12b"}`
  reflects actual upstream load even when callers used `mana/long-form`
  aliases.

16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
v1/aliases + v1/health endpoint shape, response.model carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Till JS 2026-04-26 20:52:28 +02:00
parent 3046da3b19
commit 8a49e3ffd5
6 changed files with 749 additions and 29 deletions

View file

@ -31,10 +31,15 @@ import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Iterable
from typing import Callable, Iterable
logger = logging.getLogger(__name__)
#: Signature of a health-transition callback, invoked as
#: ``listener(provider_id, healthy)`` whenever a provider transitions
#: between healthy and unhealthy. ``main.py`` wires this to the Prometheus
#: gauge — but the cache itself stays metrics-agnostic, so tests of the
#: cache don't need to mock the metrics layer.
HealthChangeListener = Callable[[str, bool], None]
DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
@ -77,6 +82,22 @@ class ProviderHealthCache:
self._clock = clock
self._lock = threading.Lock()
self._states: dict[str, ProviderState] = {}
self._listeners: list[HealthChangeListener] = []
def add_listener(self, listener: HealthChangeListener) -> None:
    """Subscribe ``listener`` to provider health transitions.

    The callback is invoked as ``listener(provider_id, healthy)`` each
    time a provider's healthy-flag flips. Callbacks are dispatched
    outside the cache's lock, and any exception they raise is logged and
    swallowed so a faulty listener can't break the underlying state
    machine.
    """
    registered = self._listeners
    registered.append(listener)
def _notify(self, provider_id: str, healthy: bool) -> None:
for listener in self._listeners:
try:
listener(provider_id, healthy)
except Exception as e: # noqa: BLE001
logger.error("health-change listener raised: %s", e)
@property
def failure_threshold(self) -> int:
@ -129,19 +150,22 @@ class ProviderHealthCache:
def mark_healthy(self, provider_id: str) -> None:
    """Provider answered correctly — clear any failure state.

    Resets the failure counter, last error and backoff deadline for
    ``provider_id`` (creating a fresh ``ProviderState`` if the provider
    has not been seen before) and stamps ``last_check`` with the current
    clock value.

    If the provider was previously flagged unhealthy, the recovery is
    logged once and listeners are notified with ``healthy=True``. Both
    happen after the lock is released so listeners never run under it.
    """
    with self._lock:
        state = self._states.setdefault(provider_id, ProviderState())
        # Capture the transition before overwriting the flag.
        transitioned = not state.healthy
        state.healthy = True
        state.consecutive_failures = 0
        state.last_check = self._clock()
        state.last_error = None
        state.unhealthy_until = 0.0
    if transitioned:
        logger.info("provider %s recovered", provider_id)
        self._notify(provider_id, True)
def mark_unhealthy(self, provider_id: str, reason: str) -> None:
"""Record a failure. Trips the breaker after the threshold."""
transitioned = False
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
state.consecutive_failures += 1
@ -151,6 +175,7 @@ class ProviderHealthCache:
if tripped and state.healthy:
state.healthy = False
state.unhealthy_until = self._clock() + self._unhealthy_backoff
transitioned = True
logger.warning(
"provider %s marked unhealthy after %d consecutive failures (%s); "
"backoff %.0fs",
@ -163,6 +188,8 @@ class ProviderHealthCache:
# Still in unhealthy window; refresh the backoff so a flapping
# provider doesn't get re-tried every probe tick.
state.unhealthy_until = self._clock() + self._unhealthy_backoff
if transitioned:
self._notify(provider_id, False)
def _copy(state: ProviderState) -> ProviderState:

View file

@ -1,6 +1,8 @@
"""Main FastAPI application for mana-llm service."""
import asyncio
import logging
import signal
import time
from contextlib import asynccontextmanager
from pathlib import Path
@ -8,10 +10,10 @@ from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.responses import JSONResponse, Response
from sse_starlette.sse import EventSourceResponse
from src.aliases import AliasRegistry
from src.aliases import AliasConfigError, AliasRegistry
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.health import ProviderHealthCache
@ -28,7 +30,18 @@ from src.providers import ProviderRouter
from src.providers.errors import ProviderError
from src.streaming import stream_chat_completion
from src.utils.cache import close_redis
from src.utils.metrics import get_metrics, record_llm_error, record_llm_request
from src.utils.metrics import (
get_metrics,
record_llm_error,
record_llm_request,
set_provider_healthy,
)
#: Header carrying the concrete provider/model that actually served a
#: non-streaming response — useful for token-cost accounting on the
#: caller side, since `mana/long-form` could resolve to ollama, groq,
#: or claude depending on which providers were healthy at request time.
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"
# Configure logging
logging.basicConfig(
@ -73,6 +86,35 @@ def _build_provider_probes(
return probes
def _on_health_change(provider: str, healthy: bool) -> None:
    """Health-cache listener: copy a provider's new state into the Prometheus gauge."""
    set_provider_healthy(provider=provider, healthy=healthy)
def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None:
"""Reload ``aliases.yaml`` when the process receives SIGHUP.
Reload errors keep the previous good state in memory (see
AliasRegistry.reload). SIGHUP isn't available on Windows; we just
log and skip in that case.
"""
def handler() -> None:
if alias_registry is None:
return
try:
alias_registry.reload()
except AliasConfigError as e:
logger.error("alias reload rejected, keeping previous state: %s", e)
try:
loop.add_signal_handler(signal.SIGHUP, handler)
except (NotImplementedError, AttributeError, RuntimeError):
# NotImplementedError on Windows; RuntimeError when the loop is
# not running in the main thread (TestClient does this).
logger.info("SIGHUP reload not available in this context — skipping")
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan: load aliases, spin up router + health probe."""
@ -85,12 +127,20 @@ async def lifespan(app: FastAPI):
logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)
health_cache = ProviderHealthCache()
health_cache.add_listener(_on_health_change)
router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
logger.info("Initialized providers: %s", list(router.providers))
# Initial gauge values so dashboards render before the first probe
# transition fires the listener.
for provider_name in router.providers:
set_provider_healthy(provider_name, True)
health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
await health_probe.start()
_install_sighup_reload(asyncio.get_running_loop())
yield
logger.info("Shutting down mana-llm service...")
@ -143,6 +193,43 @@ async def metrics() -> Response:
return Response(content=get_metrics(), media_type="text/plain")
# ----- Alias / health debug endpoints (M4) ---------------------------------
@app.get("/v1/aliases")
async def list_aliases() -> dict[str, Any]:
    """Dump the alias registry: the default alias plus every alias's chain.

    Handy for answering "which model actually answered my request" and
    for checking that a SIGHUP reload picked up edits to ``aliases.yaml``.

    Raises:
        HTTPException: 503 while the registry has not been loaded yet.
    """
    registry = alias_registry
    if registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    default_alias = registry.default_alias
    entries = []
    for alias in registry.list_aliases():
        entries.append(
            {
                "name": alias.name,
                "description": alias.description,
                "chain": list(alias.chain),
            }
        )
    return {"default": default_alias, "aliases": entries}
@app.get("/v1/health")
async def detailed_health() -> dict[str, Any]:
    """Return the router's full per-provider liveness snapshot.

    Includes the failure counter, last error, and the unhealthy-until
    backoff timestamp — info the original ``/health`` endpoint hides.

    Raises:
        HTTPException: 503 while the router has not been initialised yet.
    """
    if router is not None:
        snapshot = await router.health_check()
        return snapshot
    raise HTTPException(status_code=503, detail="Service not ready")
# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
@ -182,10 +269,12 @@ async def chat_completions(
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
# The request's `model` field is what the caller asked for — could be
# `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For
# error-path metrics we use that value (it's what the caller will
# search for); for success-path metrics we use the resolved provider
# so token-cost / latency attribute to the model that actually ran.
requested_provider, requested_model = _split_model(request.model)
start_time = time.time()
@ -198,7 +287,11 @@ async def chat_completions(
async for chunk in stream_chat_completion(router, request):
yield chunk
record_llm_request(provider, model, streaming=True)
# Streaming metrics: we don't yet know which provider answered
# at request-record time. Each chunk's `model` field carries
# the resolved name; per-token latency is harder to attribute
# cleanly so we skip it for streams.
record_llm_request(requested_provider, requested_model, streaming=True)
return EventSourceResponse(
generate(),
@ -209,35 +302,43 @@ async def chat_completions(
logger.info(f"Chat completion: {request.model}")
response = await router.chat_completion(request)
# Record metrics
resolved_provider, resolved_model = _split_model(response.model)
latency = time.time() - start_time
record_llm_request(
provider=provider,
model=model,
provider=resolved_provider,
model=resolved_model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
latency=latency,
)
return response
# `response.model` is the concrete provider/model the chain
# actually resolved to. Surface it via header so the caller
# can attribute token cost to the right model even when the
# request used an alias.
return JSONResponse(
content=response.model_dump(),
headers={RESOLVED_MODEL_HEADER: response.model},
)
except ValueError as e:
logger.error(f"Invalid request: {e}")
record_llm_error(provider, model, "invalid_request")
record_llm_error(requested_provider, requested_model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except ProviderError as e:
logger.warning(
f"Provider error on {provider}/{model}: kind={e.kind} detail={e}"
f"Provider error on {requested_provider}/{requested_model}: "
f"kind={e.kind} detail={e}"
)
record_llm_error(provider, model, e.kind)
record_llm_error(requested_provider, requested_model, e.kind)
raise HTTPException(
status_code=e.http_status,
detail={"kind": e.kind, "message": str(e)},
)
except Exception as e:
logger.error(f"Chat completion failed: {e}")
record_llm_error(provider, model, "server_error")
record_llm_error(requested_provider, requested_model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
@ -248,10 +349,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
provider, model = _split_model(request.model)
start_time = time.time()
@ -280,6 +378,21 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
raise HTTPException(status_code=500, detail=str(e))
def _split_model(model: str) -> tuple[str, str]:
"""Split a ``provider/model`` string for metric labelling.
Bare names with no slash default to ``ollama`` to match the legacy
OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace
in the metrics that's intentional, so request-side counters tell
you what callers ASKED for, while the resolved-side counters
(``mana_llm_alias_resolved_total``) tell you what they GOT.
"""
if "/" in model:
provider, _, name = model.partition("/")
return provider.lower(), name
return "ollama", model
if __name__ == "__main__":
import uvicorn

View file

@ -41,6 +41,11 @@ from src.models import (
EmbeddingResponse,
ModelInfo,
)
from src.utils.metrics import (
record_alias_resolved,
record_fallback,
set_provider_healthy,
)
from .base import LLMProvider
from .errors import (
@ -229,8 +234,10 @@ class ProviderRouter:
chain = self._resolve_chain(model_or_alias)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(model_or_alias)
for entry in chain:
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
logger.debug(
@ -239,11 +246,13 @@ class ProviderRouter:
provider_name,
)
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
logger.debug("skip chain entry %s — cache says unhealthy", entry)
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
@ -253,10 +262,13 @@ class ProviderRouter:
logger.info(
"execute → %s (alias=%s)",
entry,
model_or_alias if model_or_alias != entry else "<direct>",
model_or_alias if is_alias else "<direct>",
)
result = await call(provider, model_name, request)
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(model_or_alias, entry)
return result
except Exception as e:
if not self._is_retryable(e):
@ -266,7 +278,11 @@ class ProviderRouter:
# being wrong.
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"execute %s failed (retryable, will try next): %s",
@ -310,14 +326,18 @@ class ProviderRouter:
chain = self._resolve_chain(request.model)
attempts: list[tuple[str, str]] = []
last_exc: Exception | None = None
is_alias = AliasRegistry.is_alias(request.model)
for entry in chain:
for i, entry in enumerate(chain):
next_entry = chain[i + 1] if i + 1 < len(chain) else ""
provider_name, model_name = self._parse_model(entry)
if provider_name not in self.providers:
attempts.append((entry, "unconfigured"))
record_fallback(entry, next_entry, "unconfigured")
continue
if not self.health_cache.is_healthy(provider_name):
attempts.append((entry, "cache-unhealthy"))
record_fallback(entry, next_entry, "cache-unhealthy")
continue
provider = self.providers[provider_name]
@ -330,13 +350,20 @@ class ProviderRouter:
# Empty stream is a successful but content-free response.
# Commit and exit cleanly.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream %s yielded empty response", entry)
return
except Exception as e:
if not self._is_retryable(e):
raise
self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
set_provider_healthy(
provider_name, self.health_cache.is_healthy(provider_name)
)
attempts.append((entry, type(e).__name__))
record_fallback(entry, next_entry, type(e).__name__)
last_exc = e
logger.warning(
"stream %s failed before first byte (retryable, trying next): %s",
@ -349,6 +376,9 @@ class ProviderRouter:
# the rest of the stream. Any error from here on propagates;
# it is NOT safe to splice another provider's output in.
self.health_cache.mark_healthy(provider_name)
set_provider_healthy(provider_name, True)
if is_alias:
record_alias_resolved(request.model, entry)
logger.info("stream → %s (committed after first chunk)", entry)
yield first_chunk
async for chunk in stream:

View file

@ -4,7 +4,7 @@ import time
from collections.abc import Callable
from fastapi import Request, Response
from prometheus_client import Counter, Histogram, generate_latest
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware
# Request metrics
@ -47,6 +47,35 @@ LLM_ERRORS = Counter(
["provider", "model", "error_type"],
)
# ---------------------------------------------------------------------------
# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md.
# ---------------------------------------------------------------------------

# Incremented once per alias request, labelled with the chain entry that
# ended up serving it.
ALIAS_RESOLVED = Counter(
    name="mana_llm_alias_resolved_total",
    documentation=(
        "How often an alias resolved to a concrete provider/model. The `target` "
        "label is the chain entry that actually served the request — useful for "
        "spotting cases where the primary always falls through to a cloud entry."
    ),
    labelnames=["alias", "target"],
)

# Incremented for every chain entry the router gives up on or skips.
FALLBACK_TRIGGERED = Counter(
    name="mana_llm_fallback_total",
    documentation=(
        "Fallback transitions: a chain entry failed (or was skipped via cache) "
        "and the router moved to the next entry. `reason` is the exception class "
        "name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry "
        "that didn't serve, `to_model` is empty when no further entries existed."
    ),
    labelnames=["from_model", "to_model", "reason"],
)

# Per-provider 1/0 gauge written through ``set_provider_healthy``.
PROVIDER_HEALTHY = Gauge(
    name="mana_llm_provider_healthy",
    documentation=(
        "1 when the provider is currently considered healthy by the cache, "
        "0 when in backoff. Refreshed on every probe tick and on every router "
        "call-site state transition."
    ),
    labelnames=["provider"],
)
def get_metrics() -> bytes:
"""Generate Prometheus metrics output."""
@ -107,3 +136,23 @@ def record_llm_request(
def record_llm_error(provider: str, model: str, error_type: str) -> None:
    """Count one LLM error under the given provider/model/error_type labels."""
    error_counter = LLM_ERRORS.labels(
        provider=provider,
        model=model,
        error_type=error_type,
    )
    error_counter.inc()
def record_alias_resolved(alias: str, target: str) -> None:
    """Count one request in which ``alias`` was served by the concrete ``target``."""
    resolved_counter = ALIAS_RESOLVED.labels(alias=alias, target=target)
    resolved_counter.inc()
def record_fallback(from_model: str, to_model: str, reason: str) -> None:
    """Count one fallback transition away from ``from_model``.

    ``to_model`` is the next chain entry, or empty when the chain ran out
    (i.e. NoHealthyProviderError). ``reason`` says why ``from_model`` did
    not serve the request.
    """
    labels = {"from_model": from_model, "to_model": to_model, "reason": reason}
    FALLBACK_TRIGGERED.labels(**labels).inc()
def set_provider_healthy(provider: str, healthy: bool) -> None:
    """Write a provider's health into the Prometheus gauge (1.0 healthy, 0.0 not)."""
    gauge_value = 1.0 if healthy else 0.0
    provider_gauge = PROVIDER_HEALTHY.labels(provider=provider)
    provider_gauge.set(gauge_value)