mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-15 04:01:09 +02:00
feat(mana-llm): M2 — ProviderHealthCache + background probe loop
Per-provider liveness with circuit-breaker semantics. The router (M3)
will read `is_healthy()` to skip dead providers in a chain; the probe
loop and the call-site fallback handler write state via
`mark_healthy` / `mark_unhealthy`.
State machine: 1st failure stays healthy (transient blips happen);
2nd consecutive failure trips the breaker and sets a 60s backoff
window during which `is_healthy → False`. After the window the
provider is half-open again — next call exercises it, success
resets, failure re-arms.
HealthProbe is the background asyncio.Task that pings every
registered provider every 30s with a 3s timeout. Probes run
concurrently per tick and one bad probe can't sink the loop. Probe
functions are injected (`{name: async-fn}`) so this module stays
decoupled from the provider classes — the wiring lives in main.py
where we already know which providers are configured.
32 new tests (FakeClock for deterministic backoff timing, slow-probe
helpers for parallelism + timeout, lifecycle tests for start/stop
idempotency and tick-after-error survival). 64/64 alias+health tests
green.
Not yet wired into the request path — that's M3.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
dff8629e1d
commit
59557e62d7
4 changed files with 811 additions and 0 deletions
176
services/mana-llm/src/health.py
Normal file
176
services/mana-llm/src/health.py
Normal file
|
|
@ -0,0 +1,176 @@
|
|||
"""Provider health cache.
|
||||
|
||||
Tracks per-provider liveness for the LLM router. The router reads
|
||||
:meth:`is_healthy` to decide whether to even try a provider in a chain;
|
||||
the probe loop and the call-site fallback handler write state via
|
||||
:meth:`mark_healthy` / :meth:`mark_unhealthy`.
|
||||
|
||||
Implements a simple circuit-breaker:
|
||||
|
||||
* The first failure flips no switch — providers occasionally have
|
||||
transient blips, we don't want to bounce off after a single 502.
|
||||
* After ``failure_threshold`` consecutive failures the provider is
|
||||
marked unhealthy for ``unhealthy_backoff`` seconds. During that
|
||||
window :meth:`is_healthy` returns ``False`` so the router fails
|
||||
fast straight to the next chain entry.
|
||||
* When the backoff expires :meth:`is_healthy` returns ``True`` again
|
||||
(half-open). The next call exercises the provider; success calls
|
||||
:meth:`mark_healthy` and fully resets state, failure re-arms the
|
||||
backoff window.
|
||||
|
||||
State is kept in a plain dict guarded by a ``threading.Lock``. All
operations are short. Lock-free reads of dict references aren't safe
because we mutate state in-place — the lock keeps it boring. The probe
loop runs in the asyncio loop alongside the router, but the lock
costs are negligible at ~1 update per 30 s per provider.
|
||||
"""
|
||||
|
||||
from __future__ import annotations

import logging
import threading
import time
from dataclasses import dataclass, field, replace
from typing import Callable, Iterable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_FAILURE_THRESHOLD = 2
|
||||
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
|
||||
|
||||
|
||||
@dataclass
class ProviderState:
    """Per-provider liveness snapshot. All times are unix seconds."""

    # True until the breaker trips; reset to True by mark_healthy().
    healthy: bool = True
    # Failures since the last success; compared against failure_threshold.
    consecutive_failures: int = 0
    # Unix time of the most recent mark_healthy/mark_unhealthy call.
    last_check: float = 0.0
    # Reason string from the latest failure; None while healthy.
    last_error: str | None = None
    unhealthy_until: float = 0.0
    """When > now, the provider is currently in backoff (`is_healthy → False`)."""
|
||||
|
||||
|
||||
class ProviderHealthCache:
    """Thread-safe per-provider liveness with circuit-breaker semantics.

    Provider IDs are arbitrary strings — by convention we use the same
    short name as the provider router (``ollama``, ``groq``, ``openrouter``,
    ``together``, ``google``). The cache is provider-list agnostic; states
    are created lazily on first ``mark_*`` or queried-but-absent ``is_healthy``
    call (returning ``True`` by default — no state means no reason to skip).
    """

    def __init__(
        self,
        *,
        failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
        unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
        clock: Callable[[], float] = time.time,
    ) -> None:
        """Create an empty cache.

        Args:
            failure_threshold: consecutive failures needed to trip the
                breaker; must be >= 1 (1 trips on the very first failure).
            unhealthy_backoff_sec: seconds ``is_healthy`` stays ``False``
                after tripping; must be >= 0.
            clock: zero-arg callable returning unix seconds — injectable
                so tests can drive backoff timing deterministically.

        Raises:
            ValueError: if either numeric parameter is out of range.
        """
        if failure_threshold < 1:
            raise ValueError("failure_threshold must be >= 1")
        if unhealthy_backoff_sec < 0:
            raise ValueError("unhealthy_backoff_sec must be >= 0")
        self._failure_threshold = failure_threshold
        self._unhealthy_backoff = unhealthy_backoff_sec
        self._clock = clock
        # Guards _states; every public method holds it for its whole body.
        self._lock = threading.Lock()
        self._states: dict[str, ProviderState] = {}

    @property
    def failure_threshold(self) -> int:
        """Consecutive failures required to trip the breaker."""
        return self._failure_threshold

    @property
    def unhealthy_backoff_sec(self) -> float:
        """Length of the fail-fast window once the breaker trips."""
        return self._unhealthy_backoff

    # ------------------------------------------------------------------
    # Reads
    # ------------------------------------------------------------------

    def is_healthy(self, provider_id: str) -> bool:
        """Should the router try this provider right now?

        Returns ``True`` by default for unknown providers — the cache is
        observation-only, not a registry.
        """
        with self._lock:
            state = self._states.get(provider_id)
            if state is None:
                return True
            if state.unhealthy_until > self._clock():
                return False
            # Backoff expired: caller is allowed to try again (half-open).
            return True

    def get_state(self, provider_id: str) -> ProviderState | None:
        """Snapshot of one provider's state (for debugging / tests)."""
        with self._lock:
            state = self._states.get(provider_id)
            return None if state is None else _copy(state)

    def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
        """All known states, plus zero-state placeholders for any names in
        ``expected`` that haven't been touched yet. Used by ``GET /v1/health``
        so the response shape is stable regardless of probe order.
        """
        with self._lock:
            out = {pid: _copy(s) for pid, s in self._states.items()}
            if expected:
                for pid in expected:
                    # setdefault: a real recorded state always wins over
                    # the healthy zero-state placeholder.
                    out.setdefault(pid, ProviderState())
            return out

    # ------------------------------------------------------------------
    # Writes
    # ------------------------------------------------------------------

    def mark_healthy(self, provider_id: str) -> None:
        """Provider answered correctly — clear any failure state."""
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            # Remember the transition so recovery is logged exactly once.
            previously_unhealthy = not state.healthy
            state.healthy = True
            state.consecutive_failures = 0
            state.last_check = self._clock()
            state.last_error = None
            state.unhealthy_until = 0.0
            if previously_unhealthy:
                logger.info("provider %s recovered", provider_id)

    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
        """Record a failure. Trips the breaker after the threshold."""
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            state.consecutive_failures += 1
            state.last_check = self._clock()
            state.last_error = reason
            tripped = state.consecutive_failures >= self._failure_threshold
            if tripped and state.healthy:
                state.healthy = False
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
                logger.warning(
                    "provider %s marked unhealthy after %d consecutive failures (%s); "
                    "backoff %.0fs",
                    provider_id,
                    state.consecutive_failures,
                    reason,
                    self._unhealthy_backoff,
                )
            elif not state.healthy:
                # Still in unhealthy window; refresh the backoff so a flapping
                # provider doesn't get re-tried every probe tick.
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
|
||||
|
||||
|
||||
def _copy(state: ProviderState) -> ProviderState:
    """Return a shallow copy so callers can read without holding the lock.

    ``dataclasses.replace`` with no field overrides rebuilds the instance
    from its current field values — identical to the old field-by-field
    constructor call, but it cannot silently drift out of sync when a new
    field is added to :class:`ProviderState`.
    """
    return replace(state)
|
||||
210
services/mana-llm/src/health_probe.py
Normal file
210
services/mana-llm/src/health_probe.py
Normal file
|
|
@ -0,0 +1,210 @@
|
|||
"""Background probe loop that keeps :class:`ProviderHealthCache` fresh.
|
||||
|
||||
The router's circuit-breaker is reactive — it only learns a provider is
|
||||
sick after the next live request fails. A reactive-only design means:
|
||||
|
||||
* every cold start re-discovers Ollama is down by paying one 75-second
|
||||
``ConnectError``, and
|
||||
* a provider that quietly recovers stays marked unhealthy until its
|
||||
backoff expires and someone tries it.
|
||||
|
||||
The probe loop closes both gaps. Every ``interval`` seconds it pings
|
||||
each registered provider with a small known-cheap request (Ollama:
|
||||
``GET /api/tags``, OpenAI-compat: ``GET /v1/models``) and updates the
|
||||
cache. Probes run concurrently per tick and respect a hard
|
||||
``probe_timeout`` so a hanging provider can't stall the loop.
|
||||
|
||||
Probe functions are injected from outside (``probes={name: async-fn}``)
|
||||
so this module stays decoupled from the provider classes — wiring lives
|
||||
in ``main.py`` where we already know which providers are configured.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
from .health import ProviderHealthCache
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
#: Probe function: returns ``True`` for healthy. Raising or returning
|
||||
#: ``False`` both count as a failure (the loop just calls
|
||||
#: ``mark_unhealthy``); an exception's string form becomes the
|
||||
#: ``last_error`` for the snapshot endpoint.
|
||||
ProbeFn = Callable[[], Awaitable[bool]]
|
||||
|
||||
DEFAULT_PROBE_INTERVAL_SEC = 30.0
|
||||
DEFAULT_PROBE_TIMEOUT_SEC = 3.0
|
||||
|
||||
|
||||
class HealthProbe:
    """Periodically probes every registered provider and updates the cache.

    Lifecycle::

        probe = HealthProbe(cache, {"ollama": probe_ollama, ...})
        await probe.start()   # spawns the background task
        ...
        await probe.stop()    # cancels and awaits cleanup

    Tests typically call :meth:`tick_once` directly to exercise one cycle
    without driving the asyncio scheduler through real ``asyncio.sleep``.
    """

    def __init__(
        self,
        cache: ProviderHealthCache,
        probes: dict[str, ProbeFn],
        *,
        interval: float = DEFAULT_PROBE_INTERVAL_SEC,
        timeout: float = DEFAULT_PROBE_TIMEOUT_SEC,
    ) -> None:
        """Validate parameters and store the probe registry.

        Args:
            cache: the shared health cache this loop writes into.
            probes: mapping of provider id → async probe fn (copied, so
                later caller mutations don't affect a running loop).
            interval: seconds between ticks; must be > 0.
            timeout: per-probe hard deadline per tick; must be > 0.

        Raises:
            ValueError: if ``interval`` or ``timeout`` is non-positive.
        """
        if interval <= 0:
            raise ValueError("interval must be > 0")
        if timeout <= 0:
            raise ValueError("timeout must be > 0")
        self._cache = cache
        self._probes = dict(probes)
        self._interval = interval
        self._timeout = timeout
        # Background task handle; None when not running.
        self._task: asyncio.Task | None = None
        # Set by stop(); also doubles as the interruptible inter-tick sleep.
        self._stop = asyncio.Event()

    @property
    def interval(self) -> float:
        """Seconds between probe ticks."""
        return self._interval

    @property
    def timeout(self) -> float:
        """Hard per-probe deadline within a tick."""
        return self._timeout

    @property
    def provider_ids(self) -> list[str]:
        """Names of all providers this loop probes."""
        return list(self._probes)

    @property
    def running(self) -> bool:
        """True while the background task exists and hasn't finished."""
        return self._task is not None and not self._task.done()

    # ------------------------------------------------------------------
    # Per-tick logic — exercised directly by tests
    # ------------------------------------------------------------------

    async def tick_once(self) -> None:
        """Probe every provider once, in parallel, updating the cache.

        Errors in any one probe (including ``asyncio.TimeoutError``) are
        captured per-provider — one bad probe never sinks the loop.
        """
        if not self._probes:
            return
        results = await asyncio.gather(
            *(self._probe_one(name, fn) for name, fn in self._probes.items()),
            return_exceptions=True,
        )
        # gather(return_exceptions=True) caught everything per-probe; this
        # branch should never fire, but guard in case _probe_one ever grows
        # a code path that bypasses its try/except.
        # asyncio.gather() returns results in input order and same length,
        # so the zip is a 1:1 mapping back to provider names.
        for name, result in zip(self._probes, results):
            if isinstance(result, BaseException):
                logger.error("probe %s leaked exception: %s", name, result)

    async def _probe_one(self, name: str, fn: ProbeFn) -> None:
        """Run one probe under the deadline and record the verdict.

        Timeouts and exceptions both become ``mark_unhealthy`` with a
        descriptive reason; only ``CancelledError`` propagates so task
        cancellation still works.
        """
        try:
            healthy = await asyncio.wait_for(fn(), timeout=self._timeout)
        except asyncio.TimeoutError:
            self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)")
            return
        except asyncio.CancelledError:
            # Never swallow cancellation — stop() relies on it.
            raise
        except Exception as e:  # noqa: BLE001 — probe SHOULD be permissive
            self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}")
            return
        if healthy:
            self._cache.mark_healthy(name)
        else:
            self._cache.mark_unhealthy(name, "probe-returned-false")

    # ------------------------------------------------------------------
    # Long-running task management
    # ------------------------------------------------------------------

    async def start(self) -> None:
        """Spawn the periodic probe task. Idempotent."""
        if self.running:
            return
        self._stop.clear()
        self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe")
        logger.info(
            "HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)",
            self._interval,
            self._timeout,
            ", ".join(self._probes) or "<none>",
        )

    async def stop(self) -> None:
        """Cancel the background task and wait for it to finish."""
        if not self.running:
            return
        self._stop.set()
        assert self._task is not None
        self._task.cancel()
        try:
            await self._task
        except asyncio.CancelledError:
            # Expected: we just cancelled it.
            pass
        finally:
            self._task = None
        logger.info("HealthProbe stopped")

    async def _run_forever(self) -> None:
        # Probe immediately at boot so we don't serve traffic for `interval`
        # seconds based on optimistic-default assumptions.
        try:
            await self.tick_once()
        except Exception as e:  # noqa: BLE001
            logger.error("HealthProbe initial tick failed: %s", e)
        while not self._stop.is_set():
            # Sleep for `interval`, but wake immediately if stop() fires.
            try:
                await asyncio.wait_for(self._stop.wait(), timeout=self._interval)
            except asyncio.TimeoutError:
                pass
            else:
                # _stop.wait() succeeded → stop signalled, exit.
                return
            try:
                await self.tick_once()
            except Exception as e:  # noqa: BLE001
                logger.error("HealthProbe tick failed: %s", e)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Probe-function helpers
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def make_http_probe(
    url: str,
    *,
    headers: dict[str, str] | None = None,
    expected_status_lt: int = 500,
    request_timeout: float = 5.0,
) -> ProbeFn:
    """Return a probe function that does ``GET <url>`` and considers the
    provider healthy iff the response status is below
    ``expected_status_lt`` (default: any non-5xx counts).

    A 401/403/404 still counts as healthy because the *server* answered —
    auth or path mistakes are misconfiguration, not provider liveness.

    Args:
        url: endpoint to GET (e.g. Ollama's ``/api/tags``).
        headers: optional extra request headers (auth etc.).
        expected_status_lt: statuses below this count as healthy.
        request_timeout: HTTP-client timeout in seconds (previously a
            hard-coded 5.0). Note :class:`HealthProbe` additionally wraps
            each probe in its own ``asyncio.wait_for`` (default 3s), so
            the tighter of the two deadlines wins inside the loop.

    Raises:
        ValueError: if ``request_timeout`` is non-positive.
    """
    if request_timeout <= 0:
        raise ValueError("request_timeout must be > 0")
    import httpx

    async def probe() -> bool:
        # Fresh client per probe: at ~1 probe/30s per provider, holding a
        # long-lived connection pool isn't worth the lifecycle complexity.
        async with httpx.AsyncClient(timeout=httpx.Timeout(request_timeout)) as client:
            resp = await client.get(url, headers=headers or None)
            return resp.status_code < expected_status_lt

    return probe
|
||||
203
services/mana-llm/tests/test_health.py
Normal file
203
services/mana-llm/tests/test_health.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
"""Tests for the provider health cache."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from src.health import (
|
||||
DEFAULT_FAILURE_THRESHOLD,
|
||||
DEFAULT_UNHEALTHY_BACKOFF_SEC,
|
||||
ProviderHealthCache,
|
||||
ProviderState,
|
||||
)
|
||||
|
||||
|
||||
class FakeClock:
    """Manually advanced time source for deterministic breaker timing."""

    def __init__(self, start: float = 1_000_000.0) -> None:
        # Current fake unix time, in seconds; read directly by tests.
        self.now = start

    def __call__(self) -> float:
        # The cache accepts any zero-arg callable as its clock.
        return self.now

    def advance(self, seconds: float) -> None:
        """Move the fake clock forward by ``seconds``."""
        self.now = self.now + seconds
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstruction:
    def test_defaults(self) -> None:
        """With no arguments the cache picks up the module defaults."""
        cache = ProviderHealthCache()
        assert cache.failure_threshold == DEFAULT_FAILURE_THRESHOLD
        assert cache.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC

    def test_invalid_threshold_rejected(self) -> None:
        """A threshold below 1 could never trip — constructor refuses it."""
        with pytest.raises(ValueError, match="failure_threshold"):
            ProviderHealthCache(failure_threshold=0)

    def test_invalid_backoff_rejected(self) -> None:
        """A negative backoff window makes no sense — constructor refuses it."""
        with pytest.raises(ValueError, match="backoff"):
            ProviderHealthCache(unhealthy_backoff_sec=-1.0)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Default-healthy behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestDefaults:
    def test_unknown_provider_is_healthy(self) -> None:
        """No entry means no reason to skip — the cache is observation-only."""
        cache = ProviderHealthCache()
        assert cache.is_healthy("ollama") is True

    def test_unknown_provider_has_no_state(self) -> None:
        """Pure queries never create state; only ``mark_*`` calls do."""
        cache = ProviderHealthCache()
        assert cache.get_state("ollama") is None

    def test_snapshot_includes_expected_zero_state(self) -> None:
        """Names passed via ``expected`` show up as healthy zero-states."""
        cache = ProviderHealthCache()
        view = cache.snapshot(expected=["ollama", "groq"])
        assert set(view.keys()) == {"ollama", "groq"}
        assert all(entry.healthy for entry in view.values())
        assert all(entry.consecutive_failures == 0 for entry in view.values())
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Failure → unhealthy state machine
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestFailureSemantics:
    def test_first_failure_does_not_trip(self) -> None:
        """One blip below the threshold leaves the provider usable —
        transient failures shouldn't bounce a provider immediately."""
        cache = ProviderHealthCache(failure_threshold=2)
        cache.mark_unhealthy("ollama", "boom")
        assert cache.is_healthy("ollama") is True
        snap = cache.get_state("ollama")
        assert snap is not None
        assert snap.consecutive_failures == 1
        assert snap.healthy is True

    def test_threshold_reached_trips_breaker(self) -> None:
        """The Nth consecutive failure flips the breaker and arms backoff."""
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=2, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("ollama", "fail-1")
        cache.mark_unhealthy("ollama", "fail-2")
        assert cache.is_healthy("ollama") is False
        snap = cache.get_state("ollama")
        assert snap is not None
        assert snap.healthy is False
        assert snap.last_error == "fail-2"
        assert snap.unhealthy_until == clock.now + 60.0

    def test_threshold_one_trips_immediately(self) -> None:
        """threshold=1 means a single failure is enough to trip."""
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("ollama", "boom")
        assert cache.is_healthy("ollama") is False

    def test_mark_healthy_clears_state(self) -> None:
        """A success wipes the failure count, backoff, and stored error."""
        cache = ProviderHealthCache(failure_threshold=2)
        cache.mark_unhealthy("ollama", "x")
        cache.mark_unhealthy("ollama", "x")
        assert cache.is_healthy("ollama") is False
        cache.mark_healthy("ollama")
        assert cache.is_healthy("ollama") is True
        snap = cache.get_state("ollama")
        assert snap is not None
        assert snap.healthy is True
        assert snap.consecutive_failures == 0
        assert snap.unhealthy_until == 0.0
        assert snap.last_error is None
|
||||
|
||||
|
||||
class TestBackoffWindow:
    def test_is_healthy_false_during_backoff(self) -> None:
        """Anywhere inside the window the router must fail fast."""
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("ollama", "boom")
        assert cache.is_healthy("ollama") is False
        clock.advance(30.0)
        assert cache.is_healthy("ollama") is False

    def test_is_healthy_true_after_backoff_expires(self) -> None:
        """Past the window the breaker is half-open — callers may retry.

        Success then calls mark_healthy and resets; a failure re-arms
        the backoff window.
        """
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("ollama", "boom")
        clock.advance(61.0)
        assert cache.is_healthy("ollama") is True

    def test_failure_during_backoff_extends_window(self) -> None:
        """Every failure while unhealthy pushes the deadline out again."""
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("ollama", "first")
        deadline_before = cache.get_state("ollama").unhealthy_until
        clock.advance(20.0)
        cache.mark_unhealthy("ollama", "second")
        deadline_after = cache.get_state("ollama").unhealthy_until
        assert deadline_after > deadline_before
        assert deadline_after == clock.now + 60.0

    def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None:
        """Only the unhealthy→healthy transition emits the recovery line."""
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("ollama", "boom")
        with caplog.at_level("INFO"):
            cache.mark_healthy("ollama")
            # Re-marking an already-healthy provider must stay silent —
            # the recovery line would otherwise spam the logs every probe.
            cache.mark_healthy("ollama")
            cache.mark_healthy("ollama")
        recoveries = [rec for rec in caplog.records if "recovered" in rec.message]
        assert len(recoveries) == 1
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Snapshot shape
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestSnapshot:
    def test_snapshot_returns_copies(self) -> None:
        """Mutating a snapshot entry must not leak into the cache."""
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("ollama", "x")
        view = cache.snapshot()
        view["ollama"].healthy = True
        # Internal state is untouched by the mutation above.
        assert cache.is_healthy("ollama") is False

    def test_snapshot_matches_recorded_state(self) -> None:
        """Snapshot fields mirror exactly what mark_unhealthy recorded."""
        clock = FakeClock()
        cache = ProviderHealthCache(failure_threshold=2, unhealthy_backoff_sec=60.0, clock=clock)
        cache.mark_unhealthy("groq", "rate-limit")
        view = cache.snapshot()
        assert isinstance(view["groq"], ProviderState)
        assert view["groq"].consecutive_failures == 1
        assert view["groq"].last_error == "rate-limit"
        assert view["groq"].last_check == clock.now

    def test_snapshot_expected_does_not_overwrite_real_state(self) -> None:
        """Placeholders only fill gaps; real recorded entries win."""
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("ollama", "real-boom")
        view = cache.snapshot(expected=["ollama", "groq"])
        # ollama keeps its real (unhealthy) state; groq gets the zero-default.
        assert view["ollama"].consecutive_failures == 1
        assert view["groq"].consecutive_failures == 0
|
||||
222
services/mana-llm/tests/test_health_probe.py
Normal file
222
services/mana-llm/tests/test_health_probe.py
Normal file
|
|
@ -0,0 +1,222 @@
|
|||
"""Tests for the background health-probe loop."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from src.health import ProviderHealthCache
|
||||
from src.health_probe import HealthProbe, ProbeFn
|
||||
|
||||
|
||||
def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn:
    """Build a constant probe: yields ``returns``, or raises ``raises("boom")``."""

    async def _fixed_probe() -> bool:
        if raises is None:
            return returns
        raise raises("boom")

    return _fixed_probe
|
||||
|
||||
|
||||
def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn:
    """Probe that sleeps ``delay`` seconds before reporting ``returns``."""

    async def _sleepy_probe() -> bool:
        await asyncio.sleep(delay)
        return returns

    return _sleepy_probe
|
||||
|
||||
|
||||
def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]:
    """Always-healthy probe that records how many times it was awaited."""
    calls = [0]  # single mutable cell shared by the probe and the reader

    async def _counting_probe() -> bool:
        calls[0] += 1
        return True

    return _counting_probe, lambda: calls[0]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Construction
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestConstruction:
    def test_invalid_interval(self) -> None:
        """A non-positive interval is rejected at construction time."""
        with pytest.raises(ValueError, match="interval"):
            HealthProbe(ProviderHealthCache(), {}, interval=0.0)

    def test_invalid_timeout(self) -> None:
        """A non-positive per-probe timeout is rejected at construction time."""
        with pytest.raises(ValueError, match="timeout"):
            HealthProbe(ProviderHealthCache(), {}, timeout=0.0)

    def test_provider_ids_exposes_keys(self) -> None:
        """provider_ids reflects exactly the injected probe registry."""
        health_cache = ProviderHealthCache()
        health_probe = HealthProbe(health_cache, {"ollama": make_probe(), "groq": make_probe()})
        assert sorted(health_probe.provider_ids) == ["groq", "ollama"]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# tick_once — the per-cycle behaviour
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestTickOnce:
    @pytest.mark.asyncio
    async def test_healthy_probe_marks_healthy(self) -> None:
        """A True-returning probe recovers a previously-unhealthy provider."""
        cache = ProviderHealthCache(failure_threshold=1)
        # Pre-mark unhealthy so we can verify the probe recovers it.
        cache.mark_unhealthy("ollama", "stale")
        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)})
        await probe.tick_once()
        assert cache.is_healthy("ollama") is True

    @pytest.mark.asyncio
    async def test_returning_false_marks_unhealthy(self) -> None:
        """Returning False (no exception) still counts as a failure."""
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(cache, {"ollama": make_probe(returns=False)})
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "false" in (state.last_error or "")

    @pytest.mark.asyncio
    async def test_raising_marks_unhealthy_with_exc_info(self) -> None:
        """A raising probe records the exception type in last_error."""
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache, {"ollama": make_probe(raises=ConnectionError)}
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "ConnectionError" in (state.last_error or "")

    @pytest.mark.asyncio
    async def test_timeout_marks_unhealthy(self) -> None:
        """A probe slower than the deadline is treated as a failure."""
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache,
            {"ollama": make_slow_probe(delay=1.0)},
            timeout=0.05,
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        state = cache.get_state("ollama")
        assert state is not None
        assert "timeout" in (state.last_error or "").lower()

    @pytest.mark.asyncio
    async def test_one_bad_probe_does_not_sink_others(self) -> None:
        # Probe 'ollama' raises — 'groq' must still be evaluated and marked
        # healthy. Bug shape: an unhandled exception in gather() sinks the
        # whole loop.
        cache = ProviderHealthCache(failure_threshold=1)
        probe = HealthProbe(
            cache,
            {
                "ollama": make_probe(raises=RuntimeError),
                "groq": make_probe(returns=True),
            },
        )
        await probe.tick_once()
        assert cache.is_healthy("ollama") is False
        assert cache.is_healthy("groq") is True

    @pytest.mark.asyncio
    async def test_concurrent_probes(self) -> None:
        # All probes should run in parallel — total elapsed wall-clock for
        # N x 100ms probes should be well under N*100ms.
        import time

        cache = ProviderHealthCache()
        probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)}
        probe = HealthProbe(cache, probes, timeout=1.0)
        t0 = time.perf_counter()
        await probe.tick_once()
        elapsed = time.perf_counter() - t0
        assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s"

    @pytest.mark.asyncio
    async def test_empty_probes_is_noop(self) -> None:
        """With no registered probes a tick is a harmless no-op."""
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {})
        # No exception, no state mutation.
        await probe.tick_once()
        assert cache.snapshot() == {}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# start / stop lifecycle
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class TestLifecycle:
    @pytest.mark.asyncio
    async def test_start_runs_initial_tick_immediately(self) -> None:
        """The loop probes at boot, not after the first full interval."""
        cache = ProviderHealthCache(failure_threshold=1)
        cache.mark_unhealthy("ollama", "stale")
        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0)
        await probe.start()
        # Give the loop one event-loop turn to run the initial tick before
        # blocking on the long sleep.
        await asyncio.sleep(0.01)
        assert cache.is_healthy("ollama") is True
        await probe.stop()

    @pytest.mark.asyncio
    async def test_stop_cancels_cleanly(self) -> None:
        """stop() cancels the task and leaves the probe in a stopped state."""
        cache = ProviderHealthCache()
        probe = HealthProbe(
            cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0
        )
        await probe.start()
        assert probe.running is True
        await probe.stop()
        assert probe.running is False

    @pytest.mark.asyncio
    async def test_start_is_idempotent(self) -> None:
        """A second start() must not spawn a second background task."""
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0)
        await probe.start()
        await probe.start()  # must not spawn a second task
        assert probe.running is True
        await probe.stop()

    @pytest.mark.asyncio
    async def test_stop_without_start_is_safe(self) -> None:
        """stop() before any start() is a harmless no-op."""
        cache = ProviderHealthCache()
        probe = HealthProbe(cache, {})
        await probe.stop()  # idempotent / safe pre-start

    @pytest.mark.asyncio
    async def test_loop_keeps_running_after_tick_error(self) -> None:
        # Even if every probe raises on every tick, the loop must keep
        # ticking. (The previous version of this test used an always-healthy
        # counter probe that never raised, so it verified repetition but not
        # error survival — now the probe both counts and explodes.)
        cache = ProviderHealthCache(failure_threshold=1)
        calls = {"n": 0}

        async def exploding_probe() -> bool:
            calls["n"] += 1
            raise RuntimeError("boom on every tick")

        probe = HealthProbe(
            cache,
            {"exploder": exploding_probe},
            interval=0.05,  # short interval for test
            timeout=1.0,
        )
        await probe.start()
        await asyncio.sleep(0.18)  # ~3-4 ticks
        await probe.stop()
        # Initial tick + at least 2 interval ticks, despite every probe raising.
        assert calls["n"] >= 3
        # The failures were recorded in the cache, not swallowed silently.
        assert cache.is_healthy("exploder") is False
|
||||
Loading…
Add table
Add a link
Reference in a new issue