managarten/services/mana-llm/src/health.py
Till JS 8a49e3ffd5 feat(mana-llm): M4 — observability, debug endpoints, SIGHUP reload
- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
  responses. Streaming clients read the same info from each chunk's
  `model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics: `mana_llm_alias_resolved_total{alias,
  target}` (which concrete model an alias resolved to per request),
  `mana_llm_fallback_total{from_model, to_model, reason}` (each
  fallback transition), `mana_llm_provider_healthy{provider}` (gauge,
  mirrors the circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain
  + description per alias, useful for confirming SIGHUP reloads),
  `GET /v1/health` (full per-provider liveness snapshot — failure
  counter, last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
  previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling:
  the gauge is updated via a transition-only listener wired in main.py
  rather than the cache importing prometheus_client itself.
- Request-side metrics now use the requested model string, success-side
  uses the resolved one. So `mana_llm_llm_requests_total{provider="ollama",
  model="gemma3:12b"}` reflects actual upstream load even when callers
  used `mana/long-form` aliases.

16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
v1/aliases + v1/health endpoint shape, response.model carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-26 20:52:28 +02:00

203 lines
8 KiB
Python

"""Provider health cache.
Tracks per-provider liveness for the LLM router. The router reads
:meth:`is_healthy` to decide whether to even try a provider in a chain;
the probe loop and the call-site fallback handler write state via
:meth:`mark_healthy` / :meth:`mark_unhealthy`.
Implements a simple circuit-breaker:
* The first failure flips no switch — providers occasionally have
transient blips, we don't want to bounce off after a single 502.
* After ``failure_threshold`` consecutive failures the provider is
marked unhealthy for ``unhealthy_backoff`` seconds. During that
window :meth:`is_healthy` returns ``False`` so the router fails
fast straight to the next chain entry.
* When the backoff expires :meth:`is_healthy` returns ``True`` again
(half-open). The next call exercises the provider; success calls
:meth:`mark_healthy` and fully resets state, failure re-arms the
backoff window.
State is kept in a plain dict guarded by a ``threading.Lock``. All
operations are short; lock-free reads of dict references aren't safe
because we mutate state in-place, so the lock keeps it boring. The probe
loop runs in the asyncio loop alongside the router, but the lock
costs are negligible at ~1 update/30s/provider.
"""
from __future__ import annotations

import logging
import threading
import time
from dataclasses import dataclass, field, replace
from typing import Callable, Iterable
logger = logging.getLogger(__name__)

#: Notification fired whenever a provider transitions between healthy and
#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the
#: cache itself stays metrics-agnostic so tests don't need to mock it.
#: Invoked as ``listener(provider_id, healthy)``; the return value is ignored.
HealthChangeListener = Callable[[str, bool], None]

#: Default number of consecutive failures after which
#: ``ProviderHealthCache`` marks a provider unhealthy.
DEFAULT_FAILURE_THRESHOLD = 2

#: Default length, in seconds, of the backoff window applied once a
#: provider has been marked unhealthy.
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
@dataclass
class ProviderState:
"""Per-provider liveness snapshot. All times are unix seconds."""
healthy: bool = True
consecutive_failures: int = 0
last_check: float = 0.0
last_error: str | None = None
unhealthy_until: float = 0.0
"""When > now, the provider is currently in backoff (`is_healthy → False`)."""
class ProviderHealthCache:
    """Thread-safe per-provider liveness with circuit-breaker semantics.

    Provider IDs are arbitrary strings — by convention we use the same
    short name as the provider router (``ollama``, ``groq``, ``openrouter``,
    ``together``, ``google``). The cache is provider-list agnostic: a state
    entry is created lazily by the first ``mark_*`` call for a provider.
    Reads never create state; :meth:`is_healthy` returns ``True`` for a
    provider it has never seen (no state means no reason to skip).
    """

    def __init__(
        self,
        *,
        failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
        unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
        clock: Callable[[], float] = time.time,
    ) -> None:
        """Create an empty cache.

        Args:
            failure_threshold: Consecutive failures required to trip the
                breaker for a provider. Must be >= 1.
            unhealthy_backoff_sec: Seconds a tripped provider is reported
                as unhealthy. Must be >= 0.
            clock: Zero-argument callable returning the current time in
                seconds; injectable so tests can control time.

        Raises:
            ValueError: If ``failure_threshold`` < 1 or
                ``unhealthy_backoff_sec`` < 0.
        """
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if unhealthy_backoff_sec < 0:
            raise ValueError("unhealthy_backoff_sec must be >= 0")
        self._failure_threshold = failure_threshold
        self._unhealthy_backoff = unhealthy_backoff_sec
        self._clock = clock
        # One lock guards both the state dict and the listener list.
        self._lock = threading.Lock()
        self._states: dict[str, ProviderState] = {}
        self._listeners: list[HealthChangeListener] = []

    def add_listener(self, listener: HealthChangeListener) -> None:
        """Register a callback fired with ``(provider_id, healthy: bool)``
        whenever a provider's healthy-flag transitions. Listeners run
        outside the cache's lock; exceptions are swallowed and logged so a
        bad listener can't break the underlying state machine.
        """
        with self._lock:
            self._listeners.append(listener)

    def _notify(self, provider_id: str, healthy: bool) -> None:
        """Invoke every registered listener with the new healthy flag.

        Must be called without the lock held: the lock is taken briefly to
        snapshot the listener list, and listeners themselves may call back
        into the cache.
        """
        # Iterate a snapshot so registrations that happen during the
        # notification (from a listener or another thread) cannot disturb
        # the loop.
        with self._lock:
            listeners = tuple(self._listeners)
        for listener in listeners:
            try:
                listener(provider_id, healthy)
            except Exception as e:  # noqa: BLE001
                # Deliberately best-effort: log with traceback and move on.
                logger.exception("health-change listener raised: %s", e)

    @property
    def failure_threshold(self) -> int:
        """Consecutive failures needed to trip the breaker."""
        return self._failure_threshold

    @property
    def unhealthy_backoff_sec(self) -> float:
        """Length of the unhealthy window, in seconds."""
        return self._unhealthy_backoff

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def is_healthy(self, provider_id: str) -> bool:
        """Should the router try this provider right now?

        Returns ``True`` by default for unknown providers — the cache is
        observation-only, not a registry. Returns ``False`` only while the
        provider's backoff window is still in the future.
        """
        with self._lock:
            state = self._states.get(provider_id)
            if state is None:
                return True
            if state.unhealthy_until > self._clock():
                return False
            # Backoff expired: caller is allowed to try again (half-open).
            return True

    def get_state(self, provider_id: str) -> ProviderState | None:
        """Snapshot of one provider's state (for debugging / tests).

        Returns ``None`` when nothing has been recorded for the provider.
        The returned object is a copy; mutating it does not affect the cache.
        """
        with self._lock:
            state = self._states.get(provider_id)
            return None if state is None else _copy(state)

    def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
        """All known states, plus zero-state placeholders for any names in
        ``expected`` that haven't been touched yet. Used by ``GET /v1/health``
        so the response shape is stable regardless of probe order.

        The returned dict and its values are copies, safe to read without
        the lock.
        """
        with self._lock:
            out = {pid: _copy(s) for pid, s in self._states.items()}
        # Placeholders are fresh objects, so they can be added lock-free.
        if expected is not None:
            for pid in expected:
                out.setdefault(pid, ProviderState())
        return out

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def mark_healthy(self, provider_id: str) -> None:
        """Provider answered correctly — clear any failure state.

        Listeners are notified only when the provider was previously
        flagged unhealthy.
        """
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            transitioned = not state.healthy
            state.healthy = True
            state.consecutive_failures = 0
            state.last_check = self._clock()
            state.last_error = None
            state.unhealthy_until = 0.0
        if transitioned:
            logger.info("provider %s recovered", provider_id)
            self._notify(provider_id, True)

    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
        """Record a failure. Trips the breaker after the threshold.

        Args:
            provider_id: Provider the failure belongs to.
            reason: Description stored as ``last_error`` and included in
                the log line when the breaker trips.
        """
        transitioned = False
        failures = 0
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            state.consecutive_failures += 1
            state.last_check = self._clock()
            state.last_error = reason
            tripped = state.consecutive_failures >= self._failure_threshold
            if tripped and state.healthy:
                state.healthy = False
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
                transitioned = True
                failures = state.consecutive_failures
            elif not state.healthy:
                # Still in unhealthy window; refresh the backoff so a flapping
                # provider doesn't get re-tried every probe tick.
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
        if transitioned:
            # Log and notify after releasing the lock so neither logging
            # handlers nor listeners can stall other cache users.
            logger.warning(
                "provider %s marked unhealthy after %d consecutive failures (%s); "
                "backoff %.0fs",
                provider_id,
                failures,
                reason,
                self._unhealthy_backoff,
            )
            self._notify(provider_id, False)
def _copy(state: ProviderState) -> ProviderState:
    """Return a shallow copy so callers can read without holding the lock.

    Uses :func:`dataclasses.replace` with no overrides, which rebuilds the
    instance from all of its fields. Unlike a hand-written field list, this
    stays correct when fields are added to the dataclass later.
    """
    return replace(state)