mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 22:41:09 +02:00
🔒️ feat(stt,tts): add API key authentication with rate limiting
Add auth.py module to both STT and TTS services with: - API key validation via X-API-Key header - Rate limiting with sliding window (requests per minute) - Internal API key option for unlimited access - Environment variable configuration All protected endpoints now require authentication. Public endpoints (/health, /docs) remain accessible.
This commit is contained in:
parent
4f9d992263
commit
aab304fc95
6 changed files with 527 additions and 8 deletions
|
|
@ -29,3 +29,24 @@ MISTRAL_API_KEY=
|
|||
|
||||
# CORS Origins (comma-separated)
|
||||
CORS_ORIGINS=https://mana.how,https://chat.mana.how,http://localhost:5173
|
||||
|
||||
# ===========================================
|
||||
# Authentication
|
||||
# ===========================================
|
||||
|
||||
# Enable API key authentication (default: true for production)
|
||||
REQUIRE_AUTH=true
|
||||
|
||||
# API Keys (comma-separated, format: key:name)
|
||||
# Example: sk-abc123:myapp,sk-def456:testuser
|
||||
API_KEYS=
|
||||
|
||||
# Internal API key (no rate limit, for internal services)
|
||||
# Generate with: openssl rand -hex 32
|
||||
INTERNAL_API_KEY=
|
||||
|
||||
# Rate Limiting
|
||||
# Requests per window per API key
|
||||
RATE_LIMIT_REQUESTS=60
|
||||
# Window size in seconds
|
||||
RATE_LIMIT_WINDOW=60
|
||||
|
|
|
|||
211
services/mana-stt/app/auth.py
Normal file
211
services/mana-stt/app/auth.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""
|
||||
API Key Authentication for ManaCore STT Service
|
||||
|
||||
Simple API key authentication with rate limiting.
|
||||
Keys are configured via environment variables.
|
||||
|
||||
Usage:
|
||||
API_KEYS=sk-key1:name1,sk-key2:name2
|
||||
|
||||
Or for unlimited internal access:
|
||||
INTERNAL_API_KEY=sk-internal-xxx
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import HTTPException, Security, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
# Raw key list; format: "sk-key1:name1,sk-key2:name2"
API_KEYS_ENV = os.getenv("API_KEYS", "")
# Key that is exempt from rate limiting (intended for internal services)
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")
# Accept the usual truthy spellings and ignore stray whitespace, so that a
# value such as "1", "yes" or " true " does not silently switch authentication
# off. Anything else (including "false", "0" and the empty string) disables it.
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"true", "1", "yes", "on"}
# Requests allowed per key within one window (not necessarily one minute)
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))
# Window length in seconds
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
|
||||
|
||||
|
||||
@dataclass
class APIKey:
    """One configured API key together with its descriptive metadata."""

    # The secret value clients send in the X-API-Key header.
    key: str
    # Human-readable label; this is what appears in log lines and statistics.
    name: str
    # Internal keys are exempt from rate limiting.
    is_internal: bool = False
    # Requests permitted per rate-limit window; defaults to the global setting.
    rate_limit: int = RATE_LIMIT_REQUESTS
|
||||
|
||||
|
||||
@dataclass
class RateLimitInfo:
    """Sliding-window request log for a single API key.

    One timestamp is stored per accepted request. Timestamps come from
    ``time.monotonic()`` rather than the wall clock, so system clock
    adjustments cannot stretch or collapse the window. They are only
    meaningful relative to each other.
    """

    # Timestamps of accepted requests that were inside the window at last prune.
    requests: list = field(default_factory=list)

    def _prune(self, window: int) -> float:
        """Drop timestamps that are ``window`` seconds old or older.

        Returns the clock reading used for pruning so callers can reuse it.
        """
        now = time.monotonic()
        self.requests = [t for t in self.requests if now - t < window]
        return now

    def is_allowed(self, limit: int, window: int) -> bool:
        """Record a request if the key is still under its limit.

        Args:
            limit: Maximum number of requests permitted inside one window.
            window: Window length in seconds.

        Returns:
            True (and the request is recorded) when fewer than ``limit``
            requests are inside the window; False otherwise, in which case
            nothing is recorded.
        """
        now = self._prune(window)

        if len(self.requests) >= limit:
            return False

        self.requests.append(now)
        return True

    def remaining(self, limit: int, window: int) -> int:
        """Return how many more requests fit in the current window (never negative).

        Expired timestamps are pruned as a side effect.
        """
        self._prune(window)
        return max(0, limit - len(self.requests))
|
||||
|
||||
|
||||
# Parse API keys from environment
|
||||
def _parse_api_keys() -> dict[str, APIKey]:
    """Build the key registry from ``API_KEYS_ENV`` and ``INTERNAL_API_KEY``.

    ``API_KEYS_ENV`` is a comma-separated list of ``key:name`` entries. An
    entry without a colon gets the name ``"default"``. Blank entries (for
    example from a trailing comma or a whitespace-only value) and entries
    whose key part is empty are skipped, so an empty string is never
    registered as a key.

    Returns:
        Mapping from key string to its ``APIKey`` record. The internal key,
        when configured, is added last and marked ``is_internal``.
    """
    keys: dict[str, APIKey] = {}

    for entry in API_KEYS_ENV.split(","):
        # Only the first colon separates key from name.
        key, sep, name = entry.partition(":")
        key = key.strip()
        if not key:
            continue
        keys[key] = APIKey(key=key, name=name.strip() if sep else "default")

    # Strip for consistency with the regular keys above.
    internal_key = INTERNAL_API_KEY.strip()
    if internal_key:
        keys[internal_key] = APIKey(
            key=internal_key,
            name="internal",
            is_internal=True,
            rate_limit=999999,  # Effectively unlimited
        )

    return keys
|
||||
|
||||
|
||||
# Global state
|
||||
# Key registry, built once at import time from the configured key strings.
_api_keys = _parse_api_keys()
# Per-key request logs; an entry is created on first lookup of a key.
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)

# Security scheme
# auto_error=False: a missing header is passed on as None instead of being
# rejected by the scheme itself.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
@dataclass
class AuthResult:
    """Outcome of an authentication check, handed to the route handler."""

    # Whether the request is allowed to proceed.
    authenticated: bool
    # Label identifying who was authenticated, if anyone.
    key_name: Optional[str] = None
    # True when the matching key is flagged as internal.
    is_internal: bool = False
    # Requests left in the current window; None when no limit was evaluated.
    rate_limit_remaining: Optional[int] = None
|
||||
|
||||
|
||||
def _match_api_key(candidate: str) -> Optional[APIKey]:
    """Return the registered key equal to ``candidate``, or None if there is none."""
    import hmac  # local import keeps this block independent of the file header

    candidate_bytes = candidate.encode("utf-8")
    match = None
    for known, info in _api_keys.items():
        # compare_digest does not short-circuit on the first differing byte.
        # Every registered key is checked (no early exit) for the same reason.
        if hmac.compare_digest(known.encode("utf-8"), candidate_bytes):
            match = info
    return match


async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
    """
    Verify the API key of a request and enforce its rate limit.

    Args:
        request: Incoming request; its path and client host are inspected.
        api_key: Value of the X-API-Key header, or None when absent.

    Returns:
        AuthResult describing the caller. ``rate_limit_remaining`` is set
        only for rate-limited (non-internal) keys.

    Raises:
        HTTPException: 401 when the key is missing or unknown, 429 when the
            key has used up its requests for the current window.
    """
    # Skip auth for health and docs endpoints
    path = request.url.path
    if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
        return AuthResult(authenticated=True, key_name="public")

    # If auth not required, allow all
    if not REQUIRE_AUTH:
        return AuthResult(authenticated=True, key_name="anonymous")

    # Check for API key
    if not api_key:
        client_host = request.client.host if request.client else "unknown"
        logger.warning("Missing API key for %s from %s", path, client_host)
        raise HTTPException(
            status_code=401,
            detail="Missing API key. Provide X-API-Key header.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Validate key (constant-time comparison against every registered key)
    key_info = _match_api_key(api_key)
    if key_info is None:
        logger.warning("Invalid API key attempt for %s", path)
        raise HTTPException(
            status_code=401,
            detail="Invalid API key.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Check rate limit (skip for internal keys)
    remaining: Optional[int] = None
    if not key_info.is_internal:
        rate_info = _rate_limits[api_key]
        if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
            remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
            logger.warning("Rate limit exceeded for key '%s'", key_info.name)
            # Reset/Retry-After report the full window length, i.e. an upper
            # bound on the wait rather than the exact time until a slot frees.
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
                headers={
                    "X-RateLimit-Limit": str(key_info.rate_limit),
                    "X-RateLimit-Remaining": str(remaining),
                    "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
                    "Retry-After": str(RATE_LIMIT_WINDOW),
                },
            )
        remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)

    logger.debug("Authenticated request from '%s' to %s", key_info.name, path)

    return AuthResult(
        authenticated=True,
        key_name=key_info.name,
        is_internal=key_info.is_internal,
        rate_limit_remaining=remaining,
    )
|
||||
|
||||
|
||||
def get_api_key_stats() -> dict:
    """Get statistics about API keys (for admin endpoint).

    Returns:
        A dict with the number of configured keys, whether auth is required,
        the global rate-limit settings, and one entry per key under
        ``"keys"``. Key values themselves are masked.
    """
    # A prefix is shown only when it is at most half of the key; shorter keys
    # are fully masked so the output never reveals most of a secret.
    prefix_len = 8

    stats = {
        "total_keys": len(_api_keys),
        "auth_required": REQUIRE_AUTH,
        "rate_limit": {
            "requests_per_window": RATE_LIMIT_REQUESTS,
            "window_seconds": RATE_LIMIT_WINDOW,
        },
        "keys": [],
    }

    for key, info in _api_keys.items():
        # Don't expose actual keys, just metadata
        masked_key = key[:prefix_len] + "..." if len(key) >= 2 * prefix_len else "***"
        # .get() with a throwaway default avoids creating entries in the defaultdict.
        rate_info = _rate_limits.get(key, RateLimitInfo())
        # remaining() prunes expired timestamps, so call it before counting;
        # otherwise the count could include requests already outside the window.
        remaining = rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW)
        stats["keys"].append({
            "name": info.name,
            "key_prefix": masked_key,
            "is_internal": info.is_internal,
            "requests_in_window": len(rate_info.requests),
            "remaining": remaining,
        })

    return stats
|
||||
|
||||
|
||||
def reload_api_keys():
    """Reload API keys from environment (for runtime updates).

    The parser reads the module-level ``API_KEYS_ENV`` and
    ``INTERNAL_API_KEY`` values, which are captured at import time. They are
    refreshed from the environment here first; otherwise a reload would just
    re-parse the values seen at startup.
    """
    global _api_keys, API_KEYS_ENV, INTERNAL_API_KEY
    API_KEYS_ENV = os.getenv("API_KEYS", "")
    INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")
    _api_keys = _parse_api_keys()
    logger.info(f"Reloaded {len(_api_keys)} API keys")
|
||||
|
|
@ -11,11 +11,13 @@ import time
|
|||
from typing import Optional
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException
|
||||
from fastapi import FastAPI, File, UploadFile, Form, HTTPException, Depends, Response
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
from app.auth import verify_api_key, AuthResult, get_api_key_stats, REQUIRE_AUTH
|
||||
|
||||
# Configure logging
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
|
|
@ -52,6 +54,7 @@ class HealthResponse(BaseModel):
|
|||
vllm_available: bool
|
||||
vllm_url: Optional[str] = None
|
||||
mistral_api_available: bool
|
||||
auth_required: bool
|
||||
models: dict
|
||||
|
||||
|
||||
|
|
@ -136,6 +139,7 @@ async def health_check():
|
|||
vllm_available=vllm_health.get("status") == "healthy",
|
||||
vllm_url=VLLM_URL if USE_VLLM else None,
|
||||
mistral_api_available=api_available(),
|
||||
auth_required=REQUIRE_AUTH,
|
||||
models={
|
||||
"default_whisper": DEFAULT_WHISPER_MODEL,
|
||||
},
|
||||
|
|
@ -143,7 +147,7 @@ async def health_check():
|
|||
|
||||
|
||||
@app.get("/models", response_model=ModelsResponse)
|
||||
async def list_models():
|
||||
async def list_models(auth: AuthResult = Depends(verify_api_key)):
|
||||
"""List available models."""
|
||||
from app.whisper_service import AVAILABLE_MODELS as whisper_models
|
||||
from app.vllm_service import get_models
|
||||
|
|
@ -159,9 +163,11 @@ async def list_models():
|
|||
|
||||
@app.post("/transcribe", response_model=TranscriptionResponse)
|
||||
async def transcribe_whisper(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
|
||||
model: Optional[str] = Form(None, description="Whisper model to use"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Whisper (Lightning MLX).
|
||||
|
|
@ -170,6 +176,10 @@ async def transcribe_whisper(
|
|||
Supported formats: mp3, wav, m4a, flac, ogg, webm
|
||||
Max file size: 100MB
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
|
|
@ -216,9 +226,11 @@ async def transcribe_whisper(
|
|||
|
||||
@app.post("/transcribe/voxtral", response_model=TranscriptionResponse)
|
||||
async def transcribe_voxtral(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: str = Form("de", description="Language code"),
|
||||
use_realtime: bool = Form(False, description="Use Realtime 4B model for lower latency"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Voxtral via vLLM server.
|
||||
|
|
@ -232,6 +244,10 @@ async def transcribe_voxtral(
|
|||
Supported formats: mp3, wav, m4a, flac, ogg, webm
|
||||
Max file size: 100MB
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
|
|
@ -315,9 +331,11 @@ async def transcribe_voxtral(
|
|||
|
||||
@app.post("/transcribe/voxtral/api", response_model=TranscriptionResponse)
|
||||
async def transcribe_voxtral_api(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
|
||||
diarization: bool = Form(False, description="Enable speaker diarization"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Mistral's Voxtral API directly.
|
||||
|
|
@ -329,6 +347,10 @@ async def transcribe_voxtral_api(
|
|||
|
||||
Requires MISTRAL_API_KEY environment variable.
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
from app.voxtral_api_service import is_available, transcribe_audio_bytes
|
||||
|
||||
if not is_available():
|
||||
|
|
@ -366,9 +388,11 @@ async def transcribe_voxtral_api(
|
|||
|
||||
@app.post("/transcribe/auto", response_model=TranscriptionResponse)
|
||||
async def transcribe_auto(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language hint"),
|
||||
prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe with automatic model selection and fallback.
|
||||
|
|
@ -378,6 +402,10 @@ async def transcribe_auto(
|
|||
2. Alternative model
|
||||
3. Mistral API
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
if prefer == "voxtral":
|
||||
try:
|
||||
return await transcribe_voxtral(file, language or "de", False)
|
||||
|
|
|
|||
36
services/mana-tts/.env.example
Normal file
36
services/mana-tts/.env.example
Normal file
|
|
@ -0,0 +1,36 @@
|
|||
# ManaCore TTS Service Configuration
|
||||
# Copy to .env and adjust values as needed
|
||||
|
||||
# Server
|
||||
PORT=3022
|
||||
|
||||
# Models
|
||||
# Set to true to preload models on startup (slower startup, faster first request)
|
||||
PRELOAD_MODELS=false
|
||||
|
||||
# Text Limits
|
||||
MAX_TEXT_LENGTH=1000
|
||||
|
||||
# CORS Origins (comma-separated)
|
||||
CORS_ORIGINS=https://mana.how,https://chat.mana.how,http://localhost:5173
|
||||
|
||||
# ===========================================
|
||||
# Authentication
|
||||
# ===========================================
|
||||
|
||||
# Enable API key authentication (default: true for production)
|
||||
REQUIRE_AUTH=true
|
||||
|
||||
# API Keys (comma-separated, format: key:name)
|
||||
# Example: sk-abc123:myapp,sk-def456:testuser
|
||||
API_KEYS=
|
||||
|
||||
# Internal API key (no rate limit, for internal services)
|
||||
# Generate with: openssl rand -hex 32
|
||||
INTERNAL_API_KEY=
|
||||
|
||||
# Rate Limiting
|
||||
# Requests per window per API key
|
||||
RATE_LIMIT_REQUESTS=60
|
||||
# Window size in seconds
|
||||
RATE_LIMIT_WINDOW=60
|
||||
211
services/mana-tts/app/auth.py
Normal file
211
services/mana-tts/app/auth.py
Normal file
|
|
@ -0,0 +1,211 @@
|
|||
"""
|
||||
API Key Authentication for ManaCore TTS Service
|
||||
|
||||
Simple API key authentication with rate limiting.
|
||||
Keys are configured via environment variables.
|
||||
|
||||
Usage:
|
||||
API_KEYS=sk-key1:name1,sk-key2:name2
|
||||
|
||||
Or for unlimited internal access:
|
||||
INTERNAL_API_KEY=sk-internal-xxx
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from fastapi import HTTPException, Security, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Configuration
|
||||
# Raw key list; format: "sk-key1:name1,sk-key2:name2"
API_KEYS_ENV = os.getenv("API_KEYS", "")
# Key that is exempt from rate limiting (intended for internal services)
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")
# Accept the usual truthy spellings and ignore stray whitespace, so that a
# value such as "1", "yes" or " true " does not silently switch authentication
# off. Anything else (including "false", "0" and the empty string) disables it.
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").strip().lower() in {"true", "1", "yes", "on"}
# Requests allowed per key within one window (not necessarily one minute)
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))
# Window length in seconds
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))
|
||||
|
||||
|
||||
@dataclass
class APIKey:
    """One configured API key together with its descriptive metadata."""

    # The secret value clients send in the X-API-Key header.
    key: str
    # Human-readable label; this is what appears in log lines and statistics.
    name: str
    # Internal keys are exempt from rate limiting.
    is_internal: bool = False
    # Requests permitted per rate-limit window; defaults to the global setting.
    rate_limit: int = RATE_LIMIT_REQUESTS
|
||||
|
||||
|
||||
@dataclass
class RateLimitInfo:
    """Sliding-window request log for a single API key.

    One timestamp is stored per accepted request. Timestamps come from
    ``time.monotonic()`` rather than the wall clock, so system clock
    adjustments cannot stretch or collapse the window. They are only
    meaningful relative to each other.
    """

    # Timestamps of accepted requests that were inside the window at last prune.
    requests: list = field(default_factory=list)

    def _prune(self, window: int) -> float:
        """Drop timestamps that are ``window`` seconds old or older.

        Returns the clock reading used for pruning so callers can reuse it.
        """
        now = time.monotonic()
        self.requests = [t for t in self.requests if now - t < window]
        return now

    def is_allowed(self, limit: int, window: int) -> bool:
        """Record a request if the key is still under its limit.

        Args:
            limit: Maximum number of requests permitted inside one window.
            window: Window length in seconds.

        Returns:
            True (and the request is recorded) when fewer than ``limit``
            requests are inside the window; False otherwise, in which case
            nothing is recorded.
        """
        now = self._prune(window)

        if len(self.requests) >= limit:
            return False

        self.requests.append(now)
        return True

    def remaining(self, limit: int, window: int) -> int:
        """Return how many more requests fit in the current window (never negative).

        Expired timestamps are pruned as a side effect.
        """
        self._prune(window)
        return max(0, limit - len(self.requests))
|
||||
|
||||
|
||||
# Parse API keys from environment
|
||||
def _parse_api_keys() -> dict[str, APIKey]:
    """Build the key registry from ``API_KEYS_ENV`` and ``INTERNAL_API_KEY``.

    ``API_KEYS_ENV`` is a comma-separated list of ``key:name`` entries. An
    entry without a colon gets the name ``"default"``. Blank entries (for
    example from a trailing comma or a whitespace-only value) and entries
    whose key part is empty are skipped, so an empty string is never
    registered as a key.

    Returns:
        Mapping from key string to its ``APIKey`` record. The internal key,
        when configured, is added last and marked ``is_internal``.
    """
    keys: dict[str, APIKey] = {}

    for entry in API_KEYS_ENV.split(","):
        # Only the first colon separates key from name.
        key, sep, name = entry.partition(":")
        key = key.strip()
        if not key:
            continue
        keys[key] = APIKey(key=key, name=name.strip() if sep else "default")

    # Strip for consistency with the regular keys above.
    internal_key = INTERNAL_API_KEY.strip()
    if internal_key:
        keys[internal_key] = APIKey(
            key=internal_key,
            name="internal",
            is_internal=True,
            rate_limit=999999,  # Effectively unlimited
        )

    return keys
|
||||
|
||||
|
||||
# Global state
|
||||
# Key registry, built once at import time from the configured key strings.
_api_keys = _parse_api_keys()
# Per-key request logs; an entry is created on first lookup of a key.
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)

# Security scheme
# auto_error=False: a missing header is passed on as None instead of being
# rejected by the scheme itself.
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
@dataclass
class AuthResult:
    """Outcome of an authentication check, handed to the route handler."""

    # Whether the request is allowed to proceed.
    authenticated: bool
    # Label identifying who was authenticated, if anyone.
    key_name: Optional[str] = None
    # True when the matching key is flagged as internal.
    is_internal: bool = False
    # Requests left in the current window; None when no limit was evaluated.
    rate_limit_remaining: Optional[int] = None
|
||||
|
||||
|
||||
def _match_api_key(candidate: str) -> Optional[APIKey]:
    """Return the registered key equal to ``candidate``, or None if there is none."""
    import hmac  # local import keeps this block independent of the file header

    candidate_bytes = candidate.encode("utf-8")
    match = None
    for known, info in _api_keys.items():
        # compare_digest does not short-circuit on the first differing byte.
        # Every registered key is checked (no early exit) for the same reason.
        if hmac.compare_digest(known.encode("utf-8"), candidate_bytes):
            match = info
    return match


async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
    """
    Verify the API key of a request and enforce its rate limit.

    Args:
        request: Incoming request; its path and client host are inspected.
        api_key: Value of the X-API-Key header, or None when absent.

    Returns:
        AuthResult describing the caller. ``rate_limit_remaining`` is set
        only for rate-limited (non-internal) keys.

    Raises:
        HTTPException: 401 when the key is missing or unknown, 429 when the
            key has used up its requests for the current window.
    """
    # Skip auth for health and docs endpoints
    path = request.url.path
    if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
        return AuthResult(authenticated=True, key_name="public")

    # If auth not required, allow all
    if not REQUIRE_AUTH:
        return AuthResult(authenticated=True, key_name="anonymous")

    # Check for API key
    if not api_key:
        client_host = request.client.host if request.client else "unknown"
        logger.warning("Missing API key for %s from %s", path, client_host)
        raise HTTPException(
            status_code=401,
            detail="Missing API key. Provide X-API-Key header.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Validate key (constant-time comparison against every registered key)
    key_info = _match_api_key(api_key)
    if key_info is None:
        logger.warning("Invalid API key attempt for %s", path)
        raise HTTPException(
            status_code=401,
            detail="Invalid API key.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Check rate limit (skip for internal keys)
    remaining: Optional[int] = None
    if not key_info.is_internal:
        rate_info = _rate_limits[api_key]
        if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
            remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
            logger.warning("Rate limit exceeded for key '%s'", key_info.name)
            # Reset/Retry-After report the full window length, i.e. an upper
            # bound on the wait rather than the exact time until a slot frees.
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
                headers={
                    "X-RateLimit-Limit": str(key_info.rate_limit),
                    "X-RateLimit-Remaining": str(remaining),
                    "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
                    "Retry-After": str(RATE_LIMIT_WINDOW),
                },
            )
        remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)

    logger.debug("Authenticated request from '%s' to %s", key_info.name, path)

    return AuthResult(
        authenticated=True,
        key_name=key_info.name,
        is_internal=key_info.is_internal,
        rate_limit_remaining=remaining,
    )
|
||||
|
||||
|
||||
def get_api_key_stats() -> dict:
    """Get statistics about API keys (for admin endpoint).

    Returns:
        A dict with the number of configured keys, whether auth is required,
        the global rate-limit settings, and one entry per key under
        ``"keys"``. Key values themselves are masked.
    """
    # A prefix is shown only when it is at most half of the key; shorter keys
    # are fully masked so the output never reveals most of a secret.
    prefix_len = 8

    stats = {
        "total_keys": len(_api_keys),
        "auth_required": REQUIRE_AUTH,
        "rate_limit": {
            "requests_per_window": RATE_LIMIT_REQUESTS,
            "window_seconds": RATE_LIMIT_WINDOW,
        },
        "keys": [],
    }

    for key, info in _api_keys.items():
        # Don't expose actual keys, just metadata
        masked_key = key[:prefix_len] + "..." if len(key) >= 2 * prefix_len else "***"
        # .get() with a throwaway default avoids creating entries in the defaultdict.
        rate_info = _rate_limits.get(key, RateLimitInfo())
        # remaining() prunes expired timestamps, so call it before counting;
        # otherwise the count could include requests already outside the window.
        remaining = rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW)
        stats["keys"].append({
            "name": info.name,
            "key_prefix": masked_key,
            "is_internal": info.is_internal,
            "requests_in_window": len(rate_info.requests),
            "remaining": remaining,
        })

    return stats
|
||||
|
||||
|
||||
def reload_api_keys():
    """Reload API keys from environment (for runtime updates).

    The parser reads the module-level ``API_KEYS_ENV`` and
    ``INTERNAL_API_KEY`` values, which are captured at import time. They are
    refreshed from the environment here first; otherwise a reload would just
    re-parse the values seen at startup.
    """
    global _api_keys, API_KEYS_ENV, INTERNAL_API_KEY
    API_KEYS_ENV = os.getenv("API_KEYS", "")
    INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")
    _api_keys = _parse_api_keys()
    logger.info(f"Reloaded {len(_api_keys)} API keys")
|
||||
|
|
@ -14,10 +14,12 @@ from contextlib import asynccontextmanager
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Response
|
||||
from fastapi import FastAPI, HTTPException, UploadFile, File, Form, Response, Depends
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from .auth import verify_api_key, AuthResult, REQUIRE_AUTH
|
||||
|
||||
from .audio_utils import convert_audio, SUPPORTED_FORMATS, cleanup_temp_file, save_temp_audio
|
||||
from .kokoro_service import (
|
||||
synthesize_kokoro,
|
||||
|
|
@ -142,6 +144,7 @@ class HealthResponse(BaseModel):
|
|||
status: str
|
||||
service: str
|
||||
models_loaded: dict
|
||||
auth_required: bool
|
||||
|
||||
|
||||
class ModelsResponse(BaseModel):
|
||||
|
|
@ -196,11 +199,12 @@ async def health_check():
|
|||
"kokoro": is_kokoro_loaded(),
|
||||
"f5": is_f5_loaded(),
|
||||
},
|
||||
auth_required=REQUIRE_AUTH,
|
||||
)
|
||||
|
||||
|
||||
@app.get("/models", response_model=ModelsResponse)
|
||||
async def get_models():
|
||||
async def get_models(auth: AuthResult = Depends(verify_api_key)):
|
||||
"""Get information about available models."""
|
||||
return ModelsResponse(
|
||||
kokoro={
|
||||
|
|
@ -226,7 +230,7 @@ async def get_models():
|
|||
|
||||
|
||||
@app.get("/voices", response_model=VoicesResponse)
|
||||
async def get_voices():
|
||||
async def get_voices(auth: AuthResult = Depends(verify_api_key)):
|
||||
"""Get all available voices."""
|
||||
# Kokoro preset voices
|
||||
kokoro_voices = [
|
||||
|
|
@ -264,6 +268,7 @@ async def register_voice(
|
|||
description: str = Form("", description="Voice description"),
|
||||
transcript: str = Form(..., description="Transcript of the reference audio"),
|
||||
reference_audio: UploadFile = File(..., description="Reference audio file"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Register a new custom voice for F5-TTS voice cloning.
|
||||
|
|
@ -313,7 +318,7 @@ async def register_voice(
|
|||
|
||||
|
||||
@app.delete("/voices/{voice_id}", response_model=VoiceDeletedResponse)
|
||||
async def delete_voice(voice_id: str):
|
||||
async def delete_voice(voice_id: str, auth: AuthResult = Depends(verify_api_key)):
|
||||
"""Delete a registered custom voice."""
|
||||
voice_manager = get_voice_manager()
|
||||
|
||||
|
|
@ -332,7 +337,10 @@ async def delete_voice(voice_id: str):
|
|||
|
||||
|
||||
@app.post("/synthesize/kokoro")
|
||||
async def synthesize_with_kokoro(request: KokoroRequest):
|
||||
async def synthesize_with_kokoro(
|
||||
request: KokoroRequest,
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Synthesize speech using Kokoro with preset voices.
|
||||
|
||||
|
|
@ -403,6 +411,7 @@ async def synthesize_with_f5(
|
|||
output_format: str = Form("wav", description="Output format (wav, mp3)"),
|
||||
speed: float = Form(1.0, ge=0.5, le=2.0, description="Speech speed"),
|
||||
steps: int = Form(32, ge=8, le=64, description="Diffusion steps"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Synthesize speech using F5-TTS with voice cloning.
|
||||
|
|
@ -520,7 +529,10 @@ async def synthesize_with_f5(
|
|||
|
||||
|
||||
@app.post("/synthesize/auto")
|
||||
async def synthesize_auto(request: AutoRequest):
|
||||
async def synthesize_auto(
|
||||
request: AutoRequest,
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Auto-select the best TTS model based on voice parameter.
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue