🔒️ feat(stt,tts): add API key authentication with rate limiting

Add auth.py module to both STT and TTS services with:
- API key validation via X-API-Key header
- Rate limiting with sliding window (requests per minute)
- Internal API key option for unlimited access
- Environment variable configuration

All protected endpoints now require authentication.
Public endpoints (/health, /docs) remain accessible.
This commit is contained in:
Till-JS 2026-02-11 18:04:22 +01:00
parent 4f9d992263
commit aab304fc95
6 changed files with 527 additions and 8 deletions

View file

@ -29,3 +29,24 @@ MISTRAL_API_KEY=
# CORS Origins (comma-separated)
CORS_ORIGINS=https://mana.how,https://chat.mana.how,http://localhost:5173
# ===========================================
# Authentication
# ===========================================
# Enable API key authentication (default: true for production)
REQUIRE_AUTH=true
# API Keys (comma-separated, format: key:name)
# Example: sk-abc123:myapp,sk-def456:testuser
API_KEYS=
# Internal API key (no rate limit, for internal services)
# Generate with: openssl rand -hex 32
INTERNAL_API_KEY=
# Rate Limiting
# Requests per window per API key
RATE_LIMIT_REQUESTS=60
# Window size in seconds
RATE_LIMIT_WINDOW=60

View file

@ -0,0 +1,211 @@
"""
API Key Authentication for ManaCore STT Service
Simple API key authentication with rate limiting.
Keys are configured via environment variables.
Usage:
API_KEYS=sk-key1:name1,sk-key2:name2
Or for unlimited internal access:
INTERNAL_API_KEY=sk-internal-xxx
"""
import os
import time
import logging
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field
from fastapi import HTTPException, Security, Request
from fastapi.security import APIKeyHeader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configuration
# These values are read from the environment once, when the module is imported.
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Internal key, exempt from rate limiting
# Authentication stays enabled unless REQUIRE_AUTH is set to something other
# than "true" (compared case-insensitively).
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
# int() raises ValueError at import time if either value is not an integer.
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Max requests per window (not per minute)
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Window length in seconds
@dataclass
class APIKey:
    """A configured API key together with its descriptive metadata."""

    # Secret value that is compared against the incoming request's key.
    key: str
    # Human-readable label used in logs and statistics.
    name: str
    # True only for the internal service key.
    is_internal: bool = False
    # Maximum number of requests accepted per rate-limit window.
    rate_limit: int = RATE_LIMIT_REQUESTS
@dataclass
class RateLimitInfo:
    """Sliding-window request log for a single API key.

    ``requests`` holds one timestamp per accepted request. Timestamps are
    ``time.monotonic()`` readings rather than epoch seconds, so the window
    is unaffected by wall-clock adjustments (NTP steps, manual changes).
    They are only meaningful relative to each other inside this class.
    """

    requests: list = field(default_factory=list)

    def _prune(self, window: int) -> float:
        """Drop timestamps older than ``window`` seconds and return "now"."""
        now = time.monotonic()
        self.requests = [t for t in self.requests if now - t < window]
        return now

    def is_allowed(self, limit: int, window: int) -> bool:
        """Record the request and return True if it fits within the limit.

        Args:
            limit: Maximum number of requests accepted per window.
            window: Window length in seconds.

        Returns:
            True if the request was accepted (and recorded); False if the
            window already holds ``limit`` requests. Rejected requests are
            not recorded.
        """
        now = self._prune(window)
        if len(self.requests) >= limit:
            return False
        self.requests.append(now)
        return True

    def remaining(self, limit: int, window: int) -> int:
        """Return how many more requests fit in the current window.

        Expired timestamps are pruned as a side effect. The result is
        never negative.
        """
        self._prune(window)
        return max(0, limit - len(self.requests))
# Parse API keys from environment
def _parse_api_keys() -> dict[str, APIKey]:
    """Build the table of valid API keys from environment variables.

    ``API_KEYS`` is a comma-separated list of ``key:name`` entries; an entry
    without a colon gets the name ``"default"``. ``INTERNAL_API_KEY``, when
    set, is added as an internal key with an effectively unlimited rate
    limit.

    The environment is read on every call (not from import-time constants)
    so that a later call picks up values changed at runtime.

    Returns:
        Mapping from the raw key string to its ``APIKey`` record.
    """
    keys: dict[str, APIKey] = {}

    # Parse comma-separated keys
    for entry in os.getenv("API_KEYS", "").split(","):
        key, separator, name = entry.strip().partition(":")
        key = key.strip()
        if not key:
            # Blank entries (empty variable, trailing comma, ":name") would
            # otherwise register the empty string as a key.
            continue
        keys[key] = APIKey(key=key, name=name.strip() if separator else "default")

    # Add internal key with no rate limit. Surrounding whitespace is stripped
    # for the same reason the regular keys are stripped above.
    internal_key = os.getenv("INTERNAL_API_KEY", "").strip()
    if internal_key:
        keys[internal_key] = APIKey(
            key=internal_key,
            name="internal",
            is_internal=True,
            rate_limit=999999,  # Effectively unlimited
        )

    return keys
# Global state
# Table of valid keys (raw key string -> APIKey), built when the module is
# imported; reload_api_keys() rebinds it.
_api_keys = _parse_api_keys()
# Per-key request history. Being a defaultdict, indexing with a new key
# creates an empty RateLimitInfo for it.
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
# Security scheme
# Extracts the key from the X-API-Key request header. auto_error=False is
# used so a missing header reaches verify_api_key as None and that function
# can raise its own 401 response.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
@dataclass
class AuthResult:
    """Outcome of an authentication check."""

    # Whether the request was let through.
    authenticated: bool
    # Label identifying who was authenticated, if known.
    key_name: Optional[str] = None
    # True when the matched key is the internal service key.
    is_internal: bool = False
    # Requests left in the current window; None when no limit was applied.
    rate_limit_remaining: Optional[int] = None
async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
    """
    Authenticate a request by its API key and enforce the per-key rate limit.

    Args:
        request: Incoming request; its URL path and client host are used.
        api_key: Value of the X-API-Key header, or None/empty when absent.

    Returns:
        AuthResult describing the caller. Public paths yield key_name
        "public"; with REQUIRE_AUTH disabled the result is "anonymous".

    Raises:
        HTTPException: 401 when the key is missing or unknown, 429 when a
            non-internal key has used up its window.
    """
    path = request.url.path

    # Health and documentation endpoints are always reachable without a key.
    if path in ("/health", "/docs", "/openapi.json", "/redoc"):
        return AuthResult(authenticated=True, key_name="public")

    # Authentication switched off: every caller is accepted.
    if not REQUIRE_AUTH:
        return AuthResult(authenticated=True, key_name="anonymous")

    # No key supplied at all.
    if not api_key:
        logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
        raise HTTPException(
            status_code=401,
            detail="Missing API key. Provide X-API-Key header.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Key supplied but not configured.
    key_info = _api_keys.get(api_key)
    if key_info is None:
        logger.warning(f"Invalid API key attempt for {path}")
        raise HTTPException(
            status_code=401,
            detail="Invalid API key.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    if key_info.is_internal:
        # Internal keys bypass rate limiting entirely.
        remaining = None
    else:
        limit = key_info.rate_limit
        tracker = _rate_limits[api_key]
        if not tracker.is_allowed(limit, RATE_LIMIT_WINDOW):
            remaining = tracker.remaining(limit, RATE_LIMIT_WINDOW)
            logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
            throttle_headers = {
                "X-RateLimit-Limit": str(limit),
                "X-RateLimit-Remaining": str(remaining),
                "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
                "Retry-After": str(RATE_LIMIT_WINDOW),
            }
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
                headers=throttle_headers,
            )
        remaining = tracker.remaining(limit, RATE_LIMIT_WINDOW)

    logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
    return AuthResult(
        authenticated=True,
        key_name=key_info.name,
        is_internal=key_info.is_internal,
        rate_limit_remaining=remaining,
    )
def get_api_key_stats() -> dict:
    """Return a summary of configured API keys and their current usage.

    Intended for an admin endpoint. Full key values are never included.

    Returns:
        Dict with ``total_keys``, ``auth_required``, the global
        ``rate_limit`` settings, and one entry per key under ``keys``
        (``name``, ``key_prefix``, ``is_internal``, ``requests_in_window``,
        ``remaining``).
    """
    key_entries = []
    for key, info in _api_keys.items():
        # Don't expose actual keys, just metadata. A prefix is shown only
        # when at least as many characters stay hidden as are revealed;
        # otherwise an 8-character prefix would disclose most of a short key.
        masked_key = key[:8] + "..." if len(key) >= 16 else "***"

        # .get() with a throwaway tracker avoids inserting unused keys into
        # the rate-limit table.
        rate_info = _rate_limits.get(key, RateLimitInfo())

        # remaining() prunes expired timestamps, so call it before counting;
        # otherwise the count could include requests outside the window and
        # disagree with the remaining figure.
        remaining = rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW)

        key_entries.append({
            "name": info.name,
            "key_prefix": masked_key,
            "is_internal": info.is_internal,
            "requests_in_window": len(rate_info.requests),
            "remaining": remaining,
        })

    return {
        "total_keys": len(_api_keys),
        "auth_required": REQUIRE_AUTH,
        "rate_limit": {
            "requests_per_window": RATE_LIMIT_REQUESTS,
            "window_seconds": RATE_LIMIT_WINDOW,
        },
        "keys": key_entries,
    }
def reload_api_keys():
    """Rebuild the module-level key table by re-running _parse_api_keys()."""
    global _api_keys
    refreshed = _parse_api_keys()
    _api_keys = refreshed
    logger.info(f"Reloaded {len(_api_keys)} API keys")

View file

@ -11,11 +11,13 @@ import time
from typing import Optional
from contextlib import asynccontextmanager
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Response
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse
from pydantic import BaseModel
from app.auth import verify_api_key, AuthResult, get_api_key_stats, REQUIRE_AUTH
# Configure logging
logging.basicConfig(
level=logging.INFO,
@ -52,6 +54,7 @@ class HealthResponse(BaseModel):
vllm_available: bool
vllm_url: Optional[str] = None
mistral_api_available: bool
auth_required: bool
models: dict
@ -136,6 +139,7 @@ async def health_check():
vllm_available=vllm_health.get("status") == "healthy",
vllm_url=VLLM_URL if USE_VLLM else None,
mistral_api_available=api_available(),
auth_required=REQUIRE_AUTH,
models={
"default_whisper": DEFAULT_WHISPER_MODEL,
},
@ -143,7 +147,7 @@ async def health_check():
@app.get("/models", response_model=ModelsResponse)
async def list_models():
async def list_models(auth: AuthResult = Depends(verify_api_key)):
"""List available models."""
from app.whisper_service import AVAILABLE_MODELS as whisper_models
from app.vllm_service import get_models
@ -159,9 +163,11 @@ async def list_models():
@app.post("/transcribe", response_model=TranscriptionResponse)
async def transcribe_whisper(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
model: Optional[str] = Form(None, description="Whisper model to use"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Whisper (Lightning MLX).
@ -170,6 +176,10 @@ async def transcribe_whisper(
Supported formats: mp3, wav, m4a, flac, ogg, webm
Max file size: 100MB
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
@ -216,9 +226,11 @@ async def transcribe_whisper(
@app.post("/transcribe/voxtral", response_model=TranscriptionResponse)
async def transcribe_voxtral(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: str = Form("de", description="Language code"),
use_realtime: bool = Form(False, description="Use Realtime 4B model for lower latency"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Voxtral via vLLM server.
@ -232,6 +244,10 @@ async def transcribe_voxtral(
Supported formats: mp3, wav, m4a, flac, ogg, webm
Max file size: 100MB
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
@ -315,9 +331,11 @@ async def transcribe_voxtral(
@app.post("/transcribe/voxtral/api", response_model=TranscriptionResponse)
async def transcribe_voxtral_api(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
diarization: bool = Form(False, description="Enable speaker diarization"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Mistral's Voxtral API directly.
@ -329,6 +347,10 @@ async def transcribe_voxtral_api(
Requires MISTRAL_API_KEY environment variable.
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
from app.voxtral_api_service import is_available, transcribe_audio_bytes
if not is_available():
@ -366,9 +388,11 @@ async def transcribe_voxtral_api(
@app.post("/transcribe/auto", response_model=TranscriptionResponse)
async def transcribe_auto(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language hint"),
prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe with automatic model selection and fallback.
@ -378,6 +402,10 @@ async def transcribe_auto(
2. Alternative model
3. Mistral API
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
if prefer == "voxtral":
try:
return await transcribe_voxtral(file, language or "de", False)

View file

@ -0,0 +1,36 @@
# ManaCore TTS Service Configuration
# Copy to .env and adjust values as needed
# Server
PORT=3022
# Models
# Set to true to preload models on startup (slower startup, faster first request)
PRELOAD_MODELS=false
# Text Limits
MAX_TEXT_LENGTH=1000
# CORS Origins (comma-separated)
CORS_ORIGINS=https://mana.how,https://chat.mana.how,http://localhost:5173
# ===========================================
# Authentication
# ===========================================
# Enable API key authentication (default: true for production)
REQUIRE_AUTH=true
# API Keys (comma-separated, format: key:name)
# Example: sk-abc123:myapp,sk-def456:testuser
API_KEYS=
# Internal API key (no rate limit, for internal services)
# Generate with: openssl rand -hex 32
INTERNAL_API_KEY=
# Rate Limiting
# Requests per window per API key
RATE_LIMIT_REQUESTS=60
# Window size in seconds
RATE_LIMIT_WINDOW=60

View file

@ -0,0 +1,211 @@
"""
API Key Authentication for ManaCore TTS Service
Simple API key authentication with rate limiting.
Keys are configured via environment variables.
Usage:
API_KEYS=sk-key1:name1,sk-key2:name2
Or for unlimited internal access:
INTERNAL_API_KEY=sk-internal-xxx
"""
import os
import time
import logging
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field
from fastapi import HTTPException, Security, Request
from fastapi.security import APIKeyHeader
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
# Configuration
# These values are read from the environment once, when the module is imported.
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Internal key, exempt from rate limiting
# Authentication stays enabled unless REQUIRE_AUTH is set to something other
# than "true" (compared case-insensitively).
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
# int() raises ValueError at import time if either value is not an integer.
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Max requests per window (not per minute)
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Window length in seconds
@dataclass
class APIKey:
    """A configured API key together with its descriptive metadata."""

    # Secret value that is compared against the incoming request's key.
    key: str
    # Human-readable label used in logs and statistics.
    name: str
    # True only for the internal service key.
    is_internal: bool = False
    # Maximum number of requests accepted per rate-limit window.
    rate_limit: int = RATE_LIMIT_REQUESTS
@dataclass
class RateLimitInfo:
    """Sliding-window request log for a single API key.

    ``requests`` holds one timestamp per accepted request. Timestamps are
    ``time.monotonic()`` readings rather than epoch seconds, so the window
    is unaffected by wall-clock adjustments (NTP steps, manual changes).
    They are only meaningful relative to each other inside this class.
    """

    requests: list = field(default_factory=list)

    def _prune(self, window: int) -> float:
        """Drop timestamps older than ``window`` seconds and return "now"."""
        now = time.monotonic()
        self.requests = [t for t in self.requests if now - t < window]
        return now

    def is_allowed(self, limit: int, window: int) -> bool:
        """Record the request and return True if it fits within the limit.

        Args:
            limit: Maximum number of requests accepted per window.
            window: Window length in seconds.

        Returns:
            True if the request was accepted (and recorded); False if the
            window already holds ``limit`` requests. Rejected requests are
            not recorded.
        """
        now = self._prune(window)
        if len(self.requests) >= limit:
            return False
        self.requests.append(now)
        return True

    def remaining(self, limit: int, window: int) -> int:
        """Return how many more requests fit in the current window.

        Expired timestamps are pruned as a side effect. The result is
        never negative.
        """
        self._prune(window)
        return max(0, limit - len(self.requests))
# Parse API keys from environment
def _parse_api_keys() -> dict[str, APIKey]:
    """Build the table of valid API keys from environment variables.

    ``API_KEYS`` is a comma-separated list of ``key:name`` entries; an entry
    without a colon gets the name ``"default"``. ``INTERNAL_API_KEY``, when
    set, is added as an internal key with an effectively unlimited rate
    limit.

    The environment is read on every call (not from import-time constants)
    so that a later call picks up values changed at runtime.

    Returns:
        Mapping from the raw key string to its ``APIKey`` record.
    """
    keys: dict[str, APIKey] = {}

    # Parse comma-separated keys
    for entry in os.getenv("API_KEYS", "").split(","):
        key, separator, name = entry.strip().partition(":")
        key = key.strip()
        if not key:
            # Blank entries (empty variable, trailing comma, ":name") would
            # otherwise register the empty string as a key.
            continue
        keys[key] = APIKey(key=key, name=name.strip() if separator else "default")

    # Add internal key with no rate limit. Surrounding whitespace is stripped
    # for the same reason the regular keys are stripped above.
    internal_key = os.getenv("INTERNAL_API_KEY", "").strip()
    if internal_key:
        keys[internal_key] = APIKey(
            key=internal_key,
            name="internal",
            is_internal=True,
            rate_limit=999999,  # Effectively unlimited
        )

    return keys
# Global state
# Table of valid keys (raw key string -> APIKey), built when the module is
# imported; reload_api_keys() rebinds it.
_api_keys = _parse_api_keys()
# Per-key request history. Being a defaultdict, indexing with a new key
# creates an empty RateLimitInfo for it.
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
# Security scheme
# Extracts the key from the X-API-Key request header. auto_error=False is
# used so a missing header reaches verify_api_key as None and that function
# can raise its own 401 response.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
@dataclass
class AuthResult:
    """Outcome of an authentication check."""

    # Whether the request was let through.
    authenticated: bool
    # Label identifying who was authenticated, if known.
    key_name: Optional[str] = None
    # True when the matched key is the internal service key.
    is_internal: bool = False
    # Requests left in the current window; None when no limit was applied.
    rate_limit_remaining: Optional[int] = None
async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
    """
    Authenticate a request by its API key and enforce the per-key rate limit.

    Args:
        request: Incoming request; its URL path and client host are used.
        api_key: Value of the X-API-Key header, or None/empty when absent.

    Returns:
        AuthResult describing the caller. Public paths yield key_name
        "public"; with REQUIRE_AUTH disabled the result is "anonymous".

    Raises:
        HTTPException: 401 when the key is missing or unknown, 429 when a
            non-internal key has used up its window.
    """
    path = request.url.path

    # Health and documentation endpoints are always reachable without a key.
    if path in ("/health", "/docs", "/openapi.json", "/redoc"):
        return AuthResult(authenticated=True, key_name="public")

    # Authentication switched off: every caller is accepted.
    if not REQUIRE_AUTH:
        return AuthResult(authenticated=True, key_name="anonymous")

    # No key supplied at all.
    if not api_key:
        logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
        raise HTTPException(
            status_code=401,
            detail="Missing API key. Provide X-API-Key header.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Key supplied but not configured.
    key_info = _api_keys.get(api_key)
    if key_info is None:
        logger.warning(f"Invalid API key attempt for {path}")
        raise HTTPException(
            status_code=401,
            detail="Invalid API key.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    if key_info.is_internal:
        # Internal keys bypass rate limiting entirely.
        remaining = None
    else:
        limit = key_info.rate_limit
        tracker = _rate_limits[api_key]
        if not tracker.is_allowed(limit, RATE_LIMIT_WINDOW):
            remaining = tracker.remaining(limit, RATE_LIMIT_WINDOW)
            logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
            throttle_headers = {
                "X-RateLimit-Limit": str(limit),
                "X-RateLimit-Remaining": str(remaining),
                "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
                "Retry-After": str(RATE_LIMIT_WINDOW),
            }
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
                headers=throttle_headers,
            )
        remaining = tracker.remaining(limit, RATE_LIMIT_WINDOW)

    logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
    return AuthResult(
        authenticated=True,
        key_name=key_info.name,
        is_internal=key_info.is_internal,
        rate_limit_remaining=remaining,
    )
def get_api_key_stats() -> dict:
    """Return a summary of configured API keys and their current usage.

    Intended for an admin endpoint. Full key values are never included.

    Returns:
        Dict with ``total_keys``, ``auth_required``, the global
        ``rate_limit`` settings, and one entry per key under ``keys``
        (``name``, ``key_prefix``, ``is_internal``, ``requests_in_window``,
        ``remaining``).
    """
    key_entries = []
    for key, info in _api_keys.items():
        # Don't expose actual keys, just metadata. A prefix is shown only
        # when at least as many characters stay hidden as are revealed;
        # otherwise an 8-character prefix would disclose most of a short key.
        masked_key = key[:8] + "..." if len(key) >= 16 else "***"

        # .get() with a throwaway tracker avoids inserting unused keys into
        # the rate-limit table.
        rate_info = _rate_limits.get(key, RateLimitInfo())

        # remaining() prunes expired timestamps, so call it before counting;
        # otherwise the count could include requests outside the window and
        # disagree with the remaining figure.
        remaining = rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW)

        key_entries.append({
            "name": info.name,
            "key_prefix": masked_key,
            "is_internal": info.is_internal,
            "requests_in_window": len(rate_info.requests),
            "remaining": remaining,
        })

    return {
        "total_keys": len(_api_keys),
        "auth_required": REQUIRE_AUTH,
        "rate_limit": {
            "requests_per_window": RATE_LIMIT_REQUESTS,
            "window_seconds": RATE_LIMIT_WINDOW,
        },
        "keys": key_entries,
    }
def reload_api_keys():
    """Rebuild the module-level key table by re-running _parse_api_keys()."""
    global _api_keys
    refreshed = _parse_api_keys()
    _api_keys = refreshed
    logger.info(f"Reloaded {len(_api_keys)} API keys")

View file

@ -14,10 +14,12 @@ from contextlib import asynccontextmanager
from pathlib import Path
from typing import Optional
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Response
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Response, Depends
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel, Field
from .auth import verify_api_key, AuthResult, REQUIRE_AUTH
from .audio_utils import convert_audio, SUPPORTED_FORMATS, cleanup_temp_file, save_temp_audio
from .kokoro_service import (
synthesize_kokoro,
@ -142,6 +144,7 @@ class HealthResponse(BaseModel):
status: str
service: str
models_loaded: dict
auth_required: bool
class ModelsResponse(BaseModel):
@ -196,11 +199,12 @@ async def health_check():
"kokoro": is_kokoro_loaded(),
"f5": is_f5_loaded(),
},
auth_required=REQUIRE_AUTH,
)
@app.get("/models", response_model=ModelsResponse)
async def get_models():
async def get_models(auth: AuthResult = Depends(verify_api_key)):
"""Get information about available models."""
return ModelsResponse(
kokoro={
@ -226,7 +230,7 @@ async def get_models():
@app.get("/voices", response_model=VoicesResponse)
async def get_voices():
async def get_voices(auth: AuthResult = Depends(verify_api_key)):
"""Get all available voices."""
# Kokoro preset voices
kokoro_voices = [
@ -264,6 +268,7 @@ async def register_voice(
description: str = Form("", description="Voice description"),
transcript: str = Form(..., description="Transcript of the reference audio"),
reference_audio: UploadFile = File(..., description="Reference audio file"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Register a new custom voice for F5-TTS voice cloning.
@ -313,7 +318,7 @@ async def register_voice(
@app.delete("/voices/{voice_id}", response_model=VoiceDeletedResponse)
async def delete_voice(voice_id: str):
async def delete_voice(voice_id: str, auth: AuthResult = Depends(verify_api_key)):
"""Delete a registered custom voice."""
voice_manager = get_voice_manager()
@ -332,7 +337,10 @@ async def delete_voice(voice_id: str):
@app.post("/synthesize/kokoro")
async def synthesize_with_kokoro(request: KokoroRequest):
async def synthesize_with_kokoro(
request: KokoroRequest,
auth: AuthResult = Depends(verify_api_key),
):
"""
Synthesize speech using Kokoro with preset voices.
@ -403,6 +411,7 @@ async def synthesize_with_f5(
output_format: str = Form("wav", description="Output format (wav, mp3)"),
speed: float = Form(1.0, ge=0.5, le=2.0, description="Speech speed"),
steps: int = Form(32, ge=8, le=64, description="Diffusion steps"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Synthesize speech using F5-TTS with voice cloning.
@ -520,7 +529,10 @@ async def synthesize_with_f5(
@app.post("/synthesize/auto")
async def synthesize_auto(request: AutoRequest):
async def synthesize_auto(
request: AutoRequest,
auth: AuthResult = Depends(verify_api_key),
):
"""
Auto-select the best TTS model based on voice parameter.