diff --git a/services/mana-llm/src/health.py b/services/mana-llm/src/health.py
new file mode 100644
index 000000000..1389e0b35
--- /dev/null
+++ b/services/mana-llm/src/health.py
@@ -0,0 +1,176 @@
+"""Provider health cache.
+
+Tracks per-provider liveness for the LLM router. The router reads
+:meth:`is_healthy` to decide whether to even try a provider in a chain;
+the probe loop and the call-site fallback handler write state via
+:meth:`mark_healthy` / :meth:`mark_unhealthy`.
+
+Implements a simple circuit-breaker:
+
+* The first failure flips no switch — providers occasionally have
+  transient blips, and we don't want to trip the breaker on a single 502.
+* After ``failure_threshold`` consecutive failures the provider is
+  marked unhealthy for ``unhealthy_backoff`` seconds. During that
+  window :meth:`is_healthy` returns ``False`` so the router fails
+  fast straight to the next chain entry.
+* When the backoff expires :meth:`is_healthy` returns ``True`` again
+  (half-open). The next call exercises the provider; success calls
+  :meth:`mark_healthy` and fully resets state, while failure re-arms
+  the backoff window.
+
+State is kept in a plain dict guarded by a ``threading.Lock``. All
+operations are short. Lock-free reads aren't safe here because we
+mutate ``ProviderState`` objects in place, so the lock keeps things
+boring. The probe loop runs in the asyncio loop alongside the router,
+but lock costs are negligible at roughly one update per provider every
+30 seconds.
+"""
+
+from __future__ import annotations
+
+import logging
+import threading
+import time
+from dataclasses import dataclass
+from typing import Callable, Iterable
+
+logger = logging.getLogger(__name__)
+
+DEFAULT_FAILURE_THRESHOLD = 2
+DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0
+
+
+@dataclass
+class ProviderState:
+    """Per-provider liveness snapshot. All times are unix seconds."""
+
+    healthy: bool = True
+    consecutive_failures: int = 0
+    last_check: float = 0.0
+    last_error: str | None = None
+    unhealthy_until: float = 0.0
+    """When > now, the provider is currently in backoff (`is_healthy → False`)."""
+
+
+class ProviderHealthCache:
+    """Thread-safe per-provider liveness with circuit-breaker semantics.
+
+    Provider IDs are arbitrary strings — by convention we use the same
+    short name as the provider router (``ollama``, ``groq``, ``openrouter``,
+    ``together``, ``google``). The cache is provider-list agnostic; states
+    are created lazily on the first ``mark_*`` call. Querying an absent
+    provider via ``is_healthy`` returns ``True`` without creating state —
+    no state means no reason to skip.
+    """
+
+    def __init__(
+        self,
+        *,
+        failure_threshold: int = DEFAULT_FAILURE_THRESHOLD,
+        unhealthy_backoff_sec: float = DEFAULT_UNHEALTHY_BACKOFF_SEC,
+        clock: Callable[[], float] = time.time,
+    ) -> None:
+        if failure_threshold < 1:
+            raise ValueError("failure_threshold must be >= 1")
+        if unhealthy_backoff_sec < 0:
+            raise ValueError("unhealthy_backoff_sec must be >= 0")
+        self._failure_threshold = failure_threshold
+        self._unhealthy_backoff = unhealthy_backoff_sec
+        self._clock = clock
+        self._lock = threading.Lock()
+        self._states: dict[str, ProviderState] = {}
+
+    @property
+    def failure_threshold(self) -> int:
+        return self._failure_threshold
+
+    @property
+    def unhealthy_backoff_sec(self) -> float:
+        return self._unhealthy_backoff
+
+    # ------------------------------------------------------------------
+    # Reads
+    # ------------------------------------------------------------------
+
+    def is_healthy(self, provider_id: str) -> bool:
+        """Should the router try this provider right now?
+
+        Returns ``True`` by default for unknown providers — the cache is
+        observation-only, not a registry.
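+
+        A typical call-site check in a fallback chain (an illustrative
+        sketch; ``cache``, ``chain`` and ``call_provider`` are placeholder
+        names, the real wiring lives in the router)::
+
+            for provider_id in chain:
+                if not cache.is_healthy(provider_id):
+                    continue  # breaker open: fail fast to the next entry
+                result = await call_provider(provider_id)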
+        """
+        with self._lock:
+            state = self._states.get(provider_id)
+            if state is None:
+                return True
+            if state.unhealthy_until > self._clock():
+                return False
+            # Backoff expired: caller is allowed to try again (half-open).
+            return True
+
+    def get_state(self, provider_id: str) -> ProviderState | None:
+        """Snapshot of one provider's state (for debugging / tests)."""
+        with self._lock:
+            state = self._states.get(provider_id)
+            return None if state is None else _copy(state)
+
+    def snapshot(self, expected: Iterable[str] | None = None) -> dict[str, ProviderState]:
+        """All known states, plus zero-state placeholders for any names in
+        ``expected`` that haven't been touched yet. Used by ``GET /v1/health``
+        so the response shape is stable regardless of probe order.
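+
+        Illustrative shape, assuming only ``ollama`` has recorded failures::
+
+            cache.snapshot(expected=["ollama", "groq"])
+            # {"ollama": ProviderState(healthy=False, ...),
+            #  "groq": ProviderState()}   # untouched zero-state placeholder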
+        """
+        with self._lock:
+            out = {pid: _copy(s) for pid, s in self._states.items()}
+            if expected:
+                for pid in expected:
+                    out.setdefault(pid, ProviderState())
+            return out
+
+    # ------------------------------------------------------------------
+    # Writes
+    # ------------------------------------------------------------------
+
+    def mark_healthy(self, provider_id: str) -> None:
+        """Provider answered correctly — clear any failure state."""
+        with self._lock:
+            state = self._states.setdefault(provider_id, ProviderState())
+            previously_unhealthy = not state.healthy
+            state.healthy = True
+            state.consecutive_failures = 0
+            state.last_check = self._clock()
+            state.last_error = None
+            state.unhealthy_until = 0.0
+            if previously_unhealthy:
+                logger.info("provider %s recovered", provider_id)
+
+    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
+        """Record a failure. Trips the breaker after the threshold."""
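+        # Shape of the sequence at the defaults (threshold 2, 60s backoff):
+        #   mark_unhealthy("groq", "HTTP 502")  -> failure 1, still healthy
+        #   mark_unhealthy("groq", "HTTP 502")  -> failure 2, breaker trips
+        #   is_healthy("groq")                  -> False for the next 60s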
+        with self._lock:
+            state = self._states.setdefault(provider_id, ProviderState())
+            state.consecutive_failures += 1
+            state.last_check = self._clock()
+            state.last_error = reason
+            tripped = state.consecutive_failures >= self._failure_threshold
+            if tripped and state.healthy:
+                state.healthy = False
+                state.unhealthy_until = self._clock() + self._unhealthy_backoff
+                logger.warning(
+                    "provider %s marked unhealthy after %d consecutive failures (%s); "
+                    "backoff %.0fs",
+                    provider_id,
+                    state.consecutive_failures,
+                    reason,
+                    self._unhealthy_backoff,
+                )
+            elif not state.healthy:
+                # Still in the unhealthy window; refresh the backoff so a
+                # flapping provider doesn't get re-tried every probe tick.
+                state.unhealthy_until = self._clock() + self._unhealthy_backoff
+
+
+def _copy(state: ProviderState) -> ProviderState:
+    """Return a shallow copy so callers can read without holding the lock."""
+    return ProviderState(
+        healthy=state.healthy,
+        consecutive_failures=state.consecutive_failures,
+        last_check=state.last_check,
+        last_error=state.last_error,
+        unhealthy_until=state.unhealthy_until,
+    )
diff --git a/services/mana-llm/src/health_probe.py b/services/mana-llm/src/health_probe.py
new file mode 100644
index 000000000..9f1cdfa28
--- /dev/null
+++ b/services/mana-llm/src/health_probe.py
@@ -0,0 +1,210 @@
+"""Background probe loop that keeps :class:`ProviderHealthCache` fresh.
+
+The router's circuit-breaker is reactive — it only learns a provider is
+sick after the next live request fails. A reactive-only design means:
+
+* every cold start re-discovers Ollama is down by paying one 75-second
+  ``ConnectError``, and
+* a provider that quietly recovers stays marked unhealthy until its
+  backoff expires and someone tries it.
+
+The probe loop closes both gaps. Every ``interval`` seconds it pings
+each registered provider with a small, known-cheap request (Ollama:
+``GET /api/tags``; OpenAI-compatible APIs: ``GET /v1/models``) and
+updates the cache. Probes run concurrently per tick and respect a hard
+``probe_timeout`` so a hanging provider can't stall the loop.
+
+Probe functions are injected from outside (``probes={name: async-fn}``)
+so this module stays decoupled from the provider classes — wiring lives
+in ``main.py``, where we already know which providers are configured.
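+
+Illustrative wiring (the probe URLs are examples; the real provider set
+comes from ``main.py``'s configuration)::
+
+    cache = ProviderHealthCache()
+    probe = HealthProbe(
+        cache,
+        {
+            "ollama": make_http_probe("http://localhost:11434/api/tags"),
+            "groq": make_http_probe("https://api.groq.com/openai/v1/models"),
+        },
+    )
+    await probe.start()   # at service startup
+    ...
+    await probe.stop()    # at shutdown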
+"""
+
+from __future__ import annotations
+
+import asyncio
+import logging
+from typing import Awaitable, Callable
+
+from .health import ProviderHealthCache
+
+logger = logging.getLogger(__name__)
+
+#: Probe function: returns ``True`` for healthy. Raising or returning
+#: ``False`` both count as a failure (the loop just calls
+#: ``mark_unhealthy``); an exception's string form becomes the
+#: ``last_error`` shown by the snapshot endpoint.
+ProbeFn = Callable[[], Awaitable[bool]]
+
+DEFAULT_PROBE_INTERVAL_SEC = 30.0
+DEFAULT_PROBE_TIMEOUT_SEC = 3.0
+
+
+class HealthProbe:
+    """Periodically probes every registered provider and updates the cache.
+
+    Lifecycle::
+
+        probe = HealthProbe(cache, {"ollama": probe_ollama, ...})
+        await probe.start()   # spawns the background task
+        ...
+        await probe.stop()    # cancels and awaits cleanup
+
+    Tests typically call :meth:`tick_once` directly to exercise one cycle
+    without driving the asyncio scheduler through real ``asyncio.sleep``.
+    """
+
+    def __init__(
+        self,
+        cache: ProviderHealthCache,
+        probes: dict[str, ProbeFn],
+        *,
+        interval: float = DEFAULT_PROBE_INTERVAL_SEC,
+        timeout: float = DEFAULT_PROBE_TIMEOUT_SEC,
+    ) -> None:
+        if interval <= 0:
+            raise ValueError("interval must be > 0")
+        if timeout <= 0:
+            raise ValueError("timeout must be > 0")
+        self._cache = cache
+        self._probes = dict(probes)
+        self._interval = interval
+        self._timeout = timeout
+        self._task: asyncio.Task | None = None
+        self._stop = asyncio.Event()
+
+    @property
+    def interval(self) -> float:
+        return self._interval
+
+    @property
+    def timeout(self) -> float:
+        return self._timeout
+
+    @property
+    def provider_ids(self) -> list[str]:
+        return list(self._probes)
+
+    @property
+    def running(self) -> bool:
+        return self._task is not None and not self._task.done()
+
+    # ------------------------------------------------------------------
+    # Per-tick logic — exercised directly by tests
+    # ------------------------------------------------------------------
+
+    async def tick_once(self) -> None:
+        """Probe every provider once, in parallel, updating the cache.
+
+        Errors in any one probe (including ``asyncio.TimeoutError``) are
+        captured per-provider — one bad probe never sinks the loop.
+        """
+        if not self._probes:
+            return
+        results = await asyncio.gather(
+            *(self._probe_one(name, fn) for name, fn in self._probes.items()),
+            return_exceptions=True,
+        )
+        # _probe_one handles its own errors and gather(return_exceptions=True)
+        # is a second guard, so the branch below should never fire; log it in
+        # case _probe_one ever grows a code path that bypasses its try/except.
+        # gather() preserves input order and length, so the zip maps results
+        # 1:1 back to provider names.
+        for name, result in zip(self._probes, results):
+            if isinstance(result, BaseException):
+                logger.error("probe %s leaked exception: %s", name, result)
+
+    async def _probe_one(self, name: str, fn: ProbeFn) -> None:
+        try:
+            healthy = await asyncio.wait_for(fn(), timeout=self._timeout)
+        except asyncio.TimeoutError:
+            self._cache.mark_unhealthy(name, f"probe-timeout (>{self._timeout:.0f}s)")
+            return
+        except asyncio.CancelledError:
+            raise
+        except Exception as e:  # noqa: BLE001 — probe SHOULD be permissive
+            self._cache.mark_unhealthy(name, f"probe-exception: {type(e).__name__}: {e}")
+            return
+        if healthy:
+            self._cache.mark_healthy(name)
+        else:
+            self._cache.mark_unhealthy(name, "probe-returned-false")
+
+    # ------------------------------------------------------------------
+    # Long-running task management
+    # ------------------------------------------------------------------
+
+    async def start(self) -> None:
+        """Spawn the periodic probe task. Idempotent."""
+        if self.running:
+            return
+        self._stop.clear()
+        self._task = asyncio.create_task(self._run_forever(), name="mana-llm-health-probe")
+        logger.info(
+            "HealthProbe started (interval=%.0fs, timeout=%.0fs, providers=%s)",
+            self._interval,
+            self._timeout,
+            ", ".join(self._probes),
+        )
+
+    async def stop(self) -> None:
+        """Cancel the background task and wait for it to finish."""
+        if not self.running:
+            return
+        self._stop.set()
+        assert self._task is not None
+        self._task.cancel()
+        try:
+            await self._task
+        except asyncio.CancelledError:
+            pass
+        finally:
+            self._task = None
+        logger.info("HealthProbe stopped")
+
+    async def _run_forever(self) -> None:
+        # Probe immediately at boot so we don't serve traffic for `interval`
+        # seconds based on optimistic-default assumptions.
+        try:
+            await self.tick_once()
+        except Exception as e:  # noqa: BLE001
+            logger.error("HealthProbe initial tick failed: %s", e)
+        while not self._stop.is_set():
+            try:
+                await asyncio.wait_for(self._stop.wait(), timeout=self._interval)
+            except asyncio.TimeoutError:
+                pass
+            else:
+                # _stop.wait() completed: stop was signalled, exit.
+                return
+            try:
+                await self.tick_once()
+            except Exception as e:  # noqa: BLE001
+                logger.error("HealthProbe tick failed: %s", e)
+
+
+# ---------------------------------------------------------------------------
+# Probe-function helpers
+# ---------------------------------------------------------------------------
+
+
+def make_http_probe(
+    url: str,
+    *,
+    headers: dict[str, str] | None = None,
+    expected_status_lt: int = 500,
+) -> ProbeFn:
+    """Return a probe function that does ``GET <url>`` and considers the
+    provider healthy iff the response status is below
+    ``expected_status_lt`` (default: any non-5xx counts).
+
+    A 401/403/404 still counts as healthy because the *server* answered —
+    auth or path mistakes are misconfiguration, not provider liveness.
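+
+    Example (the URL is illustrative; Ollama's tag listing is a cheap,
+    side-effect-free request)::
+
+        probe = make_http_probe("http://localhost:11434/api/tags")
+        healthy = await probe()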
+    """
+    import httpx
+
+    async def probe() -> bool:
+        async with httpx.AsyncClient(timeout=httpx.Timeout(5.0)) as client:
+            resp = await client.get(url, headers=headers or None)
+            return resp.status_code < expected_status_lt
+
+    return probe
diff --git a/services/mana-llm/tests/test_health.py b/services/mana-llm/tests/test_health.py
new file mode 100644
index 000000000..f3b72ccc5
--- /dev/null
+++ b/services/mana-llm/tests/test_health.py
@@ -0,0 +1,203 @@
+"""Tests for the provider health cache."""
+
+from __future__ import annotations
+
+import pytest
+
+from src.health import (
+    DEFAULT_FAILURE_THRESHOLD,
+    DEFAULT_UNHEALTHY_BACKOFF_SEC,
+    ProviderHealthCache,
+    ProviderState,
+)
+
+
+class FakeClock:
+    """Deterministic clock for circuit-breaker timing tests."""
+
+    def __init__(self, start: float = 1_000_000.0) -> None:
+        self.now = start
+
+    def __call__(self) -> float:
+        return self.now
+
+    def advance(self, seconds: float) -> None:
+        self.now += seconds
+
+
+# ---------------------------------------------------------------------------
+# Construction
+# ---------------------------------------------------------------------------
+
+
+class TestConstruction:
+    def test_defaults(self) -> None:
+        c = ProviderHealthCache()
+        assert c.failure_threshold == DEFAULT_FAILURE_THRESHOLD
+        assert c.unhealthy_backoff_sec == DEFAULT_UNHEALTHY_BACKOFF_SEC
+
+    def test_invalid_threshold_rejected(self) -> None:
+        with pytest.raises(ValueError, match="failure_threshold"):
+            ProviderHealthCache(failure_threshold=0)
+
+    def test_invalid_backoff_rejected(self) -> None:
+        with pytest.raises(ValueError, match="backoff"):
+            ProviderHealthCache(unhealthy_backoff_sec=-1.0)
+
+
+# ---------------------------------------------------------------------------
+# Default-healthy behaviour
+# ---------------------------------------------------------------------------
+
+
+class TestDefaults:
+    def test_unknown_provider_is_healthy(self) -> None:
+        # The cache is observation-only — no entry means "no reason to skip".
+        c = ProviderHealthCache()
+        assert c.is_healthy("ollama") is True
+
+    def test_unknown_provider_has_no_state(self) -> None:
+        c = ProviderHealthCache()
+        assert c.get_state("ollama") is None
+
+    def test_snapshot_includes_expected_zero_state(self) -> None:
+        c = ProviderHealthCache()
+        snap = c.snapshot(expected=["ollama", "groq"])
+        assert set(snap.keys()) == {"ollama", "groq"}
+        assert all(s.healthy for s in snap.values())
+        assert all(s.consecutive_failures == 0 for s in snap.values())
+
+
+# ---------------------------------------------------------------------------
+# Failure → unhealthy state machine
+# ---------------------------------------------------------------------------
+
+
+class TestFailureSemantics:
+    def test_first_failure_does_not_trip(self) -> None:
+        # A single transient blip shouldn't bounce a provider — wait for the
+        # next consecutive failure to confirm.
+        c = ProviderHealthCache(failure_threshold=2)
+        c.mark_unhealthy("ollama", "boom")
+        assert c.is_healthy("ollama") is True
+        state = c.get_state("ollama")
+        assert state is not None
+        assert state.consecutive_failures == 1
+        assert state.healthy is True
+
+    def test_threshold_reached_trips_breaker(self) -> None:
+        clock = FakeClock()
+        c = ProviderHealthCache(
+            failure_threshold=2,
+            unhealthy_backoff_sec=60.0,
+            clock=clock,
+        )
+        c.mark_unhealthy("ollama", "fail-1")
+        c.mark_unhealthy("ollama", "fail-2")
+        assert c.is_healthy("ollama") is False
+        state = c.get_state("ollama")
+        assert state is not None
+        assert state.healthy is False
+        assert state.last_error == "fail-2"
+        assert state.unhealthy_until == clock.now + 60.0
+
+    def test_threshold_one_trips_immediately(self) -> None:
+        c = ProviderHealthCache(failure_threshold=1)
+        c.mark_unhealthy("ollama", "boom")
+        assert c.is_healthy("ollama") is False
+
+    def test_mark_healthy_clears_state(self) -> None:
+        c = ProviderHealthCache(failure_threshold=2)
+        c.mark_unhealthy("ollama", "x")
+        c.mark_unhealthy("ollama", "x")
+        assert c.is_healthy("ollama") is False
+        c.mark_healthy("ollama")
+        assert c.is_healthy("ollama") is True
+        state = c.get_state("ollama")
+        assert state is not None
+        assert state.healthy is True
+        assert state.consecutive_failures == 0
+        assert state.unhealthy_until == 0.0
+        assert state.last_error is None
+
+
+class TestBackoffWindow:
+    def test_is_healthy_false_during_backoff(self) -> None:
+        clock = FakeClock()
+        c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
+        c.mark_unhealthy("ollama", "boom")
+        assert c.is_healthy("ollama") is False
+        clock.advance(30.0)
+        assert c.is_healthy("ollama") is False
+
+    def test_is_healthy_true_after_backoff_expires(self) -> None:
+        # After backoff: half-open. The router gets one more attempt; on
+        # success mark_healthy resets, on failure mark_unhealthy re-arms
+        # the backoff.
+        clock = FakeClock()
+        c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
+        c.mark_unhealthy("ollama", "boom")
+        clock.advance(61.0)
+        assert c.is_healthy("ollama") is True
+
+    def test_failure_during_backoff_extends_window(self) -> None:
+        clock = FakeClock()
+        c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
+        c.mark_unhealthy("ollama", "first")
+        original_until = c.get_state("ollama").unhealthy_until
+        clock.advance(20.0)
+        c.mark_unhealthy("ollama", "second")
+        new_until = c.get_state("ollama").unhealthy_until
+        assert new_until > original_until
+        assert new_until == clock.now + 60.0
+
+    def test_recovery_logged_only_once(self, caplog: pytest.LogCaptureFixture) -> None:
+        clock = FakeClock()
+        c = ProviderHealthCache(failure_threshold=1, unhealthy_backoff_sec=60.0, clock=clock)
+        c.mark_unhealthy("ollama", "boom")
+        with caplog.at_level("INFO"):
+            c.mark_healthy("ollama")
+            # Calling mark_healthy on an already-healthy provider must not
+            # spam the recovery log line.
+            c.mark_healthy("ollama")
+            c.mark_healthy("ollama")
+        recovery_lines = [r for r in caplog.records if "recovered" in r.message]
+        assert len(recovery_lines) == 1
+
+
+# ---------------------------------------------------------------------------
+# Snapshot shape
+# ---------------------------------------------------------------------------
+
+
+class TestSnapshot:
+    def test_snapshot_returns_copies(self) -> None:
+        # The caller shouldn't be able to poke through to the cache's
+        # internal state via a snapshot reference.
+        c = ProviderHealthCache(failure_threshold=1)
+        c.mark_unhealthy("ollama", "x")
+        snap = c.snapshot()
+        snap["ollama"].healthy = True
+        # Original state untouched:
+        assert c.is_healthy("ollama") is False
+
+    def test_snapshot_matches_recorded_state(self) -> None:
+        clock = FakeClock()
+        c = ProviderHealthCache(
+            failure_threshold=2,
+            unhealthy_backoff_sec=60.0,
+            clock=clock,
+        )
+        c.mark_unhealthy("groq", "rate-limit")
+        snap = c.snapshot()
+        assert isinstance(snap["groq"], ProviderState)
+        assert snap["groq"].consecutive_failures == 1
+        assert snap["groq"].last_error == "rate-limit"
+        assert snap["groq"].last_check == clock.now
+
+    def test_snapshot_expected_does_not_overwrite_real_state(self) -> None:
+        c = ProviderHealthCache(failure_threshold=1)
+        c.mark_unhealthy("ollama", "real-boom")
+        snap = c.snapshot(expected=["ollama", "groq"])
+        # ollama keeps its real (unhealthy) state; groq gets the zero-default.
+        assert snap["ollama"].consecutive_failures == 1
+        assert snap["groq"].consecutive_failures == 0
diff --git a/services/mana-llm/tests/test_health_probe.py b/services/mana-llm/tests/test_health_probe.py
new file mode 100644
index 000000000..9305f5013
--- /dev/null
+++ b/services/mana-llm/tests/test_health_probe.py
@@ -0,0 +1,222 @@
+"""Tests for the background health-probe loop."""
+
+from __future__ import annotations
+
+import asyncio
+from typing import Awaitable, Callable
+
+import pytest
+
+from src.health import ProviderHealthCache
+from src.health_probe import HealthProbe, ProbeFn
+
+
+def make_probe(*, returns: bool = True, raises: type[BaseException] | None = None) -> ProbeFn:
+    """Synthesise a probe function that always returns / raises the same."""
+
+    async def probe() -> bool:
+        if raises is not None:
+            raise raises("boom")
+        return returns
+
+    return probe
+
+
+def make_slow_probe(delay: float, returns: bool = True) -> ProbeFn:
+    async def probe() -> bool:
+        await asyncio.sleep(delay)
+        return returns
+
+    return probe
+
+
+def make_call_counter() -> tuple[ProbeFn, Callable[[], int]]:
+    """Probe that counts how many times it was awaited."""
+    count = 0
+
+    async def probe() -> bool:
+        nonlocal count
+        count += 1
+        return True
+
+    return probe, lambda: count
+
+
+# ---------------------------------------------------------------------------
+# Construction
+# ---------------------------------------------------------------------------
+
+
+class TestConstruction:
+    def test_invalid_interval(self) -> None:
+        with pytest.raises(ValueError, match="interval"):
+            HealthProbe(ProviderHealthCache(), {}, interval=0.0)
+
+    def test_invalid_timeout(self) -> None:
+        with pytest.raises(ValueError, match="timeout"):
+            HealthProbe(ProviderHealthCache(), {}, timeout=0.0)
+
+    def test_provider_ids_exposes_keys(self) -> None:
+        cache = ProviderHealthCache()
+        probe = HealthProbe(
+            cache,
+            {"ollama": make_probe(), "groq": make_probe()},
+        )
+        assert sorted(probe.provider_ids) == ["groq", "ollama"]
+
+
+# ---------------------------------------------------------------------------
+# tick_once — the per-cycle behaviour
+# ---------------------------------------------------------------------------
+
+
+class TestTickOnce:
+    @pytest.mark.asyncio
+    async def test_healthy_probe_marks_healthy(self) -> None:
+        cache = ProviderHealthCache(failure_threshold=1)
+        # Pre-mark unhealthy so we can verify the probe recovers it.
+        cache.mark_unhealthy("ollama", "stale")
+        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)})
+        await probe.tick_once()
+        assert cache.is_healthy("ollama") is True
+
+    @pytest.mark.asyncio
+    async def test_returning_false_marks_unhealthy(self) -> None:
+        cache = ProviderHealthCache(failure_threshold=1)
+        probe = HealthProbe(cache, {"ollama": make_probe(returns=False)})
+        await probe.tick_once()
+        assert cache.is_healthy("ollama") is False
+        state = cache.get_state("ollama")
+        assert state is not None
+        assert "false" in (state.last_error or "")
+
+    @pytest.mark.asyncio
+    async def test_raising_marks_unhealthy_with_exc_info(self) -> None:
+        cache = ProviderHealthCache(failure_threshold=1)
+        probe = HealthProbe(
+            cache, {"ollama": make_probe(raises=ConnectionError)}
+        )
+        await probe.tick_once()
+        assert cache.is_healthy("ollama") is False
+        state = cache.get_state("ollama")
+        assert state is not None
+        assert "ConnectionError" in (state.last_error or "")
+
+    @pytest.mark.asyncio
+    async def test_timeout_marks_unhealthy(self) -> None:
+        cache = ProviderHealthCache(failure_threshold=1)
+        probe = HealthProbe(
+            cache,
+            {"ollama": make_slow_probe(delay=1.0)},
+            timeout=0.05,
+        )
+        await probe.tick_once()
+        assert cache.is_healthy("ollama") is False
+        state = cache.get_state("ollama")
+        assert state is not None
+        assert "timeout" in (state.last_error or "").lower()
+
+    @pytest.mark.asyncio
+    async def test_one_bad_probe_does_not_sink_others(self) -> None:
+        # Probe 'ollama' raises — 'groq' must still be evaluated and marked
+        # healthy. The bug shape guarded against: an unhandled exception in
+        # gather() sinking the whole tick.
+        cache = ProviderHealthCache(failure_threshold=1)
+        probe = HealthProbe(
+            cache,
+            {
+                "ollama": make_probe(raises=RuntimeError),
+                "groq": make_probe(returns=True),
+            },
+        )
+        await probe.tick_once()
+        assert cache.is_healthy("ollama") is False
+        assert cache.is_healthy("groq") is True
+
+    @pytest.mark.asyncio
+    async def test_concurrent_probes(self) -> None:
+        # All probes should run in parallel — total elapsed wall-clock for
+        # five 100ms probes should be well under 5 × 100ms.
+        import time
+
+        cache = ProviderHealthCache()
+        probes = {f"p{i}": make_slow_probe(delay=0.1) for i in range(5)}
+        probe = HealthProbe(cache, probes, timeout=1.0)
+        t0 = time.perf_counter()
+        await probe.tick_once()
+        elapsed = time.perf_counter() - t0
+        assert elapsed < 0.3, f"probes ran serially? elapsed={elapsed:.3f}s"
+
+    @pytest.mark.asyncio
+    async def test_empty_probes_is_noop(self) -> None:
+        cache = ProviderHealthCache()
+        probe = HealthProbe(cache, {})
+        # No exception, no state mutation.
+        await probe.tick_once()
+        assert cache.snapshot() == {}
+
+
+# ---------------------------------------------------------------------------
+# start / stop lifecycle
+# ---------------------------------------------------------------------------
+
+
+class TestLifecycle:
+    @pytest.mark.asyncio
+    async def test_start_runs_initial_tick_immediately(self) -> None:
+        cache = ProviderHealthCache(failure_threshold=1)
+        cache.mark_unhealthy("ollama", "stale")
+        probe = HealthProbe(cache, {"ollama": make_probe(returns=True)}, interval=10.0)
+        await probe.start()
+        # Give the loop one event-loop turn to run the initial tick before
+        # it blocks on the long sleep.
+        await asyncio.sleep(0.01)
+        assert cache.is_healthy("ollama") is True
+        await probe.stop()
+
+    @pytest.mark.asyncio
+    async def test_stop_cancels_cleanly(self) -> None:
+        cache = ProviderHealthCache()
+        probe = HealthProbe(
+            cache, {"ollama": make_probe()}, interval=10.0, timeout=1.0
+        )
+        await probe.start()
+        assert probe.running is True
+        await probe.stop()
+        assert probe.running is False
+
+    @pytest.mark.asyncio
+    async def test_start_is_idempotent(self) -> None:
+        cache = ProviderHealthCache()
+        probe = HealthProbe(cache, {"ollama": make_probe()}, interval=10.0)
+        await probe.start()
+        await probe.start()  # must not spawn a second task
+        assert probe.running is True
+        await probe.stop()
+
+    @pytest.mark.asyncio
+    async def test_stop_without_start_is_safe(self) -> None:
+        cache = ProviderHealthCache()
+        probe = HealthProbe(cache, {})
+        await probe.stop()  # idempotent / safe pre-start
+
+    @pytest.mark.asyncio
+    async def test_loop_keeps_running_after_tick_error(self) -> None:
+        # Even if probes misbehave, the loop must keep ticking. tick_once
+        # already catches per-probe errors via gather(return_exceptions=True),
+        # so the call counter is enough to verify that multiple ticks ran.
+        cache = ProviderHealthCache(failure_threshold=1)
+        fn, count = make_call_counter()
+        probe = HealthProbe(
+            cache,
+            {"counter": fn},
+            interval=0.05,  # short interval for the test
+            timeout=1.0,
+        )
+        await probe.start()
+        await asyncio.sleep(0.18)  # ~3-4 ticks
+        await probe.stop()
+        # Initial tick + at least 2 interval ticks.
+        assert count() >= 3