mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-19 08:01:23 +02:00
- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
responses. Streaming clients read the same info from each chunk's
`model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics: `mana_llm_alias_resolved_total{alias,
target}` (which concrete model an alias resolved to per request),
`mana_llm_fallback_total{from_model, to_model, reason}` (each
fallback transition), `mana_llm_provider_healthy{provider}` (gauge,
mirrors the circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain
+ description per alias, useful for confirming SIGHUP reloads),
`GET /v1/health` (full per-provider liveness snapshot — failure
counter, last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling:
the gauge is updated via a transition-only listener wired in main.py
rather than the cache importing prometheus_client itself.
- Request-side metrics now use the requested model string; success-side
metrics use the resolved one. So `mana_llm_llm_requests_total{provider="ollama",
model="gemma3:12b"}` reflects actual upstream load even when callers
used `mana/long-form` aliases.
16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
v1/aliases + v1/health endpoint shape, response.model carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
203 lines
8 KiB
Python
203 lines
8 KiB
Python
"""Provider health cache.
|
|
|
|
Tracks per-provider liveness for the LLM router. The router reads
|
|
:meth:`is_healthy` to decide whether to even try a provider in a chain;
|
|
the probe loop and the call-site fallback handler write state via
|
|
:meth:`mark_healthy` / :meth:`mark_unhealthy`.
|
|
|
|
Implements a simple circuit-breaker:
|
|
|
|
* The first failure flips no switch — providers occasionally have
|
|
transient blips, we don't want to bounce off after a single 502.
|
|
* After ``failure_threshold`` consecutive failures the provider is
|
|
marked unhealthy for ``unhealthy_backoff`` seconds. During that
|
|
window :meth:`is_healthy` returns ``False`` so the router fails
|
|
fast straight to the next chain entry.
|
|
* When the backoff expires :meth:`is_healthy` returns ``True`` again
|
|
(half-open). The next call exercises the provider; success calls
|
|
:meth:`mark_healthy` and fully resets state, failure re-arms the
|
|
backoff window.
|
|
|
|
State is kept in a plain dict guarded by a ``threading.Lock``. All
|
|
operations are short; lock-free reads of dict references aren't safe
|
|
because we mutate state in-place — the lock keeps it boring. Probe
|
|
loop runs in the asyncio loop alongside the router, but the lock
|
|
costs are negligible at ~1 update/30s/provider.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import logging
|
|
import threading
|
|
import time
|
|
from dataclasses import dataclass, field
|
|
from typing import Callable, Iterable
|
|
|
|
logger = logging.getLogger(__name__)

#: Signature of a health-transition callback. It is invoked with the
#: provider id and the provider's new healthy flag each time that flag
#: flips. ``main.py`` registers a listener that feeds the Prometheus gauge;
#: the cache only knows about this plain callable, which keeps it free of
#: any metrics dependency and easy to test.
HealthChangeListener = Callable[[str, bool], None]

#: Default number of consecutive failures before the breaker trips.
DEFAULT_FAILURE_THRESHOLD = 2

#: Default length, in seconds, of the backoff window after the breaker trips.
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
|
|
|
|
|
|
@dataclass
class ProviderState:
    """Liveness record for a single provider.

    Every timestamp stored here is expressed in unix seconds. A freshly
    constructed instance describes a provider that is healthy and has
    never been checked.
    """

    # Healthy flag; a new record starts out healthy.
    healthy: bool = True

    # Number of failures recorded in a row.
    consecutive_failures: int = 0

    # Timestamp of the most recent recorded check (0.0 = never).
    last_check: float = 0.0

    # Reason string of the most recent failure, if any.
    last_error: str | None = None

    # End of the backoff window. While this is greater than "now" the
    # provider is in backoff (`is_healthy → False`).
    unhealthy_until: float = 0.0
|
|
|
|
|
|
class ProviderHealthCache:
    """Thread-safe per-provider liveness with circuit-breaker semantics.

    Provider IDs are arbitrary strings — by convention we use the same
    short name as the provider router (``ollama``, ``groq``, ``openrouter``,
    ``together``, ``google``). The cache is provider-list agnostic: a state
    entry is created lazily by the first ``mark_*`` call for a provider.
    :meth:`is_healthy` never creates state; for a provider that has no
    entry yet it returns ``True`` (no state means no reason to skip).

    Args:
        failure_threshold: Consecutive failures required to trip the breaker.
        unhealthy_backoff_sec: Seconds a tripped provider is reported as
            unhealthy before it may be tried again.
        clock: Zero-argument callable returning the current time in seconds.
            Defaults to :func:`time.time`; injectable so tests can control
            time.

    Raises:
        ValueError: If ``failure_threshold`` is below 1 or
            ``unhealthy_backoff_sec`` is negative.
    """

    def __init__(
        self,
        *,
        failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
        unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
        clock: Callable[[], float] = time.time,
    ) -> None:
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if unhealthy_backoff_sec < 0:
            raise ValueError("unhealthy_backoff_sec must be >= 0")
        self._failure_threshold = failure_threshold
        self._unhealthy_backoff = unhealthy_backoff_sec
        self._clock = clock
        # Guards every read and write of ``_states``.
        self._lock = threading.Lock()
        self._states: dict[str, ProviderState] = {}
        self._listeners: list[HealthChangeListener] = []

    def add_listener(self, listener: HealthChangeListener) -> None:
        """Register a callback fired with ``(provider_id, healthy: bool)``
        whenever a provider's healthy-flag transitions. Listeners run
        outside the cache's lock; exceptions are swallowed and logged so a
        bad listener can't break the underlying state machine.
        """
        self._listeners.append(listener)

    def _notify(self, provider_id: str, healthy: bool) -> None:
        """Call every registered listener, isolating their failures."""
        # Iterate over a snapshot: a listener that registers another
        # listener while we are notifying must not change this loop.
        for listener in tuple(self._listeners):
            try:
                listener(provider_id, healthy)
            except Exception as e:  # noqa: BLE001
                # ``exception`` keeps the traceback, which ``error`` dropped.
                logger.exception("health-change listener raised: %s", e)

    @property
    def failure_threshold(self) -> int:
        """Consecutive failures required to trip the breaker."""
        return self._failure_threshold

    @property
    def unhealthy_backoff_sec(self) -> float:
        """Length of the backoff window in seconds."""
        return self._unhealthy_backoff

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def is_healthy(self, provider_id: str) -> bool:
        """Should the router try this provider right now?

        Returns ``True`` by default for unknown providers — the cache is
        observation-only, not a registry. For a known provider the answer
        is ``False`` only while its backoff window is still running.
        """
        with self._lock:
            state = self._states.get(provider_id)
            if state is None:
                return True
            if state.unhealthy_until > self._clock():
                return False
            # Backoff expired: caller is allowed to try again (half-open).
            return True

    def get_state(self, provider_id: str) -> ProviderState | None:
        """Snapshot of one provider's state (for debugging / tests).

        Returns a copy, or ``None`` when the provider has no state yet.
        """
        with self._lock:
            state = self._states.get(provider_id)
            return None if state is None else _copy(state)

    def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
        """All known states, plus zero-state placeholders for any names in
        ``expected`` that haven't been touched yet. Used by ``GET /v1/health``
        so the response shape is stable regardless of probe order.

        The returned dict and its values are copies; mutating them does not
        affect the cache.
        """
        with self._lock:
            out = {pid: _copy(s) for pid, s in self._states.items()}
        # ``is not None`` rather than truthiness: ``expected`` is an
        # arbitrary iterable, and an empty one is simply a no-op loop.
        if expected is not None:
            for pid in expected:
                out.setdefault(pid, ProviderState())
        return out

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def mark_healthy(self, provider_id: str) -> None:
        """Provider answered correctly — clear any failure state.

        Listeners are notified only when the healthy flag actually flips
        from ``False`` to ``True``.
        """
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            transitioned = not state.healthy
            state.healthy = True
            state.consecutive_failures = 0
            state.last_check = self._clock()
            state.last_error = None
            state.unhealthy_until = 0.0
        # Log and notify after releasing the lock so listeners may call
        # back into the cache.
        if transitioned:
            logger.info("provider %s recovered", provider_id)
            self._notify(provider_id, True)

    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
        """Record a failure. Trips the breaker after the threshold.

        Args:
            provider_id: Provider the failure belongs to.
            reason: Description stored as ``last_error`` and logged when
                the breaker trips.

        Listeners are notified only on the healthy → unhealthy transition,
        not on every subsequent failure.
        """
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            state.consecutive_failures += 1
            state.last_check = self._clock()
            state.last_error = reason
            tripped = state.consecutive_failures >= self._failure_threshold
            if tripped and state.healthy:
                state.healthy = False
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
                transitioned = True
                logger.warning(
                    "provider %s marked unhealthy after %d consecutive failures (%s); "
                    "backoff %.0fs",
                    provider_id,
                    state.consecutive_failures,
                    reason,
                    self._unhealthy_backoff,
                )
            elif not state.healthy:
                # Breaker already tripped (still in backoff, or a failed
                # half-open retry): re-arm the backoff so a flapping
                # provider doesn't get re-tried every probe tick.
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
        if transitioned:
            self._notify(provider_id, False)
|
|
|
|
|
|
def _copy(state: ProviderState) -> ProviderState:
    """Build an independent ``ProviderState`` carrying the same field values.

    The cache hands these copies to callers so they can inspect a state
    after the lock has been released without seeing later in-place updates.
    """
    field_names = (
        "healthy",
        "consecutive_failures",
        "last_check",
        "last_error",
        "unhealthy_until",
    )
    values = {name: getattr(state, name) for name in field_names}
    return ProviderState(**values)
|