managarten/services/mana-llm/src/providers/router.py
Till JS e757470cb0 feat(mana-llm): add OpenAI-style tools + tool_calls passthrough
Extends the chat-completions surface so callers can ask any provider
to call named functions and get structured tool_calls back. Wired
through all three provider adapters so the planner and companion can
switch off the fragile JSON-parsing pathway.

- Request: tools[], tool_choice, assistant tool_calls, tool-role
  messages with tool_call_id.
- Response: MessageResponse.tool_calls, Choice.finish_reason adds
  "tool_calls", DeltaContent streams tool_calls.
- Google provider: Tool(function_declarations=...) build, result
  normalised (args dict → JSON string), function_response parts on
  a user turn for tool-role messages.
- OpenAI-compat: 1:1 passthrough of the OpenAI spec.
- Ollama: /api/chat passthrough; model-level capability check via a
  TOOL_CAPABLE_OLLAMA_PATTERNS whitelist (llama3.1+, qwen2.5+,
  mistral, command-r, …) — unsupported models rejected rather than
  silently falling back to prose.
- Router: model_supports_tools() check upfront for both streaming
  and non-streaming paths; ProviderCapabilityError bubbles as 400.

No silent downgrade. Missing tool support = explicit error.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-04-20 15:22:48 +02:00

336 lines
12 KiB
Python

"""Provider routing logic for mana-llm with auto-fallback support."""
import asyncio
import logging
import time
from collections.abc import AsyncIterator
from typing import Any
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
)
from .base import LLMProvider
from .errors import ProviderCapabilityError
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
logger = logging.getLogger(__name__)
class ProviderRouter:
"""Routes requests to appropriate LLM providers with auto-fallback.
When auto_fallback_enabled is True and a Google API key is configured:
- Ollama requests that fail or exceed max_concurrent are automatically
retried on Google Gemini with model mapping.
- Explicit provider requests (e.g., openrouter/...) are never fallback-routed.
"""
def __init__(self):
self.providers: dict[str, LLMProvider] = {}
self._ollama_concurrent: int = 0
self._ollama_health_cache: tuple[dict[str, Any] | None, float] = (None, 0)
self._health_cache_ttl: float = 5.0 # seconds
self._initialize_providers()
def _initialize_providers(self) -> None:
"""Initialize available providers based on configuration."""
# Ollama is always available (local)
self.providers["ollama"] = OllamaProvider()
logger.info(f"Initialized Ollama provider at {settings.ollama_url}")
# Google Gemini (fallback provider)
if settings.google_api_key:
from .google import GoogleProvider
self.providers["google"] = GoogleProvider(
api_key=settings.google_api_key,
default_model=settings.google_default_model,
)
logger.info("Initialized Google Gemini provider (fallback)")
# OpenRouter (if API key configured)
if settings.openrouter_api_key:
self.providers["openrouter"] = OpenAICompatProvider(
name="openrouter",
base_url=settings.openrouter_base_url,
api_key=settings.openrouter_api_key,
default_model=settings.openrouter_default_model,
)
logger.info("Initialized OpenRouter provider")
# Groq (if API key configured)
if settings.groq_api_key:
self.providers["groq"] = OpenAICompatProvider(
name="groq",
base_url=settings.groq_base_url,
api_key=settings.groq_api_key,
)
logger.info("Initialized Groq provider")
# Together (if API key configured)
if settings.together_api_key:
self.providers["together"] = OpenAICompatProvider(
name="together",
base_url=settings.together_base_url,
api_key=settings.together_api_key,
)
logger.info("Initialized Together provider")
def _parse_model(self, model: str) -> tuple[str, str]:
"""Parse model string into (provider, model_name)."""
if "/" in model:
parts = model.split("/", 1)
provider = parts[0].lower()
model_name = parts[1]
else:
provider = "ollama"
model_name = model
return provider, model_name
def _get_provider(self, provider_name: str) -> LLMProvider:
"""Get provider by name, raise if not available."""
if provider_name not in self.providers:
available = list(self.providers.keys())
raise ValueError(
f"Provider '{provider_name}' not available. "
f"Available providers: {available}"
)
return self.providers[provider_name]
def _can_fallback_to_google(self, provider_name: str) -> bool:
"""Check if a request can be fallback-routed to Google."""
return (
settings.auto_fallback_enabled
and provider_name == "ollama"
and "google" in self.providers
)
def _should_use_ollama(self) -> bool:
"""Determine if Ollama should handle the request based on load."""
return self._ollama_concurrent < settings.ollama_max_concurrent
async def _get_ollama_health_cached(self) -> dict[str, Any]:
"""Get Ollama health with caching (5s TTL)."""
cached, cached_at = self._ollama_health_cache
if cached is not None and (time.time() - cached_at) < self._health_cache_ttl:
return cached
try:
provider = self.providers.get("ollama")
if provider:
result = await provider.health_check()
else:
result = {"status": "unhealthy", "error": "no ollama provider"}
except Exception as e:
result = {"status": "unhealthy", "error": str(e)}
self._ollama_health_cache = (result, time.time())
return result
async def _fallback_to_google(
self,
request: ChatCompletionRequest,
model_name: str,
original_error: Exception | None = None,
) -> ChatCompletionResponse:
"""Route a request to Google Gemini as fallback."""
from .google import GoogleProvider
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
reason = f"error: {original_error}" if original_error else "overloaded"
logger.warning(
f"Falling back to Google Gemini ({gemini_model}) for ollama/{model_name} ({reason})"
)
return await google.chat_completion(request, gemini_model)
def _check_tool_capability(
self, provider: LLMProvider, model_name: str, request: ChatCompletionRequest
) -> None:
"""Refuse tool-bearing requests for providers/models without tool support.
Silent downgrade (dropping the `tools` payload) is more dangerous
than an explicit error — the caller would get plain text back and
have no way to tell the tools never reached the model.
"""
if not request.tools:
return
if not provider.model_supports_tools(model_name):
raise ProviderCapabilityError(
f"{provider.name}/{model_name} does not support tool calling. "
"Choose a tool-capable model (e.g. gemini-2.5-flash, llama3.1:*)"
)
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
"""Route chat completion request with auto-fallback."""
provider_name, model_name = self._parse_model(request.model)
# Non-Ollama providers: direct routing, no fallback
if provider_name != "ollama":
provider = self._get_provider(provider_name)
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
return await provider.chat_completion(request, model_name)
# Ollama with fallback logic
can_fallback = self._can_fallback_to_google(provider_name)
# Check if Ollama is overloaded
if can_fallback and not self._should_use_ollama():
return await self._fallback_to_google(request, model_name)
# Try Ollama first
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing chat completion to ollama/{model_name}")
self._ollama_concurrent += 1
try:
return await provider.chat_completion(request, model_name)
except Exception as e:
logger.error(f"Chat completion failed on ollama: {e}")
if can_fallback:
return await self._fallback_to_google(request, model_name, e)
raise
finally:
self._ollama_concurrent -= 1
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Route streaming chat completion with auto-fallback."""
provider_name, model_name = self._parse_model(request.model)
# Non-Ollama: direct
if provider_name != "ollama":
provider = self._get_provider(provider_name)
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to {provider_name}/{model_name}")
async for chunk in provider.chat_completion_stream(request, model_name):
yield chunk
return
# Ollama with fallback
can_fallback = self._can_fallback_to_google(provider_name)
if can_fallback and not self._should_use_ollama():
from .google import GoogleProvider
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
async for chunk in google.chat_completion_stream(request, gemini_model):
yield chunk
return
provider = self._get_provider("ollama")
self._check_tool_capability(provider, model_name, request)
logger.info(f"Routing streaming to ollama/{model_name}")
self._ollama_concurrent += 1
try:
async for chunk in provider.chat_completion_stream(request, model_name):
yield chunk
except Exception as e:
logger.error(f"Streaming failed on ollama: {e}")
if can_fallback:
from .google import GoogleProvider
google = self.providers["google"]
assert isinstance(google, GoogleProvider)
gemini_model = google.map_model(model_name)
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
async for chunk in google.chat_completion_stream(request, gemini_model):
yield chunk
else:
raise
finally:
self._ollama_concurrent -= 1
async def embeddings(
self,
request: EmbeddingRequest,
) -> EmbeddingResponse:
"""Route embeddings request to appropriate provider."""
provider_name, model_name = self._parse_model(request.model)
provider = self._get_provider(provider_name)
logger.info(f"Routing embeddings to {provider_name}/{model_name}")
return await provider.embeddings(request, model_name)
async def list_models(self) -> list[ModelInfo]:
"""List all available models from all providers."""
all_models: list[ModelInfo] = []
for provider in self.providers.values():
try:
models = await provider.list_models()
all_models.extend(models)
except Exception as e:
logger.warning(f"Failed to list models from {provider.name}: {e}")
return all_models
async def get_model(self, model_id: str) -> ModelInfo | None:
"""Get specific model info."""
provider_name, model_name = self._parse_model(model_id)
if provider_name not in self.providers:
return None
provider = self.providers[provider_name]
models = await provider.list_models()
for model in models:
if model.id == model_id or model.id.endswith(f"/{model_name}"):
return model
return None
async def health_check(self) -> dict[str, Any]:
"""Check health of all providers."""
results: dict[str, Any] = {}
for name, provider in self.providers.items():
results[name] = await provider.health_check()
# Overall status
all_healthy = all(r.get("status") == "healthy" for r in results.values())
any_healthy = any(r.get("status") == "healthy" for r in results.values())
status_info: dict[str, Any] = {
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
"providers": results,
}
# Include fallback info
if settings.auto_fallback_enabled and "google" in self.providers:
status_info["fallback"] = {
"enabled": True,
"ollama_concurrent": self._ollama_concurrent,
"ollama_max_concurrent": settings.ollama_max_concurrent,
}
return status_info
async def close(self) -> None:
"""Close all providers."""
for provider in self.providers.values():
await provider.close()