chore(ai-services): adopt Windows GPU as source of truth for llm/stt/tts

The Windows GPU server has been the actual production home for these
services for some time, and the running code there has drifted ahead of
the repo. This sync pulls the live versions back into the repo so the
Windows box is no longer the only place those changes exist.

Pulled from C:\mana\services\* on mana-server-gpu (192.168.178.11):

mana-llm:
- src/main.py, src/config.py — small fixes (auth wiring, config tweaks)
- src/api_auth.py — NEW (cross-service GPU_API_KEY validator)
- service.pyw — Windows runner used by the ManaLLM scheduled task
  (sets up logging redirect, loads .env, calls uvicorn)

mana-stt:
- app/main.py — substantial cleanup (684→392 lines), drops the
  whisperx-as-separate-backend branching now that whisper_service.py
  rolls whisperx in directly
- app/whisper_service.py — full CUDA + whisperx rewrite (158→358 lines)
- app/auth.py + external_auth.py — significantly expanded auth
- app/vram_manager.py — NEW (shared VRAM accounting helper)
- service.pyw — Windows runner with CUDA pre-init, FFmpeg PATH
  injection, .env loading
- removed: app/whisper_service_cuda.py (folded into whisper_service.py)
- removed: app/whisperx_service.py (folded into whisper_service.py)

mana-tts:
- app/auth.py, external_auth.py — same auth expansion as stt
- app/f5_service.py, kokoro_service.py — Windows tweaks
- app/vram_manager.py — NEW (same shared helper as stt)
- service.pyw — Windows runner

mana-video-gen:
- service.pyw — Windows runner (no other changes; the .py code on the
  GPU box is byte-identical to what's already in the repo)

The service.pyw files contain absolute Windows paths
(C:\mana\services\<svc>) and a hardcoded FFmpeg PATH for the tills user
profile. Kept as-is intentionally — they exist to be deployed to that
one machine and any abstraction layer would just hide what's actually
happening. Anyone redeploying to a different layout will need to edit
the path strings, which is a known and obvious change.
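
For anyone making that edit, a path-independent variant of the runner
pattern looks roughly like this (a sketch, not deployed code; the module
path and port shown are mana-llm's):

    import os
    import sys
    from pathlib import Path

    from dotenv import load_dotenv
    import uvicorn

    ROOT = Path(__file__).resolve().parent  # instead of r"C:\mana\services\<svc>"
    os.chdir(ROOT)
    sys.path.insert(0, str(ROOT))
    load_dotenv(ROOT / ".env")

    # Same stdout/stderr redirect the deployed runners use
    log = open(ROOT / "service.log", "w", buffering=1)
    sys.stdout = log
    sys.stderr = log

    uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")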

Mac-Mini infrastructure for these services (launchd plists, install
scripts, scripts/mac-mini/setup-{stt,tts}.sh, the Mac-flux2c image-gen
implementation) is still on disk and will be removed in a follow-up
commit, along with replacing mana-image-gen with the Windows
diffusers+CUDA implementation. This commit is just the live-code sync.
Author: Till JS, 2026-04-08 12:46:03 +02:00
parent abe0a21966
commit b8e18b7f82
20 changed files with 1623 additions and 1261 deletions

View file

@ -23,3 +23,6 @@ CACHE_TTL=3600
# CORS
CORS_ORIGINS=http://localhost:5173,https://mana.how
# API key for cross-service auth (validated in src/api_auth.py)
GPU_API_KEY=

View file

@ -0,0 +1,17 @@
"""mana-llm service runner (run with pythonw.exe to run headless)."""
import os
import sys
os.chdir(r"C:\mana\services\mana-llm")
sys.path.insert(0, r"C:\mana\services\mana-llm")
# Load .env file
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-llm\.env")
# Redirect stdout/stderr to log file
log = open(r"C:\mana\services\mana-llm\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log
import uvicorn
uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")

View file

@ -0,0 +1,53 @@
"""
Simple API Key Authentication Middleware for GPU Services.
Checks X-API-Key header or ?api_key query parameter.
Skips auth for /health, /docs, /openapi.json, /redoc endpoints.
Environment variables:
GPU_API_KEY: Required API key (if empty, auth is disabled)
GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set)
"""
import os
import logging
from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware
logger = logging.getLogger(__name__)
GPU_API_KEY = os.getenv("GPU_API_KEY", "")
GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true"
# Endpoints that don't require auth
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"}
class ApiKeyMiddleware(BaseHTTPMiddleware):
async def dispatch(self, request: Request, call_next):
# Skip auth if disabled
if not GPU_REQUIRE_AUTH or not GPU_API_KEY:
return await call_next(request)
# Skip auth for public endpoints
if request.url.path in PUBLIC_PATHS:
return await call_next(request)
# Check API key from header or query param
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
if not api_key:
return JSONResponse(
status_code=401,
content={"detail": "Missing API key. Provide X-API-Key header."},
)
if api_key != GPU_API_KEY:
logger.warning(f"Invalid API key attempt from {request.client.host if request.client else 'unknown'}")
return JSONResponse(
status_code=401,
content={"detail": "Invalid API key."},
)
return await call_next(request)
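# Sketch: a client call against this middleware (host and key are
# placeholders; port 3025 is the mana-llm runner's port, and the /models
# path is illustrative).
import httpx

print(httpx.get("http://192.168.178.11:3025/health").status_code)  # public path, no key
resp = httpx.get(
    "http://192.168.178.11:3025/models",
    headers={"X-API-Key": "sk-example"},  # any other path 401s without a valid key
)
print(resp.status_code)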

View file

@ -51,6 +51,7 @@ class Settings(BaseSettings):
class Config:
env_file = ".env"
env_file_encoding = "utf-8"
extra = "ignore"
settings = Settings()

View file

@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import Response
from sse_starlette.sse import EventSourceResponse
from src.api_auth import ApiKeyMiddleware
from src.config import settings
from src.models import (
ChatCompletionRequest,
@ -70,6 +71,7 @@ app.add_middleware(
allow_methods=["*"],
allow_headers=["*"],
)
app.add_middleware(ApiKeyMiddleware)
# Health endpoint

View file

@ -1,41 +1,271 @@
"""
API Key Authentication for Mana STT Service.
Delegates to shared mana_auth package.
API Key Authentication for ManaCore STT Service
Supports two authentication modes:
1. Local API keys: Configured via environment variables
2. External API keys: Validated via mana-core-auth service (when EXTERNAL_AUTH_ENABLED=true)
Usage:
# Local keys
API_KEYS=sk-key1:name1,sk-key2:name2
INTERNAL_API_KEY=sk-internal-xxx
# External auth (for user-created keys via mana.how)
EXTERNAL_AUTH_ENABLED=true
MANA_CORE_AUTH_URL=http://localhost:3001
"""
# Re-export everything from shared auth for backward compatibility
import sys
import os
import time
import logging
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field
# Add shared-python to path
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
from fastapi import HTTPException, Security, Request
from fastapi.security import APIKeyHeader
from mana_auth import (
APIKey,
AuthResult,
RateLimitInfo,
verify_api_key as _verify_api_key,
get_api_key_stats,
reload_api_keys,
api_key_header,
create_auth_dependency,
)
from mana_auth.external_auth import (
    ExternalValidationResult,
    is_external_auth_enabled,
    validate_api_key_external,
)
from .external_auth import (
    is_external_auth_enabled,
    validate_api_key_external,
    ExternalValidationResult,
)
from typing import Optional
from fastapi import Security, Request
logger = logging.getLogger(__name__)
# STT-specific auth dependency
verify_stt_key = create_auth_dependency("stt")
# Configuration
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Unlimited internal key
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Per minute
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Seconds
@dataclass
class APIKey:
"""API Key with metadata."""
key: str
name: str
is_internal: bool = False
rate_limit: int = RATE_LIMIT_REQUESTS # Requests per window
@dataclass
class RateLimitInfo:
"""Rate limit tracking per key."""
requests: list = field(default_factory=list)
def is_allowed(self, limit: int, window: int) -> bool:
"""Check if request is allowed within rate limit."""
now = time.time()
# Remove old requests outside window
self.requests = [t for t in self.requests if now - t < window]
if len(self.requests) >= limit:
return False
self.requests.append(now)
return True
def remaining(self, limit: int, window: int) -> int:
"""Get remaining requests in current window."""
now = time.time()
self.requests = [t for t in self.requests if now - t < window]
return max(0, limit - len(self.requests))
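# Sketch: the sliding window in action (assumes the dataclass above).
_demo = RateLimitInfo()
for _ in range(60):
    assert _demo.is_allowed(60, 60)   # first 60 requests in the window pass
assert not _demo.is_allowed(60, 60)   # the 61st is rejected
assert _demo.remaining(60, 60) == 0   # nothing left until entries age out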
# Parse API keys from environment
def _parse_api_keys() -> dict[str, APIKey]:
"""Parse API keys from environment variables."""
keys = {}
# Parse comma-separated keys
if API_KEYS_ENV:
for entry in API_KEYS_ENV.split(","):
entry = entry.strip()
if ":" in entry:
key, name = entry.split(":", 1)
else:
key, name = entry, "default"
keys[key.strip()] = APIKey(key=key.strip(), name=name.strip())
# Add internal key with no rate limit
if INTERNAL_API_KEY:
keys[INTERNAL_API_KEY] = APIKey(
key=INTERNAL_API_KEY,
name="internal",
is_internal=True,
rate_limit=999999, # Effectively unlimited
)
return keys
# Global state
_api_keys = _parse_api_keys()
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
# Security scheme
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
@dataclass
class AuthResult:
"""Result of authentication check."""
authenticated: bool
key_name: Optional[str] = None
is_internal: bool = False
rate_limit_remaining: Optional[int] = None
user_id: Optional[str] = None # Set when using external auth
async def verify_api_key(
request: Request,
api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
"""Verify API key with STT scope."""
return await _verify_api_key(request, scope="stt", api_key=api_key)
"""
Verify API key and check rate limits.
Supports two authentication modes:
1. External auth via mana-core-auth (for sk_live_ keys)
2. Local auth via environment variables
Returns AuthResult with authentication status.
Raises HTTPException if auth fails or rate limited.
"""
# Skip auth for health and docs endpoints
path = request.url.path
if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
return AuthResult(authenticated=True, key_name="public")
# If auth not required, allow all
if not REQUIRE_AUTH:
return AuthResult(authenticated=True, key_name="anonymous")
# Check for API key
if not api_key:
logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
raise HTTPException(
status_code=401,
detail="Missing API key. Provide X-API-Key header.",
headers={"WWW-Authenticate": "ApiKey"},
)
# Try external auth first for sk_live_ keys (user-created keys via mana.how)
if api_key.startswith("sk_live_") and is_external_auth_enabled():
external_result = await validate_api_key_external(api_key, "stt")
if external_result is not None:
if external_result.valid:
# Use rate limits from external auth
rate_info = _rate_limits[api_key]
limit = external_result.rate_limit_requests
window = external_result.rate_limit_window
if not rate_info.is_allowed(limit, window):
remaining = rate_info.remaining(limit, window)
logger.warning(f"Rate limit exceeded for external key")
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Try again in {window} seconds.",
headers={
"X-RateLimit-Limit": str(limit),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(time.time()) + window),
"Retry-After": str(window),
},
)
remaining = rate_info.remaining(limit, window)
logger.debug(f"Authenticated external request from user {external_result.user_id} to {path}")
return AuthResult(
authenticated=True,
key_name="external",
is_internal=False,
rate_limit_remaining=remaining,
user_id=external_result.user_id,
)
else:
# External auth returned invalid
logger.warning(f"External auth failed: {external_result.error}")
raise HTTPException(
status_code=401,
detail=external_result.error or "Invalid API key.",
headers={"WWW-Authenticate": "ApiKey"},
)
# If external_result is None, fall through to local auth
# Local auth: Validate key against environment variables
if api_key not in _api_keys:
logger.warning(f"Invalid API key attempt for {path}")
raise HTTPException(
status_code=401,
detail="Invalid API key.",
headers={"WWW-Authenticate": "ApiKey"},
)
key_info = _api_keys[api_key]
# Check rate limit (skip for internal keys)
if not key_info.is_internal:
rate_info = _rate_limits[api_key]
if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
headers={
"X-RateLimit-Limit": str(key_info.rate_limit),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
"Retry-After": str(RATE_LIMIT_WINDOW),
},
)
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
else:
remaining = None
logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
return AuthResult(
authenticated=True,
key_name=key_info.name,
is_internal=key_info.is_internal,
rate_limit_remaining=remaining,
)
def get_api_key_stats() -> dict:
"""Get statistics about API keys (for admin endpoint)."""
stats = {
"total_keys": len(_api_keys),
"auth_required": REQUIRE_AUTH,
"rate_limit": {
"requests_per_window": RATE_LIMIT_REQUESTS,
"window_seconds": RATE_LIMIT_WINDOW,
},
"keys": [],
}
for key, info in _api_keys.items():
# Don't expose actual keys, just metadata
masked_key = key[:8] + "..." if len(key) > 8 else "***"
rate_info = _rate_limits.get(key, RateLimitInfo())
stats["keys"].append({
"name": info.name,
"key_prefix": masked_key,
"is_internal": info.is_internal,
"requests_in_window": len(rate_info.requests),
"remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW),
})
return stats
def reload_api_keys():
"""Reload API keys from environment (for runtime updates)."""
global _api_keys
_api_keys = _parse_api_keys()
logger.info(f"Reloaded {len(_api_keys)} API keys")

View file

@ -1,22 +1,145 @@
"""
External API Key Validation delegates to shared mana_auth package.
External API Key Validation via mana-core-auth
When EXTERNAL_AUTH_ENABLED=true, API keys are validated against the
central mana-core-auth service. This allows users to create and manage
API keys from the mana.how web interface.
Results are cached for 5 minutes to reduce load on the auth service.
"""
import sys
import os
import time
import logging
import httpx
from typing import Optional
from dataclasses import dataclass
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
logger = logging.getLogger(__name__)
from mana_auth.external_auth import (
ExternalValidationResult,
is_external_auth_enabled,
validate_api_key_external,
clear_cache,
)
# Configuration
EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true"
MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001")
API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300")) # 5 minutes
EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0")) # seconds
__all__ = [
"ExternalValidationResult",
"is_external_auth_enabled",
"validate_api_key_external",
"clear_cache",
]
@dataclass
class ExternalValidationResult:
"""Result from external API key validation."""
valid: bool
user_id: Optional[str] = None
scopes: Optional[list] = None
rate_limit_requests: int = 60
rate_limit_window: int = 60
error: Optional[str] = None
cached_at: float = 0.0
# In-memory cache for validation results
# Key: API key, Value: ExternalValidationResult
_validation_cache: dict[str, ExternalValidationResult] = {}
def is_external_auth_enabled() -> bool:
"""Check if external authentication is enabled."""
return EXTERNAL_AUTH_ENABLED
def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]:
"""Get cached validation result if still valid."""
result = _validation_cache.get(api_key)
if result and (time.time() - result.cached_at) < API_KEY_CACHE_TTL:
return result
return None
def _cache_result(api_key: str, result: ExternalValidationResult):
"""Cache a validation result."""
result.cached_at = time.time()
_validation_cache[api_key] = result
# Clean up old entries periodically (keep cache size manageable)
if len(_validation_cache) > 1000:
now = time.time()
expired_keys = [
k for k, v in _validation_cache.items()
if (now - v.cached_at) >= API_KEY_CACHE_TTL
]
for k in expired_keys:
del _validation_cache[k]
async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]:
"""
Validate an API key against mana-core-auth service.
Args:
api_key: The API key to validate (e.g., "sk_live_...")
scope: The required scope (e.g., "stt" or "tts")
Returns:
ExternalValidationResult if external auth is enabled and the key was validated.
None if external auth is disabled or the service is unavailable (fallback to local).
"""
if not EXTERNAL_AUTH_ENABLED:
return None
# Check cache first
cached = _get_cached_result(api_key)
if cached:
logger.debug(f"Using cached validation result for key prefix: {api_key[:12]}...")
# Check scope against cached result
if cached.valid and cached.scopes and scope not in cached.scopes:
return ExternalValidationResult(
valid=False,
error=f"API key does not have scope: {scope}",
)
return cached
# Call mana-core-auth validation endpoint
try:
async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client:
response = await client.post(
f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate",
json={"apiKey": api_key, "scope": scope},
)
if response.status_code == 200:
data = response.json()
result = ExternalValidationResult(
valid=data.get("valid", False),
user_id=data.get("userId"),
scopes=data.get("scopes", []),
rate_limit_requests=data.get("rateLimit", {}).get("requests", 60),
rate_limit_window=data.get("rateLimit", {}).get("window", 60),
error=data.get("error"),
)
_cache_result(api_key, result)
return result
else:
logger.warning(
f"External auth returned status {response.status_code}: {response.text}"
)
# Don't cache errors - allow retry
return ExternalValidationResult(
valid=False,
error=f"Auth service returned {response.status_code}",
)
except httpx.TimeoutException:
logger.warning("External auth service timeout - falling back to local auth")
return None
except httpx.ConnectError:
logger.warning("Cannot connect to external auth service - falling back to local auth")
return None
except Exception as e:
logger.error(f"External auth error: {e}")
return None
def clear_cache():
"""Clear the validation cache (for testing or runtime updates)."""
global _validation_cache
_validation_cache.clear()
logger.info("External auth cache cleared")

View file

@ -1,6 +1,6 @@
"""
Mana STT API Service
Speech-to-Text with Whisper (MLX), WhisperX (CUDA), Voxtral (vLLM), and Mistral API (fallback)
ManaCore STT API Service (WhisperX Edition)
Speech-to-Text with WhisperX: transcription, word timestamps, speaker diarization.
Run with: uvicorn app.main:app --host 0.0.0.0 --port 3020
"""
@ -34,66 +34,42 @@ CORS_ORIGINS = os.getenv(
"https://mana.how,https://chat.mana.how,http://localhost:5173"
).split(",")
# vLLM configuration (disabled by default - has issues on macOS CPU)
# vLLM configuration
VLLM_URL = os.getenv("VLLM_URL", "http://localhost:8100")
USE_VLLM = os.getenv("USE_VLLM", "false").lower() == "true"
# WhisperX configuration (CUDA GPU server)
USE_WHISPERX = os.getenv("USE_WHISPERX", "false").lower() == "true"
# Response models
class WordInfo(BaseModel):
word: str
start: float
end: float
score: Optional[float] = None
speaker: Optional[str] = None
class SegmentInfo(BaseModel):
start: float
end: float
text: str
speaker: Optional[str] = None
class TranscriptionResponse(BaseModel):
text: str
language: Optional[str] = None
model: str
latency_ms: Optional[float] = None
duration_seconds: Optional[float] = None
class WordTimestampResponse(BaseModel):
word: str
start: float
end: float
score: float = 0.0
speaker: Optional[str] = None
class SegmentResponse(BaseModel):
start: float
end: float
text: str
speaker: Optional[str] = None
words: list[WordTimestampResponse] = []
class UtteranceResponse(BaseModel):
speaker: int
text: str
offset: int # milliseconds
duration: int # milliseconds
class RichTranscriptionResponse(BaseModel):
"""Extended response with segments, utterances, and speaker diarization."""
text: str
language: Optional[str] = None
model: str
latency_ms: Optional[float] = None
duration_seconds: Optional[float] = None
segments: list[SegmentResponse] = []
utterances: list[UtteranceResponse] = []
speakers: dict[str, str] = {}
speaker_map: dict[str, int] = {}
languages: list[str] = []
primary_language: Optional[str] = None
words: list[WordTimestampResponse] = []
words: Optional[list[WordInfo]] = None
segments: Optional[list[SegmentInfo]] = None
speakers: Optional[list[str]] = None
class HealthResponse(BaseModel):
status: str
whisper_loaded: bool
whisperx_available: bool
whisperx: bool
vllm_available: bool
vllm_url: Optional[str] = None
mistral_api_available: bool
@ -103,7 +79,6 @@ class HealthResponse(BaseModel):
class ModelsResponse(BaseModel):
whisper: list
whisperx: list
voxtral_vllm: list
default_whisper: str
@ -111,7 +86,6 @@ class ModelsResponse(BaseModel):
# Track loaded models
models_status = {
"whisper_loaded": False,
"whisperx_available": False,
"vllm_available": False,
}
@ -119,45 +93,28 @@ models_status = {
@asynccontextmanager
async def lifespan(app: FastAPI):
"""Startup and shutdown events."""
logger.info("Starting Mana STT Service...")
logger.info("Starting ManaCore STT Service (WhisperX Edition)...")
# Check vLLM availability
if USE_VLLM:
from app.vllm_service import check_health
health = await check_health()
models_status["vllm_available"] = health.get("status") == "healthy"
if models_status["vllm_available"]:
logger.info(f"vLLM server available at {VLLM_URL}")
else:
logger.warning(f"vLLM server not available: {health}")
# Check WhisperX availability
if USE_WHISPERX:
try:
from app.whisperx_service import is_available as whisperx_available
models_status["whisperx_available"] = whisperx_available()
if models_status["whisperx_available"]:
logger.info("WhisperX (CUDA) available")
else:
logger.warning("WhisperX not available (whisperx package not installed)")
except Exception as e:
logger.warning(f"WhisperX check failed: {e}")
# Check Mistral API
from app.voxtral_api_service import is_available as api_available
if api_available():
logger.info("Mistral API fallback configured")
# Optionally preload Whisper
if PRELOAD_MODELS:
logger.info("Preloading Whisper model...")
try:
from app.whisper_service import get_whisper_model
get_whisper_model(DEFAULT_WHISPER_MODEL)
models_status["whisper_loaded"] = True
logger.info("Whisper model preloaded")
except Exception as e:
logger.warning(f"Failed to preload Whisper: {e}")
# Always preload WhisperX model at startup (avoids timeout on first request)
logger.info("Preloading WhisperX model...")
try:
from app.whisper_service import get_whisper_model
get_whisper_model(DEFAULT_WHISPER_MODEL)
models_status["whisper_loaded"] = True
logger.info("WhisperX model preloaded successfully")
except Exception as e:
logger.warning(f"Failed to preload WhisperX: {e}")
logger.info(f"STT Service ready on port {PORT}")
yield
@ -166,9 +123,9 @@ async def lifespan(app: FastAPI):
# Create FastAPI app
app = FastAPI(
title="Mana STT Service",
description="Speech-to-Text API with Whisper (MLX), Voxtral (vLLM), and Mistral API",
version="2.0.0",
title="ManaCore STT Service",
description="Speech-to-Text API with WhisperX (word timestamps + speaker diarization)",
version="3.0.0",
lifespan=lifespan,
)
@ -193,13 +150,15 @@ async def health_check():
return HealthResponse(
status="healthy",
whisper_loaded=models_status["whisper_loaded"],
whisperx_available=models_status["whisperx_available"],
whisperx=True,
vllm_available=vllm_health.get("status") == "healthy",
vllm_url=VLLM_URL if USE_VLLM else None,
mistral_api_available=api_available(),
auth_required=REQUIRE_AUTH,
models={
"default_whisper": DEFAULT_WHISPER_MODEL,
"engine": "whisperx",
"features": ["transcription", "word_timestamps", "speaker_diarization"],
},
)
@ -212,17 +171,8 @@ async def list_models(auth: AuthResult = Depends(verify_api_key)):
vllm_models = await get_models()
whisperx_models = []
if USE_WHISPERX:
try:
from app.whisperx_service import AVAILABLE_MODELS as wx_models
whisperx_models = wx_models
except ImportError:
pass
return ModelsResponse(
whisper=whisper_models,
whisperx=whisperx_models,
voxtral_vllm=vllm_models,
default_whisper=DEFAULT_WHISPER_MODEL,
)
@ -234,16 +184,22 @@ async def transcribe_whisper(
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
model: Optional[str] = Form(None, description="Whisper model to use"),
align: bool = Form(True, description="Enable word-level timestamp alignment"),
diarize: bool = Form(False, description="Enable speaker diarization"),
min_speakers: Optional[int] = Form(None, description="Min expected speakers (helps diarization)"),
max_speakers: Optional[int] = Form(None, description="Max expected speakers"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Whisper (Lightning MLX).
Transcribe audio using WhisperX.
Best for: General transcription, many languages
Supported formats: mp3, wav, m4a, flac, ogg, webm
Features:
- Word-level timestamps (align=true, default)
- Speaker diarization (diarize=true, opt-in)
Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4
Max file size: 100MB
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
@ -274,20 +230,51 @@ async def transcribe_whisper(
filename=file.filename,
language=language,
model_name=model_name,
align=align,
diarize=diarize,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
models_status["whisper_loaded"] = True
latency_ms = (time.time() - start_time) * 1000
return TranscriptionResponse(
# Build response
resp = TranscriptionResponse(
text=result.text,
language=result.language,
model=f"whisper-{model_name}",
model=f"whisperx-{model_name}",
latency_ms=latency_ms,
duration_seconds=result.duration,
)
# Add word timestamps if available
if result.words:
resp.words = [
WordInfo(
word=w.word,
start=w.start,
end=w.end,
score=w.score,
speaker=w.speaker,
)
for w in result.words
]
# Add segments
if result.segments:
resp.segments = [
SegmentInfo(**s) for s in result.segments
]
# Add speakers
if result.speakers:
resp.speakers = result.speakers
return resp
except Exception as e:
logger.error(f"Whisper transcription error: {e}")
logger.error(f"WhisperX transcription error: {e}")
raise HTTPException(status_code=500, detail=str(e))
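# Sketch: a client call exercising the new parameters (port 3020 is from the
# module docstring; the /transcribe/whisper path, host, and key are assumptions).
import httpx

with open("meeting.m4a", "rb") as f:
    r = httpx.post(
        "http://192.168.178.11:3020/transcribe/whisper",
        headers={"X-API-Key": "sk-example"},
        files={"file": ("meeting.m4a", f, "audio/mp4")},
        data={"diarize": "true", "min_speakers": "2", "max_speakers": "4"},
        timeout=300.0,
    )
print(r.json()["text"])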
@ -296,22 +283,10 @@ async def transcribe_voxtral(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: str = Form("de", description="Language code"),
use_realtime: bool = Form(False, description="Use Realtime 4B model for lower latency"),
use_realtime: bool = Form(False, description="Use Realtime 4B model"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Voxtral via vLLM server.
Models:
- Voxtral Mini 3B (default): Best quality
- Voxtral Realtime 4B: Lower latency (<500ms)
Falls back to Mistral API if vLLM is unavailable.
Supported formats: mp3, wav, m4a, flac, ogg, webm
Max file size: 100MB
"""
# Add rate limit headers
"""Transcribe audio using Voxtral via vLLM or Mistral API."""
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
@ -345,49 +320,30 @@ async def transcribe_voxtral(
if USE_VLLM:
health = await check_health()
if health.get("status") == "healthy":
logger.info("Using vLLM for Voxtral transcription")
if use_realtime:
result = await transcribe_with_realtime(
audio_bytes=audio_bytes,
filename=file.filename,
language=language,
audio_bytes=audio_bytes, filename=file.filename, language=language,
)
else:
result = await vllm_transcribe(
audio_bytes=audio_bytes,
filename=file.filename,
language=language,
audio_bytes=audio_bytes, filename=file.filename, language=language,
)
return TranscriptionResponse(
text=result.text,
language=result.language,
model=result.model,
latency_ms=result.latency_ms,
duration_seconds=result.duration_seconds,
text=result.text, language=result.language, model=result.model,
latency_ms=result.latency_ms, duration_seconds=result.duration_seconds,
)
# Fallback to Mistral API
if api_available():
logger.info("Falling back to Mistral API")
result = await api_transcribe(
audio_bytes=audio_bytes,
filename=file.filename,
language=language,
audio_bytes=audio_bytes, filename=file.filename, language=language,
)
return TranscriptionResponse(
text=result.text,
language=result.language,
model=result.model,
latency_ms=None,
text=result.text, language=result.language, model=result.model,
duration_seconds=result.duration_seconds,
)
raise HTTPException(
status_code=503,
detail="Voxtral not available. Start vLLM server or configure MISTRAL_API_KEY."
)
raise HTTPException(status_code=503, detail="Voxtral not available.")
except HTTPException:
raise
@ -396,273 +352,30 @@ async def transcribe_voxtral(
raise HTTPException(status_code=500, detail=str(e))
@app.post("/transcribe/voxtral/api", response_model=TranscriptionResponse)
async def transcribe_voxtral_api(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
diarization: bool = Form(False, description="Enable speaker diarization"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using Mistral's Voxtral API directly.
Features:
- Speaker diarization
- Auto language detection
- High quality (~4% WER)
Requires MISTRAL_API_KEY environment variable.
"""
# Add rate limit headers
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
from app.voxtral_api_service import is_available, transcribe_audio_bytes
if not is_available():
raise HTTPException(
status_code=503,
detail="Mistral API not configured. Set MISTRAL_API_KEY environment variable."
)
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
try:
audio_bytes = await file.read()
if len(audio_bytes) > 100 * 1024 * 1024:
raise HTTPException(status_code=400, detail="File too large (max 100MB)")
result = await transcribe_audio_bytes(
audio_bytes=audio_bytes,
filename=file.filename,
language=language,
diarization=diarization,
)
return TranscriptionResponse(
text=result.text,
language=result.language,
model=result.model,
duration_seconds=result.duration_seconds,
)
except Exception as e:
logger.error(f"Mistral API error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/transcribe/whisperx", response_model=RichTranscriptionResponse)
async def transcribe_whisperx(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
model: Optional[str] = Form(None, description="Whisper model to use"),
diarization: bool = Form(True, description="Enable speaker diarization"),
alignment: bool = Form(True, description="Enable word-level alignment"),
min_speakers: Optional[int] = Form(None, description="Minimum expected speakers"),
max_speakers: Optional[int] = Form(None, description="Maximum expected speakers"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe audio using WhisperX (CUDA GPU).
Returns rich transcription with:
- Word-level timestamps (via forced alignment)
- Speaker diarization (via pyannote.audio)
- Memoro-compatible utterances with speaker IDs
Requires NVIDIA GPU with CUDA and USE_WHISPERX=true.
Diarization requires HF_TOKEN with pyannote model access.
Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4
Max file size: 100MB
"""
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
if not USE_WHISPERX:
raise HTTPException(
status_code=503,
detail="WhisperX not enabled. Set USE_WHISPERX=true on a CUDA-capable server."
)
if not file.filename:
raise HTTPException(status_code=400, detail="No file provided")
allowed_extensions = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm", ".mp4"}
ext = os.path.splitext(file.filename)[1].lower()
if ext not in allowed_extensions:
raise HTTPException(
status_code=400,
detail=f"Unsupported file type: {ext}. Allowed: {allowed_extensions}"
)
start_time = time.time()
try:
from app.whisperx_service import transcribe_audio_bytes
audio_bytes = await file.read()
if len(audio_bytes) > 100 * 1024 * 1024:
raise HTTPException(status_code=400, detail="File too large (max 100MB)")
model_name = model or DEFAULT_WHISPER_MODEL
result = await transcribe_audio_bytes(
audio_bytes=audio_bytes,
filename=file.filename,
language=language,
model_name=model_name,
enable_diarization=diarization,
enable_alignment=alignment,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
latency_ms = (time.time() - start_time) * 1000
return RichTranscriptionResponse(
text=result.text,
language=result.language,
model=f"whisperx-{model_name}",
latency_ms=latency_ms,
duration_seconds=result.duration_seconds,
segments=[
SegmentResponse(
start=s.start,
end=s.end,
text=s.text,
speaker=s.speaker,
words=[
WordTimestampResponse(
word=w.word,
start=w.start,
end=w.end,
score=w.score,
speaker=w.speaker,
)
for w in s.words
],
)
for s in result.segments
],
utterances=[
UtteranceResponse(
speaker=u.speaker,
text=u.text,
offset=u.offset,
duration=u.duration,
)
for u in result.utterances
],
speakers=result.speakers,
speaker_map={k: v for k, v in result.speaker_map.items()},
languages=result.languages,
primary_language=result.primary_language,
words=[
WordTimestampResponse(
word=w.word,
start=w.start,
end=w.end,
score=w.score,
speaker=w.speaker,
)
for w in result.words
],
)
except HTTPException:
raise
except ImportError:
raise HTTPException(
status_code=503,
detail="WhisperX not installed. Install with: pip install -r requirements-cuda.txt"
)
except Exception as e:
logger.error(f"WhisperX transcription error: {e}")
raise HTTPException(status_code=500, detail=str(e))
@app.post("/transcribe/auto", response_model=TranscriptionResponse)
async def transcribe_auto(
response: Response,
file: UploadFile = File(..., description="Audio file to transcribe"),
language: Optional[str] = Form(None, description="Language hint"),
prefer: str = Form("whisper", description="Preferred: 'whisper', 'whisperx', or 'voxtral'"),
prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"),
auth: AuthResult = Depends(verify_api_key),
):
"""
Transcribe with automatic model selection and fallback.
Fallback chain:
- whisper: Whisper → WhisperX → Voxtral → Mistral API
- whisperx: WhisperX → Whisper → Voxtral → Mistral API
- voxtral: Voxtral → WhisperX → Whisper → Mistral API
"""
# Add rate limit headers
"""Auto-select best model with fallback chain."""
if auth.rate_limit_remaining is not None:
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
async def try_whisperx_simple():
"""Try WhisperX and return as simple TranscriptionResponse."""
if not USE_WHISPERX:
raise RuntimeError("WhisperX not enabled")
from app.whisperx_service import transcribe_audio_bytes as wx_transcribe
audio_bytes = await file.read()
result = await wx_transcribe(
audio_bytes=audio_bytes,
filename=file.filename or "audio.wav",
language=language,
enable_diarization=False,
enable_alignment=False,
)
return TranscriptionResponse(
text=result.text,
language=result.language,
model=f"whisperx-{DEFAULT_WHISPER_MODEL}",
latency_ms=None,
duration_seconds=result.duration_seconds,
)
# Build fallback chain based on preference
if prefer == "whisperx":
chain = [
("WhisperX", try_whisperx_simple),
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
]
elif prefer == "voxtral":
chain = [
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
("WhisperX", try_whisperx_simple),
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
]
else:
chain = [
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
("WhisperX", try_whisperx_simple),
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
]
last_error = None
for name, fn in chain:
    try:
        result = await fn()
        return result
    except Exception as e:
        last_error = e
        logger.warning(f"{name} failed: {e}")
        await file.seek(0)
raise HTTPException(
    status_code=503,
    detail=f"All transcription backends failed. Last error: {last_error}"
)
if prefer == "voxtral":
    try:
        return await transcribe_voxtral(response, file, language or "de", False, auth)
    except Exception:
        await file.seek(0)
        return await transcribe_whisper(response, file, language, None, True, False, None, None, auth)
else:
    try:
        return await transcribe_whisper(response, file, language, None, True, False, None, None, auth)
    except Exception:
        await file.seek(0)
        return await transcribe_voxtral(response, file, language or "de", False, auth)
@app.exception_handler(Exception)
@ -676,9 +389,4 @@ async def global_exception_handler(request, exc):
if __name__ == "__main__":
import uvicorn
uvicorn.run(
"app.main:app",
host="0.0.0.0",
port=PORT,
reload=False,
)
uvicorn.run("app.main:app", host="0.0.0.0", port=PORT, reload=False)

View file

@ -0,0 +1,114 @@
"""
VRAM Manager: automatic model unloading after idle timeout.
Tracks when the model was last used and unloads it after a configurable timeout.
Designed for shared GPU environments (multiple services on one RTX 3090).
Usage in a service:
from vram_manager import VramManager
vram = VramManager(idle_timeout=300) # 5 min
# Before using a model
vram.touch()
# Call periodically (e.g., from health check or background task)
vram.check_idle(unload_fn=my_unload_function)
"""
import os
import time
import logging
import threading
from typing import Optional, Callable
logger = logging.getLogger(__name__)
DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300")) # 5 minutes
class VramManager:
def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"):
self.idle_timeout = idle_timeout
self.service_name = service_name
self.last_used: float = 0.0
self.model_loaded: bool = False
self._lock = threading.Lock()
self._timer: Optional[threading.Timer] = None
def touch(self):
"""Mark the model as recently used. Call before/after each inference."""
with self._lock:
self.last_used = time.time()
self.model_loaded = True
self._schedule_check()
def mark_loaded(self):
"""Mark that a model has been loaded into VRAM."""
with self._lock:
self.model_loaded = True
self.last_used = time.time()
self._schedule_check()
logger.info(f"[{self.service_name}] Model loaded, idle timeout: {self.idle_timeout}s")
def mark_unloaded(self):
"""Mark that a model has been unloaded from VRAM."""
with self._lock:
self.model_loaded = False
if self._timer:
self._timer.cancel()
self._timer = None
logger.info(f"[{self.service_name}] Model unloaded, VRAM freed")
def is_idle(self) -> bool:
"""Check if the model has been idle longer than the timeout."""
if not self.model_loaded:
return False
return (time.time() - self.last_used) > self.idle_timeout
def seconds_until_unload(self) -> Optional[float]:
"""Seconds until the model will be unloaded, or None if not loaded."""
if not self.model_loaded:
return None
remaining = self.idle_timeout - (time.time() - self.last_used)
return max(0, remaining)
def check_and_unload(self, unload_fn: Callable[[], None]) -> bool:
"""Check if idle and unload if so. Returns True if unloaded."""
if self.is_idle():
logger.info(f"[{self.service_name}] Idle for >{self.idle_timeout}s, unloading model...")
try:
unload_fn()
self.mark_unloaded()
return True
except Exception as e:
logger.error(f"[{self.service_name}] Failed to unload: {e}")
return False
def _schedule_check(self):
"""Schedule an idle check after the timeout period."""
if self._timer:
self._timer.cancel()
self._timer = threading.Timer(
self.idle_timeout + 5, # Small buffer
self._auto_check,
)
self._timer.daemon = True
self._timer.start()
def _auto_check(self):
"""Auto-triggered idle check (called by timer)."""
# This is just a log — actual unloading needs the unload_fn
# which depends on the service. The service should call check_and_unload.
if self.is_idle():
logger.info(f"[{self.service_name}] Model idle for >{self.idle_timeout}s — ready to unload")
def status(self) -> dict:
"""Get current VRAM manager status."""
return {
"model_loaded": self.model_loaded,
"idle_seconds": round(time.time() - self.last_used, 1) if self.model_loaded else None,
"idle_timeout": self.idle_timeout,
"seconds_until_unload": round(self.seconds_until_unload(), 1) if self.model_loaded else None,
}
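# Sketch: exposing the manager from a service health endpoint (illustrative
# wiring; each service creates its own instance and unload function).
from fastapi import FastAPI

demo_app = FastAPI()
demo_vram = VramManager(idle_timeout=300, service_name="demo")

@demo_app.get("/vram")
def vram_status() -> dict:
    # Surfaces model_loaded / idle_seconds / seconds_until_unload for monitoring
    return demo_vram.status()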

View file

@ -1,6 +1,11 @@
"""
Whisper STT Service using Lightning Whisper MLX
Optimized for Apple Silicon (M1/M2/M3/M4)
Whisper STT Service using WhisperX (CUDA)
Provides: transcription, word-level timestamps, speaker diarization.
WhisperX pipeline:
1. faster-whisper for transcription
2. wav2vec2 for forced alignment (precise word timestamps)
3. pyannote-audio for speaker diarization
"""
import os
@ -8,12 +13,55 @@ import tempfile
import logging
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# Lazy load to avoid import errors if not installed
_whisper_model = None
# Lazy load
_whisperx_model = None
_align_model = None
_align_metadata = None
_align_language = None
_diarize_pipeline = None
# Config
HF_TOKEN = os.getenv("HF_TOKEN", "")
# VRAM management — unload after 10 min idle (STT uses ~3GB)
from app.vram_manager import VramManager
_vram = VramManager(
idle_timeout=int(os.getenv("VRAM_IDLE_TIMEOUT", "600")),
service_name="mana-stt",
)
def unload_models():
"""Unload all WhisperX models from GPU to free VRAM."""
global _whisperx_model, _align_model, _align_metadata, _diarize_pipeline
import torch
if _whisperx_model is not None:
del _whisperx_model
_whisperx_model = None
if _align_model is not None:
del _align_model
_align_model = None
_align_metadata = None
if _diarize_pipeline is not None:
del _diarize_pipeline
_diarize_pipeline = None
torch.cuda.empty_cache()
_vram.mark_unloaded()
logger.info("WhisperX models unloaded, VRAM freed")
@dataclass
class WordSegment:
word: str
start: float
end: float
score: Optional[float] = None
speaker: Optional[str] = None
@dataclass
@ -22,83 +70,231 @@ class TranscriptionResult:
language: Optional[str] = None
duration: Optional[float] = None
segments: Optional[list] = None
words: Optional[list[WordSegment]] = field(default_factory=list)
speakers: Optional[list[str]] = field(default_factory=list)
def get_whisper_model(model_name: str = "large-v3", batch_size: int = 12):
"""Get or create Whisper model instance (singleton pattern)."""
global _whisper_model
def get_whisper_model(model_name: str = "large-v3", **kwargs):
"""Get or create WhisperX model instance (singleton)."""
global _whisperx_model
if _whisper_model is None:
logger.info(f"Loading Whisper model: {model_name}")
try:
from lightning_whisper_mlx import LightningWhisperMLX
_whisper_model = LightningWhisperMLX(
model=model_name,
batch_size=batch_size,
quant=None # Use full precision for best quality
)
logger.info(f"Whisper model loaded successfully: {model_name}")
except ImportError as e:
logger.error(f"Failed to import lightning_whisper_mlx: {e}")
raise RuntimeError(
"lightning-whisper-mlx not installed. "
"Run: pip install lightning-whisper-mlx"
)
except Exception as e:
logger.error(f"Failed to load Whisper model: {e}")
raise
if _whisperx_model is not None:
return _whisperx_model
return _whisper_model
logger.info(f"Loading WhisperX model: {model_name}")
try:
import whisperx
device = os.getenv("WHISPER_DEVICE", "cuda")
compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
default_language = os.getenv("WHISPER_DEFAULT_LANGUAGE", "de")
_whisperx_model = whisperx.load_model(
model_name,
device=device,
compute_type=compute_type,
language=default_language,
)
logger.info(f"WhisperX model loaded: {model_name} on {device} ({compute_type})")
_vram.mark_loaded()
except ImportError as e:
logger.error(f"Failed to import whisperx: {e}")
raise RuntimeError("whisperx not installed. Run: pip install whisperx")
except Exception as e:
logger.error(f"Failed to load WhisperX model: {e}")
raise
return _whisperx_model
def _get_align_model(language: str, device: str = "cuda"):
"""Get or create alignment model for a language."""
    global _align_model, _align_metadata, _align_language
    import whisperx
    # Reload if the language changed (alignment models are language-specific)
    if _align_model is None or _align_language != language:
        logger.info(f"Loading alignment model for language: {language}")
        _align_model, _align_metadata = whisperx.load_align_model(
            language_code=language,
            device=device,
        )
        _align_language = language
logger.info("Alignment model loaded")
return _align_model, _align_metadata
def _get_diarize_pipeline(device: str = "cuda"):
"""Get or create speaker diarization pipeline."""
global _diarize_pipeline
if _diarize_pipeline is not None:
return _diarize_pipeline
import torch
from pyannote.audio import Pipeline
token = HF_TOKEN or os.getenv("HUGGING_FACE_HUB_TOKEN", "")
if not token:
logger.warning("No HF_TOKEN set — speaker diarization may fail for gated models")
logger.info("Loading speaker diarization pipeline (pyannote)...")
_diarize_pipeline = Pipeline.from_pretrained(
"pyannote/speaker-diarization-3.1",
token=token,
)
_diarize_pipeline.to(torch.device(device))
logger.info("Diarization pipeline loaded")
return _diarize_pipeline
def transcribe_audio(
audio_path: str,
language: Optional[str] = None,
model_name: str = "large-v3",
align: bool = True,
diarize: bool = False,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
) -> TranscriptionResult:
"""
Transcribe audio file using Lightning Whisper MLX.
Transcribe audio using WhisperX with optional alignment and diarization.
Args:
audio_path: Path to audio file (mp3, wav, m4a, etc.)
language: Optional language code (e.g., 'de', 'en'). Auto-detect if None.
audio_path: Path to audio file
language: Language code (auto-detect if None)
model_name: Whisper model to use
align: Enable word-level timestamp alignment
diarize: Enable speaker diarization
min_speakers: Minimum expected speakers (helps diarization)
max_speakers: Maximum expected speakers
Returns:
TranscriptionResult with text and metadata
TranscriptionResult with text, word timestamps, and speaker info
"""
import whisperx
device = os.getenv("WHISPER_DEVICE", "cuda")
model = get_whisper_model(model_name)
logger.info(f"Transcribing: {audio_path}")
logger.info(f"Transcribing: {audio_path} (align={align}, diarize={diarize})")
try:
# Lightning Whisper MLX returns dict with 'text' key
result = model.transcribe(
audio_path=audio_path,
language=language,
)
# Check and unload if idle, then reload
_vram.check_and_unload(unload_models)
_vram.touch()
# Handle different return formats
if isinstance(result, dict):
text = result.get("text", "")
segments = result.get("segments", [])
detected_language = result.get("language", language)
else:
text = str(result)
segments = []
detected_language = language
# Step 1: Load audio
audio = whisperx.load_audio(audio_path)
logger.info(f"Transcription complete: {len(text)} characters")
# Step 2: Transcribe with faster-whisper
transcribe_kwargs = {"batch_size": 16}
if language:
transcribe_kwargs["language"] = language
result = model.transcribe(audio, **transcribe_kwargs)
detected_language = result.get("language", language or "en")
return TranscriptionResult(
text=text.strip(),
language=detected_language,
segments=segments,
)
# Step 3: Align (word-level timestamps)
if align and result["segments"]:
try:
align_model, metadata = _get_align_model(detected_language, device)
result = whisperx.align(
result["segments"],
align_model,
metadata,
audio,
device,
return_char_alignments=False,
)
logger.info("Word alignment complete")
except Exception as e:
logger.warning(f"Alignment failed (continuing without): {e}")
except Exception as e:
logger.error(f"Transcription failed: {e}")
raise
# Step 4: Diarize (speaker identification)
if diarize:
try:
import torch
import torchaudio
diarize_pipe = _get_diarize_pipeline(device)
# pyannote needs waveform as tensor, not the whisperx audio array
waveform = torch.from_numpy(audio).unsqueeze(0).float()
diarize_input = {"waveform": waveform, "sample_rate": 16000}
diarize_kwargs = {}
if min_speakers is not None:
diarize_kwargs["min_speakers"] = min_speakers
if max_speakers is not None:
diarize_kwargs["max_speakers"] = max_speakers
diarize_output = diarize_pipe(diarize_input, **diarize_kwargs)
# pyannote 4.x returns DiarizeOutput, extract the Annotation
if hasattr(diarize_output, "speaker_diarization"):
diarize_annotation = diarize_output.speaker_diarization
else:
diarize_annotation = diarize_output
# Convert pyannote output to DataFrame for whisperx
import pandas as pd
diarize_rows = []
for turn, _, speaker in diarize_annotation.itertracks(yield_label=True):
diarize_rows.append({
"start": turn.start,
"end": turn.end,
"speaker": speaker,
})
diarize_df = pd.DataFrame(diarize_rows)
result = whisperx.assign_word_speakers(diarize_df, result)
logger.info("Speaker diarization complete")
except Exception as e:
logger.warning(f"Diarization failed (continuing without): {e}")
import traceback
traceback.print_exc()
# Build response
segments = result.get("segments", [])
full_text_parts = []
all_words = []
speaker_set = set()
for seg in segments:
full_text_parts.append(seg.get("text", ""))
speaker = seg.get("speaker")
if speaker:
speaker_set.add(speaker)
for word_info in seg.get("words", []):
all_words.append(WordSegment(
word=word_info.get("word", ""),
start=word_info.get("start", 0.0),
end=word_info.get("end", 0.0),
score=word_info.get("score"),
speaker=word_info.get("speaker", speaker),
))
text = " ".join(full_text_parts)
_vram.touch()
logger.info(
f"Transcription complete: {len(text)} chars, "
f"{len(all_words)} words, {len(speaker_set)} speakers"
)
return TranscriptionResult(
text=text.strip(),
language=detected_language,
segments=[{
"start": s.get("start", 0),
"end": s.get("end", 0),
"text": s.get("text", ""),
"speaker": s.get("speaker"),
} for s in segments],
words=all_words,
speakers=sorted(speaker_set),
)
async def transcribe_audio_bytes(
@ -106,53 +302,57 @@ async def transcribe_audio_bytes(
filename: str,
language: Optional[str] = None,
model_name: str = "large-v3",
align: bool = True,
diarize: bool = False,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
) -> TranscriptionResult:
"""
Transcribe audio from bytes (for API uploads).
"""Transcribe audio from bytes (for API uploads)."""
import asyncio
Args:
audio_bytes: Raw audio file bytes
filename: Original filename (for extension detection)
language: Optional language code
model_name: Whisper model to use
Returns:
TranscriptionResult
"""
# Get file extension
ext = Path(filename).suffix or ".wav"
# Write to temp file
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
result = transcribe_audio(
audio_path=tmp_path,
language=language,
model_name=model_name,
# Run in thread pool to avoid blocking the event loop
loop = asyncio.get_event_loop()
result = await loop.run_in_executor(
None,
lambda: transcribe_audio(
audio_path=tmp_path,
language=language,
model_name=model_name,
align=align,
diarize=diarize,
min_speakers=min_speakers,
max_speakers=max_speakers,
),
)
return result
finally:
# Clean up temp file
try:
os.unlink(tmp_path)
except Exception:
pass
# Available models for Lightning Whisper MLX
# Available models
AVAILABLE_MODELS = [
"tiny",
"small",
"tiny.en",
"base",
"base.en",
"small",
"small.en",
"medium",
"large",
"medium.en",
"large-v1",
"large-v2",
"large-v3", # Recommended for Mac Mini
"distil-small.en",
"distil-medium.en",
"large-v3",
"large-v3-turbo",
"distil-large-v2",
"distil-large-v3",
]
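# Sketch: calling the service layer directly on the GPU box (assumes the
# CUDA stack is installed; the file path and language are placeholders).
if __name__ == "__main__":
    res = transcribe_audio("meeting.wav", language="de", align=True, diarize=True)
    print(res.text)
    print(res.speakers)  # e.g. ["SPEAKER_00", "SPEAKER_01"]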

View file

@ -1,175 +0,0 @@
"""
Whisper STT Service using faster-whisper (CUDA)
Optimized for NVIDIA GPUs (RTX 3090 etc.)
Drop-in replacement for whisper_service.py (MLX version).
Uses faster-whisper with CTranslate2 for GPU-accelerated inference.
"""
import os
import tempfile
import logging
from pathlib import Path
from typing import Optional
from dataclasses import dataclass
logger = logging.getLogger(__name__)
# Lazy load to avoid import errors if not installed
_whisper_model = None
@dataclass
class TranscriptionResult:
text: str
language: Optional[str] = None
duration: Optional[float] = None
segments: Optional[list] = None
def get_whisper_model(model_name: str = "large-v3", **kwargs):
"""Get or create Whisper model instance (singleton pattern)."""
global _whisper_model
if _whisper_model is None:
logger.info(f"Loading Whisper model: {model_name}")
try:
from faster_whisper import WhisperModel
# Use CUDA with float16 for RTX 3090
compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
device = os.getenv("WHISPER_DEVICE", "cuda")
_whisper_model = WhisperModel(
model_name,
device=device,
compute_type=compute_type,
)
logger.info(f"Whisper model loaded: {model_name} on {device} ({compute_type})")
except ImportError as e:
logger.error(f"Failed to import faster_whisper: {e}")
raise RuntimeError(
"faster-whisper not installed. "
"Run: pip install faster-whisper"
)
except Exception as e:
logger.error(f"Failed to load Whisper model: {e}")
raise
return _whisper_model
def transcribe_audio(
audio_path: str,
language: Optional[str] = None,
model_name: str = "large-v3",
) -> TranscriptionResult:
"""
Transcribe audio file using faster-whisper (CUDA).
Args:
audio_path: Path to audio file (mp3, wav, m4a, etc.)
language: Optional language code (e.g., 'de', 'en'). Auto-detect if None.
model_name: Whisper model to use
Returns:
TranscriptionResult with text and metadata
"""
model = get_whisper_model(model_name)
logger.info(f"Transcribing: {audio_path}")
try:
segments, info = model.transcribe(
audio_path,
language=language,
beam_size=5,
vad_filter=True, # Filter out silence
)
# Collect all segments
all_segments = []
full_text_parts = []
for segment in segments:
full_text_parts.append(segment.text)
all_segments.append({
"start": segment.start,
"end": segment.end,
"text": segment.text,
})
text = " ".join(full_text_parts)
detected_language = info.language if info else language
logger.info(f"Transcription complete: {len(text)} characters, language={detected_language}")
return TranscriptionResult(
text=text.strip(),
language=detected_language,
duration=info.duration if info else None,
segments=all_segments,
)
except Exception as e:
logger.error(f"Transcription failed: {e}")
raise
async def transcribe_audio_bytes(
audio_bytes: bytes,
filename: str,
language: Optional[str] = None,
model_name: str = "large-v3",
) -> TranscriptionResult:
"""
Transcribe audio from bytes (for API uploads).
Args:
audio_bytes: Raw audio file bytes
filename: Original filename (for extension detection)
language: Optional language code
model_name: Whisper model to use
Returns:
TranscriptionResult
"""
# Get file extension
ext = Path(filename).suffix or ".wav"
# Write to temp file
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
result = transcribe_audio(
audio_path=tmp_path,
language=language,
model_name=model_name,
)
return result
finally:
# Clean up temp file
try:
os.unlink(tmp_path)
except Exception:
pass
# Available models for faster-whisper
AVAILABLE_MODELS = [
"tiny",
"tiny.en",
"base",
"base.en",
"small",
"small.en",
"medium",
"medium.en",
"large-v1",
"large-v2",
"large-v3",
"large-v3-turbo",
"distil-large-v2",
"distil-large-v3",
]

View file

@ -1,419 +0,0 @@
"""
WhisperX STT Service using faster-whisper + pyannote (CUDA)
Optimized for NVIDIA GPUs (RTX 3090 etc.)
Features:
- Word-level timestamps via forced alignment
- Speaker diarization via pyannote.audio
- Segment-level timestamps with speaker labels
- VAD filtering for silence removal
Requires HuggingFace token for pyannote models:
export HF_TOKEN=hf_xxx
# Accept terms at: https://huggingface.co/pyannote/speaker-diarization-3.1
"""
import os
import tempfile
import logging
import time
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field
logger = logging.getLogger(__name__)
# Lazy-loaded singletons
_whisper_model = None
_align_model = None
_align_metadata = None
_diarize_pipeline = None
HF_TOKEN = os.getenv("HF_TOKEN")
DEFAULT_MODEL = os.getenv("WHISPER_MODEL", "large-v3")
DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
BATCH_SIZE = int(os.getenv("WHISPERX_BATCH_SIZE", "16"))
@dataclass
class WordInfo:
"""Word with timestamp."""
word: str
start: float
end: float
score: float = 0.0
speaker: Optional[str] = None
@dataclass
class SegmentInfo:
"""Segment with speaker and word-level detail."""
start: float
end: float
text: str
speaker: Optional[str] = None
words: list[WordInfo] = field(default_factory=list)
@dataclass
class Utterance:
"""Speaker utterance in Memoro-compatible format."""
speaker: int
text: str
offset: int # milliseconds
duration: int # milliseconds
@dataclass
class WhisperXResult:
"""Rich transcription result with alignment and diarization."""
text: str
language: Optional[str] = None
duration_seconds: Optional[float] = None
segments: list[SegmentInfo] = field(default_factory=list)
utterances: list[Utterance] = field(default_factory=list)
speakers: dict[str, str] = field(default_factory=dict)
speaker_map: dict[str, int] = field(default_factory=dict)
languages: list[str] = field(default_factory=list)
primary_language: Optional[str] = None
words: list[WordInfo] = field(default_factory=list)
def get_whisper_model(model_name: str = None):
"""Load faster-whisper model (singleton)."""
global _whisper_model
model_name = model_name or DEFAULT_MODEL
if _whisper_model is None:
logger.info(f"Loading WhisperX model: {model_name} on {DEVICE} ({COMPUTE_TYPE})")
try:
import whisperx
_whisper_model = whisperx.load_model(
model_name,
device=DEVICE,
compute_type=COMPUTE_TYPE,
)
logger.info(f"WhisperX model loaded: {model_name}")
except ImportError:
raise RuntimeError(
"whisperx not installed. "
"Run: pip install whisperx"
)
return _whisper_model
def get_align_model(language_code: str):
"""Load alignment model for a specific language (cached per language)."""
global _align_model, _align_metadata
try:
import whisperx
_align_model, _align_metadata = whisperx.load_align_model(
language_code=language_code,
device=DEVICE,
)
logger.info(f"Alignment model loaded for language: {language_code}")
return _align_model, _align_metadata
except Exception as e:
logger.warning(f"Failed to load alignment model for {language_code}: {e}")
return None, None
def get_diarize_pipeline():
"""Load pyannote speaker diarization pipeline (singleton)."""
global _diarize_pipeline
if _diarize_pipeline is None:
if not HF_TOKEN:
logger.warning("HF_TOKEN not set — diarization disabled")
return None
try:
import whisperx
_diarize_pipeline = whisperx.DiarizationPipeline(
use_auth_token=HF_TOKEN,
device=DEVICE,
)
logger.info("Diarization pipeline loaded")
except Exception as e:
logger.warning(f"Failed to load diarization pipeline: {e}")
return None
return _diarize_pipeline
def _build_utterances(segments: list[SegmentInfo]) -> tuple[list[Utterance], dict[str, str], dict[str, int]]:
"""
Build Memoro-compatible utterances from diarized segments.
Groups consecutive segments by the same speaker.
"""
if not segments:
return [], {}, {}
# Collect unique speakers
speaker_labels = sorted(set(
s.speaker for s in segments if s.speaker is not None
))
speaker_map: dict[str, int] = {}
speakers: dict[str, str] = {}
for idx, label in enumerate(speaker_labels):
speaker_map[label] = idx
speakers[str(idx)] = label
# Merge consecutive segments with the same speaker
utterances: list[Utterance] = []
current_speaker = None
current_text_parts: list[str] = []
current_start = 0.0
current_end = 0.0
for seg in segments:
sp = seg.speaker or "SPEAKER_00"
if sp != current_speaker:
# Flush previous
if current_speaker is not None and current_text_parts:
utterances.append(Utterance(
speaker=speaker_map.get(current_speaker, 0),
text=" ".join(current_text_parts).strip(),
offset=int(current_start * 1000),
duration=int((current_end - current_start) * 1000),
))
current_speaker = sp
current_text_parts = [seg.text]
current_start = seg.start
current_end = seg.end
else:
current_text_parts.append(seg.text)
current_end = seg.end
# Flush last
if current_speaker is not None and current_text_parts:
utterances.append(Utterance(
speaker=speaker_map.get(current_speaker, 0),
text=" ".join(current_text_parts).strip(),
offset=int(current_start * 1000),
duration=int((current_end - current_start) * 1000),
))
return utterances, speakers, speaker_map
def transcribe_audio(
audio_path: str,
language: Optional[str] = None,
model_name: Optional[str] = None,
enable_diarization: bool = True,
enable_alignment: bool = True,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
) -> WhisperXResult:
"""
Transcribe audio with WhisperX: alignment + diarization.
Args:
audio_path: Path to audio file
language: Language code (e.g., 'de', 'en'). Auto-detect if None.
model_name: Whisper model to use
enable_diarization: Run speaker diarization
enable_alignment: Run forced word alignment
min_speakers: Minimum expected speakers (hint for diarization)
max_speakers: Maximum expected speakers (hint for diarization)
Returns:
WhisperXResult with full transcription, segments, utterances, speakers
"""
import whisperx
start_time = time.time()
# 1. Load audio
audio = whisperx.load_audio(audio_path)
audio_duration = len(audio) / 16000 # whisperx resamples to 16kHz
# 2. Transcribe with faster-whisper
model = get_whisper_model(model_name)
transcribe_result = model.transcribe(
audio,
batch_size=BATCH_SIZE,
language=language,
)
detected_language = transcribe_result.get("language", language or "en")
raw_segments = transcribe_result.get("segments", [])
logger.info(
f"Transcription: {len(raw_segments)} segments, "
f"language={detected_language}, "
f"duration={audio_duration:.1f}s"
)
# 3. Forced alignment (word-level timestamps)
if enable_alignment and raw_segments:
align_model, align_metadata = get_align_model(detected_language)
if align_model is not None:
try:
transcribe_result = whisperx.align(
raw_segments,
align_model,
align_metadata,
audio,
DEVICE,
return_char_alignments=False,
)
raw_segments = transcribe_result.get("segments", raw_segments)
logger.info("Word alignment complete")
except Exception as e:
logger.warning(f"Alignment failed, using segment-level timestamps: {e}")
# 4. Speaker diarization
if enable_diarization:
diarize_pipeline = get_diarize_pipeline()
if diarize_pipeline is not None:
try:
diarize_kwargs = {}
if min_speakers is not None:
diarize_kwargs["min_speakers"] = min_speakers
if max_speakers is not None:
diarize_kwargs["max_speakers"] = max_speakers
diarize_segments = diarize_pipeline(
audio_path,
**diarize_kwargs,
)
transcribe_result = whisperx.assign_word_speakers(
diarize_segments, transcribe_result
)
raw_segments = transcribe_result.get("segments", raw_segments)
logger.info("Diarization complete")
except Exception as e:
logger.warning(f"Diarization failed: {e}")
# 5. Build structured result
segments: list[SegmentInfo] = []
all_words: list[WordInfo] = []
full_text_parts: list[str] = []
for seg in raw_segments:
seg_words: list[WordInfo] = []
for w in seg.get("words", []):
wi = WordInfo(
word=w.get("word", ""),
start=w.get("start", 0.0),
end=w.get("end", 0.0),
score=w.get("score", 0.0),
speaker=w.get("speaker"),
)
seg_words.append(wi)
all_words.append(wi)
segment = SegmentInfo(
start=seg.get("start", 0.0),
end=seg.get("end", 0.0),
text=seg.get("text", "").strip(),
speaker=seg.get("speaker"),
words=seg_words,
)
segments.append(segment)
full_text_parts.append(segment.text)
full_text = " ".join(full_text_parts)
# 6. Build utterances (Memoro-compatible)
utterances, speakers, speaker_map = _build_utterances(segments)
latency = time.time() - start_time
logger.info(f"WhisperX complete in {latency:.1f}s: {len(full_text)} chars, {len(speakers)} speakers")
return WhisperXResult(
text=full_text,
language=detected_language,
duration_seconds=audio_duration,
segments=segments,
utterances=utterances,
speakers=speakers,
speaker_map=speaker_map,
languages=[detected_language] if detected_language else [],
primary_language=detected_language,
words=all_words,
)
async def transcribe_audio_bytes(
audio_bytes: bytes,
filename: str,
language: Optional[str] = None,
model_name: Optional[str] = None,
enable_diarization: bool = True,
enable_alignment: bool = True,
min_speakers: Optional[int] = None,
max_speakers: Optional[int] = None,
) -> WhisperXResult:
"""
Transcribe audio from bytes (for API uploads).
Args:
audio_bytes: Raw audio file bytes
filename: Original filename (for extension detection)
language: Optional language code
model_name: Whisper model to use
enable_diarization: Run speaker diarization
enable_alignment: Run forced word alignment
min_speakers: Min expected speakers
max_speakers: Max expected speakers
Returns:
WhisperXResult with full transcription data
"""
ext = Path(filename).suffix or ".wav"
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
tmp.write(audio_bytes)
tmp_path = tmp.name
try:
return transcribe_audio(
audio_path=tmp_path,
language=language,
model_name=model_name,
enable_diarization=enable_diarization,
enable_alignment=enable_alignment,
min_speakers=min_speakers,
max_speakers=max_speakers,
)
finally:
try:
os.unlink(tmp_path)
except Exception:
pass
def is_available() -> bool:
"""Check if WhisperX dependencies are installed."""
try:
import whisperx
return True
except ImportError:
return False
AVAILABLE_MODELS = [
"tiny",
"tiny.en",
"base",
"base.en",
"small",
"small.en",
"medium",
"medium.en",
"large-v1",
"large-v2",
"large-v3",
"large-v3-turbo",
"distil-large-v2",
"distil-large-v3",
]

View file

@ -0,0 +1,34 @@
"""mana-stt service runner."""
import os
import sys
os.chdir(r"C:\mana\services\mana-stt")
sys.path.insert(0, r"C:\mana\services\mana-stt")
# Redirect stdout/stderr to log file FIRST (before any imports that warn)
log = open(r"C:\mana\services\mana-stt\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log
# Load .env file
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-stt\.env")
# Ensure FFmpeg is in PATH
ffmpeg_dir = r"C:\Users\tills\AppData\Local\Microsoft\WinGet\Links"
if ffmpeg_dir not in os.environ.get("PATH", ""):
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")
# Set HF token
hf_token = os.environ.get("HF_TOKEN", "")
if hf_token:
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
# Pre-initialize CUDA before importing whisperx (avoids hangs)
import torch
if torch.cuda.is_available():
torch.cuda.init()
print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True)
import uvicorn
uvicorn.run("app.main:app", host="0.0.0.0", port=3020, log_level="info")

View file

@ -1,39 +1,271 @@
"""
API Key Authentication for Mana TTS Service.
Delegates to shared mana_auth package.
API Key Authentication for ManaCore STT Service
Supports two authentication modes:
1. Local API keys: Configured via environment variables
2. External API keys: Validated via mana-core-auth service (when EXTERNAL_AUTH_ENABLED=true)
Usage:
# Local keys
API_KEYS=sk-key1:name1,sk-key2:name2
INTERNAL_API_KEY=sk-internal-xxx
# External auth (for user-created keys via mana.how)
EXTERNAL_AUTH_ENABLED=true
MANA_CORE_AUTH_URL=http://localhost:3001
"""
import sys
import os
import time
import logging
from typing import Optional
from collections import defaultdict
from dataclasses import dataclass, field
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
from fastapi import HTTPException, Security, Request
from fastapi.security import APIKeyHeader
from mana_auth import (
APIKey,
AuthResult,
RateLimitInfo,
verify_api_key as _verify_api_key,
get_api_key_stats,
reload_api_keys,
api_key_header,
create_auth_dependency,
)
from mana_auth.external_auth import (
ExternalValidationResult,
from .external_auth import (
is_external_auth_enabled,
validate_api_key_external,
ExternalValidationResult,
)
from typing import Optional
from fastapi import Security, Request
logger = logging.getLogger(__name__)
# TTS-specific auth dependency
verify_tts_key = create_auth_dependency("tts")
# Configuration
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Unlimited internal key
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Per minute
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Seconds
@dataclass
class APIKey:
"""API Key with metadata."""
key: str
name: str
is_internal: bool = False
rate_limit: int = RATE_LIMIT_REQUESTS # Requests per window
@dataclass
class RateLimitInfo:
"""Rate limit tracking per key."""
requests: list = field(default_factory=list)
def is_allowed(self, limit: int, window: int) -> bool:
"""Check if request is allowed within rate limit."""
now = time.time()
# Remove old requests outside window
self.requests = [t for t in self.requests if now - t < window]
if len(self.requests) >= limit:
return False
self.requests.append(now)
return True
def remaining(self, limit: int, window: int) -> int:
"""Get remaining requests in current window."""
now = time.time()
self.requests = [t for t in self.requests if now - t < window]
return max(0, limit - len(self.requests))
# Parse API keys from environment
def _parse_api_keys() -> dict[str, APIKey]:
"""Parse API keys from environment variables."""
keys = {}
# Parse comma-separated keys
if API_KEYS_ENV:
for entry in API_KEYS_ENV.split(","):
entry = entry.strip()
if ":" in entry:
key, name = entry.split(":", 1)
else:
key, name = entry, "default"
keys[key.strip()] = APIKey(key=key.strip(), name=name.strip())
# Add internal key with no rate limit
if INTERNAL_API_KEY:
keys[INTERNAL_API_KEY] = APIKey(
key=INTERNAL_API_KEY,
name="internal",
is_internal=True,
rate_limit=999999, # Effectively unlimited
)
return keys
# Global state
_api_keys = _parse_api_keys()
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
# Security scheme
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
@dataclass
class AuthResult:
"""Result of authentication check."""
authenticated: bool
key_name: Optional[str] = None
is_internal: bool = False
rate_limit_remaining: Optional[int] = None
user_id: Optional[str] = None # Set when using external auth
async def verify_api_key(
request: Request,
api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
"""Verify API key with TTS scope."""
return await _verify_api_key(request, scope="tts", api_key=api_key)
"""
Verify API key and check rate limits.
Supports two authentication modes:
1. External auth via mana-core-auth (for sk_live_ keys)
2. Local auth via environment variables
Returns AuthResult with authentication status.
Raises HTTPException if auth fails or rate limited.
"""
# Skip auth for health and docs endpoints
path = request.url.path
if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
return AuthResult(authenticated=True, key_name="public")
# If auth not required, allow all
if not REQUIRE_AUTH:
return AuthResult(authenticated=True, key_name="anonymous")
# Check for API key
if not api_key:
logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
raise HTTPException(
status_code=401,
detail="Missing API key. Provide X-API-Key header.",
headers={"WWW-Authenticate": "ApiKey"},
)
# Try external auth first for sk_live_ keys (user-created keys via mana.how)
if api_key.startswith("sk_live_") and is_external_auth_enabled():
external_result = await validate_api_key_external(api_key, "stt")
if external_result is not None:
if external_result.valid:
# Use rate limits from external auth
rate_info = _rate_limits[api_key]
limit = external_result.rate_limit_requests
window = external_result.rate_limit_window
if not rate_info.is_allowed(limit, window):
remaining = rate_info.remaining(limit, window)
logger.warning(f"Rate limit exceeded for external key")
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Try again in {window} seconds.",
headers={
"X-RateLimit-Limit": str(limit),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(time.time()) + window),
"Retry-After": str(window),
},
)
remaining = rate_info.remaining(limit, window)
logger.debug(f"Authenticated external request from user {external_result.user_id} to {path}")
return AuthResult(
authenticated=True,
key_name="external",
is_internal=False,
rate_limit_remaining=remaining,
user_id=external_result.user_id,
)
else:
# External auth returned invalid
logger.warning(f"External auth failed: {external_result.error}")
raise HTTPException(
status_code=401,
detail=external_result.error or "Invalid API key.",
headers={"WWW-Authenticate": "ApiKey"},
)
# If external_result is None, fall through to local auth
# Local auth: Validate key against environment variables
if api_key not in _api_keys:
logger.warning(f"Invalid API key attempt for {path}")
raise HTTPException(
status_code=401,
detail="Invalid API key.",
headers={"WWW-Authenticate": "ApiKey"},
)
key_info = _api_keys[api_key]
# Check rate limit (skip for internal keys)
if not key_info.is_internal:
rate_info = _rate_limits[api_key]
if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
raise HTTPException(
status_code=429,
detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
headers={
"X-RateLimit-Limit": str(key_info.rate_limit),
"X-RateLimit-Remaining": str(remaining),
"X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
"Retry-After": str(RATE_LIMIT_WINDOW),
},
)
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
else:
remaining = None
logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
return AuthResult(
authenticated=True,
key_name=key_info.name,
is_internal=key_info.is_internal,
rate_limit_remaining=remaining,
)
def get_api_key_stats() -> dict:
"""Get statistics about API keys (for admin endpoint)."""
stats = {
"total_keys": len(_api_keys),
"auth_required": REQUIRE_AUTH,
"rate_limit": {
"requests_per_window": RATE_LIMIT_REQUESTS,
"window_seconds": RATE_LIMIT_WINDOW,
},
"keys": [],
}
for key, info in _api_keys.items():
# Don't expose actual keys, just metadata
masked_key = key[:8] + "..." if len(key) > 8 else "***"
rate_info = _rate_limits.get(key, RateLimitInfo())
stats["keys"].append({
"name": info.name,
"key_prefix": masked_key,
"is_internal": info.is_internal,
"requests_in_window": len(rate_info.requests),
"remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW),
})
return stats
def reload_api_keys():
"""Reload API keys from environment (for runtime updates)."""
global _api_keys
_api_keys = _parse_api_keys()
logger.info(f"Reloaded {len(_api_keys)} API keys")

View file

@ -1,22 +1,145 @@
"""
External API Key Validation: delegates to shared mana_auth package.
External API Key Validation via mana-core-auth
When EXTERNAL_AUTH_ENABLED=true, API keys are validated against the
central mana-core-auth service. This allows users to create and manage
API keys from the mana.how web interface.
Results are cached for 5 minutes to reduce load on the auth service.
"""
import sys
import os
import time
import logging
import httpx
from typing import Optional
from dataclasses import dataclass
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
logger = logging.getLogger(__name__)
from mana_auth.external_auth import (
ExternalValidationResult,
is_external_auth_enabled,
validate_api_key_external,
clear_cache,
)
# Configuration
EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true"
MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001")
API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300")) # 5 minutes
EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0")) # seconds
__all__ = [
"ExternalValidationResult",
"is_external_auth_enabled",
"validate_api_key_external",
"clear_cache",
]
@dataclass
class ExternalValidationResult:
"""Result from external API key validation."""
valid: bool
user_id: Optional[str] = None
scopes: Optional[list] = None
rate_limit_requests: int = 60
rate_limit_window: int = 60
error: Optional[str] = None
cached_at: float = 0.0
# In-memory cache for validation results
# Key: API key, Value: ExternalValidationResult
_validation_cache: dict[str, ExternalValidationResult] = {}
def is_external_auth_enabled() -> bool:
"""Check if external authentication is enabled."""
return EXTERNAL_AUTH_ENABLED
def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]:
"""Get cached validation result if still valid."""
result = _validation_cache.get(api_key)
if result and (time.time() - result.cached_at) < API_KEY_CACHE_TTL:
return result
return None
def _cache_result(api_key: str, result: ExternalValidationResult):
"""Cache a validation result."""
result.cached_at = time.time()
_validation_cache[api_key] = result
# Clean up old entries periodically (keep cache size manageable)
if len(_validation_cache) > 1000:
now = time.time()
expired_keys = [
k for k, v in _validation_cache.items()
if (now - v.cached_at) >= API_KEY_CACHE_TTL
]
for k in expired_keys:
del _validation_cache[k]
async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]:
"""
Validate an API key against mana-core-auth service.
Args:
api_key: The API key to validate (e.g., "sk_live_...")
scope: The required scope (e.g., "stt" or "tts")
Returns:
ExternalValidationResult if external auth is enabled and the key was validated.
None if external auth is disabled or the service is unavailable (fallback to local).
"""
if not EXTERNAL_AUTH_ENABLED:
return None
# Check cache first
cached = _get_cached_result(api_key)
if cached:
logger.debug(f"Using cached validation result for key prefix: {api_key[:12]}...")
# Check scope against cached result
if cached.valid and cached.scopes and scope not in cached.scopes:
return ExternalValidationResult(
valid=False,
error=f"API key does not have scope: {scope}",
)
return cached
# Call mana-core-auth validation endpoint
try:
async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client:
response = await client.post(
f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate",
json={"apiKey": api_key, "scope": scope},
)
if response.status_code == 200:
data = response.json()
result = ExternalValidationResult(
valid=data.get("valid", False),
user_id=data.get("userId"),
scopes=data.get("scopes", []),
rate_limit_requests=data.get("rateLimit", {}).get("requests", 60),
rate_limit_window=data.get("rateLimit", {}).get("window", 60),
error=data.get("error"),
)
_cache_result(api_key, result)
return result
else:
logger.warning(
f"External auth returned status {response.status_code}: {response.text}"
)
# Don't cache errors - allow retry
return ExternalValidationResult(
valid=False,
error=f"Auth service returned {response.status_code}",
)
except httpx.TimeoutException:
logger.warning("External auth service timeout - falling back to local auth")
return None
except httpx.ConnectError:
logger.warning("Cannot connect to external auth service - falling back to local auth")
return None
except Exception as e:
logger.error(f"External auth error: {e}")
return None
def clear_cache():
"""Clear the validation cache (for testing or runtime updates)."""
global _validation_cache
_validation_cache.clear()
logger.info("External auth cache cleared")

View file

@ -1,6 +1,6 @@
"""
F5-TTS Service for voice cloning synthesis.
Uses f5-tts-mlx optimized for Apple Silicon.
CUDA version using f5-tts PyTorch package.
"""
import logging
@ -15,14 +15,12 @@ import numpy as np
logger = logging.getLogger(__name__)
# Global singleton for lazy initialization
_f5_model = None
_f5_model_name = None
_f5_api = None
# Default model
DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "lucasnewman/f5-tts-mlx")
DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "F5-TTS")
# Default generation parameters
DEFAULT_DURATION = 10.0 # seconds
DEFAULT_STEPS = 32
DEFAULT_CFG_STRENGTH = 2.0
DEFAULT_SWAY_COEF = -1.0
@ -40,35 +38,25 @@ class F5Result:
def get_f5_model(model_name: str = DEFAULT_F5_MODEL):
"""
Get or create F5-TTS model instance (singleton pattern).
"""Get or create F5-TTS API instance (singleton pattern)."""
global _f5_api
Args:
model_name: HuggingFace model identifier
Returns:
F5TTS model instance
"""
global _f5_model, _f5_model_name
# Return existing model if same model name
if _f5_model is not None and _f5_model_name == model_name:
return _f5_model
if _f5_api is not None:
return _f5_api
logger.info(f"Loading F5-TTS model: {model_name}")
try:
from f5_tts_mlx import F5TTS
from f5_tts.api import F5TTS
_f5_model = F5TTS(model_name=model_name)
_f5_model_name = model_name
logger.info("F5-TTS model loaded successfully")
return _f5_model
_f5_api = F5TTS(model_type="F5-TTS")
logger.info("F5-TTS model loaded successfully (CUDA)")
return _f5_api
except ImportError as e:
logger.error(f"Failed to import f5_tts_mlx: {e}")
logger.error(f"Failed to import f5_tts: {e}")
raise RuntimeError(
"f5-tts-mlx not installed. Run: pip install f5-tts-mlx"
"f5-tts not installed. Run: pip install f5-tts"
)
except Exception as e:
logger.error(f"Failed to load F5-TTS model: {e}")
@ -77,7 +65,7 @@ def get_f5_model(model_name: str = DEFAULT_F5_MODEL):
def is_f5_loaded() -> bool:
"""Check if F5-TTS model is currently loaded."""
return _f5_model is not None
return _f5_api is not None
async def synthesize_f5(
@ -103,13 +91,14 @@ async def synthesize_f5(
cfg_strength: Classifier-free guidance strength
sway_coef: Sway sampling coefficient
speed: Speech speed multiplier
model_name: HuggingFace model identifier
model_name: Model identifier
Returns:
F5Result with audio data
"""
# Get model
model = get_f5_model(model_name)
import asyncio
api = get_f5_model(model_name)
logger.info(
f"Synthesizing with F5-TTS: text_length={len(text)}, "
@ -117,17 +106,26 @@ async def synthesize_f5(
)
try:
# Generate audio
audio, sample_rate = model.generate(
text=text,
ref_audio_path=reference_audio_path,
ref_audio_text=reference_text,
duration=duration,
steps=steps,
cfg_strength=cfg_strength,
sway_coef=sway_coef,
speed=speed,
)
# F5-TTS API infer method (runs synchronously, offload to thread)
loop = asyncio.get_event_loop()
def _generate():
wav, sr, _ = api.infer(
ref_file=reference_audio_path,
ref_text=reference_text,
gen_text=text,
nfe_step=steps,
cfg_strength=cfg_strength,
sway_sampling_coeff=sway_coef,
speed=speed,
)
return wav, sr
audio, sample_rate = await loop.run_in_executor(None, _generate)
# Convert to numpy if needed
if not isinstance(audio, np.ndarray):
audio = np.array(audio, dtype=np.float32)
# Calculate duration
audio_duration = len(audio) / sample_rate
@ -152,24 +150,8 @@ async def synthesize_f5_from_bytes(
audio_extension: str = ".wav",
**kwargs,
) -> F5Result:
"""
Synthesize speech using F5-TTS with reference audio as bytes.
Args:
text: Text to synthesize
reference_audio_bytes: Reference audio as bytes
reference_text: Transcript of the reference audio
audio_extension: File extension for temp file
**kwargs: Additional arguments passed to synthesize_f5
Returns:
F5Result with audio data
"""
# Save reference audio to temp file
with tempfile.NamedTemporaryFile(
suffix=audio_extension,
delete=False,
) as tmp:
"""Synthesize speech using F5-TTS with reference audio as bytes."""
with tempfile.NamedTemporaryFile(suffix=audio_extension, delete=False) as tmp:
tmp.write(reference_audio_bytes)
tmp_path = tmp.name
@ -182,7 +164,6 @@ async def synthesize_f5_from_bytes(
)
return result
finally:
# Clean up temp file
try:
Path(tmp_path).unlink()
except Exception:
@ -190,18 +171,7 @@ async def synthesize_f5_from_bytes(
def estimate_duration(text: str, speed: float = 1.0) -> float:
"""
Estimate audio duration from text.
Args:
text: Text to synthesize
speed: Speech speed multiplier
Returns:
Estimated duration in seconds
"""
# Rough estimate: ~150 words per minute at normal speed
# Average word length: ~5 characters
"""Estimate audio duration from text."""
words = len(text) / 5
minutes = words / 150
seconds = minutes * 60
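A hypothetical caller for orientation; the F5Result fields are elided by the hunks above, so nothing beyond the call signature and the kwargs forwarding is assumed here:

import asyncio

from app import f5_service

async def main() -> None:
    ref_bytes = open("ref.wav", "rb").read()  # any short reference clip
    result = await f5_service.synthesize_f5_from_bytes(
        text="Hello from the Windows GPU box.",
        reference_audio_bytes=ref_bytes,
        reference_text="Transcript of the reference clip.",
        audio_extension=".wav",
        steps=32,    # forwarded to api.infer as nfe_step
        speed=1.0,
    )
    print(type(result))  # F5Result; its field names are not shown in this diff

asyncio.run(main())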

View file

@ -1,6 +1,6 @@
"""
Kokoro TTS Service for fast preset voice synthesis.
Uses mlx-audio's Kokoro implementation optimized for Apple Silicon.
CUDA version using kokoro PyTorch package.
"""
import logging
@ -12,11 +12,10 @@ import numpy as np
logger = logging.getLogger(__name__)
# Global singleton for lazy initialization
_kokoro_model = None
_kokoro_model_name = None
_kokoro_pipeline = None
# Default model
DEFAULT_KOKORO_MODEL = "mlx-community/Kokoro-82M-bf16"
DEFAULT_KOKORO_MODEL = "hexgrad/Kokoro-82M"
# Available Kokoro voices (American Female/Male, British Female/Male)
KOKORO_VOICES = {
@ -67,35 +66,25 @@ class KokoroResult:
def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL):
"""
Get or create Kokoro model instance (singleton pattern).
"""Get or create Kokoro pipeline instance (singleton pattern)."""
global _kokoro_pipeline
Args:
model_name: HuggingFace model identifier
Returns:
Kokoro model instance
"""
global _kokoro_model, _kokoro_model_name
# Return existing model if same model name
if _kokoro_model is not None and _kokoro_model_name == model_name:
return _kokoro_model
if _kokoro_pipeline is not None:
return _kokoro_pipeline
logger.info(f"Loading Kokoro model: {model_name}")
try:
from mlx_audio.tts import load
from kokoro import KPipeline
_kokoro_model = load(model_name)
_kokoro_model_name = model_name
logger.info("Kokoro model loaded successfully")
return _kokoro_model
_kokoro_pipeline = KPipeline(lang_code="a") # 'a' for American English
logger.info("Kokoro pipeline loaded successfully")
return _kokoro_pipeline
except ImportError as e:
logger.error(f"Failed to import mlx_audio: {e}")
logger.error(f"Failed to import kokoro: {e}")
raise RuntimeError(
"mlx-audio not installed. Run: pip install mlx-audio"
"kokoro not installed. Run: pip install kokoro"
)
except Exception as e:
logger.error(f"Failed to load Kokoro model: {e}")
@ -104,7 +93,7 @@ def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL):
def is_kokoro_loaded() -> bool:
"""Check if Kokoro model is currently loaded."""
return _kokoro_model is not None
return _kokoro_pipeline is not None
def get_available_voices() -> dict[str, str]:
@ -125,7 +114,7 @@ async def synthesize_kokoro(
text: Text to synthesize
voice: Voice ID from KOKORO_VOICES
speed: Speech speed multiplier (0.5-2.0)
model_name: HuggingFace model identifier
model_name: Model identifier
Returns:
KokoroResult with audio data
@ -139,29 +128,18 @@ async def synthesize_kokoro(
speed = max(0.5, min(2.0, speed))
# Get model
model = get_kokoro_model(model_name)
pipeline = get_kokoro_model(model_name)
logger.info(f"Synthesizing with Kokoro: voice={voice}, speed={speed}, text_length={len(text)}")
try:
# Generate audio using mlx-audio's generate method
# Returns a generator of GenerationResult objects
result_gen = model.generate(
text=text,
voice=voice,
speed=speed,
)
# Collect all audio chunks from the generator
# Generate audio using kokoro pipeline
audio_chunks = []
sample_rate = 24000 # Default, will be updated from result
sample_rate = 24000 # Kokoro default
for result in result_gen:
# Each result has audio, sample_rate, audio_duration (string)
sample_rate = result.sample_rate
# Convert MLX array to numpy
audio_np = np.array(result.audio, dtype=np.float32)
for result in pipeline(text, voice=voice, speed=speed):
# result is a KPipelineResult with .audio (tensor) and .graphemes/.phonemes
audio_np = result.audio.numpy()
audio_chunks.append(audio_np)
# Concatenate all chunks
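A hypothetical caller for the new pipeline path; the voice id below is an assumption, since the KOKORO_VOICES entries are elided in this hunk:

import asyncio

from app import kokoro_service

async def main() -> None:
    print(kokoro_service.get_available_voices())  # pick a real id from here
    result = await kokoro_service.synthesize_kokoro(
        text="Fast preset-voice synthesis on CUDA.",
        voice="af_heart",  # assumed id, not shown in this diff
        speed=1.0,
    )
    print(type(result))  # KokoroResult

asyncio.run(main())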

View file

@ -0,0 +1,114 @@
"""
VRAM Manager: automatic model unloading after idle timeout.
Tracks last usage time per model and unloads after configurable timeout.
Designed for shared GPU environments (multiple services on one RTX 3090).
Usage in a service:
from vram_manager import VramManager
vram = VramManager(idle_timeout=300) # 5 min
# Before using a model
vram.touch()
# Call periodically (e.g., from health check or background task)
vram.check_and_unload(unload_fn=my_unload_function)
"""
import os
import time
import logging
import threading
from typing import Optional, Callable
logger = logging.getLogger(__name__)
DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300")) # 5 minutes
class VramManager:
def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"):
self.idle_timeout = idle_timeout
self.service_name = service_name
self.last_used: float = 0.0
self.model_loaded: bool = False
self._lock = threading.Lock()
self._timer: Optional[threading.Timer] = None
def touch(self):
"""Mark the model as recently used. Call before/after each inference."""
with self._lock:
self.last_used = time.time()
self.model_loaded = True
self._schedule_check()
def mark_loaded(self):
"""Mark that a model has been loaded into VRAM."""
with self._lock:
self.model_loaded = True
self.last_used = time.time()
self._schedule_check()
logger.info(f"[{self.service_name}] Model loaded, idle timeout: {self.idle_timeout}s")
def mark_unloaded(self):
"""Mark that a model has been unloaded from VRAM."""
with self._lock:
self.model_loaded = False
if self._timer:
self._timer.cancel()
self._timer = None
logger.info(f"[{self.service_name}] Model unloaded, VRAM freed")
def is_idle(self) -> bool:
"""Check if the model has been idle longer than the timeout."""
if not self.model_loaded:
return False
return (time.time() - self.last_used) > self.idle_timeout
def seconds_until_unload(self) -> Optional[float]:
"""Seconds until the model will be unloaded, or None if not loaded."""
if not self.model_loaded:
return None
remaining = self.idle_timeout - (time.time() - self.last_used)
return max(0, remaining)
def check_and_unload(self, unload_fn: Callable[[], None]) -> bool:
"""Check if idle and unload if so. Returns True if unloaded."""
if self.is_idle():
logger.info(f"[{self.service_name}] Idle for >{self.idle_timeout}s, unloading model...")
try:
unload_fn()
self.mark_unloaded()
return True
except Exception as e:
logger.error(f"[{self.service_name}] Failed to unload: {e}")
return False
def _schedule_check(self):
"""Schedule an idle check after the timeout period."""
if self._timer:
self._timer.cancel()
self._timer = threading.Timer(
self.idle_timeout + 5, # Small buffer
self._auto_check,
)
self._timer.daemon = True
self._timer.start()
def _auto_check(self):
"""Auto-triggered idle check (called by timer)."""
# This is just a log — actual unloading needs the unload_fn
# which depends on the service. The service should call check_and_unload.
if self.is_idle():
logger.info(f"[{self.service_name}] Model idle for >{self.idle_timeout}s — ready to unload")
def status(self) -> dict:
"""Get current VRAM manager status."""
return {
"model_loaded": self.model_loaded,
"idle_seconds": round(time.time() - self.last_used, 1) if self.model_loaded else None,
"idle_timeout": self.idle_timeout,
"seconds_until_unload": round(self.seconds_until_unload(), 1) if self.model_loaded else None,
}
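One way to wire it, following the docstring's suggestion to piggyback the idle check on the health endpoint; unload_whisper() is a hypothetical helper that drops the module-level singleton and frees CUDA memory:

import torch
from fastapi import FastAPI

import app.whisper_service as ws
from app.vram_manager import VramManager

app = FastAPI()
vram = VramManager(idle_timeout=300, service_name="stt")

def unload_whisper() -> None:
    # Hypothetical: clear the singleton so the next request reloads the model
    ws._whisper_model = None
    torch.cuda.empty_cache()

@app.get("/health")
async def health():
    vram.check_and_unload(unload_fn=unload_whisper)  # no-op unless idle
    return {"status": "ok", "vram": vram.status()}

The model loader would pair this with vram.mark_loaded() after loading and vram.touch() around each inference, as the docstring describes.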

View file

@ -0,0 +1,17 @@
"""mana-tts service runner."""
import os
import sys
os.chdir(r"C:\mana\services\mana-tts")
sys.path.insert(0, r"C:\mana\services\mana-tts")
# Load .env file
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-tts\.env")
# Redirect stdout/stderr to log file
log = open(r"C:\mana\services\mana-tts\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log
import uvicorn
uvicorn.run("app.main:app", host="0.0.0.0", port=3022, log_level="info")

View file

@ -0,0 +1,37 @@
"""mana-video-gen service runner."""
import os
import sys
os.chdir(r"C:\mana\services\mana-video-gen")
sys.path.insert(0, r"C:\mana\services\mana-video-gen")
# Redirect stdout/stderr to log file FIRST
log = open(r"C:\mana\services\mana-video-gen\service.log", "w", buffering=1)
sys.stdout = log
sys.stderr = log
# Load .env file
from dotenv import load_dotenv
load_dotenv(r"C:\mana\services\mana-video-gen\.env")
# Ensure FFmpeg is in PATH (LTX needs it for video encoding)
ffmpeg_dir = r"C:\Users\tills\AppData\Local\Microsoft\WinGet\Links"
if ffmpeg_dir not in os.environ.get("PATH", ""):
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")
# HF token for model downloads
hf_token = os.environ.get("HF_TOKEN", "")
if hf_token:
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
# Pre-initialize CUDA before importing diffusers
import torch
if torch.cuda.is_available():
torch.cuda.init()
print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True)
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True)
else:
print("WARNING: CUDA not available, LTX will be unusably slow on CPU", flush=True)
import uvicorn
uvicorn.run("app.main:app", host="0.0.0.0", port=3026, log_level="info")