mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-19 22:01:26 +02:00
feat(mana-llm): add Google Gemini fallback provider with auto-routing
Add Google Gemini as a fallback provider that activates automatically when Ollama is overloaded or unavailable, ensuring LLM requests always succeed even under load. New provider (src/providers/google.py): - Full LLMProvider implementation using google-genai SDK - Chat completions (streaming + non-streaming) - Vision/multimodal support (base64 images) - Embeddings via text-embedding-004 - Model mapping: Ollama models → Gemini equivalents (gemma3:4b → gemini-2.0-flash, llava:7b → gemini-2.0-flash, etc.) Auto-fallback routing (src/providers/router.py): - Concurrent request tracking for Ollama (OLLAMA_MAX_CONCURRENT=3) - When Ollama concurrent > max: route to Google automatically - When Ollama fails: retry on Google with model mapping - Health check caching (5s TTL) to avoid hammering Ollama - Non-Ollama providers (openrouter, groq, together) are never fallback-routed - Fallback info included in /health endpoint response New config (src/config.py): - GOOGLE_API_KEY: enables Google provider - GOOGLE_DEFAULT_MODEL: default gemini-2.0-flash - AUTO_FALLBACK_ENABLED: toggle fallback (default: true) - OLLAMA_MAX_CONCURRENT: concurrent request threshold (default: 3) Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
28286d126c
commit
45063b88be
5 changed files with 430 additions and 19 deletions
|
|
@ -1779,6 +1779,10 @@ services:
|
|||
OPENROUTER_API_KEY: ${OPENROUTER_API_KEY:-}
|
||||
GROQ_API_KEY: ${GROQ_API_KEY:-}
|
||||
TOGETHER_API_KEY: ${TOGETHER_API_KEY:-}
|
||||
GOOGLE_API_KEY: ${GOOGLE_API_KEY:-}
|
||||
GOOGLE_DEFAULT_MODEL: gemini-2.0-flash
|
||||
AUTO_FALLBACK_ENABLED: "true"
|
||||
OLLAMA_MAX_CONCURRENT: 3
|
||||
CORS_ORIGINS: https://playground.mana.how,https://mana.how,https://chat.mana.how
|
||||
extra_hosts:
|
||||
- "host.docker.internal:host-gateway"
|
||||
|
|
|
|||
|
|
@ -12,6 +12,7 @@ dependencies = [
|
|||
"sse-starlette>=2.2.0",
|
||||
"redis>=5.2.0",
|
||||
"prometheus-client>=0.21.0",
|
||||
"google-genai>=1.0.0",
|
||||
]
|
||||
|
||||
[project.optional-dependencies]
|
||||
|
|
|
|||
|
|
@ -28,6 +28,14 @@ class Settings(BaseSettings):
|
|||
together_api_key: str | None = None
|
||||
together_base_url: str = "https://api.together.xyz/v1"
|
||||
|
||||
# Google Gemini (Fallback provider)
|
||||
google_api_key: str | None = None
|
||||
google_default_model: str = "gemini-2.0-flash"
|
||||
|
||||
# Auto-fallback: Ollama → Google when Ollama is overloaded/down
|
||||
auto_fallback_enabled: bool = True
|
||||
ollama_max_concurrent: int = 3
|
||||
|
||||
# Caching (Optional)
|
||||
redis_url: str | None = None
|
||||
cache_ttl: int = 3600
|
||||
|
|
|
|||
270
services/mana-llm/src/providers/google.py
Normal file
270
services/mana-llm/src/providers/google.py
Normal file
|
|
@ -0,0 +1,270 @@
|
|||
"""Google Gemini provider for mana-llm (fallback when Ollama is unavailable)."""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
from src.config import settings
|
||||
from src.models import (
|
||||
ChatCompletionRequest,
|
||||
ChatCompletionResponse,
|
||||
ChatCompletionStreamResponse,
|
||||
Choice,
|
||||
DeltaContent,
|
||||
EmbeddingData,
|
||||
EmbeddingRequest,
|
||||
EmbeddingResponse,
|
||||
MessageResponse,
|
||||
ModelInfo,
|
||||
StreamChoice,
|
||||
Usage,
|
||||
)
|
||||
|
||||
from .base import LLMProvider
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Model mapping: Ollama model → Google Gemini equivalent
|
||||
OLLAMA_TO_GEMINI: dict[str, str] = {
|
||||
"gemma3:4b": "gemini-2.0-flash",
|
||||
"gemma3:12b": "gemini-2.0-flash",
|
||||
"gemma3:27b": "gemini-2.5-pro",
|
||||
"llava:7b": "gemini-2.0-flash", # Gemini has native vision
|
||||
"qwen3-vl:4b": "gemini-2.0-flash", # vision fallback
|
||||
"qwen2.5-coder:7b": "gemini-2.0-flash",
|
||||
"qwen2.5-coder:14b": "gemini-2.5-pro",
|
||||
"phi3.5:latest": "gemini-2.0-flash",
|
||||
"ministral-3:3b": "gemini-2.0-flash",
|
||||
"deepseek-ocr:latest": "gemini-2.0-flash",
|
||||
}
|
||||
|
||||
|
||||
class GoogleProvider(LLMProvider):
|
||||
"""Google Gemini API provider."""
|
||||
|
||||
name = "google"
|
||||
|
||||
def __init__(self, api_key: str, default_model: str = "gemini-2.0-flash"):
|
||||
self.api_key = api_key
|
||||
self.default_model = default_model
|
||||
self.client = genai.Client(api_key=api_key)
|
||||
|
||||
def map_model(self, ollama_model: str) -> str:
|
||||
"""Map an Ollama model name to a Google Gemini equivalent."""
|
||||
return OLLAMA_TO_GEMINI.get(ollama_model, self.default_model)
|
||||
|
||||
def _convert_messages(
|
||||
self, request: ChatCompletionRequest
|
||||
) -> tuple[str | None, list[types.Content]]:
|
||||
"""Convert OpenAI-format messages to Google Gemini format.
|
||||
|
||||
Returns (system_instruction, contents).
|
||||
"""
|
||||
system_instruction: str | None = None
|
||||
contents: list[types.Content] = []
|
||||
|
||||
for msg in request.messages:
|
||||
if msg.role == "system":
|
||||
# Gemini uses system_instruction separately
|
||||
if isinstance(msg.content, str):
|
||||
system_instruction = msg.content
|
||||
continue
|
||||
|
||||
role = "user" if msg.role == "user" else "model"
|
||||
|
||||
if isinstance(msg.content, str):
|
||||
contents.append(types.Content(role=role, parts=[types.Part.from_text(msg.content)]))
|
||||
else:
|
||||
# Multimodal content
|
||||
parts: list[types.Part] = []
|
||||
for part in msg.content:
|
||||
if part.type == "text":
|
||||
parts.append(types.Part.from_text(part.text))
|
||||
elif part.type == "image_url" and part.image_url:
|
||||
url = part.image_url.url
|
||||
if url.startswith("data:"):
|
||||
# Parse data URI: data:image/jpeg;base64,<data>
|
||||
header, b64_data = url.split(",", 1)
|
||||
mime_type = header.split(":")[1].split(";")[0]
|
||||
import base64
|
||||
|
||||
image_bytes = base64.b64decode(b64_data)
|
||||
parts.append(
|
||||
types.Part.from_bytes(data=image_bytes, mime_type=mime_type)
|
||||
)
|
||||
else:
|
||||
# URL-based image - use as URI
|
||||
parts.append(types.Part.from_uri(file_uri=url, mime_type="image/jpeg"))
|
||||
contents.append(types.Content(role=role, parts=parts))
|
||||
|
||||
return system_instruction, contents
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
model: str,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Generate a chat completion via Google Gemini."""
|
||||
gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
|
||||
system_instruction, contents = self._convert_messages(request)
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
if request.temperature is not None:
|
||||
config["temperature"] = request.temperature
|
||||
if request.max_tokens is not None:
|
||||
config["max_output_tokens"] = request.max_tokens
|
||||
if request.top_p is not None:
|
||||
config["top_p"] = request.top_p
|
||||
if request.stop:
|
||||
stop_seqs = request.stop if isinstance(request.stop, list) else [request.stop]
|
||||
config["stop_sequences"] = stop_seqs
|
||||
|
||||
gen_config = types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
**config,
|
||||
)
|
||||
|
||||
logger.debug(f"Google Gemini request: {gemini_model}, messages: {len(contents)}")
|
||||
|
||||
response = await self.client.aio.models.generate_content(
|
||||
model=gemini_model,
|
||||
contents=contents,
|
||||
config=gen_config,
|
||||
)
|
||||
|
||||
content = response.text or ""
|
||||
usage_meta = response.usage_metadata
|
||||
|
||||
return ChatCompletionResponse(
|
||||
model=f"google/{gemini_model}",
|
||||
choices=[
|
||||
Choice(
|
||||
index=0,
|
||||
message=MessageResponse(content=content),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
usage=Usage(
|
||||
prompt_tokens=usage_meta.prompt_token_count if usage_meta else 0,
|
||||
completion_tokens=usage_meta.candidates_token_count if usage_meta else 0,
|
||||
total_tokens=usage_meta.total_token_count if usage_meta else 0,
|
||||
),
|
||||
)
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
model: str,
|
||||
) -> AsyncIterator[ChatCompletionStreamResponse]:
|
||||
"""Generate a streaming chat completion via Google Gemini."""
|
||||
gemini_model = self.map_model(model) if model in OLLAMA_TO_GEMINI else model
|
||||
system_instruction, contents = self._convert_messages(request)
|
||||
|
||||
config: dict[str, Any] = {}
|
||||
if request.temperature is not None:
|
||||
config["temperature"] = request.temperature
|
||||
if request.max_tokens is not None:
|
||||
config["max_output_tokens"] = request.max_tokens
|
||||
if request.top_p is not None:
|
||||
config["top_p"] = request.top_p
|
||||
|
||||
gen_config = types.GenerateContentConfig(
|
||||
system_instruction=system_instruction,
|
||||
**config,
|
||||
)
|
||||
|
||||
# First chunk with role
|
||||
yield ChatCompletionStreamResponse(
|
||||
model=f"google/{gemini_model}",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaContent(role="assistant"),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async for chunk in await self.client.aio.models.generate_content_stream(
|
||||
model=gemini_model,
|
||||
contents=contents,
|
||||
config=gen_config,
|
||||
):
|
||||
text = chunk.text
|
||||
if text:
|
||||
yield ChatCompletionStreamResponse(
|
||||
model=f"google/{gemini_model}",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaContent(content=text),
|
||||
finish_reason=None,
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
# Final chunk
|
||||
yield ChatCompletionStreamResponse(
|
||||
model=f"google/{gemini_model}",
|
||||
choices=[
|
||||
StreamChoice(
|
||||
delta=DeltaContent(),
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
async def list_models(self) -> list[ModelInfo]:
|
||||
"""List available Google Gemini models."""
|
||||
# Return a static list of commonly used models
|
||||
return [
|
||||
ModelInfo(id="google/gemini-2.0-flash", owned_by="google"),
|
||||
ModelInfo(id="google/gemini-2.5-pro", owned_by="google"),
|
||||
ModelInfo(id="google/gemini-2.5-flash", owned_by="google"),
|
||||
]
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
request: EmbeddingRequest,
|
||||
model: str,
|
||||
) -> EmbeddingResponse:
|
||||
"""Generate embeddings via Google Gemini."""
|
||||
inputs = request.input if isinstance(request.input, list) else [request.input]
|
||||
|
||||
result = await self.client.aio.models.embed_content(
|
||||
model="text-embedding-004",
|
||||
contents=inputs,
|
||||
)
|
||||
|
||||
return EmbeddingResponse(
|
||||
data=[
|
||||
EmbeddingData(index=i, embedding=emb.values)
|
||||
for i, emb in enumerate(result.embeddings)
|
||||
],
|
||||
model="google/text-embedding-004",
|
||||
usage=Usage(
|
||||
prompt_tokens=sum(len(t.split()) for t in inputs),
|
||||
total_tokens=sum(len(t.split()) for t in inputs),
|
||||
),
|
||||
)
|
||||
|
||||
async def health_check(self) -> dict[str, Any]:
|
||||
"""Check Google API health."""
|
||||
try:
|
||||
# Quick test: list models
|
||||
response = await self.client.aio.models.list(config={"page_size": 1})
|
||||
return {
|
||||
"status": "healthy",
|
||||
"provider": "google",
|
||||
}
|
||||
except Exception as e:
|
||||
return {
|
||||
"status": "unhealthy",
|
||||
"provider": "google",
|
||||
"error": str(e),
|
||||
}
|
||||
|
||||
async def close(self) -> None:
|
||||
"""No cleanup needed for Google client."""
|
||||
pass
|
||||
|
|
@ -1,6 +1,8 @@
|
|||
"""Provider routing logic for mana-llm."""
|
||||
"""Provider routing logic for mana-llm with auto-fallback support."""
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import AsyncIterator
|
||||
from typing import Any
|
||||
|
||||
|
|
@ -22,10 +24,19 @@ logger = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class ProviderRouter:
|
||||
"""Routes requests to appropriate LLM providers based on model prefix."""
|
||||
"""Routes requests to appropriate LLM providers with auto-fallback.
|
||||
|
||||
When auto_fallback_enabled is True and a Google API key is configured:
|
||||
- Ollama requests that fail or exceed max_concurrent are automatically
|
||||
retried on Google Gemini with model mapping.
|
||||
- Explicit provider requests (e.g., openrouter/...) are never fallback-routed.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.providers: dict[str, LLMProvider] = {}
|
||||
self._ollama_concurrent: int = 0
|
||||
self._ollama_health_cache: tuple[dict[str, Any] | None, float] = (None, 0)
|
||||
self._health_cache_ttl: float = 5.0 # seconds
|
||||
self._initialize_providers()
|
||||
|
||||
def _initialize_providers(self) -> None:
|
||||
|
|
@ -34,6 +45,16 @@ class ProviderRouter:
|
|||
self.providers["ollama"] = OllamaProvider()
|
||||
logger.info(f"Initialized Ollama provider at {settings.ollama_url}")
|
||||
|
||||
# Google Gemini (fallback provider)
|
||||
if settings.google_api_key:
|
||||
from .google import GoogleProvider
|
||||
|
||||
self.providers["google"] = GoogleProvider(
|
||||
api_key=settings.google_api_key,
|
||||
default_model=settings.google_default_model,
|
||||
)
|
||||
logger.info("Initialized Google Gemini provider (fallback)")
|
||||
|
||||
# OpenRouter (if API key configured)
|
||||
if settings.openrouter_api_key:
|
||||
self.providers["openrouter"] = OpenAICompatProvider(
|
||||
|
|
@ -63,17 +84,12 @@ class ProviderRouter:
|
|||
logger.info("Initialized Together provider")
|
||||
|
||||
def _parse_model(self, model: str) -> tuple[str, str]:
|
||||
"""
|
||||
Parse model string into (provider, model_name).
|
||||
|
||||
Format: "provider/model" or just "model" (defaults to ollama)
|
||||
"""
|
||||
"""Parse model string into (provider, model_name)."""
|
||||
if "/" in model:
|
||||
parts = model.split("/", 1)
|
||||
provider = parts[0].lower()
|
||||
model_name = parts[1]
|
||||
else:
|
||||
# Default to Ollama
|
||||
provider = "ollama"
|
||||
model_name = model
|
||||
|
||||
|
|
@ -89,39 +105,141 @@ class ProviderRouter:
|
|||
)
|
||||
return self.providers[provider_name]
|
||||
|
||||
def _can_fallback_to_google(self, provider_name: str) -> bool:
|
||||
"""Check if a request can be fallback-routed to Google."""
|
||||
return (
|
||||
settings.auto_fallback_enabled
|
||||
and provider_name == "ollama"
|
||||
and "google" in self.providers
|
||||
)
|
||||
|
||||
def _should_use_ollama(self) -> bool:
|
||||
"""Determine if Ollama should handle the request based on load."""
|
||||
return self._ollama_concurrent < settings.ollama_max_concurrent
|
||||
|
||||
async def _get_ollama_health_cached(self) -> dict[str, Any]:
|
||||
"""Get Ollama health with caching (5s TTL)."""
|
||||
cached, cached_at = self._ollama_health_cache
|
||||
if cached is not None and (time.time() - cached_at) < self._health_cache_ttl:
|
||||
return cached
|
||||
|
||||
try:
|
||||
provider = self.providers.get("ollama")
|
||||
if provider:
|
||||
result = await provider.health_check()
|
||||
else:
|
||||
result = {"status": "unhealthy", "error": "no ollama provider"}
|
||||
except Exception as e:
|
||||
result = {"status": "unhealthy", "error": str(e)}
|
||||
|
||||
self._ollama_health_cache = (result, time.time())
|
||||
return result
|
||||
|
||||
async def _fallback_to_google(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
model_name: str,
|
||||
original_error: Exception | None = None,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Route a request to Google Gemini as fallback."""
|
||||
from .google import GoogleProvider
|
||||
|
||||
google = self.providers["google"]
|
||||
assert isinstance(google, GoogleProvider)
|
||||
|
||||
gemini_model = google.map_model(model_name)
|
||||
reason = f"error: {original_error}" if original_error else "overloaded"
|
||||
logger.warning(
|
||||
f"Falling back to Google Gemini ({gemini_model}) for ollama/{model_name} ({reason})"
|
||||
)
|
||||
return await google.chat_completion(request, gemini_model)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Route chat completion request to appropriate provider."""
|
||||
"""Route chat completion request with auto-fallback."""
|
||||
provider_name, model_name = self._parse_model(request.model)
|
||||
provider = self._get_provider(provider_name)
|
||||
|
||||
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
|
||||
# Non-Ollama providers: direct routing, no fallback
|
||||
if provider_name != "ollama":
|
||||
provider = self._get_provider(provider_name)
|
||||
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
|
||||
return await provider.chat_completion(request, model_name)
|
||||
|
||||
# Ollama with fallback logic
|
||||
can_fallback = self._can_fallback_to_google(provider_name)
|
||||
|
||||
# Check if Ollama is overloaded
|
||||
if can_fallback and not self._should_use_ollama():
|
||||
return await self._fallback_to_google(request, model_name)
|
||||
|
||||
# Try Ollama first
|
||||
provider = self._get_provider("ollama")
|
||||
logger.info(f"Routing chat completion to ollama/{model_name}")
|
||||
self._ollama_concurrent += 1
|
||||
|
||||
try:
|
||||
return await provider.chat_completion(request, model_name)
|
||||
except Exception as e:
|
||||
logger.error(f"Chat completion failed on {provider_name}: {e}")
|
||||
# Could implement fallback logic here
|
||||
logger.error(f"Chat completion failed on ollama: {e}")
|
||||
if can_fallback:
|
||||
return await self._fallback_to_google(request, model_name, e)
|
||||
raise
|
||||
finally:
|
||||
self._ollama_concurrent -= 1
|
||||
|
||||
async def chat_completion_stream(
|
||||
self,
|
||||
request: ChatCompletionRequest,
|
||||
) -> AsyncIterator[ChatCompletionStreamResponse]:
|
||||
"""Route streaming chat completion request to appropriate provider."""
|
||||
"""Route streaming chat completion with auto-fallback."""
|
||||
provider_name, model_name = self._parse_model(request.model)
|
||||
provider = self._get_provider(provider_name)
|
||||
|
||||
logger.info(f"Routing streaming chat completion to {provider_name}/{model_name}")
|
||||
# Non-Ollama: direct
|
||||
if provider_name != "ollama":
|
||||
provider = self._get_provider(provider_name)
|
||||
logger.info(f"Routing streaming to {provider_name}/{model_name}")
|
||||
async for chunk in provider.chat_completion_stream(request, model_name):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
# Ollama with fallback
|
||||
can_fallback = self._can_fallback_to_google(provider_name)
|
||||
|
||||
if can_fallback and not self._should_use_ollama():
|
||||
from .google import GoogleProvider
|
||||
|
||||
google = self.providers["google"]
|
||||
assert isinstance(google, GoogleProvider)
|
||||
gemini_model = google.map_model(model_name)
|
||||
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
|
||||
async for chunk in google.chat_completion_stream(request, gemini_model):
|
||||
yield chunk
|
||||
return
|
||||
|
||||
provider = self._get_provider("ollama")
|
||||
logger.info(f"Routing streaming to ollama/{model_name}")
|
||||
self._ollama_concurrent += 1
|
||||
|
||||
try:
|
||||
async for chunk in provider.chat_completion_stream(request, model_name):
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
logger.error(f"Streaming chat completion failed on {provider_name}: {e}")
|
||||
raise
|
||||
logger.error(f"Streaming failed on ollama: {e}")
|
||||
if can_fallback:
|
||||
from .google import GoogleProvider
|
||||
|
||||
google = self.providers["google"]
|
||||
assert isinstance(google, GoogleProvider)
|
||||
gemini_model = google.map_model(model_name)
|
||||
logger.warning(f"Streaming fallback to Google Gemini ({gemini_model})")
|
||||
async for chunk in google.chat_completion_stream(request, gemini_model):
|
||||
yield chunk
|
||||
else:
|
||||
raise
|
||||
finally:
|
||||
self._ollama_concurrent -= 1
|
||||
|
||||
async def embeddings(
|
||||
self,
|
||||
|
|
@ -175,11 +293,21 @@ class ProviderRouter:
|
|||
all_healthy = all(r.get("status") == "healthy" for r in results.values())
|
||||
any_healthy = any(r.get("status") == "healthy" for r in results.values())
|
||||
|
||||
return {
|
||||
status_info: dict[str, Any] = {
|
||||
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
|
||||
"providers": results,
|
||||
}
|
||||
|
||||
# Include fallback info
|
||||
if settings.auto_fallback_enabled and "google" in self.providers:
|
||||
status_info["fallback"] = {
|
||||
"enabled": True,
|
||||
"ollama_concurrent": self._ollama_concurrent,
|
||||
"ollama_max_concurrent": settings.ollama_max_concurrent,
|
||||
}
|
||||
|
||||
return status_info
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close all providers."""
|
||||
for provider in self.providers.values():
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue