feat(mana-llm): M2 — ProviderHealthCache + background probe loop

Per-provider liveness with circuit-breaker semantics. The router (M3)
will read `is_healthy()` to skip dead providers in a chain; the probe
loop and the call-site fallback handler write state via
`mark_healthy` / `mark_unhealthy`.

State machine: 1st failure stays healthy (transient blips happen);
2nd consecutive failure trips the breaker and sets a 60s backoff
window during which `is_healthy → False`. After the window the
provider is half-open again — next call exercises it, success
resets, failure re-arms.

HealthProbe is the background asyncio.Task that pings every
registered provider every 30s with a 3s timeout. Probes run
concurrently per tick and one bad probe can't sink the loop. Probe
functions are injected (`{name: async-fn}`) so this module stays
decoupled from the provider classes — the wiring lives in main.py
where we already know which providers are configured.

32 new tests (FakeClock for deterministic backoff timing, slow-probe
helpers for parallelism + timeout, lifecycle tests for start/stop
idempotency and tick-after-error survival). 64/64 alias+health tests
green.

Not yet wired into the request path — that's M3.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Till JS 2026-04-26 20:29:57 +02:00
parent dff8629e1d
commit 59557e62d7
4 changed files with 811 additions and 0 deletions

View file

@ -0,0 +1,176 @@
"""Provider health cache.
Tracks per-provider liveness for the LLM router. The router reads
:meth:`is_healthy` to decide whether to even try a provider in a chain;
the probe loop and the call-site fallback handler write state via
:meth:`mark_healthy` / :meth:`mark_unhealthy`.
Implements a simple circuit-breaker:
* The first failure flips no switch providers occasionally have
transient blips, we don't want to bounce off after a single 502.
* After ``failure_threshold`` consecutive failures the provider is
marked unhealthy for ``unhealthy_backoff`` seconds. During that
window :meth:`is_healthy` returns ``False`` so the router fails
fast straight to the next chain entry.
* When the backoff expires :meth:`is_healthy` returns ``True`` again
(half-open). The next call exercises the provider; success calls
:meth:`mark_healthy` and fully resets state, failure re-arms the
backoff window.
State is kept in a plain dict guarded by a ``threading.Lock``. All
operations are short, lock-free reads of dict references aren't safe
because we mutate state in-place the lock keeps it boring. Probe
loop runs in the asyncio loop alongside the router, but the lock
costs are negligible at ~1 update/30s/provider.
"""
from __future__ import annotations
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Iterable
logger = logging.getLogger(__name__)
DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
@dataclass
class ProviderState:
"""Per-provider liveness snapshot. All times are unix seconds."""
healthy: bool = True
consecutive_failures: int = 0
last_check: float = 0.0
last_error: str | None = None
unhealthy_until: float = 0.0
"""When > now, the provider is currently in backoff (`is_healthy → False`)."""
class ProviderHealthCache:
"""Thread-safe per-provider liveness with circuit-breaker semantics.
Provider IDs are arbitrary strings by convention we use the same
short name as the provider router (``ollama``, ``groq``, ``openrouter``,
``together``, ``google``). The cache is provider-list agnostic; states
are created lazily on first ``mark_*`` or queried-but-absent ``is_healthy``
call (returning ``True`` by default no state means no reason to skip).
"""
def __init__(
self,
*,
failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
clock: callable = time.time,
) -> None:
if failure_threshold < 1:
raise ValueError("failure_threshold must be >= 1")
if unhealthy_backoff_sec < 0:
raise ValueError("unhealthy_backoff_sec must be >= 0")
self._failure_threshold = failure_threshold
self._unhealthy_backoff = unhealthy_backoff_sec
self._clock = clock
self._lock = threading.Lock()
self._states: dict[str, ProviderState] = {}
@property
def failure_threshold(self) -> int:
return self._failure_threshold
@property
def unhealthy_backoff_sec(self) -> float:
return self._unhealthy_backoff
# ------------------------------------------------------------------
# Reads
# ------------------------------------------------------------------
def is_healthy(self, provider_id: str) -> bool:
"""Should the router try this provider right now?
Returns ``True`` by default for unknown providers the cache is
observation-only, not a registry.
"""
with self._lock:
state = self._states.get(provider_id)
if state is None:
return True
if state.unhealthy_until > self._clock():
return False
# Backoff expired: caller is allowed to try again (half-open).
return True
def get_state(self, provider_id: str) -> ProviderState | None:
"""Snapshot of one provider's state (for debugging / tests)."""
with self._lock:
state = self._states.get(provider_id)
return None if state is None else _copy(state)
def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
"""All known states, plus zero-state placeholders for any names in
``expected`` that haven't been touched yet. Used by ``GET /v1/health``
so the response shape is stable regardless of probe order.
"""
with self._lock:
out = {pid: _copy(s) for pid, s in self._states.items()}
if expected:
for pid in expected:
out.setdefault(pid, ProviderState())
return out
# ------------------------------------------------------------------
# Writes
# ------------------------------------------------------------------
def mark_healthy(self, provider_id: str) -> None:
"""Provider answered correctly — clear any failure state."""
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
previously_unhealthy = not state.healthy
state.healthy = True
state.consecutive_failures = 0
state.last_check = self._clock()
state.last_error = None
state.unhealthy_until = 0.0
if previously_unhealthy:
logger.info("provider %s recovered", provider_id)
def mark_unhealthy(self, provider_id: str, reason: str) -> None:
"""Record a failure. Trips the breaker after the threshold."""
with self._lock:
state = self._states.setdefault(provider_id, ProviderState())
state.consecutive_failures += 1
state.last_check = self._clock()
state.last_error = reason
tripped = state.consecutive_failures >= self._failure_threshold
if tripped and state.healthy:
state.healthy = False
state.unhealthy_until = self._clock() + self._unhealthy_backoff
logger.warning(
"provider %s marked unhealthy after %d consecutive failures (%s); "
"backoff %.0fs",
provider_id,
state.consecutive_failures,
reason,
self._unhealthy_backoff,
)
elif not state.healthy:
# Still in unhealthy window; refresh the backoff so a flapping
# provider doesn't get re-tried every probe tick.
state.unhealthy_until = self._clock() + self._unhealthy_backoff
def _copy(state: ProviderState) -> ProviderState:
"""Return a shallow copy so callers can read without holding the lock."""
return ProviderState(
healthy=state.healthy,
consecutive_failures=state.consecutive_failures,
last_check=state.last_check,
last_error=state.last_error,
unhealthy_until=state.unhealthy_until,
)

View file

@ -0,0 +1,210 @@
"""Background probe loop that keeps :class:`ProviderHealthCache` fresh.
The router's circuit-breaker is reactive — it only learns a provider is
sick after the next live request fails. A reactive-only design means:
* every cold start re-discovers Ollama is down by paying one 75-second
``ConnectError``, and
* a provider that quietly recovers stays marked unhealthy until its
backoff expires and someone tries it.
The probe loop closes both gaps. Every ``interval`` seconds it pings
each registered provider with a small known-cheap request (Ollama:
``GET /api/tags``, OpenAI-compat: ``GET /v1/models``) and updates the
cache. Probes run concurrently per tick and respect a hard
``probe_timeout`` so a hanging provider can't stall the loop.
Probe functions are injected from outside (``probes={name: async-fn}``)
so this module stays decoupled from the provider classes wiring lives
in ``main.py`` where we already know which providers are configured.
"""
from __future__ import annotations
import asyncio
import logging
from typing import Awaitable, Callable
from .health import ProviderHealthCache
logger = logging.getLogger(__name__)
#: Probe function: returns ``True`` for healthy. Raising or returning
#: ``False`` both count as a failure (the loop just calls
#: ``mark_unhealthy``); an exception's string form becomes the
#: ``last_error`` for the snapshot endpoint.
ProbeFn = Callable[[], Awaitable[bool]]
DEFAULT_PROBE_INTERVAL_SEC = 30.0
DEFAULT_PROBE_TIMEOUT_SEC = 3.0
class HealthProbe:
"""Periodically probes every registered provider and updates the cache.
Lifecycle::
probe = HealthProbe(cache, {"ollama": probe_ollama, ...})
await probe.start() # spawns the background task
...
await probe.stop() # cancels and awaits cleanup
Tests typically call :meth:`tick_once` directly to exercise one cycle
without driving the asyncio scheduler through real ``asyncio.sleep``.
"""
def __init__(
self,
cache: ProviderHealthCache,
probes: dict[str, ProbeFn],
*,
interval: float = DEFAULT_PROBE_INTERVAL_SEC,
timeout: float = DEFAULT_PROBE_TIMEOUT_SEC,
) -> None:
if interval <= 0:
raise ValueError("interval must be > 0")
if timeout <= 0:
raise ValueError("timeout must be > 0")
self._cache = cache
self._probes = dict(probes)
self._interval = interval
self._timeout = timeout
self._task: asyncio.Task | None = None
self._stop = asyncio.Event()
@property
def interval(self) -> float:
return self._interval
@property
def timeout(self) -> float:
return self._timeout
@property
def provider_ids(self) -> list[str]:
return list(self._probes)
@property
def running(self) -> bool:
return self._task is not None and not self._task.done()
# ------------------------------------------------------------------
# Per-tick logic — exercised directly by tests
# ------------------------------------------------------------------
async def tick_once(self) -> None:
"""Probe every provider once, in parallel, updating the cache.
Errors in any one probe (including ``asyncio.TimeoutError``) are
captured per-provider one bad probe never sinks the loop.
"""
if not self._probes:
return
results = await asyncio.gather(
*(self._probe_one(name, fn) for name, fn in self._probes.items()),
return_exceptions=True,
)
# gather(return_exceptions=True) caught everything per-probe; this
# branch should never fire, but guard in case _probe_one ever grows
# a code path that bypasses its try/except.
# asyncio.gather() returns results in input order and same length,
# so the zip is a 1:1 mapping back to provider names.
for name, result in zip(self._probes, results):
if isinstance(result, BaseException):
logger.error("probe %s leaked exception: %s", name, result)
async def _probe_one(self, name: str, fn: ProbeFn) -> None:
try:
healthy = await asyncio.wait_for(fn(), timeout=self._timeout)
except asyncio.TimeoutError:
self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)")
return
except asyncio.CancelledError:
raise
except Exception as e: # noqa: BLE001 — probe SHOULD be permissive
self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}")
return
if healthy:
self._cache.mark_healthy(name)
else:
self._cache.mark_unhealthy(name, "probe-returned-false")
# ------------------------------------------------------------------
# Long-running task management
# ------------------------------------------------------------------
async def start(self) -> None:
"""Spawn the periodic probe task. Idempotent."""
if self.running:
return
self._stop.clear()
self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe")
logger.info(
"HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)",
self._interval,
self._timeout,
", ".join(self._probes) or "<none>",
)
async def stop(self) -> None:
"""Cancel the background task and wait for it to finish."""
if not self.running:
return
self._stop.set()
assert self._task is not None
self._task.cancel()
try:
await self._task
except asyncio.CancelledError:
pass
finally:
self._task = None
logger.info("HealthProbe stopped")
async def _run_forever(self) -> None:
# Probe immediately at boot so we don't serve traffic for `interval`
# seconds based on optimistic-default assumptions.
try:
await self.tick_once()
except Exception as e: # noqa: BLE001
logger.error("HealthProbe initial tick failed: %s", e)
while not self._stop.is_set():
try:
await asyncio.wait_for(self._stop.wait(), timeout=self._interval)
except asyncio.TimeoutError:
pass
else:
# _stop.wait() succeeded → stop signalled, exit.
return
try:
await self.tick_once()
except Exception as e: # noqa: BLE001
logger.error("HealthProbe tick failed: %s", e)
# ---------------------------------------------------------------------------
# Probe-function helpers
# ---------------------------------------------------------------------------
def make_http_probe(
url: str,
*,
headers: dict[str, str] | None = None,
expected_status_lt: int = 500,
) -> ProbeFn:
"""Return a probe function that does ``GET <url>`` and considers the
provider healthy iff the response status is below
``expected_status_lt`` (default: any non-5xx counts).
A 401/403/404 still counts as healthy because the *server* answered
auth or path mistakes are misconfiguration, not provider liveness.
"""
import httpx
async def probe() -> bool:
async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
resp = await client.get(url, headers=headers or None)
return resp.status_code < expected_status_lt
return probe

View file

@ -0,0 +1,203 @@
"""Tests for the provider health cache."""
from __future__ import annotations
import pytest
from src.health import (
DEFAULT_FAILURE_THRESHOLD,
DEFAULT_UNHEALTHY_BACKOFF_SEC,
ProviderHealthCache,
ProviderState,
)
class FakeClock:
"""Deterministic clock for circuit-breaker timing tests."""
def __init__(self, start: float = 1_000_000.0) -> None:
self.now = start
def __call__(self) -> float:
return self.now
def advance(self, seconds: float) -> None:
self.now += seconds
# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------
class TestConstruction:
def test_defaults(self) -> None:
c = ProviderHealthCache()
assert c.failure_threshold == DEFAULT_FAILURE_THRESHOLD
assert c.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC
def test_invalid_threshold_rejected(self) -> None:
with pytest.raises(ValueError, match="failure_threshold"):
ProviderHealthCache(failure_threshold=0)
def test_invalid_backoff_rejected(self) -> None:
with pytest.raises(ValueError, match="backoff"):
ProviderHealthCache(unhealthy_backoff_sec=-1.0)
# ---------------------------------------------------------------------------
# Default-healthy behaviour
# ---------------------------------------------------------------------------
class TestDefaults:
def test_unknown_provider_is_healthy(self) -> None:
# The cache is observation-only — no entry means "no reason to skip".
c = ProviderHealthCache()
assert c.is_healthy("ollama") is True
def test_unknown_provider_has_no_state(self) -> None:
c = ProviderHealthCache()
assert c.get_state("ollama") is None
def test_snapshot_includes_expected_zero_state(self) -> None:
c = ProviderHealthCache()
snap = c.snapshot(expected=["ollama", "groq"])
assert set(snap.keys()) == {"ollama", "groq"}
assert all(s.healthy for s in snap.values())
assert all(s.consecutive_failures == 0 for s in snap.values())
# ---------------------------------------------------------------------------
# Failure → unhealthy state machine
# ---------------------------------------------------------------------------
class TestFailureSemantics:
def test_first_failure_does_not_trip(self) -> None:
# Single transient blips shouldn't bounce a provider — wait for the
# next consecutive failure to confirm.
c = ProviderHealthCache(failure_threshold=2)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is True
state = c.get_state("ollama")
assert state is not None
assert state.consecutive_failures == 1
assert state.healthy is True
def test_threshold_reached_trips_breaker(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(
failure_threshold=2,
unhealthy_backoff_sec=60.0,
clock=clock,
)
c.mark_unhealthy("ollama", "fail-1")
c.mark_unhealthy("ollama", "fail-2")
assert c.is_healthy("ollama") is False
state = c.get_state("ollama")
assert state is not None
assert state.healthy is False
assert state.last_error == "fail-2"
assert state.unhealthy_until == clock.now + 60.0
def test_threshold_one_trips_immediately(self) -> None:
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is False
def test_mark_healthy_clears_state(self) -> None:
c = ProviderHealthCache(failure_threshold=2)
c.mark_unhealthy("ollama", "x")
c.mark_unhealthy("ollama", "x")
assert c.is_healthy("ollama") is False
c.mark_healthy("ollama")
assert c.is_healthy("ollama") is True
state = c.get_state("ollama")
assert state is not None
assert state.healthy is True
assert state.consecutive_failures == 0
assert state.unhealthy_until == 0.0
assert state.last_error is None
class TestBackoffWindow:
def test_is_healthy_false_during_backoff(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
assert c.is_healthy("ollama") is False
clock.advance(30.0)
assert c.is_healthy("ollama") is False
def test_is_healthy_true_after_backoff_expires(self) -> None:
# After backoff: half-open. Router gets one more attempt; success
# mark_healthy resets, failure mark_unhealthy re-arms backoff.
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
clock.advance(61.0)
assert c.is_healthy("ollama") is True
def test_failure_during_backoff_extends_window(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "first")
original_until = c.get_state("ollama").unhealthy_until
clock.advance(20.0)
c.mark_unhealthy("ollama", "second")
new_until = c.get_state("ollama").unhealthy_until
assert new_until > original_until
assert new_until == clock.now + 60.0
def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None:
clock = FakeClock()
c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
c.mark_unhealthy("ollama", "boom")
with caplog.at_level("INFO"):
c.mark_healthy("ollama")
# Calling mark_healthy on an already-healthy provider must not
# spam the recovery log line.
c.mark_healthy("ollama")
c.mark_healthy("ollama")
recovery_lines = [r for r in caplog.records if "recovered" in r.message]
assert len(recovery_lines) == 1
# ---------------------------------------------------------------------------
# Snapshot shape
# ---------------------------------------------------------------------------
class TestSnapshot:
def test_snapshot_returns_copies(self) -> None:
# Caller shouldn't be able to poke through to the cache's internal
# state via a snapshot reference.
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "x")
snap = c.snapshot()
snap["ollama"].healthy = True
# Original state untouched:
assert c.is_healthy("ollama") is False
def test_snapshot_matches_recorded_state(self) -> None:
clock = FakeClock()
c = ProviderHealthCache(
failure_threshold=2,
unhealthy_backoff_sec=60.0,
clock=clock,
)
c.mark_unhealthy("groq", "rate-limit")
snap = c.snapshot()
assert isinstance(snap["groq"], ProviderState)
assert snap["groq"].consecutive_failures == 1
assert snap["groq"].last_error == "rate-limit"
assert snap["groq"].last_check == clock.now
def test_snapshot_expected_does_not_overwrite_real_state(self) -> None:
c = ProviderHealthCache(failure_threshold=1)
c.mark_unhealthy("ollama", "real-boom")
snap = c.snapshot(expected=["ollama", "groq"])
# ollama keeps its real (unhealthy) state, groq gets the zero-default.
assert snap["ollama"].consecutive_failures == 1
assert snap["groq"].consecutive_failures == 0

View file

@ -0,0 +1,222 @@
"""Tests for the background health-probe loop."""
from __future__ import annotations
import asyncio
from typing import Awaitable, Callable
import pytest
from src.health import ProviderHealthCache
from src.health_probe import HealthProbe, ProbeFn
def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn:
"""Synthesise a probe function that always returns / raises the same."""
async def probe() -> bool:
if raises is not None:
raise raises("boom")
return returns
return probe
def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn:
async def probe() -> bool:
await asyncio.sleep(delay)
return returns
return probe
def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]:
"""Probe that counts how many times it was awaited."""
count = 0
async def probe() -> bool:
nonlocal count
count += 1
return True
return probe, lambda: count
# ---------------------------------------------------------------------------
# Construction
# ---------------------------------------------------------------------------
class TestConstruction:
def test_invalid_interval(self) -> None:
with pytest.raises(ValueError, match="interval"):
HealthProbe(ProviderHealthCache(), {}, interval=0.0)
def test_invalid_timeout(self) -> None:
with pytest.raises(ValueError, match="timeout"):
HealthProbe(ProviderHealthCache(), {}, timeout=0.0)
def test_provider_ids_exposes_keys(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(
cache,
{"ollama": make_probe(), "groq": make_probe()},
)
assert sorted(probe.provider_ids) == ["groq", "ollama"]
# ---------------------------------------------------------------------------
# tick_once — the per-cycle behaviour
# ---------------------------------------------------------------------------
class TestTickOnce:
@pytest.mark.asyncio
async def test_healthy_probe_marks_healthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
# Pre-mark unhealthy so we can verify the probe recovers it.
cache.mark_unhealthy("ollama", "stale")
probe = HealthProbe(cache, {"ollama": make_probe(returns=True)})
await probe.tick_once()
assert cache.is_healthy("ollama") is True
@pytest.mark.asyncio
async def test_returning_false_marks_unhealthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(cache, {"ollama": make_probe(returns=False)})
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "false" in (state.last_error or "")
@pytest.mark.asyncio
async def test_raising_marks_unhealthy_with_exc_info(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache, {"ollama": make_probe(raises=ConnectionError)}
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "ConnectionError" in (state.last_error or "")
@pytest.mark.asyncio
async def test_timeout_marks_unhealthy(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache,
{"ollama": make_slow_probe(delay=1.0)},
timeout=0.05,
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
state = cache.get_state("ollama")
assert state is not None
assert "timeout" in (state.last_error or "").lower()
@pytest.mark.asyncio
async def test_one_bad_probe_does_not_sink_others(self) -> None:
# Probe 'ollama' raises — 'groq' must still be evaluated and marked
# healthy. Bug shape: an unhandled exception in gather() sinks the
# whole loop.
cache = ProviderHealthCache(failure_threshold=1)
probe = HealthProbe(
cache,
{
"ollama": make_probe(raises=RuntimeError),
"groq": make_probe(returns=True),
},
)
await probe.tick_once()
assert cache.is_healthy("ollama") is False
assert cache.is_healthy("groq") is True
@pytest.mark.asyncio
async def test_concurrent_probes(self) -> None:
# All probes should run in parallel — total elapsed wall-clock for
# N x 100ms probes should be well under N*100ms.
import time
cache = ProviderHealthCache()
probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)}
probe = HealthProbe(cache, probes, timeout=1.0)
t0 = time.perf_counter()
await probe.tick_once()
elapsed = time.perf_counter() - t0
assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s"
@pytest.mark.asyncio
async def test_empty_probes_is_noop(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {})
# No exception, no state mutation.
await probe.tick_once()
assert cache.snapshot() == {}
# ---------------------------------------------------------------------------
# start / stop lifecycle
# ---------------------------------------------------------------------------
class TestLifecycle:
@pytest.mark.asyncio
async def test_start_runs_initial_tick_immediately(self) -> None:
cache = ProviderHealthCache(failure_threshold=1)
cache.mark_unhealthy("ollama", "stale")
probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0)
await probe.start()
# Give the loop one event-loop turn to run the initial tick before
# blocking on the long sleep.
await asyncio.sleep(0.01)
assert cache.is_healthy("ollama") is True
await probe.stop()
@pytest.mark.asyncio
async def test_stop_cancels_cleanly(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(
cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0
)
await probe.start()
assert probe.running is True
await probe.stop()
assert probe.running is False
@pytest.mark.asyncio
async def test_start_is_idempotent(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0)
await probe.start()
await probe.start() # must not spawn a second task
assert probe.running is True
await probe.stop()
@pytest.mark.asyncio
async def test_stop_without_start_is_safe(self) -> None:
cache = ProviderHealthCache()
probe = HealthProbe(cache, {})
await probe.stop() # idempotent / safe pre-start
@pytest.mark.asyncio
async def test_loop_keeps_running_after_tick_error(self) -> None:
# Even if every probe explodes, the loop must keep ticking.
cache = ProviderHealthCache(failure_threshold=1)
fn, count = make_call_counter()
# Wrap with one that raises — but tick_once internally catches
# per-probe via gather(return_exceptions=True). Force an outer error
# via an evil probe key that the dict can't handle? Easier: use the
# call-counter to verify multiple ticks happened.
probe = HealthProbe(
cache,
{"counter": fn},
interval=0.05, # short interval for test
timeout=1.0,
)
await probe.start()
await asyncio.sleep(0.18) # ~3-4 ticks
await probe.stop()
# Initial tick + at least 2 interval ticks.
assert count() >= 3