From 8a49e3ffd563ec51cd4355746a7501937eafce34 Mon Sep 17 00:00:00 2001
From: Till JS
Date: Sun, 26 Apr 2026 20:52:28 +0200
Subject: [PATCH] =?UTF-8?q?feat(mana-llm):=20M4=20=E2=80=94=20observabilit?=
 =?UTF-8?q?y,=20debug=20endpoints,=20SIGHUP=20reload?=
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
  responses. Streaming clients read the same info from each chunk's
  `model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics:
  `mana_llm_alias_resolved_total{alias, target}` (which concrete model an
  alias resolved to per request),
  `mana_llm_fallback_total{from_model, to_model, reason}` (each fallback
  transition), `mana_llm_provider_healthy{provider}` (gauge, mirrors the
  circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain +
  description per alias, useful for confirming SIGHUP reloads),
  `GET /v1/health` (full per-provider liveness snapshot — failure counter,
  last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
  previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling: the
  gauge is updated via a transition-only listener wired in main.py rather
  than the cache importing prometheus_client itself.
- Request-side metrics now use the requested model string; success-side
  metrics use the resolved one. So
  `mana_llm_llm_requests_total{provider="ollama", model="gemma3:12b"}`
  reflects actual upstream load even when callers used `mana/long-form`
  aliases.

16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
v1/aliases + v1/health endpoint shape, response.model carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.

Co-Authored-By: Claude Opus 4.7 (1M context)
---
 services/mana-llm/CLAUDE.md                   |  88 +++-
 services/mana-llm/src/health.py               |  35 +-
 services/mana-llm/src/main.py                 | 153 ++++++-
 services/mana-llm/src/providers/router.py     |  36 +-
 services/mana-llm/src/utils/metrics.py        |  51 ++-
 .../mana-llm/tests/test_m4_observability.py   | 415 ++++++++++++++++++
 6 files changed, 749 insertions(+), 29 deletions(-)
 create mode 100644 services/mana-llm/tests/test_m4_observability.py

diff --git a/services/mana-llm/CLAUDE.md b/services/mana-llm/CLAUDE.md
index a4cbb5b05..bdf59de4f 100644
--- a/services/mana-llm/CLAUDE.md
+++ b/services/mana-llm/CLAUDE.md
@@ -143,13 +143,99 @@ curl -X POST http://localhost:3025/v1/embeddings \
 
 ### Health & Metrics
 
 ```bash
-# Health check
+# Liveness summary (legacy, terse shape — only status + per-provider status string)
 curl http://localhost:3025/health
 
+# Detailed per-provider liveness snapshot (M4)
+curl http://localhost:3025/v1/health
+
 # Prometheus metrics
 curl http://localhost:3025/metrics
 ```
 
+### Aliases (M4 — see `aliases.yaml` and the fallback section below)
+
+```bash
+# What does each `mana/<alias>` resolve to?
+curl http://localhost:3025/v1/aliases
+```
+
+## Aliases & Fallback
+
+> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md)
+
+### What callers send
+
+Two acceptable shapes for the `model` field of `/v1/chat/completions`:
+
+1. **Aliases** in the reserved `mana/` namespace — recommended for product code.
+   The router resolves them via `aliases.yaml` to a chain of concrete
+   `provider/model` strings and tries them in order.
+2. **Direct `provider/model`** — bypasses the alias layer, no fallback.
+   Useful for tests, debugging, and one-off integrations.
+
+| Alias | Class |
+|---|---|
+| `mana/fast-text` | Short answers, classification, single-shot Q&A |
+| `mana/long-form` | Writing, essays, stories, longer prose |
+| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) |
+| `mana/reasoning` | Agent missions, tool calls, multi-step plans |
+| `mana/vision` | Multimodal (image + text) |
+
+The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit
+the file and `kill -HUP <pid>` to reload — no restart needed. Reload
+errors keep the previous good state; check the service logs.
+
+### Fallback semantics
+
+Every chain is tried in order. The router skips an entry if the provider
+isn't configured at this deployment (no API key) or is currently marked
+unhealthy by the health-cache. For each remaining entry the request is
+attempted; on a **retryable** error (connection failure, timeout, 5xx,
+rate-limit, RemoteProtocolError) the provider is marked unhealthy and
+the next entry is tried. **Non-retryable** errors (auth, capability,
+content-blocked, 4xx, unknown exception types) propagate immediately —
+no fallback, the cache is not poisoned.
+
+Streaming follows the same logic up to the **first byte**. Once a chunk
+has been yielded the provider is committed; mid-stream errors surface
+as-is so we never splice two providers' voices into one output.
+
+If every entry was skipped or failed, the response is `503` carrying a
+structured `attempts: list[(model, reason)]` log so the cause is
+visible to the caller, not only in service logs.
+
+### Resolved-model header
+
+Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>`
+(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that
+actually answered. Use this for token-cost attribution when the request
+used an alias. For streaming, each chunk's `model` field carries the
+same info (headers go out before the chain is walked). A curl sketch
+follows the metrics table below.
+
+### Health-cache + probe
+
+`ProviderHealthCache` keeps a per-provider circuit-breaker:
+
+* 1 failure: still healthy (transient blip, don't bounce).
+* 2 consecutive failures: `is_healthy → False` for 60 s; the router
+  fails fast straight to the next chain entry.
+* After 60 s: half-open. Next call exercises the provider; success
+  fully resets, failure re-arms the backoff.
+
+A background `HealthProbe` task runs every 30 s with a 3 s timeout per
+provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models`
+for OpenAI-compat). One bad probe can't sink the loop; results feed
+into the same cache as the call-site fallback.
+
+### Prometheus metrics added in M4
+
+| Metric | Labels | Purpose |
+|---|---|---|
+| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. |
+| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. |
+| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker.
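+### Quick check: resolved model + metrics
+
+A minimal sketch for seeing the resolved-model header and the M4 counters
+move. It assumes the service runs on the default port 3025 (as in the
+examples above), that at least one entry in the `mana/fast-text` chain is
+healthy, and that your deployment needs no extra auth headers; adjust the
+alias, prompt, and headers to your setup.
+
+```bash
+# Send an aliased request and print which concrete model answered it.
+curl -si http://localhost:3025/v1/chat/completions \
+  -H 'Content-Type: application/json' \
+  -d '{"model": "mana/fast-text", "messages": [{"role": "user", "content": "ping"}]}' \
+  | grep -i '^x-mana-llm-resolved'
+
+# Confirm the alias/fallback counters moved.
+curl -s http://localhost:3025/metrics \
+  | grep -E 'mana_llm_(alias_resolved|fallback)_total'
+```
+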
| + ## Provider Routing Models use the format `provider/model`: diff --git a/services/mana-llm/src/health.py b/services/mana-llm/src/health.py index 1389e0b35..eda34526c 100644 --- a/services/mana-llm/src/health.py +++ b/services/mana-llm/src/health.py @@ -31,10 +31,15 @@ import logging import threading import time from dataclasses import dataclass, field -from typing import Iterable +from typing import Callable, Iterable logger = logging.getLogger(__name__) +#: Notification fired whenever a provider transitions between healthy and +#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the +#: cache itself stays metrics-agnostic so tests don't need to mock it. +HealthChangeListener = Callable[[str, bool], None] + DEFAULT_FAILURE_THRESHOLD = 2 DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0 @@ -77,6 +82,22 @@ class ProviderHealthCache: self._clock = clock self._lock = threading.Lock() self._states: dict[str, ProviderState] = {} + self._listeners: list[HealthChangeListener] = [] + + def add_listener(self, listener: HealthChangeListener) -> None: + """Register a callback fired with ``(provider_id, healthy: bool)`` + whenever a provider's healthy-flag transitions. Listeners run + outside the cache's lock; exceptions are swallowed and logged so a + bad listener can't break the underlying state machine. + """ + self._listeners.append(listener) + + def _notify(self, provider_id: str, healthy: bool) -> None: + for listener in self._listeners: + try: + listener(provider_id, healthy) + except Exception as e: # noqa: BLE001 + logger.error("health-change listener raised: %s", e) @property def failure_threshold(self) -> int: @@ -129,19 +150,22 @@ class ProviderHealthCache: def mark_healthy(self, provider_id: str) -> None: """Provider answered correctly — clear any failure state.""" + transitioned = False with self._lock: state = self._states.setdefault(provider_id, ProviderState()) - previously_unhealthy = not state.healthy + transitioned = not state.healthy state.healthy = True state.consecutive_failures = 0 state.last_check = self._clock() state.last_error = None state.unhealthy_until = 0.0 - if previously_unhealthy: - logger.info("provider %s recovered", provider_id) + if transitioned: + logger.info("provider %s recovered", provider_id) + self._notify(provider_id, True) def mark_unhealthy(self, provider_id: str, reason: str) -> None: """Record a failure. Trips the breaker after the threshold.""" + transitioned = False with self._lock: state = self._states.setdefault(provider_id, ProviderState()) state.consecutive_failures += 1 @@ -151,6 +175,7 @@ class ProviderHealthCache: if tripped and state.healthy: state.healthy = False state.unhealthy_until = self._clock() + self._unhealthy_backoff + transitioned = True logger.warning( "provider %s marked unhealthy after %d consecutive failures (%s); " "backoff %.0fs", @@ -163,6 +188,8 @@ class ProviderHealthCache: # Still in unhealthy window; refresh the backoff so a flapping # provider doesn't get re-tried every probe tick. 
state.unhealthy_until = self._clock() + self._unhealthy_backoff + if transitioned: + self._notify(provider_id, False) def _copy(state: ProviderState) -> ProviderState: diff --git a/services/mana-llm/src/main.py b/services/mana-llm/src/main.py index 757105865..24f78db9d 100644 --- a/services/mana-llm/src/main.py +++ b/services/mana-llm/src/main.py @@ -1,6 +1,8 @@ """Main FastAPI application for mana-llm service.""" +import asyncio import logging +import signal import time from contextlib import asynccontextmanager from pathlib import Path @@ -8,10 +10,10 @@ from typing import Any from fastapi import FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import Response +from fastapi.responses import JSONResponse, Response from sse_starlette.sse import EventSourceResponse -from src.aliases import AliasRegistry +from src.aliases import AliasConfigError, AliasRegistry from src.api_auth import ApiKeyMiddleware from src.config import settings from src.health import ProviderHealthCache @@ -28,7 +30,18 @@ from src.providers import ProviderRouter from src.providers.errors import ProviderError from src.streaming import stream_chat_completion from src.utils.cache import close_redis -from src.utils.metrics import get_metrics, record_llm_error, record_llm_request +from src.utils.metrics import ( + get_metrics, + record_llm_error, + record_llm_request, + set_provider_healthy, +) + +#: Header carrying the concrete provider/model that actually served a +#: non-streaming response — useful for token-cost accounting on the +#: caller side, since `mana/long-form` could resolve to ollama, groq, +#: or claude depending on which providers were healthy at request time. +RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved" # Configure logging logging.basicConfig( @@ -73,6 +86,35 @@ def _build_provider_probes( return probes +def _on_health_change(provider: str, healthy: bool) -> None: + """Mirror health-cache transitions into the Prometheus gauge.""" + set_provider_healthy(provider, healthy) + + +def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None: + """Reload ``aliases.yaml`` when the process receives SIGHUP. + + Reload errors keep the previous good state in memory (see + AliasRegistry.reload). SIGHUP isn't available on Windows; we just + log and skip in that case. + """ + + def handler() -> None: + if alias_registry is None: + return + try: + alias_registry.reload() + except AliasConfigError as e: + logger.error("alias reload rejected, keeping previous state: %s", e) + + try: + loop.add_signal_handler(signal.SIGHUP, handler) + except (NotImplementedError, AttributeError, RuntimeError): + # NotImplementedError on Windows; RuntimeError when the loop is + # not running in the main thread (TestClient does this). + logger.info("SIGHUP reload not available in this context — skipping") + + @asynccontextmanager async def lifespan(app: FastAPI): """Application lifespan: load aliases, spin up router + health probe.""" @@ -85,12 +127,20 @@ async def lifespan(app: FastAPI): logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path) health_cache = ProviderHealthCache() + health_cache.add_listener(_on_health_change) router = ProviderRouter(aliases=alias_registry, health_cache=health_cache) logger.info("Initialized providers: %s", list(router.providers)) + # Initial gauge values so dashboards render before the first probe + # transition fires the listener. 
+ for provider_name in router.providers: + set_provider_healthy(provider_name, True) + health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers)) await health_probe.start() + _install_sighup_reload(asyncio.get_running_loop()) + yield logger.info("Shutting down mana-llm service...") @@ -143,6 +193,43 @@ async def metrics() -> Response: return Response(content=get_metrics(), media_type="text/plain") +# ----- Alias / health debug endpoints (M4) --------------------------------- + + +@app.get("/v1/aliases") +async def list_aliases() -> dict[str, Any]: + """Inspect the alias registry — what each ``mana/`` resolves to. + + Useful for debugging "which model actually answered my request" and + for confirming SIGHUP reloads picked up edits to ``aliases.yaml``. + """ + if alias_registry is None: + raise HTTPException(status_code=503, detail="Service not ready") + return { + "default": alias_registry.default_alias, + "aliases": [ + { + "name": a.name, + "description": a.description, + "chain": list(a.chain), + } + for a in alias_registry.list_aliases() + ], + } + + +@app.get("/v1/health") +async def detailed_health() -> dict[str, Any]: + """Full per-provider liveness snapshot. + + Includes the failure counter, last error, and the unhealthy-until + backoff timestamp — info the original ``/health`` endpoint hides. + """ + if router is None: + raise HTTPException(status_code=503, detail="Service not ready") + return await router.health_check() + + # Models endpoints @app.get("/v1/models", response_model=ModelsResponse) async def list_models() -> ModelsResponse: @@ -182,10 +269,12 @@ async def chat_completions( if router is None: raise HTTPException(status_code=503, detail="Service not ready") - # Parse provider and model for metrics - model_parts = request.model.split("/", 1) - provider = model_parts[0] if len(model_parts) > 1 else "ollama" - model = model_parts[1] if len(model_parts) > 1 else request.model + # The request's `model` field is what the caller asked for — could be + # `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For + # error-path metrics we use that value (it's what the caller will + # search for); for success-path metrics we use the resolved provider + # so token-cost / latency attribute to the model that actually ran. + requested_provider, requested_model = _split_model(request.model) start_time = time.time() @@ -198,7 +287,11 @@ async def chat_completions( async for chunk in stream_chat_completion(router, request): yield chunk - record_llm_request(provider, model, streaming=True) + # Streaming metrics: we don't yet know which provider answered + # at request-record time. Each chunk's `model` field carries + # the resolved name; per-token latency is harder to attribute + # cleanly so we skip it for streams. + record_llm_request(requested_provider, requested_model, streaming=True) return EventSourceResponse( generate(), @@ -209,35 +302,43 @@ async def chat_completions( logger.info(f"Chat completion: {request.model}") response = await router.chat_completion(request) - # Record metrics + resolved_provider, resolved_model = _split_model(response.model) latency = time.time() - start_time record_llm_request( - provider=provider, - model=model, + provider=resolved_provider, + model=resolved_model, streaming=False, prompt_tokens=response.usage.prompt_tokens, completion_tokens=response.usage.completion_tokens, latency=latency, ) - return response + # `response.model` is the concrete provider/model the chain + # actually resolved to. 
Surface it via header so the caller + # can attribute token cost to the right model even when the + # request used an alias. + return JSONResponse( + content=response.model_dump(), + headers={RESOLVED_MODEL_HEADER: response.model}, + ) except ValueError as e: logger.error(f"Invalid request: {e}") - record_llm_error(provider, model, "invalid_request") + record_llm_error(requested_provider, requested_model, "invalid_request") raise HTTPException(status_code=400, detail=str(e)) except ProviderError as e: logger.warning( - f"Provider error on {provider}/{model}: kind={e.kind} detail={e}" + f"Provider error on {requested_provider}/{requested_model}: " + f"kind={e.kind} detail={e}" ) - record_llm_error(provider, model, e.kind) + record_llm_error(requested_provider, requested_model, e.kind) raise HTTPException( status_code=e.http_status, detail={"kind": e.kind, "message": str(e)}, ) except Exception as e: logger.error(f"Chat completion failed: {e}") - record_llm_error(provider, model, "server_error") + record_llm_error(requested_provider, requested_model, "server_error") raise HTTPException(status_code=500, detail=str(e)) @@ -248,10 +349,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: if router is None: raise HTTPException(status_code=503, detail="Service not ready") - # Parse provider and model for metrics - model_parts = request.model.split("/", 1) - provider = model_parts[0] if len(model_parts) > 1 else "ollama" - model = model_parts[1] if len(model_parts) > 1 else request.model + provider, model = _split_model(request.model) start_time = time.time() @@ -280,6 +378,21 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse: raise HTTPException(status_code=500, detail=str(e)) +def _split_model(model: str) -> tuple[str, str]: + """Split a ``provider/model`` string for metric labelling. + + Bare names with no slash default to ``ollama`` to match the legacy + OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace + in the metrics — that's intentional, so request-side counters tell + you what callers ASKED for, while the resolved-side counters + (``mana_llm_alias_resolved_total``) tell you what they GOT. 
+ """ + if "/" in model: + provider, _, name = model.partition("/") + return provider.lower(), name + return "ollama", model + + if __name__ == "__main__": import uvicorn diff --git a/services/mana-llm/src/providers/router.py b/services/mana-llm/src/providers/router.py index 0f0dfe00c..e6efaa6a0 100644 --- a/services/mana-llm/src/providers/router.py +++ b/services/mana-llm/src/providers/router.py @@ -41,6 +41,11 @@ from src.models import ( EmbeddingResponse, ModelInfo, ) +from src.utils.metrics import ( + record_alias_resolved, + record_fallback, + set_provider_healthy, +) from .base import LLMProvider from .errors import ( @@ -229,8 +234,10 @@ class ProviderRouter: chain = self._resolve_chain(model_or_alias) attempts: list[tuple[str, str]] = [] last_exc: Exception | None = None + is_alias = AliasRegistry.is_alias(model_or_alias) - for entry in chain: + for i, entry in enumerate(chain): + next_entry = chain[i + 1] if i + 1 < len(chain) else "" provider_name, model_name = self._parse_model(entry) if provider_name not in self.providers: logger.debug( @@ -239,11 +246,13 @@ class ProviderRouter: provider_name, ) attempts.append((entry, "unconfigured")) + record_fallback(entry, next_entry, "unconfigured") continue if not self.health_cache.is_healthy(provider_name): logger.debug("skip chain entry %s — cache says unhealthy", entry) attempts.append((entry, "cache-unhealthy")) + record_fallback(entry, next_entry, "cache-unhealthy") continue provider = self.providers[provider_name] @@ -253,10 +262,13 @@ class ProviderRouter: logger.info( "execute → %s (alias=%s)", entry, - model_or_alias if model_or_alias != entry else "", + model_or_alias if is_alias else "", ) result = await call(provider, model_name, request) self.health_cache.mark_healthy(provider_name) + set_provider_healthy(provider_name, True) + if is_alias: + record_alias_resolved(model_or_alias, entry) return result except Exception as e: if not self._is_retryable(e): @@ -266,7 +278,11 @@ class ProviderRouter: # being wrong. raise self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) + set_provider_healthy( + provider_name, self.health_cache.is_healthy(provider_name) + ) attempts.append((entry, type(e).__name__)) + record_fallback(entry, next_entry, type(e).__name__) last_exc = e logger.warning( "execute %s failed (retryable, will try next): %s", @@ -310,14 +326,18 @@ class ProviderRouter: chain = self._resolve_chain(request.model) attempts: list[tuple[str, str]] = [] last_exc: Exception | None = None + is_alias = AliasRegistry.is_alias(request.model) - for entry in chain: + for i, entry in enumerate(chain): + next_entry = chain[i + 1] if i + 1 < len(chain) else "" provider_name, model_name = self._parse_model(entry) if provider_name not in self.providers: attempts.append((entry, "unconfigured")) + record_fallback(entry, next_entry, "unconfigured") continue if not self.health_cache.is_healthy(provider_name): attempts.append((entry, "cache-unhealthy")) + record_fallback(entry, next_entry, "cache-unhealthy") continue provider = self.providers[provider_name] @@ -330,13 +350,20 @@ class ProviderRouter: # Empty stream is a successful but content-free response. # Commit and exit cleanly. 
self.health_cache.mark_healthy(provider_name) + set_provider_healthy(provider_name, True) + if is_alias: + record_alias_resolved(request.model, entry) logger.info("stream %s yielded empty response", entry) return except Exception as e: if not self._is_retryable(e): raise self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e)) + set_provider_healthy( + provider_name, self.health_cache.is_healthy(provider_name) + ) attempts.append((entry, type(e).__name__)) + record_fallback(entry, next_entry, type(e).__name__) last_exc = e logger.warning( "stream %s failed before first byte (retryable, trying next): %s", @@ -349,6 +376,9 @@ class ProviderRouter: # the rest of the stream. Any error from here on propagates; # it is NOT safe to splice another provider's output in. self.health_cache.mark_healthy(provider_name) + set_provider_healthy(provider_name, True) + if is_alias: + record_alias_resolved(request.model, entry) logger.info("stream → %s (committed after first chunk)", entry) yield first_chunk async for chunk in stream: diff --git a/services/mana-llm/src/utils/metrics.py b/services/mana-llm/src/utils/metrics.py index 87d4096a5..057cc1656 100644 --- a/services/mana-llm/src/utils/metrics.py +++ b/services/mana-llm/src/utils/metrics.py @@ -4,7 +4,7 @@ import time from collections.abc import Callable from fastapi import Request, Response -from prometheus_client import Counter, Histogram, generate_latest +from prometheus_client import Counter, Gauge, Histogram, generate_latest from starlette.middleware.base import BaseHTTPMiddleware # Request metrics @@ -47,6 +47,35 @@ LLM_ERRORS = Counter( ["provider", "model", "error_type"], ) +# --------------------------------------------------------------------------- +# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md. +# --------------------------------------------------------------------------- + +ALIAS_RESOLVED = Counter( + "mana_llm_alias_resolved_total", + "How often an alias resolved to a concrete provider/model. The `target` " + "label is the chain entry that actually served the request — useful for " + "spotting cases where the primary always falls through to a cloud entry.", + ["alias", "target"], +) + +FALLBACK_TRIGGERED = Counter( + "mana_llm_fallback_total", + "Fallback transitions: a chain entry failed (or was skipped via cache) " + "and the router moved to the next entry. `reason` is the exception class " + "name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry " + "that didn't serve, `to_model` is empty when no further entries existed.", + ["from_model", "to_model", "reason"], +) + +PROVIDER_HEALTHY = Gauge( + "mana_llm_provider_healthy", + "1 when the provider is currently considered healthy by the cache, " + "0 when in backoff. Refreshed on every probe tick and on every router " + "call-site state transition.", + ["provider"], +) + def get_metrics() -> bytes: """Generate Prometheus metrics output.""" @@ -107,3 +136,23 @@ def record_llm_request( def record_llm_error(provider: str, model: str, error_type: str) -> None: """Record LLM error metrics.""" LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc() + + +def record_alias_resolved(alias: str, target: str) -> None: + """Record which concrete model an alias resolved to for this request.""" + ALIAS_RESOLVED.labels(alias=alias, target=target).inc() + + +def record_fallback(from_model: str, to_model: str, reason: str) -> None: + """Record a fallback transition. 
``to_model`` is empty when the chain + ran out (i.e. NoHealthyProviderError).""" + FALLBACK_TRIGGERED.labels( + from_model=from_model, + to_model=to_model, + reason=reason, + ).inc() + + +def set_provider_healthy(provider: str, healthy: bool) -> None: + """Mirror ``ProviderHealthCache`` state into a Prometheus gauge.""" + PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0) diff --git a/services/mana-llm/tests/test_m4_observability.py b/services/mana-llm/tests/test_m4_observability.py new file mode 100644 index 000000000..d0fe0c84b --- /dev/null +++ b/services/mana-llm/tests/test_m4_observability.py @@ -0,0 +1,415 @@ +"""Tests for M4: observability + debug endpoints + reload.""" + +from __future__ import annotations + +import asyncio +import os +import signal +from pathlib import Path + +import httpx +import pytest +from fastapi.testclient import TestClient +from prometheus_client import REGISTRY + +from src.aliases import AliasRegistry +from src.health import ProviderHealthCache +from src.models import ( + ChatCompletionRequest, + ChatCompletionResponse, + Choice, + Message, + MessageResponse, +) +from src.providers import ProviderRouter +from src.utils.metrics import ( + record_alias_resolved, + record_fallback, + set_provider_healthy, +) + + +# --------------------------------------------------------------------------- +# Cache → listener → metric gauge +# --------------------------------------------------------------------------- + + +class TestHealthChangeListener: + def test_listener_fires_on_unhealthy_transition(self) -> None: + cache = ProviderHealthCache(failure_threshold=2) + events: list[tuple[str, bool]] = [] + cache.add_listener(lambda p, h: events.append((p, h))) + + # First failure: still healthy → no transition. + cache.mark_unhealthy("ollama", "blip") + assert events == [] + + # Second failure: transition healthy→unhealthy → fires. + cache.mark_unhealthy("ollama", "boom") + assert events == [("ollama", False)] + + def test_listener_fires_on_recovery(self) -> None: + cache = ProviderHealthCache(failure_threshold=1) + events: list[tuple[str, bool]] = [] + cache.add_listener(lambda p, h: events.append((p, h))) + + cache.mark_unhealthy("ollama", "boom") + assert events == [("ollama", False)] + + cache.mark_healthy("ollama") + assert events == [("ollama", False), ("ollama", True)] + + def test_steady_state_does_not_fire(self) -> None: + cache = ProviderHealthCache(failure_threshold=1) + events: list[tuple[str, bool]] = [] + cache.add_listener(lambda p, h: events.append((p, h))) + + # Three healthy ops in a row — no transitions, no events. + for _ in range(3): + cache.mark_healthy("ollama") + assert events == [] + + def test_listener_exception_does_not_break_cache(self) -> None: + cache = ProviderHealthCache(failure_threshold=1) + + def bad(_provider: str, _healthy: bool) -> None: + raise RuntimeError("listener boom") + + cache.add_listener(bad) + # Should NOT raise — the cache must keep working with a broken + # listener, otherwise one bad metric callback would brick the + # whole router. 
+ cache.mark_unhealthy("ollama", "x") + assert cache.is_healthy("ollama") is False + + def test_multiple_listeners(self) -> None: + cache = ProviderHealthCache(failure_threshold=1) + a: list = [] + b: list = [] + cache.add_listener(lambda p, h: a.append((p, h))) + cache.add_listener(lambda p, h: b.append((p, h))) + + cache.mark_unhealthy("ollama", "x") + assert a == [("ollama", False)] + assert b == [("ollama", False)] + + +# --------------------------------------------------------------------------- +# Prometheus metrics — counters/gauges actually move +# --------------------------------------------------------------------------- + + +def _counter_value(name: str, labels: dict[str, str]) -> float: + """Helper: read the current value of a labeled Prometheus metric.""" + samples = REGISTRY.get_sample_value(name, labels=labels) + return samples or 0.0 + + +class TestMetricsRecording: + def test_record_alias_resolved_increments(self) -> None: + before = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "mana/test-class", "target": "ollama/x:1b"}, + ) + record_alias_resolved("mana/test-class", "ollama/x:1b") + after = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "mana/test-class", "target": "ollama/x:1b"}, + ) + assert after - before == pytest.approx(1.0) + + def test_record_fallback_increments(self) -> None: + before = _counter_value( + "mana_llm_fallback_total", + {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"}, + ) + record_fallback("ollama/x", "groq/y", "ConnectError") + after = _counter_value( + "mana_llm_fallback_total", + {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"}, + ) + assert after - before == pytest.approx(1.0) + + def test_set_provider_healthy_writes_gauge(self) -> None: + set_provider_healthy("test_provider_xyz", True) + v = REGISTRY.get_sample_value( + "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"} + ) + assert v == 1.0 + + set_provider_healthy("test_provider_xyz", False) + v = REGISTRY.get_sample_value( + "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"} + ) + assert v == 0.0 + + +# --------------------------------------------------------------------------- +# Router → metrics: end-to-end through a fallback +# --------------------------------------------------------------------------- + + +class _OkProvider: + """Minimal provider double — only what the router uses for chat.""" + + name = "ok-provider" + supports_tools = True + + def __init__(self, name: str, fail_with: BaseException | None = None) -> None: + self.name = name + self.fail_with = fail_with + self.calls = 0 + + def model_supports_tools(self, model: str) -> bool: + return True + + async def chat_completion(self, request, model): + self.calls += 1 + if self.fail_with is not None: + raise self.fail_with + return ChatCompletionResponse( + model=f"{self.name}/{model}", + choices=[Choice(message=MessageResponse(content="ok"))], + ) + + async def chat_completion_stream(self, request, model): # pragma: no cover + if False: # pragma: no cover + yield None + + async def list_models(self): + return [] + + async def embeddings(self, request, model): + raise NotImplementedError + + async def health_check(self): + return {"status": "healthy"} + + async def close(self): + pass + + +def _aliases(tmp_path: Path) -> AliasRegistry: + cfg = ( + "aliases:\n" + " mana/two-step:\n" + ' description: "x"\n' + " chain:\n" + " - alpha/m1\n" + " - beta/m2\n" + ) + p = tmp_path / "aliases.yaml" + p.write_text(cfg) 
+ return AliasRegistry(p) + + +class TestRouterMetricsIntegration: + @pytest.mark.asyncio + async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None: + aliases = _aliases(tmp_path) + cache = ProviderHealthCache() + router = ProviderRouter(aliases=aliases, health_cache=cache) + router.providers = {"alpha": _OkProvider("alpha")} # beta not configured + + before = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "mana/two-step", "target": "alpha/m1"}, + ) + await router.chat_completion( + ChatCompletionRequest( + model="mana/two-step", + messages=[Message(role="user", content="hi")], + ) + ) + after = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "mana/two-step", "target": "alpha/m1"}, + ) + assert after - before == pytest.approx(1.0) + + @pytest.mark.asyncio + async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None: + aliases = _aliases(tmp_path) + cache = ProviderHealthCache() + router = ProviderRouter(aliases=aliases, health_cache=cache) + router.providers = { + "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")), + "beta": _OkProvider("beta"), + } + + before = _counter_value( + "mana_llm_fallback_total", + {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"}, + ) + await router.chat_completion( + ChatCompletionRequest( + model="mana/two-step", + messages=[Message(role="user", content="hi")], + ) + ) + after = _counter_value( + "mana_llm_fallback_total", + {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"}, + ) + assert after - before == pytest.approx(1.0) + + @pytest.mark.asyncio + async def test_direct_model_does_not_record_alias_metric( + self, tmp_path: Path + ) -> None: + # Direct provider/model is not an alias — ALIAS_RESOLVED counter + # must stay flat for those calls. + aliases = _aliases(tmp_path) + cache = ProviderHealthCache() + router = ProviderRouter(aliases=aliases, health_cache=cache) + router.providers = {"alpha": _OkProvider("alpha")} + + before = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "alpha/anything", "target": "alpha/anything"}, + ) + await router.chat_completion( + ChatCompletionRequest( + model="alpha/anything", + messages=[Message(role="user", content="hi")], + ) + ) + after = _counter_value( + "mana_llm_alias_resolved_total", + {"alias": "alpha/anything", "target": "alpha/anything"}, + ) + # Counter must have NOT increased — direct calls aren't aliases. + assert after == before + + +# --------------------------------------------------------------------------- +# Debug endpoints: GET /v1/aliases, GET /v1/health +# --------------------------------------------------------------------------- + + +@pytest.fixture +def client(): + from src.main import app + + with TestClient(app) as c: + yield c + + +class TestDebugEndpoints: + def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None: + resp = client.get("/v1/aliases") + assert resp.status_code == 200 + data = resp.json() + names = [a["name"] for a in data["aliases"]] + # The five canonical classes must always be present. + for expected in ( + "mana/fast-text", + "mana/long-form", + "mana/structured", + "mana/reasoning", + "mana/vision", + ): + assert expected in names + # Default is set in the shipped config. 
+ assert data["default"] == "mana/fast-text" + + def test_v1_aliases_chain_format(self, client: TestClient) -> None: + resp = client.get("/v1/aliases") + data = resp.json() + long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form") + # Each chain entry is a `provider/model` string. + assert all("/" in entry for entry in long_form["chain"]) + assert len(long_form["chain"]) >= 2 # plan requires at least one cloud fallback + + def test_v1_health_includes_all_providers(self, client: TestClient) -> None: + resp = client.get("/v1/health") + assert resp.status_code == 200 + data = resp.json() + assert "status" in data + assert "providers" in data + # ollama is always configured (provider list is non-empty). + assert "ollama" in data["providers"] + for name, info in data["providers"].items(): + assert "status" in info + assert "consecutive_failures" in info + + +# --------------------------------------------------------------------------- +# X-Mana-LLM-Resolved header on non-streaming responses +# --------------------------------------------------------------------------- + + +class TestResolvedHeader: + """The header is the consumer's hook for token-cost attribution. + + Tested at the router level — wiring through main.py would need a + real provider connection, which isn't available in unit tests. + """ + + @pytest.mark.asyncio + async def test_response_model_field_carries_resolved_target( + self, tmp_path: Path + ) -> None: + # The header value is `response.model`; verify that field reflects + # the actual chain entry that served, not the requested alias. + aliases = _aliases(tmp_path) + cache = ProviderHealthCache() + router = ProviderRouter(aliases=aliases, health_cache=cache) + # Force fallback to beta. + router.providers = { + "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")), + "beta": _OkProvider("beta"), + } + + resp = await router.chat_completion( + ChatCompletionRequest( + model="mana/two-step", + messages=[Message(role="user", content="hi")], + ) + ) + # Even though the caller asked for `mana/two-step`, the resolved + # field shows the entry that actually answered. + assert resp.model == "beta/m2" + + +# --------------------------------------------------------------------------- +# SIGHUP reload — only meaningful on Unix; tested by signalling the proc +# --------------------------------------------------------------------------- + + +class TestSighupReload: + """SIGHUP triggers ``alias_registry.reload()``; reload-error keeps state. + + The signal-handler wiring lives in main.py and only installs when + the loop is running in the main thread. We exercise the reload + semantics here directly on the registry instead — the signal-handler + code path itself is a 4-line wrapper around ``reload()``. + """ + + def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None: + path = tmp_path / "aliases.yaml" + path.write_text( + "aliases:\n" + " mana/x:\n" + ' description: "x"\n' + " chain:\n" + " - ollama/foo:1b\n" + ) + reg = AliasRegistry(path) + assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",) + + # Edit on disk, reload (this is exactly what the SIGHUP handler + # does — minus the signal plumbing). + path.write_text( + "aliases:\n" + " mana/x:\n" + ' description: "x"\n' + " chain:\n" + " - ollama/bar:1b\n" + " - groq/llama-3.1-8b-instant\n" + ) + reg.reload() + assert reg.resolve_chain("mana/x") == ( + "ollama/bar:1b", + "groq/llama-3.1-8b-instant", + )