feat(mana-llm): add central LLM abstraction service

Python/FastAPI service providing a unified OpenAI-compatible API for
Ollama and cloud LLM providers (OpenRouter, Groq, Together).

Features:
- Chat completions with streaming (SSE)
- Vision/multimodal support
- Embeddings generation
- Multi-provider routing (provider/model format)
- Prometheus metrics
- Optional Redis caching
Author: Till-JS
Date: 2026-01-29 22:01:00 +01:00
Parent: 4a3295d1d0
Commit: 1495dbe476

29 changed files with 2270 additions and 1 deletion

services/mana-llm/.env.example
@@ -0,0 +1,25 @@
# Service
PORT=3025
LOG_LEVEL=info
# Ollama (Primary)
OLLAMA_URL=http://localhost:11434
OLLAMA_DEFAULT_MODEL=gemma3:4b
OLLAMA_TIMEOUT=120
# OpenRouter (Cloud Fallback)
OPENROUTER_API_KEY=sk-or-v1-xxx
OPENROUTER_DEFAULT_MODEL=meta-llama/llama-3.1-8b-instruct
# Groq (Optional)
GROQ_API_KEY=gsk_xxx
# Together (Optional)
TOGETHER_API_KEY=xxx
# Caching (Optional)
REDIS_URL=redis://localhost:6379
CACHE_TTL=3600
# CORS
CORS_ORIGINS=http://localhost:5173,https://mana.how

services/mana-llm/.gitignore
@@ -0,0 +1,31 @@
# Python
__pycache__/
*.py[cod]
*$py.class
*.so
.Python
venv/
.venv/
ENV/
.env
# IDE
.idea/
.vscode/
*.swp
*.swo
# Testing
.pytest_cache/
.coverage
htmlcov/
.tox/
# Build
dist/
build/
*.egg-info/
# Local
.env.local
*.log

services/mana-llm/CLAUDE.md
@@ -0,0 +1,292 @@
# mana-llm
Central LLM abstraction service providing a unified OpenAI-compatible API for Ollama and cloud LLM providers.
## Overview
mana-llm acts as a central gateway for all LLM requests in the monorepo, providing:
- Unified OpenAI-compatible API
- Provider routing (Ollama, OpenRouter, Groq, Together)
- Streaming via Server-Sent Events (SSE)
- Vision/multimodal support
- Embeddings generation
- Prometheus metrics
## Architecture
```
┌─────────────────────────────────────────────────────────────────────┐
│ Consumer Apps │
│ matrix-ollama-bot │ telegram-ollama-bot │ chat-backend │ etc. │
└────────────────────────────────┬────────────────────────────────────┘
│ HTTP/SSE
┌─────────────────────────────────────────────────────────────────────┐
│ mana-llm (Port 3025) │
│ ┌─────────────┐ ┌─────────────┐ ┌─────────────┐ │
│ │ Router │ │ Cache │ │ Metrics │ │
│ │ (Provider) │ │ (Redis) │ │ (Prometheus)│ │
│ └──────┬──────┘ └─────────────┘ └─────────────┘ │
│ │ │
│ ┌──────┴──────────────────────────────────────────┐ │
│ │ Provider Adapters │ │
│ │ ┌──────────┐ ┌──────────┐ ┌──────────────┐ │ │
│ │ │ Ollama │ │ OpenAI │ │ OpenRouter │ │ │
│ │ │ Adapter │ │ Adapter │ │ Adapter │ │ │
│ │ └──────────┘ └──────────┘ └──────────────┘ │ │
│ └─────────────────────────────────────────────────┘ │
└─────────────────────────────────────────────────────────────────────┘
```
## Quick Start
### Prerequisites
- Python 3.11+
- Ollama running locally (http://localhost:11434)
- Redis (optional, for caching)
### Development
```bash
cd services/mana-llm
# Create virtual environment
python -m venv venv
source venv/bin/activate # or venv\Scripts\activate on Windows
# Install dependencies
pip install -r requirements.txt
# Copy environment file
cp .env.example .env
# Start Redis (optional)
docker-compose -f docker-compose.dev.yml up -d
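# Note: docker-compose.dev.yml maps Redis to host port 6380, so set
# REDIS_URL=redis://localhost:6380 in .env when enabling caching locally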
# Run service
python -m uvicorn src.main:app --port 3025 --reload
```
### Docker
```bash
# Full stack (mana-llm + Redis)
docker-compose up -d
# View logs
docker-compose logs -f mana-llm
```
## API Endpoints
### Chat Completions
```bash
# Non-streaming
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/gemma3:4b",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": false
}'
# Streaming (SSE)
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/gemma3:4b",
"messages": [{"role": "user", "content": "Hello!"}],
"stream": true
}'
```
### Vision/Multimodal
```bash
curl -X POST http://localhost:3025/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/llava:7b",
"messages": [{
"role": "user",
"content": [
{"type": "text", "text": "What is in this image?"},
{"type": "image_url", "image_url": {"url": "data:image/png;base64,..."}}
]
}]
}'
```
### Models
```bash
# List all models
curl http://localhost:3025/v1/models
# Get specific model
curl http://localhost:3025/v1/models/ollama/gemma3:4b
```
### Embeddings
```bash
curl -X POST http://localhost:3025/v1/embeddings \
-H "Content-Type: application/json" \
-d '{
"model": "ollama/nomic-embed-text",
"input": "Text to embed"
}'
```
### Health & Metrics
```bash
# Health check
curl http://localhost:3025/health
# Prometheus metrics
curl http://localhost:3025/metrics
```
## Provider Routing
Models use the format `provider/model`:
| Model | Provider | Target |
|-------|----------|--------|
| `ollama/gemma3:4b` | Ollama | localhost:11434 |
| `ollama/llava:7b` | Ollama | localhost:11434 |
| `openrouter/meta-llama/llama-3.1-8b-instruct` | OpenRouter | openrouter.ai/api |
| `groq/llama-3.1-8b-instant` | Groq | api.groq.com |
| `together/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo` | Together | api.together.xyz |
**Default:** If no provider prefix is given (e.g., `gemma3:4b`), Ollama is used.
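Because the endpoints mirror the OpenAI API, existing OpenAI clients should be able to target the gateway directly. A minimal sketch using the `openai` Python package (a consumer-side dependency, not part of this service; the API key is a placeholder since mana-llm performs no authentication):

```python
from openai import OpenAI

# Placeholder key: mana-llm does not check Authorization headers.
client = OpenAI(base_url="http://localhost:3025/v1", api_key="unused")

# Local Ollama model (default provider)...
resp = client.chat.completions.create(
    model="ollama/gemma3:4b",
    messages=[{"role": "user", "content": "Hello!"}],
)
print(resp.choices[0].message.content)

# ...or a cloud provider, just by switching the prefix.
resp = client.chat.completions.create(
    model="groq/llama-3.1-8b-instant",
    messages=[{"role": "user", "content": "Hello!"}],
)
```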
## Configuration
Environment variables (see `.env.example`):
| Variable | Default | Description |
|----------|---------|-------------|
| `PORT` | 3025 | Service port |
| `LOG_LEVEL` | info | Logging level |
| `OLLAMA_URL` | http://localhost:11434 | Ollama server URL |
| `OLLAMA_DEFAULT_MODEL` | gemma3:4b | Default Ollama model |
| `OLLAMA_TIMEOUT` | 120 | Ollama request timeout (seconds) |
| `OPENROUTER_API_KEY` | - | OpenRouter API key |
| `GROQ_API_KEY` | - | Groq API key |
| `TOGETHER_API_KEY` | - | Together API key |
| `REDIS_URL` | - | Redis URL for caching |
| `CACHE_TTL` | 3600 | Cache TTL in seconds |
| `CORS_ORIGINS` | http://localhost:5173,https://mana.how | Allowed CORS origins (comma-separated) |
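Each variable maps 1:1 to a pydantic-settings field on `Settings` in `src/config.py`, so values can come from the environment or a local `.env` file. A small sketch (the hostname below is made up):

```python
import os

# Hypothetical override: any field in src/config.py can be set via env vars
# (names are case-insensitive for pydantic-settings).
os.environ["OLLAMA_URL"] = "http://ollama.internal:11434"
os.environ["CORS_ORIGINS"] = "http://localhost:5173"

from src.config import Settings

settings = Settings()
print(settings.ollama_url)          # http://ollama.internal:11434
print(settings.cors_origins_list)   # ['http://localhost:5173']
```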
## Project Structure
```
services/mana-llm/
├── src/
│ ├── main.py # FastAPI app entry point
│ ├── config.py # Settings via pydantic-settings
│ ├── providers/
│ │ ├── base.py # Abstract provider interface
│ │ ├── ollama.py # Ollama provider
│ │ ├── openai_compat.py # OpenAI-compatible provider
│ │ └── router.py # Provider routing logic
│ ├── models/
│ │ ├── requests.py # Request Pydantic models
│ │ └── responses.py # Response Pydantic models
│ ├── streaming/
│ │ └── sse.py # SSE response handling
│ └── utils/
│ ├── cache.py # Redis caching
│ └── metrics.py # Prometheus metrics
├── tests/
│ ├── test_api.py # API endpoint tests
│ ├── test_providers.py # Provider tests
│ └── test_streaming.py # Streaming tests
├── Dockerfile
├── docker-compose.yml
├── docker-compose.dev.yml
├── requirements.txt
├── pyproject.toml
└── .env.example
```
## Testing
```bash
# Run tests
pytest
# Run with coverage
pytest --cov=src
# Run specific test file
pytest tests/test_providers.py -v
```
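Provider HTTP calls can also be exercised without a running Ollama instance by mocking them with `pytest-httpx` (already listed under dev dependencies). A sketch of such a test; the mocked response fields mirror what `OllamaProvider.chat_completion` reads from `/api/chat`:

```python
"""Sketch: mock the Ollama API with pytest-httpx."""
from src.models import ChatCompletionRequest, Message
from src.providers import OllamaProvider


async def test_chat_completion_mocked(httpx_mock):
    # Fake Ollama /api/chat response with the fields the provider consumes.
    httpx_mock.add_response(
        url="http://localhost:11434/api/chat",
        json={
            "message": {"role": "assistant", "content": "Hi there"},
            "done": True,
            "prompt_eval_count": 5,
            "eval_count": 3,
        },
    )
    provider = OllamaProvider(base_url="http://localhost:11434")
    request = ChatCompletionRequest(
        model="gemma3:4b",
        messages=[Message(role="user", content="Hello")],
    )
    response = await provider.chat_completion(request, "gemma3:4b")
    assert response.choices[0].message.content == "Hi there"
    assert response.usage.total_tokens == 8
```

With `asyncio_mode = "auto"` in `pyproject.toml`, the async test needs no extra marker.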
## Integration Example
### TypeScript/Node.js Client
```typescript
// Using fetch
const response = await fetch('http://localhost:3025/v1/chat/completions', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model: 'ollama/gemma3:4b',
messages: [{ role: 'user', content: 'Hello!' }],
stream: false,
}),
});
const data = await response.json();
console.log(data.choices[0].message.content);
```
### Streaming with fetch (SSE)
`EventSource` only supports GET requests, so the SSE stream is read with `fetch` and a `ReadableStream` reader:
```typescript
const response = await fetch('http://localhost:3025/v1/chat/completions', {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({
model: 'ollama/gemma3:4b',
messages: [{ role: 'user', content: 'Hello!' }],
stream: true,
}),
});
const reader = response.body?.getReader();
const decoder = new TextDecoder();
while (true) {
const { done, value } = await reader!.read();
if (done) break;
const chunk = decoder.decode(value);
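  // Note: a read may end mid-line; a robust client buffers partial SSE lines
  // across chunks before parsing.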
const lines = chunk.split('\n').filter(line => line.startsWith('data: '));
for (const line of lines) {
const data = line.slice(6);
if (data === '[DONE]') break;
const parsed = JSON.parse(data);
const content = parsed.choices[0]?.delta?.content;
if (content) process.stdout.write(content);
}
}
```
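### Python Client (Streaming)

A comparable sketch for Python consumers using `httpx`; it parses the same `data:` / `[DONE]` framing emitted by `src/streaming/sse.py`:

```python
import json
import httpx


async def stream_chat(prompt: str) -> None:
    payload = {
        "model": "ollama/gemma3:4b",
        "messages": [{"role": "user", "content": prompt}],
        "stream": True,
    }
    async with httpx.AsyncClient(timeout=None) as client:
        async with client.stream(
            "POST", "http://localhost:3025/v1/chat/completions", json=payload
        ) as response:
            async for line in response.aiter_lines():
                if not line.startswith("data: "):
                    continue
                data = line[6:]
                if data == "[DONE]":
                    break
                chunk = json.loads(data)
                choices = chunk.get("choices") or []
                if not choices:
                    continue
                content = choices[0].get("delta", {}).get("content")
                if content:
                    print(content, end="", flush=True)

# Usage (hypothetical): asyncio.run(stream_chat("Hello!"))
```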
## Related Services
| Service | Port | Description |
|---------|------|-------------|
| mana-tts | 3022 | Text-to-speech service |
| mana-stt | 3023 | Speech-to-text service |
| mana-search | 3021 | Web search & extraction |
| matrix-ollama-bot | - | Matrix bot (consumer) |
| telegram-ollama-bot | - | Telegram bot (consumer) |

services/mana-llm/Dockerfile
@@ -0,0 +1,45 @@
# Build stage
FROM python:3.12-slim AS builder
WORKDIR /app
# Install build dependencies
RUN apt-get update && apt-get install -y --no-install-recommends \
build-essential \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements first for caching
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Production stage
FROM python:3.12-slim
WORKDIR /app
# Create non-root user
RUN useradd -m -u 1000 appuser
# Copy installed packages from builder
COPY --from=builder /root/.local /home/appuser/.local
# Copy application code
COPY --chown=appuser:appuser src/ ./src/
# Set environment
ENV PATH=/home/appuser/.local/bin:$PATH
ENV PYTHONUNBUFFERED=1
ENV PYTHONDONTWRITEBYTECODE=1
# Switch to non-root user
USER appuser
# Expose port
EXPOSE 3025
# Health check
HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \
CMD python -c "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"
# Run application
CMD ["python", "-m", "uvicorn", "src.main:app", "--host", "0.0.0.0", "--port", "3025"]

services/mana-llm/docker-compose.dev.yml
@@ -0,0 +1,15 @@
# Development compose - only Redis (run Python locally)
version: "3.8"
services:
redis:
image: redis:7-alpine
container_name: mana-llm-redis-dev
ports:
- "6380:6379"
volumes:
- redis-data-dev:/data
restart: unless-stopped
volumes:
redis-data-dev:

services/mana-llm/docker-compose.yml
@@ -0,0 +1,50 @@
version: "3.8"
services:
mana-llm:
build:
context: .
dockerfile: Dockerfile
container_name: mana-llm
ports:
- "3025:3025"
environment:
- PORT=3025
- LOG_LEVEL=info
- OLLAMA_URL=http://host.docker.internal:11434
- OLLAMA_DEFAULT_MODEL=gemma3:4b
- OLLAMA_TIMEOUT=120
- REDIS_URL=redis://redis:6379
# Add API keys via .env file
- OPENROUTER_API_KEY=${OPENROUTER_API_KEY:-}
- GROQ_API_KEY=${GROQ_API_KEY:-}
- TOGETHER_API_KEY=${TOGETHER_API_KEY:-}
- CORS_ORIGINS=http://localhost:5173,http://localhost:3000,https://mana.how
depends_on:
- redis
restart: unless-stopped
healthcheck:
test: ["CMD", "python", "-c", "import httpx; httpx.get('http://localhost:3025/health').raise_for_status()"]
interval: 30s
timeout: 10s
retries: 3
start_period: 10s
extra_hosts:
- "host.docker.internal:host-gateway"
redis:
image: redis:7-alpine
container_name: mana-llm-redis
ports:
- "6380:6379"
volumes:
- redis-data:/data
restart: unless-stopped
healthcheck:
test: ["CMD", "redis-cli", "ping"]
interval: 10s
timeout: 5s
retries: 5
volumes:
redis-data:

services/mana-llm/pyproject.toml
@@ -0,0 +1,38 @@
[project]
name = "mana-llm"
version = "0.1.0"
description = "Central LLM abstraction service for Ollama and OpenAI-compatible APIs"
requires-python = ">=3.11"
dependencies = [
"fastapi>=0.115.0",
"uvicorn[standard]>=0.32.0",
"pydantic>=2.10.0",
"pydantic-settings>=2.6.0",
"httpx>=0.28.0",
"sse-starlette>=2.2.0",
"redis>=5.2.0",
"prometheus-client>=0.21.0",
]
[project.optional-dependencies]
dev = [
"pytest>=8.3.0",
"pytest-asyncio>=0.24.0",
"pytest-httpx>=0.35.0",
"ruff>=0.8.0",
]
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"
[tool.ruff]
line-length = 100
target-version = "py311"
[tool.ruff.lint]
select = ["E", "F", "I", "W"]
[tool.pytest.ini_options]
asyncio_mode = "auto"
testpaths = ["tests"]

services/mana-llm/requirements.txt
@@ -0,0 +1,23 @@
# Core
fastapi>=0.115.0
uvicorn[standard]>=0.32.0
pydantic>=2.10.0
pydantic-settings>=2.6.0
# HTTP Client
httpx>=0.28.0
# Streaming
sse-starlette>=2.2.0
# Caching (optional)
redis>=5.2.0
# Metrics
prometheus-client>=0.21.0
# Dev
pytest>=8.3.0
pytest-asyncio>=0.24.0
pytest-httpx>=0.35.0
ruff>=0.8.0

services/mana-llm/src/__init__.py
@@ -0,0 +1,3 @@
"""mana-llm - Central LLM abstraction service."""
__version__ = "0.1.0"

services/mana-llm/src/config.py
@@ -0,0 +1,48 @@
"""Configuration settings for mana-llm service."""
from pydantic_settings import BaseSettings
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Service
port: int = 3025
log_level: str = "info"
# Ollama (Primary provider)
ollama_url: str = "http://localhost:11434"
ollama_default_model: str = "gemma3:4b"
ollama_timeout: int = 120
# OpenRouter (Cloud fallback)
openrouter_api_key: str | None = None
openrouter_base_url: str = "https://openrouter.ai/api/v1"
openrouter_default_model: str = "meta-llama/llama-3.1-8b-instruct"
# Groq (Optional)
groq_api_key: str | None = None
groq_base_url: str = "https://api.groq.com/openai/v1"
# Together (Optional)
together_api_key: str | None = None
together_base_url: str = "https://api.together.xyz/v1"
# Caching (Optional)
redis_url: str | None = None
cache_ttl: int = 3600
# CORS
cors_origins: str = "http://localhost:5173,https://mana.how"
@property
def cors_origins_list(self) -> list[str]:
"""Parse CORS origins from comma-separated string."""
return [origin.strip() for origin in self.cors_origins.split(",")]
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
settings = Settings()

services/mana-llm/src/main.py
@@ -0,0 +1,235 @@
"""Main FastAPI application for mana-llm service."""
import logging
import time
from contextlib import asynccontextmanager
from typing import Any
from fastapi import FastAPI, HTTPException, Request
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sse_starlette.sse import EventSourceResponse
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
ModelsResponse,
)
from src.providers import ProviderRouter
from src.streaming import stream_chat_completion
from src.utils.cache import close_redis
from src.utils.metrics import get_metrics, record_llm_error, record_llm_request
# Configure logging
logging.basicConfig(
level=getattr(logging, settings.log_level.upper()),
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
)
logger = logging.getLogger(__name__)
# Global router instance
router: ProviderRouter | None = None
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Application lifespan management."""
global router
# Startup
logger.info("Starting mana-llm service...")
router = ProviderRouter()
logger.info(f"Initialized providers: {list(router.providers.keys())}")
yield
# Shutdown
logger.info("Shutting down mana-llm service...")
if router:
await router.close()
await close_redis()
# Create FastAPI app
app = FastAPI(
title="mana-llm",
description="Central LLM abstraction service for Ollama and OpenAI-compatible APIs",
version="0.1.0",
lifespan=lifespan,
)
# Add CORS middleware
app.add_middleware(
CORSMiddleware,
allow_origins=settings.cors_origins_list,
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# Health endpoint
@app.get("/health")
async def health_check() -> dict[str, Any]:
"""Check service health and provider status."""
if router is None:
return {"status": "unhealthy", "error": "Router not initialized"}
provider_health = await router.health_check()
return {
"status": provider_health["status"],
"service": "mana-llm",
"version": "0.1.0",
"providers": provider_health["providers"],
}
# Metrics endpoint
@app.get("/metrics")
async def metrics() -> Response:
"""Prometheus metrics endpoint."""
return Response(content=get_metrics(), media_type="text/plain")
# Models endpoints
@app.get("/v1/models", response_model=ModelsResponse)
async def list_models() -> ModelsResponse:
"""List all available models from all providers."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
models = await router.list_models()
return ModelsResponse(data=models)
@app.get("/v1/models/{model_id:path}")
async def get_model(model_id: str) -> ModelInfo:
"""Get specific model information."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
model = await router.get_model(model_id)
if model is None:
raise HTTPException(status_code=404, detail=f"Model '{model_id}' not found")
return model
# Chat completions endpoint
@app.post("/v1/chat/completions")
async def chat_completions(
request: ChatCompletionRequest,
http_request: Request,
) -> ChatCompletionResponse | EventSourceResponse:
"""
Create a chat completion.
Supports both streaming (SSE) and non-streaming responses based on the
`stream` parameter in the request body.
"""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
start_time = time.time()
try:
if request.stream:
# Streaming response via SSE
logger.info(f"Streaming chat completion: {request.model}")
async def generate():
async for chunk in stream_chat_completion(router, request):
yield chunk
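            # Usage/token counts are not available per-chunk for streamed
            # responses, so only the request itself is counted here.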
record_llm_request(provider, model, streaming=True)
return EventSourceResponse(
generate(),
media_type="text/event-stream",
)
else:
# Non-streaming response
logger.info(f"Chat completion: {request.model}")
response = await router.chat_completion(request)
# Record metrics
latency = time.time() - start_time
record_llm_request(
provider=provider,
model=model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
completion_tokens=response.usage.completion_tokens,
latency=latency,
)
return response
except ValueError as e:
logger.error(f"Invalid request: {e}")
record_llm_error(provider, model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Chat completion failed: {e}")
record_llm_error(provider, model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
# Embeddings endpoint
@app.post("/v1/embeddings", response_model=EmbeddingResponse)
async def create_embeddings(request: EmbeddingRequest) -> EmbeddingResponse:
"""Create embeddings for the input text."""
if router is None:
raise HTTPException(status_code=503, detail="Service not ready")
# Parse provider and model for metrics
model_parts = request.model.split("/", 1)
provider = model_parts[0] if len(model_parts) > 1 else "ollama"
model = model_parts[1] if len(model_parts) > 1 else request.model
start_time = time.time()
try:
logger.info(f"Creating embeddings: {request.model}")
response = await router.embeddings(request)
latency = time.time() - start_time
record_llm_request(
provider=provider,
model=model,
streaming=False,
prompt_tokens=response.usage.prompt_tokens,
latency=latency,
)
return response
except ValueError as e:
logger.error(f"Invalid embedding request: {e}")
record_llm_error(provider, model, "invalid_request")
raise HTTPException(status_code=400, detail=str(e))
except Exception as e:
logger.error(f"Embeddings failed: {e}")
record_llm_error(provider, model, "server_error")
raise HTTPException(status_code=500, detail=str(e))
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"src.main:app",
host="0.0.0.0",
port=settings.port,
reload=True,
log_level=settings.log_level,
)

services/mana-llm/src/models/__init__.py
@@ -0,0 +1,22 @@
"""Pydantic models for OpenAI-compatible API."""
from .requests import ChatCompletionRequest, EmbeddingRequest
from .responses import (
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingResponse,
ModelInfo,
ModelsResponse,
Usage,
)
__all__ = [
"ChatCompletionRequest",
"ChatCompletionResponse",
"ChatCompletionStreamResponse",
"EmbeddingRequest",
"EmbeddingResponse",
"ModelInfo",
"ModelsResponse",
"Usage",
]

services/mana-llm/src/models/requests.py
@@ -0,0 +1,57 @@
"""Request models for OpenAI-compatible API."""
from typing import Literal
from pydantic import BaseModel, Field
class TextContent(BaseModel):
"""Text content in a message."""
type: Literal["text"] = "text"
text: str
class ImageUrl(BaseModel):
"""Image URL reference."""
url: str # Can be http(s):// or data:image/...;base64,...
class ImageContent(BaseModel):
"""Image content in a message."""
type: Literal["image_url"] = "image_url"
image_url: ImageUrl
MessageContent = str | list[TextContent | ImageContent]
class Message(BaseModel):
"""A single message in the conversation."""
role: Literal["system", "user", "assistant"]
content: MessageContent
class ChatCompletionRequest(BaseModel):
"""Request body for chat completions endpoint."""
model: str = Field(..., description="Model identifier in format 'provider/model' or just 'model'")
messages: list[Message] = Field(..., min_length=1)
stream: bool = False
temperature: float | None = Field(default=None, ge=0.0, le=2.0)
max_tokens: int | None = Field(default=None, gt=0)
top_p: float | None = Field(default=None, ge=0.0, le=1.0)
frequency_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
presence_penalty: float | None = Field(default=None, ge=-2.0, le=2.0)
stop: str | list[str] | None = None
class EmbeddingRequest(BaseModel):
"""Request body for embeddings endpoint."""
model: str = Field(..., description="Model identifier")
input: str | list[str] = Field(..., description="Text(s) to embed")
encoding_format: Literal["float", "base64"] = "float"

services/mana-llm/src/models/responses.py
@@ -0,0 +1,99 @@
"""Response models for OpenAI-compatible API."""
import time
import uuid
from typing import Literal
from pydantic import BaseModel, Field
class Usage(BaseModel):
"""Token usage information."""
prompt_tokens: int = 0
completion_tokens: int = 0
total_tokens: int = 0
class MessageResponse(BaseModel):
"""Response message from the model."""
role: Literal["assistant"] = "assistant"
content: str
class Choice(BaseModel):
"""A single completion choice."""
index: int = 0
message: MessageResponse
finish_reason: Literal["stop", "length", "content_filter"] | None = "stop"
class ChatCompletionResponse(BaseModel):
"""Response from chat completions endpoint (non-streaming)."""
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
object: Literal["chat.completion"] = "chat.completion"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[Choice]
usage: Usage = Field(default_factory=Usage)
class DeltaContent(BaseModel):
"""Delta content for streaming responses."""
role: Literal["assistant"] | None = None
content: str | None = None
class StreamChoice(BaseModel):
"""A single streaming choice."""
index: int = 0
delta: DeltaContent
finish_reason: Literal["stop", "length", "content_filter"] | None = None
class ChatCompletionStreamResponse(BaseModel):
"""Response chunk from chat completions endpoint (streaming)."""
id: str = Field(default_factory=lambda: f"chatcmpl-{uuid.uuid4().hex[:12]}")
object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
created: int = Field(default_factory=lambda: int(time.time()))
model: str
choices: list[StreamChoice]
class ModelInfo(BaseModel):
"""Information about a model."""
id: str
object: Literal["model"] = "model"
created: int = Field(default_factory=lambda: int(time.time()))
owned_by: str = "mana-llm"
class ModelsResponse(BaseModel):
"""Response from models endpoint."""
object: Literal["list"] = "list"
data: list[ModelInfo]
class EmbeddingData(BaseModel):
"""A single embedding result."""
object: Literal["embedding"] = "embedding"
index: int = 0
embedding: list[float]
class EmbeddingResponse(BaseModel):
"""Response from embeddings endpoint."""
object: Literal["list"] = "list"
data: list[EmbeddingData]
model: str
usage: Usage = Field(default_factory=Usage)

services/mana-llm/src/providers/__init__.py
@@ -0,0 +1,13 @@
"""LLM Provider implementations."""
from .base import LLMProvider
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
from .router import ProviderRouter
__all__ = [
"LLMProvider",
"OllamaProvider",
"OpenAICompatProvider",
"ProviderRouter",
]

services/mana-llm/src/providers/base.py
@@ -0,0 +1,61 @@
"""Abstract base class for LLM providers."""
from abc import ABC, abstractmethod
from collections.abc import AsyncIterator
from typing import Any
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
)
class LLMProvider(ABC):
"""Abstract base class for LLM providers."""
name: str = "base"
@abstractmethod
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
...
@abstractmethod
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
...
@abstractmethod
async def list_models(self) -> list[ModelInfo]:
"""List available models."""
...
@abstractmethod
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
...
@abstractmethod
async def health_check(self) -> dict[str, Any]:
"""Check provider health status."""
...
async def close(self) -> None:
"""Clean up resources."""
pass

services/mana-llm/src/providers/ollama.py
@@ -0,0 +1,289 @@
"""Ollama provider implementation."""
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
import httpx
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
MessageResponse,
ModelInfo,
StreamChoice,
Usage,
)
from .base import LLMProvider
logger = logging.getLogger(__name__)
class OllamaProvider(LLMProvider):
"""Ollama LLM provider."""
name = "ollama"
def __init__(self, base_url: str | None = None, timeout: int | None = None):
self.base_url = (base_url or settings.ollama_url).rstrip("/")
self.timeout = timeout or settings.ollama_timeout
self._client: httpx.AsyncClient | None = None
@property
def client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout),
)
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert OpenAI message format to Ollama format."""
messages = []
for msg in request.messages:
if isinstance(msg.content, str):
messages.append({"role": msg.role, "content": msg.content})
else:
# Handle multimodal content (vision)
text_parts = []
images = []
for part in msg.content:
if part.type == "text":
text_parts.append(part.text)
elif part.type == "image_url":
url = part.image_url.url
# Extract base64 data from data URL
if url.startswith("data:"):
# Format: data:image/png;base64,<base64_data>
base64_data = url.split(",", 1)[1] if "," in url else url
images.append(base64_data)
else:
# HTTP URL - Ollama expects base64, so we'd need to fetch
# For now, log warning and skip
logger.warning(f"HTTP image URLs not supported, skipping: {url[:50]}...")
message_data: dict[str, Any] = {
"role": msg.role,
"content": " ".join(text_parts),
}
if images:
message_data["images"] = images
messages.append(message_data)
return messages
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": False,
}
# Add optional parameters
options: dict[str, Any] = {}
if request.temperature is not None:
options["temperature"] = request.temperature
if request.top_p is not None:
options["top_p"] = request.top_p
if request.max_tokens is not None:
options["num_predict"] = request.max_tokens
if request.stop:
options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]
if options:
payload["options"] = options
logger.debug(f"Ollama request: {model}, messages: {len(request.messages)}")
response = await self.client.post("/api/chat", json=payload)
response.raise_for_status()
data = response.json()
return ChatCompletionResponse(
model=f"ollama/{model}",
choices=[
Choice(
message=MessageResponse(content=data["message"]["content"]),
finish_reason="stop" if data.get("done") else None,
)
],
usage=Usage(
prompt_tokens=data.get("prompt_eval_count", 0),
completion_tokens=data.get("eval_count", 0),
total_tokens=data.get("prompt_eval_count", 0) + data.get("eval_count", 0),
),
)
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": True,
}
# Add optional parameters
options: dict[str, Any] = {}
if request.temperature is not None:
options["temperature"] = request.temperature
if request.top_p is not None:
options["top_p"] = request.top_p
if request.max_tokens is not None:
options["num_predict"] = request.max_tokens
if request.stop:
options["stop"] = request.stop if isinstance(request.stop, list) else [request.stop]
if options:
payload["options"] = options
logger.debug(f"Ollama streaming request: {model}")
response_id = f"chatcmpl-{model[:8]}"
first_chunk = True
async with self.client.stream("POST", "/api/chat", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line:
continue
try:
data = json.loads(line)
except json.JSONDecodeError:
logger.warning(f"Failed to parse Ollama response line: {line}")
continue
# First chunk includes role
if first_chunk:
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(role="assistant"),
)
],
)
first_chunk = False
# Content chunks
content = data.get("message", {}).get("content", "")
if content:
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(content=content),
)
],
)
# Final chunk with finish_reason
if data.get("done"):
yield ChatCompletionStreamResponse(
id=response_id,
model=f"ollama/{model}",
choices=[
StreamChoice(
delta=DeltaContent(),
finish_reason="stop",
)
],
)
async def list_models(self) -> list[ModelInfo]:
"""List available Ollama models."""
response = await self.client.get("/api/tags")
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("models", []):
name = model_data.get("name", "")
models.append(
ModelInfo(
id=f"ollama/{name}",
owned_by="ollama",
                    # "modified_at" from /api/tags is an ISO-8601 string, not an int;
                    # rely on ModelInfo's default "created" timestamp instead.
)
)
return models
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
inputs = request.input if isinstance(request.input, list) else [request.input]
embeddings_data = []
for i, text in enumerate(inputs):
response = await self.client.post(
"/api/embeddings",
json={"model": model, "prompt": text},
)
response.raise_for_status()
data = response.json()
embeddings_data.append(
EmbeddingData(
index=i,
embedding=data.get("embedding", []),
)
)
return EmbeddingResponse(
data=embeddings_data,
model=f"ollama/{model}",
usage=Usage(
prompt_tokens=sum(len(text.split()) for text in inputs), # Approximate
total_tokens=sum(len(text.split()) for text in inputs),
),
)
async def health_check(self) -> dict[str, Any]:
"""Check Ollama health status."""
try:
response = await self.client.get("/api/tags")
response.raise_for_status()
data = response.json()
model_count = len(data.get("models", []))
return {
"status": "healthy",
"provider": self.name,
"url": self.base_url,
"models_available": model_count,
}
except Exception as e:
return {
"status": "unhealthy",
"provider": self.name,
"url": self.base_url,
"error": str(e),
}
async def close(self) -> None:
"""Close HTTP client."""
if self._client and not self._client.is_closed:
await self._client.aclose()

services/mana-llm/src/providers/openai_compat.py
@@ -0,0 +1,274 @@
"""OpenAI-compatible provider for OpenRouter, Groq, Together, etc."""
import json
import logging
from collections.abc import AsyncIterator
from typing import Any
import httpx
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
Choice,
DeltaContent,
EmbeddingData,
EmbeddingRequest,
EmbeddingResponse,
MessageResponse,
ModelInfo,
StreamChoice,
Usage,
)
from .base import LLMProvider
logger = logging.getLogger(__name__)
class OpenAICompatProvider(LLMProvider):
"""OpenAI-compatible API provider (OpenRouter, Groq, Together, etc.)."""
def __init__(
self,
name: str,
base_url: str,
api_key: str,
default_model: str | None = None,
timeout: int = 120,
):
self.name = name
self.base_url = base_url.rstrip("/")
self.api_key = api_key
self.default_model = default_model
self.timeout = timeout
self._client: httpx.AsyncClient | None = None
@property
def client(self) -> httpx.AsyncClient:
"""Get or create HTTP client."""
if self._client is None or self._client.is_closed:
self._client = httpx.AsyncClient(
base_url=self.base_url,
timeout=httpx.Timeout(self.timeout),
headers={
"Authorization": f"Bearer {self.api_key}",
"Content-Type": "application/json",
},
)
return self._client
def _convert_messages(self, request: ChatCompletionRequest) -> list[dict[str, Any]]:
"""Convert internal message format to OpenAI format."""
messages = []
for msg in request.messages:
if isinstance(msg.content, str):
messages.append({"role": msg.role, "content": msg.content})
else:
# Handle multimodal content
content_parts = []
for part in msg.content:
if part.type == "text":
content_parts.append({"type": "text", "text": part.text})
elif part.type == "image_url":
content_parts.append({
"type": "image_url",
"image_url": {"url": part.image_url.url},
})
messages.append({"role": msg.role, "content": content_parts})
return messages
async def chat_completion(
self,
request: ChatCompletionRequest,
model: str,
) -> ChatCompletionResponse:
"""Generate a chat completion (non-streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": False,
}
# Add optional parameters
if request.temperature is not None:
payload["temperature"] = request.temperature
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.top_p is not None:
payload["top_p"] = request.top_p
if request.frequency_penalty is not None:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty is not None:
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
logger.debug(f"{self.name} request: {model}, messages: {len(request.messages)}")
response = await self.client.post("/chat/completions", json=payload)
response.raise_for_status()
data = response.json()
return ChatCompletionResponse(
id=data.get("id", ""),
model=f"{self.name}/{model}",
choices=[
Choice(
index=choice.get("index", 0),
message=MessageResponse(content=choice["message"]["content"]),
finish_reason=choice.get("finish_reason", "stop"),
)
for choice in data.get("choices", [])
],
usage=Usage(
prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
completion_tokens=data.get("usage", {}).get("completion_tokens", 0),
total_tokens=data.get("usage", {}).get("total_tokens", 0),
),
)
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
model: str,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Generate a chat completion (streaming)."""
payload: dict[str, Any] = {
"model": model,
"messages": self._convert_messages(request),
"stream": True,
}
# Add optional parameters
if request.temperature is not None:
payload["temperature"] = request.temperature
if request.max_tokens is not None:
payload["max_tokens"] = request.max_tokens
if request.top_p is not None:
payload["top_p"] = request.top_p
if request.frequency_penalty is not None:
payload["frequency_penalty"] = request.frequency_penalty
if request.presence_penalty is not None:
payload["presence_penalty"] = request.presence_penalty
if request.stop:
payload["stop"] = request.stop
logger.debug(f"{self.name} streaming request: {model}")
async with self.client.stream("POST", "/chat/completions", json=payload) as response:
response.raise_for_status()
async for line in response.aiter_lines():
if not line or not line.startswith("data: "):
continue
data_str = line[6:] # Remove "data: " prefix
if data_str == "[DONE]":
break
try:
data = json.loads(data_str)
except json.JSONDecodeError:
logger.warning(f"Failed to parse stream line: {data_str}")
continue
choices = data.get("choices", [])
if not choices:
continue
choice = choices[0]
delta = choice.get("delta", {})
yield ChatCompletionStreamResponse(
id=data.get("id", ""),
model=f"{self.name}/{model}",
choices=[
StreamChoice(
index=choice.get("index", 0),
delta=DeltaContent(
role=delta.get("role"),
content=delta.get("content"),
),
finish_reason=choice.get("finish_reason"),
)
],
)
async def list_models(self) -> list[ModelInfo]:
"""List available models."""
try:
response = await self.client.get("/models")
response.raise_for_status()
data = response.json()
models = []
for model_data in data.get("data", []):
model_id = model_data.get("id", "")
models.append(
ModelInfo(
id=f"{self.name}/{model_id}",
owned_by=model_data.get("owned_by", self.name),
)
)
return models
except httpx.HTTPError as e:
logger.warning(f"Failed to list models from {self.name}: {e}")
return []
async def embeddings(
self,
request: EmbeddingRequest,
model: str,
) -> EmbeddingResponse:
"""Generate embeddings for input text."""
payload = {
"model": model,
"input": request.input,
}
response = await self.client.post("/embeddings", json=payload)
response.raise_for_status()
data = response.json()
return EmbeddingResponse(
data=[
EmbeddingData(
index=item.get("index", i),
embedding=item.get("embedding", []),
)
for i, item in enumerate(data.get("data", []))
],
model=f"{self.name}/{model}",
usage=Usage(
prompt_tokens=data.get("usage", {}).get("prompt_tokens", 0),
total_tokens=data.get("usage", {}).get("total_tokens", 0),
),
)
async def health_check(self) -> dict[str, Any]:
"""Check provider health status."""
try:
response = await self.client.get("/models")
response.raise_for_status()
data = response.json()
model_count = len(data.get("data", []))
return {
"status": "healthy",
"provider": self.name,
"url": self.base_url,
"models_available": model_count,
}
except Exception as e:
return {
"status": "unhealthy",
"provider": self.name,
"url": self.base_url,
"error": str(e),
}
async def close(self) -> None:
"""Close HTTP client."""
if self._client and not self._client.is_closed:
await self._client.aclose()

services/mana-llm/src/providers/router.py
@@ -0,0 +1,186 @@
"""Provider routing logic for mana-llm."""
import logging
from collections.abc import AsyncIterator
from typing import Any
from src.config import settings
from src.models import (
ChatCompletionRequest,
ChatCompletionResponse,
ChatCompletionStreamResponse,
EmbeddingRequest,
EmbeddingResponse,
ModelInfo,
)
from .base import LLMProvider
from .ollama import OllamaProvider
from .openai_compat import OpenAICompatProvider
logger = logging.getLogger(__name__)
class ProviderRouter:
"""Routes requests to appropriate LLM providers based on model prefix."""
def __init__(self):
self.providers: dict[str, LLMProvider] = {}
self._initialize_providers()
def _initialize_providers(self) -> None:
"""Initialize available providers based on configuration."""
# Ollama is always available (local)
self.providers["ollama"] = OllamaProvider()
logger.info(f"Initialized Ollama provider at {settings.ollama_url}")
# OpenRouter (if API key configured)
if settings.openrouter_api_key:
self.providers["openrouter"] = OpenAICompatProvider(
name="openrouter",
base_url=settings.openrouter_base_url,
api_key=settings.openrouter_api_key,
default_model=settings.openrouter_default_model,
)
logger.info("Initialized OpenRouter provider")
# Groq (if API key configured)
if settings.groq_api_key:
self.providers["groq"] = OpenAICompatProvider(
name="groq",
base_url=settings.groq_base_url,
api_key=settings.groq_api_key,
)
logger.info("Initialized Groq provider")
# Together (if API key configured)
if settings.together_api_key:
self.providers["together"] = OpenAICompatProvider(
name="together",
base_url=settings.together_base_url,
api_key=settings.together_api_key,
)
logger.info("Initialized Together provider")
def _parse_model(self, model: str) -> tuple[str, str]:
"""
Parse model string into (provider, model_name).
Format: "provider/model" or just "model" (defaults to ollama)
"""
if "/" in model:
parts = model.split("/", 1)
provider = parts[0].lower()
model_name = parts[1]
else:
# Default to Ollama
provider = "ollama"
model_name = model
return provider, model_name
def _get_provider(self, provider_name: str) -> LLMProvider:
"""Get provider by name, raise if not available."""
if provider_name not in self.providers:
available = list(self.providers.keys())
raise ValueError(
f"Provider '{provider_name}' not available. "
f"Available providers: {available}"
)
return self.providers[provider_name]
async def chat_completion(
self,
request: ChatCompletionRequest,
) -> ChatCompletionResponse:
"""Route chat completion request to appropriate provider."""
provider_name, model_name = self._parse_model(request.model)
provider = self._get_provider(provider_name)
logger.info(f"Routing chat completion to {provider_name}/{model_name}")
try:
return await provider.chat_completion(request, model_name)
except Exception as e:
logger.error(f"Chat completion failed on {provider_name}: {e}")
# Could implement fallback logic here
raise
async def chat_completion_stream(
self,
request: ChatCompletionRequest,
) -> AsyncIterator[ChatCompletionStreamResponse]:
"""Route streaming chat completion request to appropriate provider."""
provider_name, model_name = self._parse_model(request.model)
provider = self._get_provider(provider_name)
logger.info(f"Routing streaming chat completion to {provider_name}/{model_name}")
try:
async for chunk in provider.chat_completion_stream(request, model_name):
yield chunk
except Exception as e:
logger.error(f"Streaming chat completion failed on {provider_name}: {e}")
raise
async def embeddings(
self,
request: EmbeddingRequest,
) -> EmbeddingResponse:
"""Route embeddings request to appropriate provider."""
provider_name, model_name = self._parse_model(request.model)
provider = self._get_provider(provider_name)
logger.info(f"Routing embeddings to {provider_name}/{model_name}")
return await provider.embeddings(request, model_name)
async def list_models(self) -> list[ModelInfo]:
"""List all available models from all providers."""
all_models: list[ModelInfo] = []
for provider in self.providers.values():
try:
models = await provider.list_models()
all_models.extend(models)
except Exception as e:
logger.warning(f"Failed to list models from {provider.name}: {e}")
return all_models
async def get_model(self, model_id: str) -> ModelInfo | None:
"""Get specific model info."""
provider_name, model_name = self._parse_model(model_id)
if provider_name not in self.providers:
return None
provider = self.providers[provider_name]
models = await provider.list_models()
for model in models:
if model.id == model_id or model.id.endswith(f"/{model_name}"):
return model
return None
async def health_check(self) -> dict[str, Any]:
"""Check health of all providers."""
results: dict[str, Any] = {}
for name, provider in self.providers.items():
results[name] = await provider.health_check()
# Overall status
all_healthy = all(r.get("status") == "healthy" for r in results.values())
any_healthy = any(r.get("status") == "healthy" for r in results.values())
return {
"status": "healthy" if all_healthy else ("degraded" if any_healthy else "unhealthy"),
"providers": results,
}
async def close(self) -> None:
"""Close all providers."""
for provider in self.providers.values():
await provider.close()

services/mana-llm/src/streaming/__init__.py
@@ -0,0 +1,5 @@
"""Streaming utilities for SSE responses."""
from .sse import stream_chat_completion
__all__ = ["stream_chat_completion"]

services/mana-llm/src/streaming/sse.py
@@ -0,0 +1,43 @@
"""Server-Sent Events (SSE) response handling."""
import json
import logging
from collections.abc import AsyncIterator
from src.models import ChatCompletionRequest, ChatCompletionStreamResponse
from src.providers import ProviderRouter
logger = logging.getLogger(__name__)
async def stream_chat_completion(
router: ProviderRouter,
request: ChatCompletionRequest,
) -> AsyncIterator[str]:
"""
Stream chat completion responses as SSE data lines.
Yields strings in SSE format:
data: {"choices":[{"delta":{"content":"Hello"}}]}
data: [DONE]
"""
try:
async for chunk in router.chat_completion_stream(request):
# Convert to OpenAI-compatible SSE format
data = chunk.model_dump(exclude_none=True)
yield f"data: {json.dumps(data)}\n\n"
# Send final [DONE] marker
yield "data: [DONE]\n\n"
except Exception as e:
logger.error(f"Streaming error: {e}")
# Send error as SSE event
error_data = {
"error": {
"message": str(e),
"type": "server_error",
}
}
yield f"data: {json.dumps(error_data)}\n\n"
yield "data: [DONE]\n\n"

services/mana-llm/src/utils/__init__.py
@@ -0,0 +1,5 @@
"""Utility modules."""
from .metrics import get_metrics, metrics_middleware
__all__ = ["get_metrics", "metrics_middleware"]

services/mana-llm/src/utils/cache.py
@@ -0,0 +1,85 @@
"""Redis caching utilities (optional)."""
import hashlib
import json
import logging
from typing import Any
from src.config import settings
logger = logging.getLogger(__name__)
# Redis client (lazy initialized)
_redis_client = None
async def get_redis_client():
"""Get or create Redis client."""
global _redis_client
if _redis_client is not None:
return _redis_client
if not settings.redis_url:
return None
try:
import redis.asyncio as redis
_redis_client = redis.from_url(settings.redis_url)
# Test connection
await _redis_client.ping()
logger.info(f"Connected to Redis at {settings.redis_url}")
return _redis_client
except Exception as e:
logger.warning(f"Failed to connect to Redis: {e}")
return None
def generate_cache_key(prefix: str, data: dict[str, Any]) -> str:
"""Generate a cache key from request data."""
# Serialize and hash the data for consistent key
serialized = json.dumps(data, sort_keys=True)
hash_value = hashlib.sha256(serialized.encode()).hexdigest()[:16]
return f"mana-llm:{prefix}:{hash_value}"
async def get_cached(key: str) -> dict[str, Any] | None:
"""Get cached value by key."""
client = await get_redis_client()
if client is None:
return None
try:
value = await client.get(key)
if value:
return json.loads(value)
except Exception as e:
logger.warning(f"Cache get failed: {e}")
return None
async def set_cached(key: str, value: dict[str, Any], ttl: int | None = None) -> bool:
"""Set cached value with optional TTL."""
client = await get_redis_client()
if client is None:
return False
try:
ttl = ttl or settings.cache_ttl
serialized = json.dumps(value)
await client.setex(key, ttl, serialized)
return True
except Exception as e:
logger.warning(f"Cache set failed: {e}")
return False
async def close_redis() -> None:
"""Close Redis connection."""
global _redis_client
if _redis_client is not None:
await _redis_client.aclose()
_redis_client = None

services/mana-llm/src/utils/metrics.py
@@ -0,0 +1,109 @@
"""Prometheus metrics for mana-llm."""
import time
from collections.abc import Callable
from fastapi import Request, Response
from prometheus_client import Counter, Histogram, generate_latest
from starlette.middleware.base import BaseHTTPMiddleware
# Request metrics
REQUEST_COUNT = Counter(
"mana_llm_requests_total",
"Total number of requests",
["method", "endpoint", "status"],
)
REQUEST_LATENCY = Histogram(
"mana_llm_request_latency_seconds",
"Request latency in seconds",
["method", "endpoint"],
buckets=[0.1, 0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)
# LLM-specific metrics
LLM_REQUEST_COUNT = Counter(
"mana_llm_llm_requests_total",
"Total number of LLM requests",
["provider", "model", "streaming"],
)
LLM_TOKEN_COUNT = Counter(
"mana_llm_tokens_total",
"Total tokens processed",
["provider", "model", "type"], # type: prompt, completion
)
LLM_LATENCY = Histogram(
"mana_llm_llm_latency_seconds",
"LLM request latency in seconds",
["provider", "model"],
buckets=[0.5, 1.0, 2.0, 5.0, 10.0, 30.0, 60.0, 120.0],
)
LLM_ERRORS = Counter(
"mana_llm_llm_errors_total",
"Total LLM errors",
["provider", "model", "error_type"],
)
def get_metrics() -> bytes:
"""Generate Prometheus metrics output."""
return generate_latest()
class MetricsMiddleware(BaseHTTPMiddleware):
"""Middleware for collecting HTTP metrics."""
async def dispatch(self, request: Request, call_next: Callable) -> Response:
start_time = time.time()
response = await call_next(request)
# Record metrics
duration = time.time() - start_time
endpoint = request.url.path
method = request.method
status = str(response.status_code)
REQUEST_COUNT.labels(method=method, endpoint=endpoint, status=status).inc()
REQUEST_LATENCY.labels(method=method, endpoint=endpoint).observe(duration)
return response
# Export the middleware class (register it via app.add_middleware(MetricsMiddleware))
metrics_middleware = MetricsMiddleware
def record_llm_request(
provider: str,
model: str,
streaming: bool,
prompt_tokens: int = 0,
completion_tokens: int = 0,
latency: float | None = None,
) -> None:
"""Record LLM request metrics."""
LLM_REQUEST_COUNT.labels(
provider=provider,
model=model,
streaming=str(streaming).lower(),
).inc()
if prompt_tokens > 0:
LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="prompt").inc(prompt_tokens)
if completion_tokens > 0:
LLM_TOKEN_COUNT.labels(provider=provider, model=model, type="completion").inc(
completion_tokens
)
if latency is not None:
LLM_LATENCY.labels(provider=provider, model=model).observe(latency)
def record_llm_error(provider: str, model: str, error_type: str) -> None:
"""Record LLM error metrics."""
LLM_ERRORS.labels(provider=provider, model=model, error_type=error_type).inc()

View file

@ -0,0 +1 @@
"""Tests for mana-llm service."""

services/mana-llm/tests/test_api.py
@@ -0,0 +1,41 @@
"""API endpoint tests."""
import pytest
from fastapi.testclient import TestClient
@pytest.fixture
def client():
"""Create test client."""
from src.main import app
with TestClient(app) as c:
yield c
def test_health_endpoint(client):
"""Test health check endpoint."""
response = client.get("/health")
assert response.status_code == 200
data = response.json()
assert "status" in data
assert "service" in data
assert data["service"] == "mana-llm"
def test_metrics_endpoint(client):
"""Test metrics endpoint."""
response = client.get("/metrics")
assert response.status_code == 200
assert "mana_llm" in response.text
def test_list_models_endpoint(client):
"""Test list models endpoint."""
response = client.get("/v1/models")
# May fail if Ollama is not running, but should return valid response structure
if response.status_code == 200:
data = response.json()
assert "data" in data
assert "object" in data
assert data["object"] == "list"

services/mana-llm/tests/test_providers.py
@@ -0,0 +1,113 @@
"""Provider tests."""
import pytest
from src.models import ChatCompletionRequest, Message
from src.providers import OllamaProvider, OpenAICompatProvider, ProviderRouter
class TestProviderRouter:
"""Test provider routing logic."""
def test_parse_model_with_provider(self):
"""Test model parsing with provider prefix."""
router = ProviderRouter()
provider, model = router._parse_model("ollama/gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_without_provider(self):
"""Test model parsing without provider prefix (defaults to ollama)."""
router = ProviderRouter()
provider, model = router._parse_model("gemma3:4b")
assert provider == "ollama"
assert model == "gemma3:4b"
def test_parse_model_openrouter(self):
"""Test model parsing for OpenRouter."""
router = ProviderRouter()
provider, model = router._parse_model("openrouter/meta-llama/llama-3.1-8b-instruct")
assert provider == "openrouter"
assert model == "meta-llama/llama-3.1-8b-instruct"
def test_get_invalid_provider(self):
"""Test getting invalid provider raises error."""
router = ProviderRouter()
with pytest.raises(ValueError, match="not available"):
router._get_provider("invalid_provider")
class TestOllamaProvider:
"""Test Ollama provider."""
def test_convert_simple_messages(self):
"""Test converting simple text messages."""
provider = OllamaProvider()
request = ChatCompletionRequest(
model="gemma3:4b",
messages=[
Message(role="user", content="Hello"),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 1
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "Hello"
def test_convert_multimodal_messages(self):
"""Test converting multimodal messages."""
provider = OllamaProvider()
request = ChatCompletionRequest(
model="llava:7b",
messages=[
Message(
role="user",
content=[
{"type": "text", "text": "What's in this image?"},
{
"type": "image_url",
"image_url": {"url": "data:image/png;base64,iVBORw0KGgo="},
},
],
),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 1
assert messages[0]["role"] == "user"
assert messages[0]["content"] == "What's in this image?"
assert "images" in messages[0]
assert len(messages[0]["images"]) == 1
class TestOpenAICompatProvider:
"""Test OpenAI-compatible provider."""
def test_convert_simple_messages(self):
"""Test converting simple text messages."""
provider = OpenAICompatProvider(
name="test",
base_url="http://localhost",
api_key="test-key",
)
request = ChatCompletionRequest(
model="test-model",
messages=[
Message(role="system", content="You are helpful."),
Message(role="user", content="Hello"),
],
)
messages = provider._convert_messages(request)
assert len(messages) == 2
assert messages[0]["role"] == "system"
assert messages[0]["content"] == "You are helpful."
assert messages[1]["role"] == "user"
assert messages[1]["content"] == "Hello"

View file

@ -0,0 +1,57 @@
"""Streaming tests."""
import pytest
from src.models import ChatCompletionStreamResponse, DeltaContent, StreamChoice
class TestStreamingModels:
"""Test streaming response models."""
def test_stream_response_serialization(self):
"""Test streaming response serializes correctly."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(content="Hello"),
)
],
)
data = response.model_dump(exclude_none=True)
assert data["id"] == "test-id"
assert data["model"] == "ollama/gemma3:4b"
assert data["choices"][0]["delta"]["content"] == "Hello"
def test_stream_response_with_role(self):
"""Test first chunk with role."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(role="assistant"),
)
],
)
data = response.model_dump(exclude_none=True)
assert data["choices"][0]["delta"]["role"] == "assistant"
def test_stream_response_with_finish_reason(self):
"""Test final chunk with finish_reason."""
response = ChatCompletionStreamResponse(
id="test-id",
model="ollama/gemma3:4b",
choices=[
StreamChoice(
delta=DeltaContent(),
finish_reason="stop",
)
],
)
data = response.model_dump(exclude_none=True)
assert data["choices"][0]["finish_reason"] == "stop"