mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 19:01:08 +02:00
feat(mana-llm): M4 — observability, debug endpoints, SIGHUP reload
- `X-Mana-LLM-Resolved: <provider>/<model>` header on non-streaming
responses. Streaming clients read the same info from each chunk's
`model` field (SSE headers go out before the chain is walked).
- Three new Prometheus metrics: `mana_llm_alias_resolved_total{alias,
target}` (which concrete model an alias resolved to per request),
`mana_llm_fallback_total{from_model, to_model, reason}` (each
fallback transition), `mana_llm_provider_healthy{provider}` (gauge,
mirrors the circuit-breaker).
- New debug endpoints: `GET /v1/aliases` (registry inspection — chain
+ description per alias, useful for confirming SIGHUP reloads),
`GET /v1/health` (full per-provider liveness snapshot — failure
counter, last error, unhealthy-until backoff).
- `kill -HUP <pid>` reloads `aliases.yaml`. Parse errors leave the
previous good state in memory and log the rejection.
- `ProviderHealthCache.add_listener()` for cache→metrics decoupling:
the gauge is updated via a transition-only listener wired in main.py
rather than the cache importing prometheus_client itself.
- Request-side metrics now use the requested model string; success-side
metrics use the resolved one. So `mana_llm_llm_requests_total{provider="ollama",
model="gemma3:12b"}` reflects actual upstream load even when callers
used `mana/long-form` aliases.
16 new observability tests (test_m4_observability.py): listener
fire-on-transition semantics, exception-isolation, multi-listener,
counter increments, gauge writes, end-to-end alias→metric flow,
v1/aliases + v1/health endpoint shape, response.model carries the
resolved target after fallback. Total suite: 115/115 in 1.6s.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
3046da3b19
commit
8a49e3ffd5
6 changed files with 749 additions and 29 deletions
@@ -143,13 +143,99 @@ curl -X POST http://localhost:3025/v1/embeddings \

### Health & Metrics

```bash
# Liveness summary (legacy, terse shape — only status + per-provider status string)
curl http://localhost:3025/health

# Detailed per-provider liveness snapshot (M4)
curl http://localhost:3025/v1/health

# Prometheus metrics
curl http://localhost:3025/metrics
```
### Aliases (M4 — see `aliases.yaml` and the fallback section below)

```bash
# What does each `mana/<class>` resolve to?
curl http://localhost:3025/v1/aliases
```
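For scripting, the endpoint's JSON (per the M4 handler: `default` plus `aliases` entries carrying `name`, `description`, and `chain`) folds naturally into one line per alias. A minimal sketch; the `sample` payload values are illustrative and the actual HTTP fetch (e.g. via `httpx`) is omitted:

```python
def summarize_aliases(payload: dict) -> list[str]:
    """Render each alias and its fallback chain as one line."""
    lines = []
    for alias in payload["aliases"]:
        chain = " -> ".join(alias["chain"])
        marker = " (default)" if alias["name"] == payload["default"] else ""
        lines.append(f"{alias['name']}{marker}: {chain}")
    return lines


# Illustrative payload in the /v1/aliases shape; real chains come from
# aliases.yaml at the deployment.
sample = {
    "default": "mana/fast-text",
    "aliases": [
        {
            "name": "mana/fast-text",
            "description": "Short answers",
            "chain": ["ollama/gemma3:4b", "groq/llama-3.3-70b-versatile"],
        },
    ],
}

for line in summarize_aliases(sample):
    print(line)
```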
## Aliases & Fallback

> Background: [`docs/plans/llm-fallback-aliases.md`](../../docs/plans/llm-fallback-aliases.md)

### What callers send

Two acceptable shapes for the `model` field of `/v1/chat/completions`:

1. **Aliases** in the reserved `mana/` namespace — recommended for product code.
   The router resolves them via `aliases.yaml` to a chain of concrete
   `provider/model` strings and tries them in order.
2. **Direct `provider/model`** — bypasses the alias layer, no fallback.
   Useful for tests, debugging, and one-off integrations.

| Alias | Class |
|---|---|
| `mana/fast-text` | Short answers, classification, single-shot Q&A |
| `mana/long-form` | Writing, essays, stories, longer prose |
| `mana/structured` | JSON output (comic storyboards, research subqueries, tag suggestions) |
| `mana/reasoning` | Agent missions, tool calls, multi-step plans |
| `mana/vision` | Multimodal (image + text) |

The chain for each alias lives in `services/mana-llm/aliases.yaml`. Edit
the file and `kill -HUP <pid>` to reload — no restart needed. Reload
errors keep the previous good state; check the service logs.
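The reload hook is plain POSIX signal handling, so the mechanism can be demonstrated in-process without the service. A self-contained sketch (not the service's handler; SIGHUP does not exist on Windows):

```python
import os
import signal

events: list[str] = []

# Stand-in for the service's reload callback (AliasRegistry.reload in the
# real code): just record that a reload was requested.
signal.signal(signal.SIGHUP, lambda signum, frame: events.append("reload"))

# What `kill -HUP <pid>` does from the shell, aimed at our own process.
os.kill(os.getpid(), signal.SIGHUP)
```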
### Fallback semantics

Every chain is tried in order. The router skips an entry if the provider
isn't configured in this deployment (no API key) or is currently marked
unhealthy by the health-cache. Each remaining entry is attempted in turn;
on a **retryable** error (connection failure, timeout, 5xx, rate-limit,
RemoteProtocolError) the provider is marked unhealthy and the next entry
is tried. **Non-retryable** errors (auth, capability, content-blocked,
4xx, unknown exception types) propagate immediately — no fallback, and
the cache is not poisoned.

Streaming follows the same logic up to the **first byte**. Once a chunk
has been yielded the provider is committed; mid-stream errors surface
as-is so we never splice two providers' voices into one output.

If every entry was skipped or failed, the response is a `503` carrying a
structured `attempts: list[(model, reason)]` log so the cause is
visible to the caller, not only in the service logs.
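The rules above fit in a small loop. A toy sketch, not the real `ProviderRouter` (names such as `walk_chain` and `ChainExhausted` are invented here), showing the skip reasons, the retryable path, and the `attempts` list that ends up in the 503:

```python
class Retryable(Exception):
    """Stands in for connection failures, timeouts, 5xx, rate limits."""


class ChainExhausted(Exception):
    """Every entry was skipped or failed; maps to the HTTP 503."""

    def __init__(self, attempts: list[tuple[str, str]]) -> None:
        super().__init__(f"no entry served the request: {attempts}")
        self.attempts = attempts  # the structured (model, reason) log


def walk_chain(chain, configured, healthy, call):
    attempts: list[tuple[str, str]] = []
    for entry in chain:
        provider = entry.split("/", 1)[0]
        if provider not in configured:          # no API key at this deployment
            attempts.append((entry, "unconfigured"))
            continue
        if not healthy.get(provider, True):     # circuit-breaker says skip
            attempts.append((entry, "cache-unhealthy"))
            continue
        try:
            return call(entry)                  # first success wins
        except Retryable as e:
            healthy[provider] = False           # mark unhealthy, move on
            attempts.append((entry, type(e).__name__))
    raise ChainExhausted(attempts)


# ollama is not configured here, so the chain falls through to groq.
result = walk_chain(
    ["ollama/gemma3:12b", "groq/llama-3.3-70b-versatile"],
    configured={"groq"},
    healthy={},
    call=lambda entry: f"answered by {entry}",
)
```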
### Resolved-model header

Non-streaming responses carry `X-Mana-LLM-Resolved: <provider>/<model>`
(e.g. `groq/llama-3.3-70b-versatile`) — the concrete model that
actually answered. Use this for token-cost attribution when the request
used an alias. For streaming, each chunk's `model` field carries the
same info (headers go out before the chain is walked).
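A consumer-side sketch of that attribution logic; `attribution_key` is a hypothetical helper and the `headers` dict stands in for a real HTTP response's headers:

```python
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"


def attribution_key(requested_model: str, headers: dict[str, str]) -> str:
    """Prefer the concrete model that answered; fall back to the request."""
    return headers.get(RESOLVED_MODEL_HEADER, requested_model)


# e.g. resp.headers after a non-streaming call that used an alias:
headers = {RESOLVED_MODEL_HEADER: "groq/llama-3.3-70b-versatile"}
key = attribution_key("mana/long-form", headers)
```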
### Health-cache + probe

`ProviderHealthCache` keeps a per-provider circuit-breaker:

* 1 failure: still healthy (transient blip, don't bounce).
* 2 consecutive failures: `is_healthy → False` for 60 s; the router
  fail-fasts straight to the next chain entry.
* After 60 s: half-open. The next call exercises the provider; success
  fully resets, failure re-arms the backoff.

A background `HealthProbe` task runs every 30 s with a 3 s timeout per
provider, calling cheap endpoints (`/api/tags` for Ollama, `/v1/models`
for OpenAI-compat). One bad probe can't sink the loop; results feed
into the same cache as the call-site fallback.
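Those breaker rules reduce to a few lines of state. A toy model, not the real `ProviderHealthCache`, with an injected clock so the 60 s window can be stepped manually:

```python
class Breaker:
    """Sketch of the two-failure breaker described above."""

    def __init__(self, threshold: int = 2, backoff: float = 60.0, clock=None):
        self.threshold = threshold
        self.backoff = backoff
        self.clock = clock or (lambda: 0.0)
        self.failures = 0
        self.unhealthy_until = 0.0

    def record_failure(self) -> None:
        self.failures += 1
        if self.failures >= self.threshold:     # trip after the threshold
            self.unhealthy_until = self.clock() + self.backoff

    def record_success(self) -> None:           # full reset
        self.failures = 0
        self.unhealthy_until = 0.0

    def is_healthy(self) -> bool:
        # After the backoff window we report healthy again (half-open):
        # the next real call decides whether to reset or re-arm.
        return self.clock() >= self.unhealthy_until


now = [0.0]
b = Breaker(clock=lambda: now[0])
b.record_failure()          # 1st failure: still healthy
assert b.is_healthy()
b.record_failure()          # 2nd consecutive failure: trips for 60 s
assert not b.is_healthy()
now[0] = 61.0               # backoff expired: half-open, try again
assert b.is_healthy()
```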
### Prometheus metrics added in M4

| Metric | Labels | Purpose |
|---|---|---|
| `mana_llm_alias_resolved_total` | `alias`, `target` | How often an alias resolved to which concrete model — useful for spotting cases where the primary always falls through. |
| `mana_llm_fallback_total` | `from_model`, `to_model`, `reason` | Each fallback transition. `reason` is the exception class name or `cache-unhealthy` / `unconfigured`. |
| `mana_llm_provider_healthy` | `provider` | Gauge: 1 healthy, 0 in backoff. Mirrors the circuit-breaker. |

## Provider Routing

Models use the format `provider/model`:
@@ -31,10 +31,15 @@ import logging

import threading
import time
from dataclasses import dataclass, field
from typing import Iterable
from typing import Callable, Iterable

logger = logging.getLogger(__name__)

#: Notification fired whenever a provider transitions between healthy and
#: unhealthy. ``main.py`` wires this to the Prometheus gauge — but the
#: cache itself stays metrics-agnostic so tests don't need to mock it.
HealthChangeListener = Callable[[str, bool], None]

DEFAULT_FAILURE_THRESHOLD = 2
DEFAULT_UNHEALTHY_BACKOFF_SEC = 60.0


@@ -77,6 +82,22 @@ class ProviderHealthCache:
        self._clock = clock
        self._lock = threading.Lock()
        self._states: dict[str, ProviderState] = {}
        self._listeners: list[HealthChangeListener] = []

    def add_listener(self, listener: HealthChangeListener) -> None:
        """Register a callback fired with ``(provider_id, healthy: bool)``
        whenever a provider's healthy-flag transitions. Listeners run
        outside the cache's lock; exceptions are swallowed and logged so a
        bad listener can't break the underlying state machine.
        """
        self._listeners.append(listener)

    def _notify(self, provider_id: str, healthy: bool) -> None:
        for listener in self._listeners:
            try:
                listener(provider_id, healthy)
            except Exception as e:  # noqa: BLE001
                logger.error("health-change listener raised: %s", e)

    @property
    def failure_threshold(self) -> int:

@@ -129,19 +150,22 @@ class ProviderHealthCache:

    def mark_healthy(self, provider_id: str) -> None:
        """Provider answered correctly — clear any failure state."""
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            previously_unhealthy = not state.healthy
            transitioned = not state.healthy
            state.healthy = True
            state.consecutive_failures = 0
            state.last_check = self._clock()
            state.last_error = None
            state.unhealthy_until = 0.0
        if previously_unhealthy:
            logger.info("provider %s recovered", provider_id)
        if transitioned:
            logger.info("provider %s recovered", provider_id)
            self._notify(provider_id, True)

    def mark_unhealthy(self, provider_id: str, reason: str) -> None:
        """Record a failure. Trips the breaker after the threshold."""
        transitioned = False
        with self._lock:
            state = self._states.setdefault(provider_id, ProviderState())
            state.consecutive_failures += 1

@@ -151,6 +175,7 @@ class ProviderHealthCache:
            if tripped and state.healthy:
                state.healthy = False
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
                transitioned = True
                logger.warning(
                    "provider %s marked unhealthy after %d consecutive failures (%s); "
                    "backoff %.0fs",

@@ -163,6 +188,8 @@ class ProviderHealthCache:
                # Still in unhealthy window; refresh the backoff so a flapping
                # provider doesn't get re-tried every probe tick.
                state.unhealthy_until = self._clock() + self._unhealthy_backoff
        if transitioned:
            self._notify(provider_id, False)


def _copy(state: ProviderState) -> ProviderState:
@@ -1,6 +1,8 @@
"""Main FastAPI application for mana-llm service."""

import asyncio
import logging
import signal
import time
from contextlib import asynccontextmanager
from pathlib import Path

@@ -8,10 +10,10 @@ from typing import Any

from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from fastapi.responses import JSONResponse, Response
from sse_starlette.sse import EventSourceResponse

from src.aliases import AliasRegistry
from src.aliases import AliasConfigError, AliasRegistry
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.health import ProviderHealthCache
@@ -28,7 +30,18 @@ from src.providers import ProviderRouter
from src.providers.errors import ProviderError
from src.streaming import stream_chat_completion
from src.utils.cache import close_redis
from src.utils.metrics import get_metrics, record_llm_error, record_llm_request
from src.utils.metrics import (
    get_metrics,
    record_llm_error,
    record_llm_request,
    set_provider_healthy,
)

#: Header carrying the concrete provider/model that actually served a
#: non-streaming response — useful for token-cost accounting on the
#: caller side, since `mana/long-form` could resolve to ollama, groq,
#: or claude depending on which providers were healthy at request time.
RESOLVED_MODEL_HEADER = "X-Mana-LLM-Resolved"

# Configure logging
logging.basicConfig(
@@ -73,6 +86,35 @@ def _build_provider_probes(
    return probes


def _on_health_change(provider: str, healthy: bool) -> None:
    """Mirror health-cache transitions into the Prometheus gauge."""
    set_provider_healthy(provider, healthy)


def _install_sighup_reload(loop: asyncio.AbstractEventLoop) -> None:
    """Reload ``aliases.yaml`` when the process receives SIGHUP.

    Reload errors keep the previous good state in memory (see
    AliasRegistry.reload). SIGHUP isn't available on Windows; we just
    log and skip in that case.
    """

    def handler() -> None:
        if alias_registry is None:
            return
        try:
            alias_registry.reload()
        except AliasConfigError as e:
            logger.error("alias reload rejected, keeping previous state: %s", e)

    try:
        loop.add_signal_handler(signal.SIGHUP, handler)
    except (NotImplementedError, AttributeError, RuntimeError):
        # NotImplementedError on Windows; RuntimeError when the loop is
        # not running in the main thread (TestClient does this).
        logger.info("SIGHUP reload not available in this context — skipping")


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: load aliases, spin up router + health probe."""
@@ -85,12 +127,20 @@ async def lifespan(app: FastAPI):
    logger.info("Loaded %d aliases from %s", len(alias_registry.list_aliases()), aliases_path)

    health_cache = ProviderHealthCache()
    health_cache.add_listener(_on_health_change)
    router = ProviderRouter(aliases=alias_registry, health_cache=health_cache)
    logger.info("Initialized providers: %s", list(router.providers))

    # Initial gauge values so dashboards render before the first probe
    # transition fires the listener.
    for provider_name in router.providers:
        set_provider_healthy(provider_name, True)

    health_probe = HealthProbe(health_cache, _build_provider_probes(router.providers))
    await health_probe.start()

    _install_sighup_reload(asyncio.get_running_loop())

    yield

    logger.info("Shutting down mana-llm service...")
@@ -143,6 +193,43 @@ async def metrics() -> Response:
    return Response(content=get_metrics(), media_type="text/plain")


# ----- Alias / health debug endpoints (M4) ---------------------------------


@app.get("/v1/aliases")
async def list_aliases() -> dict[str, Any]:
    """Inspect the alias registry — what each ``mana/<class>`` resolves to.

    Useful for debugging "which model actually answered my request" and
    for confirming SIGHUP reloads picked up edits to ``aliases.yaml``.
    """
    if alias_registry is None:
        raise HTTPException(status_code=503, detail="Service not ready")
    return {
        "default": alias_registry.default_alias,
        "aliases": [
            {
                "name": a.name,
                "description": a.description,
                "chain": list(a.chain),
            }
            for a in alias_registry.list_aliases()
        ],
    }


@app.get("/v1/health")
async def detailed_health() -> dict[str, Any]:
    """Full per-provider liveness snapshot.

    Includes the failure counter, last error, and the unhealthy-until
    backoff timestamp — info the original ``/health`` endpoint hides.
    """
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")
    return await router.health_check()


# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
@@ -182,10 +269,12 @@ async def chat_completions(
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    # Parse provider and model for metrics
    model_parts = request.model.split("/", 1)
    provider = model_parts[0] if len(model_parts) > 1 else "ollama"
    model = model_parts[1] if len(model_parts) > 1 else request.model
    # The request's `model` field is what the caller asked for — could be
    # `mana/long-form`, `ollama/gemma3:4b`, or even bare `gemma3:4b`. For
    # error-path metrics we use that value (it's what the caller will
    # search for); for success-path metrics we use the resolved provider
    # so token-cost / latency attribute to the model that actually ran.
    requested_provider, requested_model = _split_model(request.model)

    start_time = time.time()
@@ -198,7 +287,11 @@ async def chat_completions(
            async for chunk in stream_chat_completion(router, request):
                yield chunk

        record_llm_request(provider, model, streaming=True)
        # Streaming metrics: we don't yet know which provider answered
        # at request-record time. Each chunk's `model` field carries
        # the resolved name; per-token latency is harder to attribute
        # cleanly so we skip it for streams.
        record_llm_request(requested_provider, requested_model, streaming=True)

        return EventSourceResponse(
            generate(),
@@ -209,35 +302,43 @@ async def chat_completions(
        logger.info(f"Chat completion: {request.model}")
        response = await router.chat_completion(request)

        # Record metrics
        resolved_provider, resolved_model = _split_model(response.model)
        latency = time.time() - start_time
        record_llm_request(
            provider=provider,
            model=model,
            provider=resolved_provider,
            model=resolved_model,
            streaming=False,
            prompt_tokens=response.usage.prompt_tokens,
            completion_tokens=response.usage.completion_tokens,
            latency=latency,
        )

        return response
        # `response.model` is the concrete provider/model the chain
        # actually resolved to. Surface it via header so the caller
        # can attribute token cost to the right model even when the
        # request used an alias.
        return JSONResponse(
            content=response.model_dump(),
            headers={RESOLVED_MODEL_HEADER: response.model},
        )

    except ValueError as e:
        logger.error(f"Invalid request: {e}")
        record_llm_error(provider, model, "invalid_request")
        record_llm_error(requested_provider, requested_model, "invalid_request")
        raise HTTPException(status_code=400, detail=str(e))
    except ProviderError as e:
        logger.warning(
            f"Provider error on {provider}/{model}: kind={e.kind} detail={e}"
            f"Provider error on {requested_provider}/{requested_model}: "
            f"kind={e.kind} detail={e}"
        )
        record_llm_error(provider, model, e.kind)
        record_llm_error(requested_provider, requested_model, e.kind)
        raise HTTPException(
            status_code=e.http_status,
            detail={"kind": e.kind, "message": str(e)},
        )
    except Exception as e:
        logger.error(f"Chat completion failed: {e}")
        record_llm_error(provider, model, "server_error")
        record_llm_error(requested_provider, requested_model, "server_error")
        raise HTTPException(status_code=500, detail=str(e))
@@ -248,10 +349,7 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
    if router is None:
        raise HTTPException(status_code=503, detail="Service not ready")

    # Parse provider and model for metrics
    model_parts = request.model.split("/", 1)
    provider = model_parts[0] if len(model_parts) > 1 else "ollama"
    model = model_parts[1] if len(model_parts) > 1 else request.model
    provider, model = _split_model(request.model)

    start_time = time.time()
@@ -280,6 +378,21 @@ async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
        raise HTTPException(status_code=500, detail=str(e))


def _split_model(model: str) -> tuple[str, str]:
    """Split a ``provider/model`` string for metric labelling.

    Bare names with no slash default to ``ollama`` to match the legacy
    OpenAI-style behaviour. Aliases (``mana/...``) keep their namespace
    in the metrics — that's intentional, so request-side counters tell
    you what callers ASKED for, while the resolved-side counters
    (``mana_llm_alias_resolved_total``) tell you what they GOT.
    """
    if "/" in model:
        provider, _, name = model.partition("/")
        return provider.lower(), name
    return "ollama", model


if __name__ == "__main__":
    import uvicorn
@@ -41,6 +41,11 @@ from src.models import (
    EmbeddingResponse,
    ModelInfo,
)
from src.utils.metrics import (
    record_alias_resolved,
    record_fallback,
    set_provider_healthy,
)

from .base import LLMProvider
from .errors import (

@@ -229,8 +234,10 @@ class ProviderRouter:
        chain = self._resolve_chain(model_or_alias)
        attempts: list[tuple[str, str]] = []
        last_exc: Exception | None = None
        is_alias = AliasRegistry.is_alias(model_or_alias)

        for entry in chain:
        for i, entry in enumerate(chain):
            next_entry = chain[i + 1] if i + 1 < len(chain) else ""
            provider_name, model_name = self._parse_model(entry)
            if provider_name not in self.providers:
                logger.debug(

@@ -239,11 +246,13 @@ class ProviderRouter:
                    provider_name,
                )
                attempts.append((entry, "unconfigured"))
                record_fallback(entry, next_entry, "unconfigured")
                continue

            if not self.health_cache.is_healthy(provider_name):
                logger.debug("skip chain entry %s — cache says unhealthy", entry)
                attempts.append((entry, "cache-unhealthy"))
                record_fallback(entry, next_entry, "cache-unhealthy")
                continue

            provider = self.providers[provider_name]
@@ -253,10 +262,13 @@ class ProviderRouter:
                logger.info(
                    "execute → %s (alias=%s)",
                    entry,
                    model_or_alias if model_or_alias != entry else "<direct>",
                    model_or_alias if is_alias else "<direct>",
                )
                result = await call(provider, model_name, request)
                self.health_cache.mark_healthy(provider_name)
                set_provider_healthy(provider_name, True)
                if is_alias:
                    record_alias_resolved(model_or_alias, entry)
                return result
            except Exception as e:
                if not self._is_retryable(e):

@@ -266,7 +278,11 @@ class ProviderRouter:
                    # being wrong.
                    raise
                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
                set_provider_healthy(
                    provider_name, self.health_cache.is_healthy(provider_name)
                )
                attempts.append((entry, type(e).__name__))
                record_fallback(entry, next_entry, type(e).__name__)
                last_exc = e
                logger.warning(
                    "execute %s failed (retryable, will try next): %s",
@@ -310,14 +326,18 @@ class ProviderRouter:
        chain = self._resolve_chain(request.model)
        attempts: list[tuple[str, str]] = []
        last_exc: Exception | None = None
        is_alias = AliasRegistry.is_alias(request.model)

        for entry in chain:
        for i, entry in enumerate(chain):
            next_entry = chain[i + 1] if i + 1 < len(chain) else ""
            provider_name, model_name = self._parse_model(entry)
            if provider_name not in self.providers:
                attempts.append((entry, "unconfigured"))
                record_fallback(entry, next_entry, "unconfigured")
                continue
            if not self.health_cache.is_healthy(provider_name):
                attempts.append((entry, "cache-unhealthy"))
                record_fallback(entry, next_entry, "cache-unhealthy")
                continue

            provider = self.providers[provider_name]

@@ -330,13 +350,20 @@ class ProviderRouter:
                # Empty stream is a successful but content-free response.
                # Commit and exit cleanly.
                self.health_cache.mark_healthy(provider_name)
                set_provider_healthy(provider_name, True)
                if is_alias:
                    record_alias_resolved(request.model, entry)
                logger.info("stream %s yielded empty response", entry)
                return
            except Exception as e:
                if not self._is_retryable(e):
                    raise
                self.health_cache.mark_unhealthy(provider_name, self._exception_summary(e))
                set_provider_healthy(
                    provider_name, self.health_cache.is_healthy(provider_name)
                )
                attempts.append((entry, type(e).__name__))
                record_fallback(entry, next_entry, type(e).__name__)
                last_exc = e
                logger.warning(
                    "stream %s failed before first byte (retryable, trying next): %s",

@@ -349,6 +376,9 @@ class ProviderRouter:
            # the rest of the stream. Any error from here on propagates;
            # it is NOT safe to splice another provider's output in.
            self.health_cache.mark_healthy(provider_name)
            set_provider_healthy(provider_name, True)
            if is_alias:
                record_alias_resolved(request.model, entry)
            logger.info("stream → %s (committed after first chunk)", entry)
            yield first_chunk
            async for chunk in stream:
@@ -4,7 +4,7 @@ import time
from collections.abc import Callable

from fastapi import Request, Response
from prometheus_client import Counter, Histogram, generate_latest
from prometheus_client import Counter, Gauge, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware

# Request metrics

@@ -47,6 +47,35 @@ LLM_ERRORS = Counter(
    ["provider", "model", "error_type"],
)

# ---------------------------------------------------------------------------
# Alias / fallback / health metrics — added in M4 of llm-fallback-aliases.md.
# ---------------------------------------------------------------------------

ALIAS_RESOLVED = Counter(
    "mana_llm_alias_resolved_total",
    "How often an alias resolved to a concrete provider/model. The `target` "
    "label is the chain entry that actually served the request — useful for "
    "spotting cases where the primary always falls through to a cloud entry.",
    ["alias", "target"],
)

FALLBACK_TRIGGERED = Counter(
    "mana_llm_fallback_total",
    "Fallback transitions: a chain entry failed (or was skipped via cache) "
    "and the router moved to the next entry. `reason` is the exception class "
    "name or `cache-unhealthy` / `unconfigured`. `from_model` is the entry "
    "that didn't serve, `to_model` is empty when no further entries existed.",
    ["from_model", "to_model", "reason"],
)

PROVIDER_HEALTHY = Gauge(
    "mana_llm_provider_healthy",
    "1 when the provider is currently considered healthy by the cache, "
    "0 when in backoff. Refreshed on every probe tick and on every router "
    "call-site state transition.",
    ["provider"],
)


def get_metrics() -> bytes:
    """Generate Prometheus metrics output."""
@@ -107,3 +136,23 @@ def record_llm_request(
def record_llm_error(provider: str, model: str, error_type: str) -> None:
    """Record LLM error metrics."""
    LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc()


def record_alias_resolved(alias: str, target: str) -> None:
    """Record which concrete model an alias resolved to for this request."""
    ALIAS_RESOLVED.labels(alias=alias, target=target).inc()


def record_fallback(from_model: str, to_model: str, reason: str) -> None:
    """Record a fallback transition. ``to_model`` is empty when the chain
    ran out (i.e. NoHealthyProviderError)."""
    FALLBACK_TRIGGERED.labels(
        from_model=from_model,
        to_model=to_model,
        reason=reason,
    ).inc()


def set_provider_healthy(provider: str, healthy: bool) -> None:
    """Mirror ``ProviderHealthCache`` state into a Prometheus gauge."""
    PROVIDER_HEALTHY.labels(provider=provider).set(1.0 if healthy else 0.0)
415 services/mana-llm/tests/test_m4_observability.py Normal file
@@ -0,0 +1,415 @@
"""Tests for M4: observability + debug endpoints + reload."""

from __future__ import annotations

import asyncio
import os
import signal
from pathlib import Path

import httpx
import pytest
from fastapi.testclient import TestClient
from prometheus_client import REGISTRY

from src.aliases import AliasRegistry
from src.health import ProviderHealthCache
from src.models import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    Choice,
    Message,
    MessageResponse,
)
from src.providers import ProviderRouter
from src.utils.metrics import (
    record_alias_resolved,
    record_fallback,
    set_provider_healthy,
)


# ---------------------------------------------------------------------------
# Cache → listener → metric gauge
# ---------------------------------------------------------------------------
class TestHealthChangeListener:
    def test_listener_fires_on_unhealthy_transition(self) -> None:
        cache = ProviderHealthCache(failure_threshold=2)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        # First failure: still healthy → no transition.
        cache.mark_unhealthy("ollama", "blip")
        assert events == []

        # Second failure: transition healthy→unhealthy → fires.
        cache.mark_unhealthy("ollama", "boom")
        assert events == [("ollama", False)]

    def test_listener_fires_on_recovery(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        cache.mark_unhealthy("ollama", "boom")
        assert events == [("ollama", False)]

        cache.mark_healthy("ollama")
        assert events == [("ollama", False), ("ollama", True)]

    def test_steady_state_does_not_fire(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        events: list[tuple[str, bool]] = []
        cache.add_listener(lambda p, h: events.append((p, h)))

        # Three healthy ops in a row — no transitions, no events.
        for _ in range(3):
            cache.mark_healthy("ollama")
        assert events == []

    def test_listener_exception_does_not_break_cache(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)

        def bad(_provider: str, _healthy: bool) -> None:
            raise RuntimeError("listener boom")

        cache.add_listener(bad)
        # Should NOT raise — the cache must keep working with a broken
        # listener, otherwise one bad metric callback would brick the
        # whole router.
        cache.mark_unhealthy("ollama", "x")
        assert cache.is_healthy("ollama") is False

    def test_multiple_listeners(self) -> None:
        cache = ProviderHealthCache(failure_threshold=1)
        a: list = []
        b: list = []
        cache.add_listener(lambda p, h: a.append((p, h)))
        cache.add_listener(lambda p, h: b.append((p, h)))

        cache.mark_unhealthy("ollama", "x")
        assert a == [("ollama", False)]
        assert b == [("ollama", False)]
# ---------------------------------------------------------------------------
# Prometheus metrics — counters/gauges actually move
# ---------------------------------------------------------------------------


def _counter_value(name: str, labels: dict[str, str]) -> float:
    """Helper: read the current value of a labeled Prometheus metric."""
    value = REGISTRY.get_sample_value(name, labels=labels)
    return value or 0.0


class TestMetricsRecording:
    def test_record_alias_resolved_increments(self) -> None:
        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/test-class", "target": "ollama/x:1b"},
        )
        record_alias_resolved("mana/test-class", "ollama/x:1b")
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/test-class", "target": "ollama/x:1b"},
        )
        assert after - before == pytest.approx(1.0)

    def test_record_fallback_increments(self) -> None:
        before = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
        )
        record_fallback("ollama/x", "groq/y", "ConnectError")
        after = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "ollama/x", "to_model": "groq/y", "reason": "ConnectError"},
        )
        assert after - before == pytest.approx(1.0)

    def test_set_provider_healthy_writes_gauge(self) -> None:
        set_provider_healthy("test_provider_xyz", True)
        v = REGISTRY.get_sample_value(
            "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
        )
        assert v == 1.0

        set_provider_healthy("test_provider_xyz", False)
        v = REGISTRY.get_sample_value(
            "mana_llm_provider_healthy", labels={"provider": "test_provider_xyz"}
        )
        assert v == 0.0


# ---------------------------------------------------------------------------
# Router → metrics: end-to-end through a fallback
# ---------------------------------------------------------------------------


class _OkProvider:
    """Minimal provider double — only what the router uses for chat."""

    name = "ok-provider"
    supports_tools = True

    def __init__(self, name: str, fail_with: BaseException | None = None) -> None:
        self.name = name
        self.fail_with = fail_with
        self.calls = 0

    def model_supports_tools(self, model: str) -> bool:
        return True

    async def chat_completion(self, request, model):
        self.calls += 1
        if self.fail_with is not None:
            raise self.fail_with
        return ChatCompletionResponse(
            model=f"{self.name}/{model}",
            choices=[Choice(message=MessageResponse(content="ok"))],
        )

    async def chat_completion_stream(self, request, model):  # pragma: no cover
        if False:  # pragma: no cover
            yield None

    async def list_models(self):
        return []

    async def embeddings(self, request, model):
        raise NotImplementedError

    async def health_check(self):
        return {"status": "healthy"}

    async def close(self):
        pass


def _aliases(tmp_path: Path) -> AliasRegistry:
    cfg = (
        "aliases:\n"
        "  mana/two-step:\n"
        '    description: "x"\n'
        "    chain:\n"
        "      - alpha/m1\n"
        "      - beta/m2\n"
    )
    p = tmp_path / "aliases.yaml"
    p.write_text(cfg)
    return AliasRegistry(p)


class TestRouterMetricsIntegration:
    @pytest.mark.asyncio
    async def test_alias_resolved_metric_records_target(self, tmp_path: Path) -> None:
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {"alpha": _OkProvider("alpha")}  # beta not configured

        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/two-step", "target": "alpha/m1"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "mana/two-step", "target": "alpha/m1"},
        )
        assert after - before == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_fallback_metric_records_transition(self, tmp_path: Path) -> None:
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {
            "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("dead")),
            "beta": _OkProvider("beta"),
        }

        before = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_fallback_total",
            {"from_model": "alpha/m1", "to_model": "beta/m2", "reason": "ConnectError"},
        )
        assert after - before == pytest.approx(1.0)

    @pytest.mark.asyncio
    async def test_direct_model_does_not_record_alias_metric(
        self, tmp_path: Path
    ) -> None:
        # A direct provider/model is not an alias — the ALIAS_RESOLVED counter
        # must stay flat for those calls.
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        router.providers = {"alpha": _OkProvider("alpha")}

        before = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "alpha/anything", "target": "alpha/anything"},
        )
        await router.chat_completion(
            ChatCompletionRequest(
                model="alpha/anything",
                messages=[Message(role="user", content="hi")],
            )
        )
        after = _counter_value(
            "mana_llm_alias_resolved_total",
            {"alias": "alpha/anything", "target": "alpha/anything"},
        )
        # Counter must NOT have increased — direct calls aren't aliases.
        assert after == before


# ---------------------------------------------------------------------------
# Debug endpoints: GET /v1/aliases, GET /v1/health
# ---------------------------------------------------------------------------


@pytest.fixture
def client():
    from src.main import app

    with TestClient(app) as c:
        yield c


class TestDebugEndpoints:
    def test_v1_aliases_returns_shipped_config(self, client: TestClient) -> None:
        resp = client.get("/v1/aliases")
        assert resp.status_code == 200
        data = resp.json()
        names = [a["name"] for a in data["aliases"]]
        # The five canonical classes must always be present.
        for expected in (
            "mana/fast-text",
            "mana/long-form",
            "mana/structured",
            "mana/reasoning",
            "mana/vision",
        ):
            assert expected in names
        # The default is set in the shipped config.
        assert data["default"] == "mana/fast-text"

    def test_v1_aliases_chain_format(self, client: TestClient) -> None:
        resp = client.get("/v1/aliases")
        data = resp.json()
        long_form = next(a for a in data["aliases"] if a["name"] == "mana/long-form")
        # Each chain entry is a `provider/model` string.
        assert all("/" in entry for entry in long_form["chain"])
        assert len(long_form["chain"]) >= 2  # plan requires at least one cloud fallback

    def test_v1_health_includes_all_providers(self, client: TestClient) -> None:
        resp = client.get("/v1/health")
        assert resp.status_code == 200
        data = resp.json()
        assert "status" in data
        assert "providers" in data
        # ollama is always configured (the provider list is non-empty).
        assert "ollama" in data["providers"]
        for info in data["providers"].values():
            assert "status" in info
            assert "consecutive_failures" in info


# ---------------------------------------------------------------------------
# X-Mana-LLM-Resolved header on non-streaming responses
# ---------------------------------------------------------------------------


class TestResolvedHeader:
    """The header is the consumer's hook for token-cost attribution.

    Tested at the router level — wiring through main.py would need a
    real provider connection, which isn't available in unit tests.
    """

    @pytest.mark.asyncio
    async def test_response_model_field_carries_resolved_target(
        self, tmp_path: Path
    ) -> None:
        # The header value is `response.model`; verify that the field reflects
        # the chain entry that actually served, not the requested alias.
        aliases = _aliases(tmp_path)
        cache = ProviderHealthCache()
        router = ProviderRouter(aliases=aliases, health_cache=cache)
        # Force fallback to beta.
        router.providers = {
            "alpha": _OkProvider("alpha", fail_with=httpx.ConnectError("d")),
            "beta": _OkProvider("beta"),
        }

        resp = await router.chat_completion(
            ChatCompletionRequest(
                model="mana/two-step",
                messages=[Message(role="user", content="hi")],
            )
        )
        # Even though the caller asked for `mana/two-step`, the resolved
        # field shows the entry that actually answered.
        assert resp.model == "beta/m2"


# ---------------------------------------------------------------------------
# SIGHUP reload — only meaningful on Unix; semantics tested on the registry
# ---------------------------------------------------------------------------


class TestSighupReload:
    """SIGHUP triggers ``alias_registry.reload()``; a reload error keeps state.

    The signal-handler wiring lives in main.py and only installs when
    the loop is running in the main thread. We exercise the reload
    semantics here directly on the registry instead — the signal-handler
    code path itself is a 4-line wrapper around ``reload()``.
    """

    def test_reload_picks_up_yaml_edits(self, tmp_path: Path) -> None:
        path = tmp_path / "aliases.yaml"
        path.write_text(
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/foo:1b\n"
        )
        reg = AliasRegistry(path)
        assert reg.resolve_chain("mana/x") == ("ollama/foo:1b",)

        # Edit on disk, reload (this is exactly what the SIGHUP handler
        # does — minus the signal plumbing).
        path.write_text(
            "aliases:\n"
            "  mana/x:\n"
            '    description: "x"\n'
            "    chain:\n"
            "      - ollama/bar:1b\n"
            "      - groq/llama-3.1-8b-instant\n"
        )
        reg.reload()
        assert reg.resolve_chain("mana/x") == (
            "ollama/bar:1b",
            "groq/llama-3.1-8b-instant",
        )