"""mana-llm service runner (run with pythonw.exe to run headless)."""
import os
import sys

# Single source of truth for the install location (was repeated four times).
BASE_DIR = r"C:\mana\services\mana-llm"

os.chdir(BASE_DIR)
sys.path.insert(0, BASE_DIR)

# Load .env before importing the app so settings pick up the values.
from dotenv import load_dotenv
load_dotenv(os.path.join(BASE_DIR, ".env"))

# Redirect stdout/stderr to a log file. Open in APPEND mode so output from
# previous runs (e.g. crash tracebacks) survives a restart; buffering=1 gives
# line-buffered output so the log is readable while the service runs.
# NOTE: the handle is deliberately kept open for the lifetime of the process.
log = open(os.path.join(BASE_DIR, "service.log"), "a", buffering=1, encoding="utf-8")
sys.stdout = log
sys.stderr = log

import uvicorn
uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")
"""
Simple API Key Authentication Middleware for GPU Services.

Checks X-API-Key header or ?api_key query parameter.
Skips auth for /health, /docs, /openapi.json, /redoc and /metrics endpoints.

Environment variables:
    GPU_API_KEY: Required API key (if empty, auth is disabled)
    GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set)
"""

import hmac
import logging
import os

from fastapi import Request
from fastapi.responses import JSONResponse
from starlette.middleware.base import BaseHTTPMiddleware

logger = logging.getLogger(__name__)

GPU_API_KEY = os.getenv("GPU_API_KEY", "")
GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true"

# Endpoints that don't require auth
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"}


class ApiKeyMiddleware(BaseHTTPMiddleware):
    """Rejects requests that do not present the configured GPU_API_KEY."""

    async def dispatch(self, request: Request, call_next):
        # Auth disabled entirely (explicitly off, or no key configured).
        if not GPU_REQUIRE_AUTH or not GPU_API_KEY:
            return await call_next(request)

        # Public endpoints are always reachable.
        if request.url.path in PUBLIC_PATHS:
            return await call_next(request)

        # Accept the key from a header or a query parameter.
        api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")

        if not api_key:
            return JSONResponse(
                status_code=401,
                content={"detail": "Missing API key. Provide X-API-Key header."},
            )

        # Constant-time comparison so the key cannot be recovered via a
        # byte-by-byte timing side channel (plain != short-circuits).
        if not hmac.compare_digest(api_key, GPU_API_KEY):
            logger.warning(
                "Invalid API key attempt from %s",
                request.client.host if request.client else "unknown",
            )
            return JSONResponse(
                status_code=401,
                content={"detail": "Invalid API key."},
            )

        return await call_next(request)
logger = logging.getLogger(__name__)

# Configuration (read once at import time; reload_api_keys() re-parses keys).
API_KEYS_ENV = os.getenv("API_KEYS", "")  # Format: "sk-key1:name1,sk-key2:name2"
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "")  # Unlimited internal key
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60"))  # Per window
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60"))  # Seconds


@dataclass
class APIKey:
    """API Key with metadata."""
    key: str
    name: str
    is_internal: bool = False
    rate_limit: int = RATE_LIMIT_REQUESTS  # Requests per window


@dataclass
class RateLimitInfo:
    """Sliding-window rate limit tracking for a single key."""
    # time.time() stamps of requests inside the current window.
    requests: list[float] = field(default_factory=list)

    def is_allowed(self, limit: int, window: int) -> bool:
        """Record a request and return True if it fits within the limit."""
        now = time.time()
        # Drop requests that have aged out of the window.
        self.requests = [t for t in self.requests if now - t < window]

        if len(self.requests) >= limit:
            return False

        self.requests.append(now)
        return True

    def remaining(self, limit: int, window: int) -> int:
        """Get remaining requests in current window."""
        now = time.time()
        self.requests = [t for t in self.requests if now - t < window]
        return max(0, limit - len(self.requests))


def _parse_api_keys() -> dict[str, APIKey]:
    """Parse API keys from environment variables."""
    keys: dict[str, APIKey] = {}

    for entry in API_KEYS_ENV.split(","):
        entry = entry.strip()
        if not entry:
            # Skip empty segments (e.g. a trailing or doubled comma) instead
            # of registering an empty-string key in the registry.
            continue
        if ":" in entry:
            key, name = entry.split(":", 1)
        else:
            key, name = entry, "default"
        keys[key.strip()] = APIKey(key=key.strip(), name=name.strip())

    # Add internal key with no effective rate limit.
    if INTERNAL_API_KEY:
        keys[INTERNAL_API_KEY] = APIKey(
            key=INTERNAL_API_KEY,
            name="internal",
            is_internal=True,
            rate_limit=999999,  # Effectively unlimited
        )

    return keys


# Global state
_api_keys = _parse_api_keys()
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)


@dataclass
class AuthResult:
    """Result of authentication check."""
    authenticated: bool
    key_name: Optional[str] = None
    is_internal: bool = False
    rate_limit_remaining: Optional[int] = None
    user_id: Optional[str] = None  # Set when using external auth
async def verify_api_key(
    request: Request,
    api_key: Optional[str] = Security(api_key_header),
) -> AuthResult:
    """
    Verify API key and check rate limits.

    Supports two authentication modes:
    1. External auth via mana-core-auth (for sk_live_ keys)
    2. Local auth via environment variables

    Returns AuthResult with authentication status.
    Raises HTTPException if auth fails or rate limited.
    """
    # Skip auth for health and docs endpoints.
    path = request.url.path
    if path in {"/health", "/docs", "/openapi.json", "/redoc"}:
        return AuthResult(authenticated=True, key_name="public")

    # If auth not required, allow all.
    if not REQUIRE_AUTH:
        return AuthResult(authenticated=True, key_name="anonymous")

    # Check for API key.
    if not api_key:
        logger.warning(
            "Missing API key for %s from %s",
            path,
            request.client.host if request.client else "unknown",
        )
        raise HTTPException(
            status_code=401,
            detail="Missing API key. Provide X-API-Key header.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    # Try external auth first for sk_live_ keys (user-created keys via mana.how).
    if api_key.startswith("sk_live_") and is_external_auth_enabled():
        external_result = await validate_api_key_external(api_key, "stt")

        if external_result is not None:
            if external_result.valid:
                # Use rate limits supplied by the external auth service.
                rate_info = _rate_limits[api_key]
                limit = external_result.rate_limit_requests
                window = external_result.rate_limit_window

                if not rate_info.is_allowed(limit, window):
                    remaining = rate_info.remaining(limit, window)
                    logger.warning("Rate limit exceeded for external key")
                    raise HTTPException(
                        status_code=429,
                        detail=f"Rate limit exceeded. Try again in {window} seconds.",
                        headers={
                            "X-RateLimit-Limit": str(limit),
                            "X-RateLimit-Remaining": str(remaining),
                            "X-RateLimit-Reset": str(int(time.time()) + window),
                            "Retry-After": str(window),
                        },
                    )

                remaining = rate_info.remaining(limit, window)
                logger.debug(
                    "Authenticated external request from user %s to %s",
                    external_result.user_id,
                    path,
                )
                return AuthResult(
                    authenticated=True,
                    key_name="external",
                    is_internal=False,
                    rate_limit_remaining=remaining,
                    user_id=external_result.user_id,
                )

            # External auth explicitly rejected the key.
            logger.warning("External auth failed: %s", external_result.error)
            raise HTTPException(
                status_code=401,
                detail=external_result.error or "Invalid API key.",
                headers={"WWW-Authenticate": "ApiKey"},
            )
        # external_result is None: auth service unavailable -> fall through
        # to local auth below.

    # Local auth: validate key against environment-configured keys.
    if api_key not in _api_keys:
        logger.warning("Invalid API key attempt for %s", path)
        raise HTTPException(
            status_code=401,
            detail="Invalid API key.",
            headers={"WWW-Authenticate": "ApiKey"},
        )

    key_info = _api_keys[api_key]

    # Check rate limit (skip for internal keys).
    if not key_info.is_internal:
        rate_info = _rate_limits[api_key]
        if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
            remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
            logger.warning("Rate limit exceeded for key '%s'", key_info.name)
            raise HTTPException(
                status_code=429,
                detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
                headers={
                    "X-RateLimit-Limit": str(key_info.rate_limit),
                    "X-RateLimit-Remaining": str(remaining),
                    "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
                    "Retry-After": str(RATE_LIMIT_WINDOW),
                },
            )
        remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
    else:
        remaining = None

    logger.debug("Authenticated request from '%s' to %s", key_info.name, path)

    return AuthResult(
        authenticated=True,
        key_name=key_info.name,
        is_internal=key_info.is_internal,
        rate_limit_remaining=remaining,
    )


def get_api_key_stats() -> dict:
    """Get statistics about API keys (for admin endpoint)."""
    stats = {
        "total_keys": len(_api_keys),
        "auth_required": REQUIRE_AUTH,
        "rate_limit": {
            "requests_per_window": RATE_LIMIT_REQUESTS,
            "window_seconds": RATE_LIMIT_WINDOW,
        },
        "keys": [],
    }

    for key, info in _api_keys.items():
        # Don't expose actual keys, just metadata.
        masked_key = key[:8] + "..." if len(key) > 8 else "***"
        rate_info = _rate_limits.get(key, RateLimitInfo())
        stats["keys"].append({
            "name": info.name,
            "key_prefix": masked_key,
            "is_internal": info.is_internal,
            "requests_in_window": len(rate_info.requests),
            "remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW),
        })

    return stats


def reload_api_keys():
    """Reload API keys from environment (for runtime updates)."""
    global _api_keys
    _api_keys = _parse_api_keys()
    logger.info("Reloaded %d API keys", len(_api_keys))
logger = logging.getLogger(__name__)

# Configuration (read once at import time).
EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true"
MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001")
API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300"))  # 5 minutes
EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0"))  # seconds


@dataclass
class ExternalValidationResult:
    """Result from external API key validation."""
    valid: bool
    user_id: Optional[str] = None
    scopes: Optional[list] = None
    rate_limit_requests: int = 60
    rate_limit_window: int = 60
    error: Optional[str] = None
    cached_at: float = 0.0  # Stamped by _cache_result


# In-memory cache for validation results.
# Key: API key, Value: ExternalValidationResult
_validation_cache: dict[str, ExternalValidationResult] = {}


def is_external_auth_enabled() -> bool:
    """Check if external authentication is enabled."""
    return EXTERNAL_AUTH_ENABLED


def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]:
    """Return the cached validation result, or None if absent or expired.

    Expired entries are evicted immediately on access so stale results do
    not linger until the size-based cleanup in _cache_result kicks in.
    """
    result = _validation_cache.get(api_key)
    if result is None:
        return None
    if (time.time() - result.cached_at) < API_KEY_CACHE_TTL:
        return result
    _validation_cache.pop(api_key, None)
    return None


def _cache_result(api_key: str, result: ExternalValidationResult):
    """Cache a validation result (stamps cached_at)."""
    result.cached_at = time.time()
    _validation_cache[api_key] = result

    # Clean up old entries periodically (keep cache size manageable).
    if len(_validation_cache) > 1000:
        now = time.time()
        expired_keys = [
            k for k, v in _validation_cache.items()
            if (now - v.cached_at) >= API_KEY_CACHE_TTL
        ]
        for k in expired_keys:
            del _validation_cache[k]


async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]:
    """
    Validate an API key against mana-core-auth service.

    Args:
        api_key: The API key to validate (e.g., "sk_live_...")
        scope: The required scope (e.g., "stt" or "tts")

    Returns:
        ExternalValidationResult if external auth is enabled and the key was validated.
        None if external auth is disabled or the service is unavailable (fallback to local).
    """
    if not EXTERNAL_AUTH_ENABLED:
        return None

    # Deferred import: only needed when external auth is actually active.
    import httpx

    # Check cache first.
    cached = _get_cached_result(api_key)
    if cached:
        logger.debug("Using cached validation result for key prefix: %s...", api_key[:12])
        # Re-check the requested scope against the cached scopes.
        if cached.valid and cached.scopes and scope not in cached.scopes:
            return ExternalValidationResult(
                valid=False,
                error=f"API key does not have scope: {scope}",
            )
        return cached

    # Call mana-core-auth validation endpoint.
    try:
        async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client:
            response = await client.post(
                f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate",
                json={"apiKey": api_key, "scope": scope},
            )

            if response.status_code == 200:
                data = response.json()
                result = ExternalValidationResult(
                    valid=data.get("valid", False),
                    user_id=data.get("userId"),
                    scopes=data.get("scopes", []),
                    rate_limit_requests=data.get("rateLimit", {}).get("requests", 60),
                    rate_limit_window=data.get("rateLimit", {}).get("window", 60),
                    error=data.get("error"),
                )
                _cache_result(api_key, result)
                return result

            logger.warning(
                "External auth returned status %s: %s",
                response.status_code,
                response.text,
            )
            # Don't cache errors - allow retry.
            return ExternalValidationResult(
                valid=False,
                error=f"Auth service returned {response.status_code}",
            )

    except httpx.TimeoutException:
        logger.warning("External auth service timeout - falling back to local auth")
        return None
    except httpx.ConnectError:
        logger.warning("Cannot connect to external auth service - falling back to local auth")
        return None
    except Exception as e:
        logger.error("External auth error: %s", e)
        return None


def clear_cache():
    """Clear the validation cache (for testing or runtime updates)."""
    # .clear() mutates in place; no `global` rebinding needed.
    _validation_cache.clear()
    logger.info("External auth cache cleared")
class WordInfo(BaseModel):
    """A single word with its aligned timestamps (speaker set only when diarized)."""
    word: str
    start: float
    end: float
    score: Optional[float] = None
    speaker: Optional[str] = None


class SegmentInfo(BaseModel):
    """A contiguous span of speech and its text (speaker set only when diarized)."""
    start: float
    end: float
    text: str
    speaker: Optional[str] = None


class TranscriptionResponse(BaseModel):
    """Transcription result.

    words / segments / speakers are populated only when alignment and/or
    diarization were requested for the transcription.
    """
    text: str
    language: Optional[str] = None
    model: str
    latency_ms: Optional[float] = None
    duration_seconds: Optional[float] = None
    words: Optional[list[WordInfo]] = None
    segments: Optional[list[SegmentInfo]] = None
    speakers: Optional[list[str]] = None
except Exception as e: - logger.warning(f"Failed to preload Whisper: {e}") + # Always preload WhisperX model at startup (avoids timeout on first request) + logger.info("Preloading WhisperX model...") + try: + from app.whisper_service import get_whisper_model + get_whisper_model(DEFAULT_WHISPER_MODEL) + models_status["whisper_loaded"] = True + logger.info("WhisperX model preloaded successfully") + except Exception as e: + logger.warning(f"Failed to preload WhisperX: {e}") logger.info(f"STT Service ready on port {PORT}") yield @@ -166,9 +123,9 @@ async def lifespan(app: FastAPI): # Create FastAPI app app = FastAPI( - title="Mana STT Service", - description="Speech-to-Text API with Whisper (MLX), Voxtral (vLLM), and Mistral API", - version="2.0.0", + title="ManaCore STT Service", + description="Speech-to-Text API with WhisperX (word timestamps + speaker diarization)", + version="3.0.0", lifespan=lifespan, ) @@ -193,13 +150,15 @@ async def health_check(): return HealthResponse( status="healthy", whisper_loaded=models_status["whisper_loaded"], - whisperx_available=models_status["whisperx_available"], + whisperx=True, vllm_available=vllm_health.get("status") == "healthy", vllm_url=VLLM_URL if USE_VLLM else None, mistral_api_available=api_available(), auth_required=REQUIRE_AUTH, models={ "default_whisper": DEFAULT_WHISPER_MODEL, + "engine": "whisperx", + "features": ["transcription", "word_timestamps", "speaker_diarization"], }, ) @@ -212,17 +171,8 @@ async def list_models(auth: AuthResult = Depends(verify_api_key)): vllm_models = await get_models() - whisperx_models = [] - if USE_WHISPERX: - try: - from app.whisperx_service import AVAILABLE_MODELS as wx_models - whisperx_models = wx_models - except ImportError: - pass - return ModelsResponse( whisper=whisper_models, - whisperx=whisperx_models, voxtral_vllm=vllm_models, default_whisper=DEFAULT_WHISPER_MODEL, ) @@ -234,16 +184,22 @@ async def transcribe_whisper( file: UploadFile = File(..., description="Audio file to 
transcribe"), language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"), model: Optional[str] = Form(None, description="Whisper model to use"), + align: bool = Form(True, description="Enable word-level timestamp alignment"), + diarize: bool = Form(False, description="Enable speaker diarization"), + min_speakers: Optional[int] = Form(None, description="Min expected speakers (helps diarization)"), + max_speakers: Optional[int] = Form(None, description="Max expected speakers"), auth: AuthResult = Depends(verify_api_key), ): """ - Transcribe audio using Whisper (Lightning MLX). + Transcribe audio using WhisperX. - Best for: General transcription, many languages - Supported formats: mp3, wav, m4a, flac, ogg, webm + Features: + - Word-level timestamps (align=true, default) + - Speaker diarization (diarize=true, opt-in) + + Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4 Max file size: 100MB """ - # Add rate limit headers if auth.rate_limit_remaining is not None: response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining) @@ -274,20 +230,51 @@ async def transcribe_whisper( filename=file.filename, language=language, model_name=model_name, + align=align, + diarize=diarize, + min_speakers=min_speakers, + max_speakers=max_speakers, ) models_status["whisper_loaded"] = True latency_ms = (time.time() - start_time) * 1000 - return TranscriptionResponse( + # Build response + resp = TranscriptionResponse( text=result.text, language=result.language, - model=f"whisper-{model_name}", + model=f"whisperx-{model_name}", latency_ms=latency_ms, + duration_seconds=result.duration, ) + # Add word timestamps if available + if result.words: + resp.words = [ + WordInfo( + word=w.word, + start=w.start, + end=w.end, + score=w.score, + speaker=w.speaker, + ) + for w in result.words + ] + + # Add segments + if result.segments: + resp.segments = [ + SegmentInfo(**s) for s in result.segments + ] + + # Add speakers + if result.speakers: + 
resp.speakers = result.speakers + + return resp + except Exception as e: - logger.error(f"Whisper transcription error: {e}") + logger.error(f"WhisperX transcription error: {e}") raise HTTPException(status_code=500, detail=str(e)) @@ -296,22 +283,10 @@ async def transcribe_voxtral( response: Response, file: UploadFile = File(..., description="Audio file to transcribe"), language: str = Form("de", description="Language code"), - use_realtime: bool = Form(False, description="Use Realtime 4B model for lower latency"), + use_realtime: bool = Form(False, description="Use Realtime 4B model"), auth: AuthResult = Depends(verify_api_key), ): - """ - Transcribe audio using Voxtral via vLLM server. - - Models: - - Voxtral Mini 3B (default): Best quality - - Voxtral Realtime 4B: Lower latency (<500ms) - - Falls back to Mistral API if vLLM is unavailable. - - Supported formats: mp3, wav, m4a, flac, ogg, webm - Max file size: 100MB - """ - # Add rate limit headers + """Transcribe audio using Voxtral via vLLM or Mistral API.""" if auth.rate_limit_remaining is not None: response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining) @@ -345,49 +320,30 @@ async def transcribe_voxtral( if USE_VLLM: health = await check_health() if health.get("status") == "healthy": - logger.info("Using vLLM for Voxtral transcription") if use_realtime: result = await transcribe_with_realtime( - audio_bytes=audio_bytes, - filename=file.filename, - language=language, + audio_bytes=audio_bytes, filename=file.filename, language=language, ) else: result = await vllm_transcribe( - audio_bytes=audio_bytes, - filename=file.filename, - language=language, + audio_bytes=audio_bytes, filename=file.filename, language=language, ) - return TranscriptionResponse( - text=result.text, - language=result.language, - model=result.model, - latency_ms=result.latency_ms, - duration_seconds=result.duration_seconds, + text=result.text, language=result.language, model=result.model, + latency_ms=result.latency_ms, 
duration_seconds=result.duration_seconds, ) # Fallback to Mistral API if api_available(): - logger.info("Falling back to Mistral API") result = await api_transcribe( - audio_bytes=audio_bytes, - filename=file.filename, - language=language, + audio_bytes=audio_bytes, filename=file.filename, language=language, ) - return TranscriptionResponse( - text=result.text, - language=result.language, - model=result.model, - latency_ms=None, + text=result.text, language=result.language, model=result.model, duration_seconds=result.duration_seconds, ) - raise HTTPException( - status_code=503, - detail="Voxtral not available. Start vLLM server or configure MISTRAL_API_KEY." - ) + raise HTTPException(status_code=503, detail="Voxtral not available.") except HTTPException: raise @@ -396,273 +352,30 @@ async def transcribe_voxtral( raise HTTPException(status_code=500, detail=str(e)) -@app.post("/transcribe/voxtral/api", response_model=TranscriptionResponse) -async def transcribe_voxtral_api( - response: Response, - file: UploadFile = File(..., description="Audio file to transcribe"), - language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"), - diarization: bool = Form(False, description="Enable speaker diarization"), - auth: AuthResult = Depends(verify_api_key), -): - """ - Transcribe audio using Mistral's Voxtral API directly. - - Features: - - Speaker diarization - - Auto language detection - - High quality (~4% WER) - - Requires MISTRAL_API_KEY environment variable. - """ - # Add rate limit headers - if auth.rate_limit_remaining is not None: - response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining) - - from app.voxtral_api_service import is_available, transcribe_audio_bytes - - if not is_available(): - raise HTTPException( - status_code=503, - detail="Mistral API not configured. Set MISTRAL_API_KEY environment variable." 
- ) - - if not file.filename: - raise HTTPException(status_code=400, detail="No file provided") - - try: - audio_bytes = await file.read() - if len(audio_bytes) > 100 * 1024 * 1024: - raise HTTPException(status_code=400, detail="File too large (max 100MB)") - - result = await transcribe_audio_bytes( - audio_bytes=audio_bytes, - filename=file.filename, - language=language, - diarization=diarization, - ) - - return TranscriptionResponse( - text=result.text, - language=result.language, - model=result.model, - duration_seconds=result.duration_seconds, - ) - - except Exception as e: - logger.error(f"Mistral API error: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - -@app.post("/transcribe/whisperx", response_model=RichTranscriptionResponse) -async def transcribe_whisperx( - response: Response, - file: UploadFile = File(..., description="Audio file to transcribe"), - language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"), - model: Optional[str] = Form(None, description="Whisper model to use"), - diarization: bool = Form(True, description="Enable speaker diarization"), - alignment: bool = Form(True, description="Enable word-level alignment"), - min_speakers: Optional[int] = Form(None, description="Minimum expected speakers"), - max_speakers: Optional[int] = Form(None, description="Maximum expected speakers"), - auth: AuthResult = Depends(verify_api_key), -): - """ - Transcribe audio using WhisperX (CUDA GPU). - - Returns rich transcription with: - - Word-level timestamps (via forced alignment) - - Speaker diarization (via pyannote.audio) - - Memoro-compatible utterances with speaker IDs - - Requires NVIDIA GPU with CUDA and USE_WHISPERX=true. - Diarization requires HF_TOKEN with pyannote model access. 
- - Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4 - Max file size: 100MB - """ - if auth.rate_limit_remaining is not None: - response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining) - - if not USE_WHISPERX: - raise HTTPException( - status_code=503, - detail="WhisperX not enabled. Set USE_WHISPERX=true on a CUDA-capable server." - ) - - if not file.filename: - raise HTTPException(status_code=400, detail="No file provided") - - allowed_extensions = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm", ".mp4"} - ext = os.path.splitext(file.filename)[1].lower() - if ext not in allowed_extensions: - raise HTTPException( - status_code=400, - detail=f"Unsupported file type: {ext}. Allowed: {allowed_extensions}" - ) - - start_time = time.time() - - try: - from app.whisperx_service import transcribe_audio_bytes - - audio_bytes = await file.read() - if len(audio_bytes) > 100 * 1024 * 1024: - raise HTTPException(status_code=400, detail="File too large (max 100MB)") - - model_name = model or DEFAULT_WHISPER_MODEL - - result = await transcribe_audio_bytes( - audio_bytes=audio_bytes, - filename=file.filename, - language=language, - model_name=model_name, - enable_diarization=diarization, - enable_alignment=alignment, - min_speakers=min_speakers, - max_speakers=max_speakers, - ) - - latency_ms = (time.time() - start_time) * 1000 - - return RichTranscriptionResponse( - text=result.text, - language=result.language, - model=f"whisperx-{model_name}", - latency_ms=latency_ms, - duration_seconds=result.duration_seconds, - segments=[ - SegmentResponse( - start=s.start, - end=s.end, - text=s.text, - speaker=s.speaker, - words=[ - WordTimestampResponse( - word=w.word, - start=w.start, - end=w.end, - score=w.score, - speaker=w.speaker, - ) - for w in s.words - ], - ) - for s in result.segments - ], - utterances=[ - UtteranceResponse( - speaker=u.speaker, - text=u.text, - offset=u.offset, - duration=u.duration, - ) - for u in result.utterances - ], - 
speakers=result.speakers, - speaker_map={k: v for k, v in result.speaker_map.items()}, - languages=result.languages, - primary_language=result.primary_language, - words=[ - WordTimestampResponse( - word=w.word, - start=w.start, - end=w.end, - score=w.score, - speaker=w.speaker, - ) - for w in result.words - ], - ) - - except HTTPException: - raise - except ImportError: - raise HTTPException( - status_code=503, - detail="WhisperX not installed. Install with: pip install -r requirements-cuda.txt" - ) - except Exception as e: - logger.error(f"WhisperX transcription error: {e}") - raise HTTPException(status_code=500, detail=str(e)) - - @app.post("/transcribe/auto", response_model=TranscriptionResponse) async def transcribe_auto( response: Response, file: UploadFile = File(..., description="Audio file to transcribe"), language: Optional[str] = Form(None, description="Language hint"), - prefer: str = Form("whisper", description="Preferred: 'whisper', 'whisperx', or 'voxtral'"), + prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"), auth: AuthResult = Depends(verify_api_key), ): - """ - Transcribe with automatic model selection and fallback. 
- - Fallback chain: - - whisper: Whisper → WhisperX → Voxtral → Mistral API - - whisperx: WhisperX → Whisper → Voxtral → Mistral API - - voxtral: Voxtral → WhisperX → Whisper → Mistral API - """ - # Add rate limit headers + """Auto-select best model with fallback chain.""" if auth.rate_limit_remaining is not None: response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining) - async def try_whisperx_simple(): - """Try WhisperX and return as simple TranscriptionResponse.""" - if not USE_WHISPERX: - raise RuntimeError("WhisperX not enabled") - from app.whisperx_service import transcribe_audio_bytes as wx_transcribe - audio_bytes = await file.read() - result = await wx_transcribe( - audio_bytes=audio_bytes, - filename=file.filename or "audio.wav", - language=language, - enable_diarization=False, - enable_alignment=False, - ) - return TranscriptionResponse( - text=result.text, - language=result.language, - model=f"whisperx-{DEFAULT_WHISPER_MODEL}", - latency_ms=None, - duration_seconds=result.duration_seconds, - ) - - # Build fallback chain based on preference - if prefer == "whisperx": - chain = [ - ("WhisperX", try_whisperx_simple), - ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)), - ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)), - ("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)), - ] - elif prefer == "voxtral": - chain = [ - ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)), - ("WhisperX", try_whisperx_simple), - ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)), - ("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)), - ] - else: - chain = [ - ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)), - ("WhisperX", try_whisperx_simple), - ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)), - 
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)), - ] - - last_error = None - for name, fn in chain: + if prefer == "voxtral": try: - result = await fn() - return result - except Exception as e: - last_error = e - logger.warning(f"{name} failed: {e}") + return await transcribe_voxtral(response, file, language or "de", False, auth) + except Exception: await file.seek(0) - - raise HTTPException( - status_code=503, - detail=f"All transcription backends failed. Last error: {last_error}" - ) + return await transcribe_whisper(response, file, language, None, True, False, None, None, auth) + else: + try: + return await transcribe_whisper(response, file, language, None, True, False, None, None, auth) + except Exception: + await file.seek(0) + return await transcribe_voxtral(response, file, language or "de", False, auth) @app.exception_handler(Exception) @@ -676,9 +389,4 @@ async def global_exception_handler(request, exc): if __name__ == "__main__": import uvicorn - uvicorn.run( - "app.main:app", - host="0.0.0.0", - port=PORT, - reload=False, - ) + uvicorn.run("app.main:app", host="0.0.0.0", port=PORT, reload=False) diff --git a/services/mana-stt/app/vram_manager.py b/services/mana-stt/app/vram_manager.py new file mode 100644 index 000000000..89b5656ae --- /dev/null +++ b/services/mana-stt/app/vram_manager.py @@ -0,0 +1,114 @@ +""" +VRAM Manager — Automatic model unloading after idle timeout. + +Tracks last usage time per model and unloads after configurable timeout. +Designed for shared GPU environments (multiple services on one RTX 3090). 
+ +Usage in a service: + from vram_manager import VramManager + + vram = VramManager(idle_timeout=300) # 5 min + + # Before using a model + vram.touch() + + # Call periodically (e.g., from health check or background task) + vram.check_idle(unload_fn=my_unload_function) +""" + +import os +import time +import logging +import threading +from typing import Optional, Callable + +logger = logging.getLogger(__name__) + +DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300")) # 5 minutes + + +class VramManager: + def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"): + self.idle_timeout = idle_timeout + self.service_name = service_name + self.last_used: float = 0.0 + self.model_loaded: bool = False + self._lock = threading.Lock() + self._timer: Optional[threading.Timer] = None + + def touch(self): + """Mark the model as recently used. Call before/after each inference.""" + with self._lock: + self.last_used = time.time() + self.model_loaded = True + self._schedule_check() + + def mark_loaded(self): + """Mark that a model has been loaded into VRAM.""" + with self._lock: + self.model_loaded = True + self.last_used = time.time() + self._schedule_check() + logger.info(f"[{self.service_name}] Model loaded, idle timeout: {self.idle_timeout}s") + + def mark_unloaded(self): + """Mark that a model has been unloaded from VRAM.""" + with self._lock: + self.model_loaded = False + if self._timer: + self._timer.cancel() + self._timer = None + logger.info(f"[{self.service_name}] Model unloaded, VRAM freed") + + def is_idle(self) -> bool: + """Check if the model has been idle longer than the timeout.""" + if not self.model_loaded: + return False + return (time.time() - self.last_used) > self.idle_timeout + + def seconds_until_unload(self) -> Optional[float]: + """Seconds until the model will be unloaded, or None if not loaded.""" + if not self.model_loaded: + return None + remaining = self.idle_timeout - (time.time() - self.last_used) + 
return max(0, remaining) + + def check_and_unload(self, unload_fn: Callable[[], None]) -> bool: + """Check if idle and unload if so. Returns True if unloaded.""" + if self.is_idle(): + logger.info(f"[{self.service_name}] Idle for >{self.idle_timeout}s, unloading model...") + try: + unload_fn() + self.mark_unloaded() + return True + except Exception as e: + logger.error(f"[{self.service_name}] Failed to unload: {e}") + return False + + def _schedule_check(self): + """Schedule an idle check after the timeout period.""" + if self._timer: + self._timer.cancel() + + self._timer = threading.Timer( + self.idle_timeout + 5, # Small buffer + self._auto_check, + ) + self._timer.daemon = True + self._timer.start() + + def _auto_check(self): + """Auto-triggered idle check (called by timer).""" + # This is just a log — actual unloading needs the unload_fn + # which depends on the service. The service should call check_and_unload. + if self.is_idle(): + logger.info(f"[{self.service_name}] Model idle for >{self.idle_timeout}s — ready to unload") + + def status(self) -> dict: + """Get current VRAM manager status.""" + return { + "model_loaded": self.model_loaded, + "idle_seconds": round(time.time() - self.last_used, 1) if self.model_loaded else None, + "idle_timeout": self.idle_timeout, + "seconds_until_unload": round(self.seconds_until_unload(), 1) if self.model_loaded else None, + } diff --git a/services/mana-stt/app/whisper_service.py b/services/mana-stt/app/whisper_service.py index db5e17ca9..821e22d9b 100644 --- a/services/mana-stt/app/whisper_service.py +++ b/services/mana-stt/app/whisper_service.py @@ -1,6 +1,11 @@ """ -Whisper STT Service using Lightning Whisper MLX -Optimized for Apple Silicon (M1/M2/M3/M4) +Whisper STT Service using WhisperX (CUDA) +Provides: transcription, word-level timestamps, speaker diarization. + +WhisperX pipeline: +1. faster-whisper for transcription +2. wav2vec2 for forced alignment (precise word timestamps) +3. 
pyannote-audio for speaker diarization """ import os @@ -8,12 +13,55 @@ import tempfile import logging from pathlib import Path from typing import Optional -from dataclasses import dataclass +from dataclasses import dataclass, field logger = logging.getLogger(__name__) -# Lazy load to avoid import errors if not installed -_whisper_model = None +# Lazy load +_whisperx_model = None +_align_model = None +_align_metadata = None +_diarize_pipeline = None + +# Config +HF_TOKEN = os.getenv("HF_TOKEN", "") + +# VRAM management — unload after 10 min idle (STT uses ~3GB) +from app.vram_manager import VramManager +_vram = VramManager( + idle_timeout=int(os.getenv("VRAM_IDLE_TIMEOUT", "600")), + service_name="mana-stt", +) + + +def unload_models(): + """Unload all WhisperX models from GPU to free VRAM.""" + global _whisperx_model, _align_model, _align_metadata, _diarize_pipeline + import torch + + if _whisperx_model is not None: + del _whisperx_model + _whisperx_model = None + if _align_model is not None: + del _align_model + _align_model = None + _align_metadata = None + if _diarize_pipeline is not None: + del _diarize_pipeline + _diarize_pipeline = None + + torch.cuda.empty_cache() + _vram.mark_unloaded() + logger.info("WhisperX models unloaded, VRAM freed") + + +@dataclass +class WordSegment: + word: str + start: float + end: float + score: Optional[float] = None + speaker: Optional[str] = None @dataclass @@ -22,83 +70,231 @@ class TranscriptionResult: language: Optional[str] = None duration: Optional[float] = None segments: Optional[list] = None + words: Optional[list[WordSegment]] = field(default_factory=list) + speakers: Optional[list[str]] = field(default_factory=list) -def get_whisper_model(model_name: str = "large-v3", batch_size: int = 12): - """Get or create Whisper model instance (singleton pattern).""" - global _whisper_model +def get_whisper_model(model_name: str = "large-v3", **kwargs): + """Get or create WhisperX model instance (singleton).""" + global 
_whisperx_model - if _whisper_model is None: - logger.info(f"Loading Whisper model: {model_name}") - try: - from lightning_whisper_mlx import LightningWhisperMLX - _whisper_model = LightningWhisperMLX( - model=model_name, - batch_size=batch_size, - quant=None # Use full precision for best quality - ) - logger.info(f"Whisper model loaded successfully: {model_name}") - except ImportError as e: - logger.error(f"Failed to import lightning_whisper_mlx: {e}") - raise RuntimeError( - "lightning-whisper-mlx not installed. " - "Run: pip install lightning-whisper-mlx" - ) - except Exception as e: - logger.error(f"Failed to load Whisper model: {e}") - raise + if _whisperx_model is not None: + return _whisperx_model - return _whisper_model + logger.info(f"Loading WhisperX model: {model_name}") + try: + import whisperx + + device = os.getenv("WHISPER_DEVICE", "cuda") + compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16") + + default_language = os.getenv("WHISPER_DEFAULT_LANGUAGE", "de") + _whisperx_model = whisperx.load_model( + model_name, + device=device, + compute_type=compute_type, + language=default_language, + ) + logger.info(f"WhisperX model loaded: {model_name} on {device} ({compute_type})") + _vram.mark_loaded() + except ImportError as e: + logger.error(f"Failed to import whisperx: {e}") + raise RuntimeError("whisperx not installed. 
Run: pip install whisperx") + except Exception as e: + logger.error(f"Failed to load WhisperX model: {e}") + raise + + return _whisperx_model + + +def _get_align_model(language: str, device: str = "cuda"): + """Get or create alignment model for a language.""" + global _align_model, _align_metadata + + import whisperx + + # Reload if language changed (alignment models are language-specific) + if _align_model is None: + logger.info(f"Loading alignment model for language: {language}") + _align_model, _align_metadata = whisperx.load_align_model( + language_code=language, + device=device, + ) + logger.info("Alignment model loaded") + + return _align_model, _align_metadata + + +def _get_diarize_pipeline(device: str = "cuda"): + """Get or create speaker diarization pipeline.""" + global _diarize_pipeline + + if _diarize_pipeline is not None: + return _diarize_pipeline + + import torch + from pyannote.audio import Pipeline + + token = HF_TOKEN or os.getenv("HUGGING_FACE_HUB_TOKEN", "") + if not token: + logger.warning("No HF_TOKEN set — speaker diarization may fail for gated models") + + logger.info("Loading speaker diarization pipeline (pyannote)...") + _diarize_pipeline = Pipeline.from_pretrained( + "pyannote/speaker-diarization-3.1", + token=token, + ) + _diarize_pipeline.to(torch.device(device)) + logger.info("Diarization pipeline loaded") + return _diarize_pipeline def transcribe_audio( audio_path: str, language: Optional[str] = None, model_name: str = "large-v3", + align: bool = True, + diarize: bool = False, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, ) -> TranscriptionResult: """ - Transcribe audio file using Lightning Whisper MLX. + Transcribe audio using WhisperX with optional alignment and diarization. Args: - audio_path: Path to audio file (mp3, wav, m4a, etc.) - language: Optional language code (e.g., 'de', 'en'). Auto-detect if None. 
+ audio_path: Path to audio file + language: Language code (auto-detect if None) model_name: Whisper model to use + align: Enable word-level timestamp alignment + diarize: Enable speaker diarization + min_speakers: Minimum expected speakers (helps diarization) + max_speakers: Maximum expected speakers Returns: - TranscriptionResult with text and metadata + TranscriptionResult with text, word timestamps, and speaker info """ + import whisperx + + device = os.getenv("WHISPER_DEVICE", "cuda") model = get_whisper_model(model_name) - logger.info(f"Transcribing: {audio_path}") + logger.info(f"Transcribing: {audio_path} (align={align}, diarize={diarize})") - try: - # Lightning Whisper MLX returns dict with 'text' key - result = model.transcribe( - audio_path=audio_path, - language=language, - ) + # Check and unload if idle, then reload + _vram.check_and_unload(unload_models) + _vram.touch() - # Handle different return formats - if isinstance(result, dict): - text = result.get("text", "") - segments = result.get("segments", []) - detected_language = result.get("language", language) - else: - text = str(result) - segments = [] - detected_language = language + # Step 1: Load audio + audio = whisperx.load_audio(audio_path) - logger.info(f"Transcription complete: {len(text)} characters") + # Step 2: Transcribe with faster-whisper + transcribe_kwargs = {"batch_size": 16} + if language: + transcribe_kwargs["language"] = language + result = model.transcribe(audio, **transcribe_kwargs) + detected_language = result.get("language", language or "en") - return TranscriptionResult( - text=text.strip(), - language=detected_language, - segments=segments, - ) + # Step 3: Align (word-level timestamps) + if align and result["segments"]: + try: + align_model, metadata = _get_align_model(detected_language, device) + result = whisperx.align( + result["segments"], + align_model, + metadata, + audio, + device, + return_char_alignments=False, + ) + logger.info("Word alignment complete") + except 
Exception as e: + logger.warning(f"Alignment failed (continuing without): {e}") - except Exception as e: - logger.error(f"Transcription failed: {e}") - raise + # Step 4: Diarize (speaker identification) + if diarize: + try: + import torch + import torchaudio + + diarize_pipe = _get_diarize_pipeline(device) + + # pyannote needs waveform as tensor, not the whisperx audio array + waveform = torch.from_numpy(audio).unsqueeze(0).float() + diarize_input = {"waveform": waveform, "sample_rate": 16000} + + diarize_kwargs = {} + if min_speakers is not None: + diarize_kwargs["min_speakers"] = min_speakers + if max_speakers is not None: + diarize_kwargs["max_speakers"] = max_speakers + + diarize_output = diarize_pipe(diarize_input, **diarize_kwargs) + + # pyannote 4.x returns DiarizeOutput, extract the Annotation + if hasattr(diarize_output, "speaker_diarization"): + diarize_annotation = diarize_output.speaker_diarization + else: + diarize_annotation = diarize_output + + # Convert pyannote output to DataFrame for whisperx + import pandas as pd + diarize_rows = [] + for turn, _, speaker in diarize_annotation.itertracks(yield_label=True): + diarize_rows.append({ + "start": turn.start, + "end": turn.end, + "speaker": speaker, + }) + + diarize_df = pd.DataFrame(diarize_rows) + result = whisperx.assign_word_speakers(diarize_df, result) + logger.info("Speaker diarization complete") + except Exception as e: + logger.warning(f"Diarization failed (continuing without): {e}") + import traceback + traceback.print_exc() + + # Build response + segments = result.get("segments", []) + full_text_parts = [] + all_words = [] + speaker_set = set() + + for seg in segments: + full_text_parts.append(seg.get("text", "")) + speaker = seg.get("speaker") + if speaker: + speaker_set.add(speaker) + + for word_info in seg.get("words", []): + all_words.append(WordSegment( + word=word_info.get("word", ""), + start=word_info.get("start", 0.0), + end=word_info.get("end", 0.0), + score=word_info.get("score"), + 
speaker=word_info.get("speaker", speaker), + )) + + text = " ".join(full_text_parts) + + _vram.touch() + logger.info( + f"Transcription complete: {len(text)} chars, " + f"{len(all_words)} words, {len(speaker_set)} speakers" + ) + + return TranscriptionResult( + text=text.strip(), + language=detected_language, + segments=[{ + "start": s.get("start", 0), + "end": s.get("end", 0), + "text": s.get("text", ""), + "speaker": s.get("speaker"), + } for s in segments], + words=all_words, + speakers=sorted(speaker_set), + ) async def transcribe_audio_bytes( @@ -106,53 +302,57 @@ async def transcribe_audio_bytes( filename: str, language: Optional[str] = None, model_name: str = "large-v3", + align: bool = True, + diarize: bool = False, + min_speakers: Optional[int] = None, + max_speakers: Optional[int] = None, ) -> TranscriptionResult: - """ - Transcribe audio from bytes (for API uploads). + """Transcribe audio from bytes (for API uploads).""" + import asyncio - Args: - audio_bytes: Raw audio file bytes - filename: Original filename (for extension detection) - language: Optional language code - model_name: Whisper model to use - - Returns: - TranscriptionResult - """ - # Get file extension ext = Path(filename).suffix or ".wav" - # Write to temp file with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: tmp.write(audio_bytes) tmp_path = tmp.name try: - result = transcribe_audio( - audio_path=tmp_path, - language=language, - model_name=model_name, + # Run in thread pool to avoid blocking the event loop + loop = asyncio.get_event_loop() + result = await loop.run_in_executor( + None, + lambda: transcribe_audio( + audio_path=tmp_path, + language=language, + model_name=model_name, + align=align, + diarize=diarize, + min_speakers=min_speakers, + max_speakers=max_speakers, + ), ) return result finally: - # Clean up temp file try: os.unlink(tmp_path) except Exception: pass -# Available models for Lightning Whisper MLX +# Available models AVAILABLE_MODELS = [ "tiny", - 
"small", + "tiny.en", "base", + "base.en", + "small", + "small.en", "medium", - "large", + "medium.en", + "large-v1", "large-v2", - "large-v3", # Recommended for Mac Mini - "distil-small.en", - "distil-medium.en", + "large-v3", + "large-v3-turbo", "distil-large-v2", "distil-large-v3", ] diff --git a/services/mana-stt/app/whisper_service_cuda.py b/services/mana-stt/app/whisper_service_cuda.py deleted file mode 100644 index 90460a454..000000000 --- a/services/mana-stt/app/whisper_service_cuda.py +++ /dev/null @@ -1,175 +0,0 @@ -""" -Whisper STT Service using faster-whisper (CUDA) -Optimized for NVIDIA GPUs (RTX 3090 etc.) - -Drop-in replacement for whisper_service.py (MLX version). -Uses faster-whisper with CTranslate2 for GPU-accelerated inference. -""" - -import os -import tempfile -import logging -from pathlib import Path -from typing import Optional -from dataclasses import dataclass - -logger = logging.getLogger(__name__) - -# Lazy load to avoid import errors if not installed -_whisper_model = None - - -@dataclass -class TranscriptionResult: - text: str - language: Optional[str] = None - duration: Optional[float] = None - segments: Optional[list] = None - - -def get_whisper_model(model_name: str = "large-v3", **kwargs): - """Get or create Whisper model instance (singleton pattern).""" - global _whisper_model - - if _whisper_model is None: - logger.info(f"Loading Whisper model: {model_name}") - try: - from faster_whisper import WhisperModel - - # Use CUDA with float16 for RTX 3090 - compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16") - device = os.getenv("WHISPER_DEVICE", "cuda") - - _whisper_model = WhisperModel( - model_name, - device=device, - compute_type=compute_type, - ) - logger.info(f"Whisper model loaded: {model_name} on {device} ({compute_type})") - except ImportError as e: - logger.error(f"Failed to import faster_whisper: {e}") - raise RuntimeError( - "faster-whisper not installed. 
" - "Run: pip install faster-whisper" - ) - except Exception as e: - logger.error(f"Failed to load Whisper model: {e}") - raise - - return _whisper_model - - -def transcribe_audio( - audio_path: str, - language: Optional[str] = None, - model_name: str = "large-v3", -) -> TranscriptionResult: - """ - Transcribe audio file using faster-whisper (CUDA). - - Args: - audio_path: Path to audio file (mp3, wav, m4a, etc.) - language: Optional language code (e.g., 'de', 'en'). Auto-detect if None. - model_name: Whisper model to use - - Returns: - TranscriptionResult with text and metadata - """ - model = get_whisper_model(model_name) - - logger.info(f"Transcribing: {audio_path}") - - try: - segments, info = model.transcribe( - audio_path, - language=language, - beam_size=5, - vad_filter=True, # Filter out silence - ) - - # Collect all segments - all_segments = [] - full_text_parts = [] - for segment in segments: - full_text_parts.append(segment.text) - all_segments.append({ - "start": segment.start, - "end": segment.end, - "text": segment.text, - }) - - text = " ".join(full_text_parts) - detected_language = info.language if info else language - - logger.info(f"Transcription complete: {len(text)} characters, language={detected_language}") - - return TranscriptionResult( - text=text.strip(), - language=detected_language, - duration=info.duration if info else None, - segments=all_segments, - ) - - except Exception as e: - logger.error(f"Transcription failed: {e}") - raise - - -async def transcribe_audio_bytes( - audio_bytes: bytes, - filename: str, - language: Optional[str] = None, - model_name: str = "large-v3", -) -> TranscriptionResult: - """ - Transcribe audio from bytes (for API uploads). 
- - Args: - audio_bytes: Raw audio file bytes - filename: Original filename (for extension detection) - language: Optional language code - model_name: Whisper model to use - - Returns: - TranscriptionResult - """ - # Get file extension - ext = Path(filename).suffix or ".wav" - - # Write to temp file - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: - tmp.write(audio_bytes) - tmp_path = tmp.name - - try: - result = transcribe_audio( - audio_path=tmp_path, - language=language, - model_name=model_name, - ) - return result - finally: - # Clean up temp file - try: - os.unlink(tmp_path) - except Exception: - pass - - -# Available models for faster-whisper -AVAILABLE_MODELS = [ - "tiny", - "tiny.en", - "base", - "base.en", - "small", - "small.en", - "medium", - "medium.en", - "large-v1", - "large-v2", - "large-v3", - "large-v3-turbo", - "distil-large-v2", - "distil-large-v3", -] diff --git a/services/mana-stt/app/whisperx_service.py b/services/mana-stt/app/whisperx_service.py deleted file mode 100644 index 156c1581a..000000000 --- a/services/mana-stt/app/whisperx_service.py +++ /dev/null @@ -1,419 +0,0 @@ -""" -WhisperX STT Service using faster-whisper + pyannote (CUDA) -Optimized for NVIDIA GPUs (RTX 3090 etc.) 
- -Features: -- Word-level timestamps via forced alignment -- Speaker diarization via pyannote.audio -- Segment-level timestamps with speaker labels -- VAD filtering for silence removal - -Requires HuggingFace token for pyannote models: - export HF_TOKEN=hf_xxx - # Accept terms at: https://huggingface.co/pyannote/speaker-diarization-3.1 -""" - -import os -import tempfile -import logging -import time -from pathlib import Path -from typing import Optional -from dataclasses import dataclass, field - -logger = logging.getLogger(__name__) - -# Lazy-loaded singletons -_whisper_model = None -_align_model = None -_align_metadata = None -_diarize_pipeline = None - -HF_TOKEN = os.getenv("HF_TOKEN") -DEFAULT_MODEL = os.getenv("WHISPER_MODEL", "large-v3") -DEVICE = os.getenv("WHISPER_DEVICE", "cuda") -COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16") -BATCH_SIZE = int(os.getenv("WHISPERX_BATCH_SIZE", "16")) - - -@dataclass -class WordInfo: - """Word with timestamp.""" - word: str - start: float - end: float - score: float = 0.0 - speaker: Optional[str] = None - - -@dataclass -class SegmentInfo: - """Segment with speaker and word-level detail.""" - start: float - end: float - text: str - speaker: Optional[str] = None - words: list[WordInfo] = field(default_factory=list) - - -@dataclass -class Utterance: - """Speaker utterance in Memoro-compatible format.""" - speaker: int - text: str - offset: int # milliseconds - duration: int # milliseconds - - -@dataclass -class WhisperXResult: - """Rich transcription result with alignment and diarization.""" - text: str - language: Optional[str] = None - duration_seconds: Optional[float] = None - segments: list[SegmentInfo] = field(default_factory=list) - utterances: list[Utterance] = field(default_factory=list) - speakers: dict[str, str] = field(default_factory=dict) - speaker_map: dict[str, int] = field(default_factory=dict) - languages: list[str] = field(default_factory=list) - primary_language: Optional[str] = None - words: 
list[WordInfo] = field(default_factory=list) - - -def get_whisper_model(model_name: str = None): - """Load faster-whisper model (singleton).""" - global _whisper_model - model_name = model_name or DEFAULT_MODEL - - if _whisper_model is None: - logger.info(f"Loading WhisperX model: {model_name} on {DEVICE} ({COMPUTE_TYPE})") - try: - import whisperx - - _whisper_model = whisperx.load_model( - model_name, - device=DEVICE, - compute_type=COMPUTE_TYPE, - ) - logger.info(f"WhisperX model loaded: {model_name}") - except ImportError: - raise RuntimeError( - "whisperx not installed. " - "Run: pip install whisperx" - ) - - return _whisper_model - - -def get_align_model(language_code: str): - """Load alignment model for a specific language (cached per language).""" - global _align_model, _align_metadata - - try: - import whisperx - - _align_model, _align_metadata = whisperx.load_align_model( - language_code=language_code, - device=DEVICE, - ) - logger.info(f"Alignment model loaded for language: {language_code}") - return _align_model, _align_metadata - except Exception as e: - logger.warning(f"Failed to load alignment model for {language_code}: {e}") - return None, None - - -def get_diarize_pipeline(): - """Load pyannote speaker diarization pipeline (singleton).""" - global _diarize_pipeline - - if _diarize_pipeline is None: - if not HF_TOKEN: - logger.warning("HF_TOKEN not set — diarization disabled") - return None - - try: - import whisperx - - _diarize_pipeline = whisperx.DiarizationPipeline( - use_auth_token=HF_TOKEN, - device=DEVICE, - ) - logger.info("Diarization pipeline loaded") - except Exception as e: - logger.warning(f"Failed to load diarization pipeline: {e}") - return None - - return _diarize_pipeline - - -def _build_utterances(segments: list[SegmentInfo]) -> tuple[list[Utterance], dict[str, str], dict[str, int]]: - """ - Build Memoro-compatible utterances from diarized segments. - Groups consecutive segments by the same speaker. 
- """ - if not segments: - return [], {}, {} - - # Collect unique speakers - speaker_labels = sorted(set( - s.speaker for s in segments if s.speaker is not None - )) - speaker_map: dict[str, int] = {} - speakers: dict[str, str] = {} - for idx, label in enumerate(speaker_labels): - speaker_map[label] = idx - speakers[str(idx)] = label - - # Merge consecutive segments with the same speaker - utterances: list[Utterance] = [] - current_speaker = None - current_text_parts: list[str] = [] - current_start = 0.0 - current_end = 0.0 - - for seg in segments: - sp = seg.speaker or "SPEAKER_00" - if sp != current_speaker: - # Flush previous - if current_speaker is not None and current_text_parts: - utterances.append(Utterance( - speaker=speaker_map.get(current_speaker, 0), - text=" ".join(current_text_parts).strip(), - offset=int(current_start * 1000), - duration=int((current_end - current_start) * 1000), - )) - current_speaker = sp - current_text_parts = [seg.text] - current_start = seg.start - current_end = seg.end - else: - current_text_parts.append(seg.text) - current_end = seg.end - - # Flush last - if current_speaker is not None and current_text_parts: - utterances.append(Utterance( - speaker=speaker_map.get(current_speaker, 0), - text=" ".join(current_text_parts).strip(), - offset=int(current_start * 1000), - duration=int((current_end - current_start) * 1000), - )) - - return utterances, speakers, speaker_map - - -def transcribe_audio( - audio_path: str, - language: Optional[str] = None, - model_name: Optional[str] = None, - enable_diarization: bool = True, - enable_alignment: bool = True, - min_speakers: Optional[int] = None, - max_speakers: Optional[int] = None, -) -> WhisperXResult: - """ - Transcribe audio with WhisperX: alignment + diarization. - - Args: - audio_path: Path to audio file - language: Language code (e.g., 'de', 'en'). Auto-detect if None. 
- model_name: Whisper model to use - enable_diarization: Run speaker diarization - enable_alignment: Run forced word alignment - min_speakers: Minimum expected speakers (hint for diarization) - max_speakers: Maximum expected speakers (hint for diarization) - - Returns: - WhisperXResult with full transcription, segments, utterances, speakers - """ - import whisperx - - start_time = time.time() - - # 1. Load audio - audio = whisperx.load_audio(audio_path) - audio_duration = len(audio) / 16000 # whisperx resamples to 16kHz - - # 2. Transcribe with faster-whisper - model = get_whisper_model(model_name) - transcribe_result = model.transcribe( - audio, - batch_size=BATCH_SIZE, - language=language, - ) - - detected_language = transcribe_result.get("language", language or "en") - raw_segments = transcribe_result.get("segments", []) - - logger.info( - f"Transcription: {len(raw_segments)} segments, " - f"language={detected_language}, " - f"duration={audio_duration:.1f}s" - ) - - # 3. Forced alignment (word-level timestamps) - if enable_alignment and raw_segments: - align_model, align_metadata = get_align_model(detected_language) - if align_model is not None: - try: - transcribe_result = whisperx.align( - raw_segments, - align_model, - align_metadata, - audio, - DEVICE, - return_char_alignments=False, - ) - raw_segments = transcribe_result.get("segments", raw_segments) - logger.info("Word alignment complete") - except Exception as e: - logger.warning(f"Alignment failed, using segment-level timestamps: {e}") - - # 4. 
Speaker diarization - if enable_diarization: - diarize_pipeline = get_diarize_pipeline() - if diarize_pipeline is not None: - try: - diarize_kwargs = {} - if min_speakers is not None: - diarize_kwargs["min_speakers"] = min_speakers - if max_speakers is not None: - diarize_kwargs["max_speakers"] = max_speakers - - diarize_segments = diarize_pipeline( - audio_path, - **diarize_kwargs, - ) - transcribe_result = whisperx.assign_word_speakers( - diarize_segments, transcribe_result - ) - raw_segments = transcribe_result.get("segments", raw_segments) - logger.info("Diarization complete") - except Exception as e: - logger.warning(f"Diarization failed: {e}") - - # 5. Build structured result - segments: list[SegmentInfo] = [] - all_words: list[WordInfo] = [] - full_text_parts: list[str] = [] - - for seg in raw_segments: - seg_words: list[WordInfo] = [] - for w in seg.get("words", []): - wi = WordInfo( - word=w.get("word", ""), - start=w.get("start", 0.0), - end=w.get("end", 0.0), - score=w.get("score", 0.0), - speaker=w.get("speaker"), - ) - seg_words.append(wi) - all_words.append(wi) - - segment = SegmentInfo( - start=seg.get("start", 0.0), - end=seg.get("end", 0.0), - text=seg.get("text", "").strip(), - speaker=seg.get("speaker"), - words=seg_words, - ) - segments.append(segment) - full_text_parts.append(segment.text) - - full_text = " ".join(full_text_parts) - - # 6. 
Build utterances (Memoro-compatible) - utterances, speakers, speaker_map = _build_utterances(segments) - - latency = time.time() - start_time - logger.info(f"WhisperX complete in {latency:.1f}s: {len(full_text)} chars, {len(speakers)} speakers") - - return WhisperXResult( - text=full_text, - language=detected_language, - duration_seconds=audio_duration, - segments=segments, - utterances=utterances, - speakers=speakers, - speaker_map=speaker_map, - languages=[detected_language] if detected_language else [], - primary_language=detected_language, - words=all_words, - ) - - -async def transcribe_audio_bytes( - audio_bytes: bytes, - filename: str, - language: Optional[str] = None, - model_name: Optional[str] = None, - enable_diarization: bool = True, - enable_alignment: bool = True, - min_speakers: Optional[int] = None, - max_speakers: Optional[int] = None, -) -> WhisperXResult: - """ - Transcribe audio from bytes (for API uploads). - - Args: - audio_bytes: Raw audio file bytes - filename: Original filename (for extension detection) - language: Optional language code - model_name: Whisper model to use - enable_diarization: Run speaker diarization - enable_alignment: Run forced word alignment - min_speakers: Min expected speakers - max_speakers: Max expected speakers - - Returns: - WhisperXResult with full transcription data - """ - ext = Path(filename).suffix or ".wav" - - with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp: - tmp.write(audio_bytes) - tmp_path = tmp.name - - try: - return transcribe_audio( - audio_path=tmp_path, - language=language, - model_name=model_name, - enable_diarization=enable_diarization, - enable_alignment=enable_alignment, - min_speakers=min_speakers, - max_speakers=max_speakers, - ) - finally: - try: - os.unlink(tmp_path) - except Exception: - pass - - -def is_available() -> bool: - """Check if WhisperX dependencies are installed.""" - try: - import whisperx - return True - except ImportError: - return False - - 
-AVAILABLE_MODELS = [ - "tiny", - "tiny.en", - "base", - "base.en", - "small", - "small.en", - "medium", - "medium.en", - "large-v1", - "large-v2", - "large-v3", - "large-v3-turbo", - "distil-large-v2", - "distil-large-v3", -] diff --git a/services/mana-stt/service.pyw b/services/mana-stt/service.pyw new file mode 100644 index 000000000..056059e98 --- /dev/null +++ b/services/mana-stt/service.pyw @@ -0,0 +1,34 @@ +"""mana-stt service runner.""" +import os +import sys + +os.chdir(r"C:\mana\services\mana-stt") +sys.path.insert(0, r"C:\mana\services\mana-stt") + +# Redirect stdout/stderr to log file FIRST (before any imports that warn) +log = open(r"C:\mana\services\mana-stt\service.log", "w", buffering=1) +sys.stdout = log +sys.stderr = log + +# Load .env file +from dotenv import load_dotenv +load_dotenv(r"C:\mana\services\mana-stt\.env") + +# Ensure FFmpeg is in PATH +ffmpeg_dir = r"C:\Users\tills\AppData\Local\Microsoft\WinGet\Links" +if ffmpeg_dir not in os.environ.get("PATH", ""): + os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "") + +# Set HF token +hf_token = os.environ.get("HF_TOKEN", "") +if hf_token: + os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token + +# Pre-initialize CUDA before importing whisperx (avoids hangs) +import torch +if torch.cuda.is_available(): + torch.cuda.init() + print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True) + +import uvicorn +uvicorn.run("app.main:app", host="0.0.0.0", port=3020, log_level="info") diff --git a/services/mana-tts/app/auth.py b/services/mana-tts/app/auth.py index f6d3315cf..40258c730 100644 --- a/services/mana-tts/app/auth.py +++ b/services/mana-tts/app/auth.py @@ -1,39 +1,271 @@ """ -API Key Authentication for Mana TTS Service. -Delegates to shared mana_auth package. +API Key Authentication for ManaCore TTS Service + +Supports two authentication modes: +1. Local API keys: Configured via environment variables +2. 
External API keys: Validated via mana-core-auth service (when EXTERNAL_AUTH_ENABLED=true) + +Usage: + # Local keys + API_KEYS=sk-key1:name1,sk-key2:name2 + INTERNAL_API_KEY=sk-internal-xxx + + # External auth (for user-created keys via mana.how) + EXTERNAL_AUTH_ENABLED=true + MANA_CORE_AUTH_URL=http://localhost:3001 """ -import sys import os +import time +import logging +from typing import Optional +from collections import defaultdict +from dataclasses import dataclass, field -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python")) +from fastapi import HTTPException, Security, Request +from fastapi.security import APIKeyHeader -from mana_auth import ( - APIKey, - AuthResult, - RateLimitInfo, - verify_api_key as _verify_api_key, - get_api_key_stats, - reload_api_keys, - api_key_header, - create_auth_dependency, -) -from mana_auth.external_auth import ( - ExternalValidationResult, +from .external_auth import ( is_external_auth_enabled, validate_api_key_external, + ExternalValidationResult, ) -from typing import Optional -from fastapi import Security, Request +logger = logging.getLogger(__name__) -# TTS-specific auth dependency -verify_tts_key = create_auth_dependency("tts") +# Configuration +API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2" +INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Unlimited internal key +REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true" +RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Per minute +RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Seconds + + +@dataclass +class APIKey: + """API Key with metadata.""" + key: str + name: str + is_internal: bool = False + rate_limit: int = RATE_LIMIT_REQUESTS # Requests per window + + +@dataclass +class RateLimitInfo: + """Rate limit tracking per key.""" + requests: list = field(default_factory=list) + + def is_allowed(self, limit: int, window: int) -> bool: + 
"""Check if request is allowed within rate limit.""" + now = time.time() + # Remove old requests outside window + self.requests = [t for t in self.requests if now - t < window] + + if len(self.requests) >= limit: + return False + + self.requests.append(now) + return True + + def remaining(self, limit: int, window: int) -> int: + """Get remaining requests in current window.""" + now = time.time() + self.requests = [t for t in self.requests if now - t < window] + return max(0, limit - len(self.requests)) + + +# Parse API keys from environment +def _parse_api_keys() -> dict[str, APIKey]: + """Parse API keys from environment variables.""" + keys = {} + + # Parse comma-separated keys + if API_KEYS_ENV: + for entry in API_KEYS_ENV.split(","): + entry = entry.strip() + if ":" in entry: + key, name = entry.split(":", 1) + else: + key, name = entry, "default" + keys[key.strip()] = APIKey(key=key.strip(), name=name.strip()) + + # Add internal key with no rate limit + if INTERNAL_API_KEY: + keys[INTERNAL_API_KEY] = APIKey( + key=INTERNAL_API_KEY, + name="internal", + is_internal=True, + rate_limit=999999, # Effectively unlimited + ) + + return keys + + +# Global state +_api_keys = _parse_api_keys() +_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo) + +# Security scheme +api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False) + + +@dataclass +class AuthResult: + """Result of authentication check.""" + authenticated: bool + key_name: Optional[str] = None + is_internal: bool = False + rate_limit_remaining: Optional[int] = None + user_id: Optional[str] = None # Set when using external auth async def verify_api_key( request: Request, api_key: Optional[str] = Security(api_key_header), ) -> AuthResult: - """Verify API key with TTS scope.""" - return await _verify_api_key(request, scope="tts", api_key=api_key) + """ + Verify API key and check rate limits. + + Supports two authentication modes: + 1. External auth via mana-core-auth (for sk_live_ keys) + 2. 
Local auth via environment variables + + Returns AuthResult with authentication status. + Raises HTTPException if auth fails or rate limited. + """ + # Skip auth for health and docs endpoints + path = request.url.path + if path in ["/health", "/docs", "/openapi.json", "/redoc"]: + return AuthResult(authenticated=True, key_name="public") + + # If auth not required, allow all + if not REQUIRE_AUTH: + return AuthResult(authenticated=True, key_name="anonymous") + + # Check for API key + if not api_key: + logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}") + raise HTTPException( + status_code=401, + detail="Missing API key. Provide X-API-Key header.", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + # Try external auth first for sk_live_ keys (user-created keys via mana.how) + if api_key.startswith("sk_live_") and is_external_auth_enabled(): + external_result = await validate_api_key_external(api_key, "tts") + + if external_result is not None: + if external_result.valid: + # Use rate limits from external auth + rate_info = _rate_limits[api_key] + limit = external_result.rate_limit_requests + window = external_result.rate_limit_window + + if not rate_info.is_allowed(limit, window): + remaining = rate_info.remaining(limit, window) + logger.warning(f"Rate limit exceeded for external key") + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded. 
Try again in {window} seconds.", + headers={ + "X-RateLimit-Limit": str(limit), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(time.time()) + window), + "Retry-After": str(window), + }, + ) + + remaining = rate_info.remaining(limit, window) + logger.debug(f"Authenticated external request from user {external_result.user_id} to {path}") + + return AuthResult( + authenticated=True, + key_name="external", + is_internal=False, + rate_limit_remaining=remaining, + user_id=external_result.user_id, + ) + else: + # External auth returned invalid + logger.warning(f"External auth failed: {external_result.error}") + raise HTTPException( + status_code=401, + detail=external_result.error or "Invalid API key.", + headers={"WWW-Authenticate": "ApiKey"}, + ) + # If external_result is None, fall through to local auth + + # Local auth: Validate key against environment variables + if api_key not in _api_keys: + logger.warning(f"Invalid API key attempt for {path}") + raise HTTPException( + status_code=401, + detail="Invalid API key.", + headers={"WWW-Authenticate": "ApiKey"}, + ) + + key_info = _api_keys[api_key] + + # Check rate limit (skip for internal keys) + if not key_info.is_internal: + rate_info = _rate_limits[api_key] + if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW): + remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW) + logger.warning(f"Rate limit exceeded for key '{key_info.name}'") + raise HTTPException( + status_code=429, + detail=f"Rate limit exceeded. 
Try again in {RATE_LIMIT_WINDOW} seconds.", + headers={ + "X-RateLimit-Limit": str(key_info.rate_limit), + "X-RateLimit-Remaining": str(remaining), + "X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW), + "Retry-After": str(RATE_LIMIT_WINDOW), + }, + ) + remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW) + else: + remaining = None + + logger.debug(f"Authenticated request from '{key_info.name}' to {path}") + + return AuthResult( + authenticated=True, + key_name=key_info.name, + is_internal=key_info.is_internal, + rate_limit_remaining=remaining, + ) + + +def get_api_key_stats() -> dict: + """Get statistics about API keys (for admin endpoint).""" + stats = { + "total_keys": len(_api_keys), + "auth_required": REQUIRE_AUTH, + "rate_limit": { + "requests_per_window": RATE_LIMIT_REQUESTS, + "window_seconds": RATE_LIMIT_WINDOW, + }, + "keys": [], + } + + for key, info in _api_keys.items(): + # Don't expose actual keys, just metadata + masked_key = key[:8] + "..." if len(key) > 8 else "***" + rate_info = _rate_limits.get(key, RateLimitInfo()) + stats["keys"].append({ + "name": info.name, + "key_prefix": masked_key, + "is_internal": info.is_internal, + "requests_in_window": len(rate_info.requests), + "remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW), + }) + + return stats + + +def reload_api_keys(): + """Reload API keys from environment (for runtime updates).""" + global _api_keys + _api_keys = _parse_api_keys() + logger.info(f"Reloaded {len(_api_keys)} API keys") diff --git a/services/mana-tts/app/external_auth.py b/services/mana-tts/app/external_auth.py index 18a1cfe68..6f64bd315 100644 --- a/services/mana-tts/app/external_auth.py +++ b/services/mana-tts/app/external_auth.py @@ -1,22 +1,145 @@ """ -External API Key Validation — delegates to shared mana_auth package. +External API Key Validation via mana-core-auth + +When EXTERNAL_AUTH_ENABLED=true, API keys are validated against the +central mana-core-auth service. 
This allows users to create and manage +API keys from the mana.how web interface. + +Results are cached for 5 minutes to reduce load on the auth service. """ -import sys import os +import time +import logging +import httpx +from typing import Optional +from dataclasses import dataclass -sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python")) +logger = logging.getLogger(__name__) -from mana_auth.external_auth import ( - ExternalValidationResult, - is_external_auth_enabled, - validate_api_key_external, - clear_cache, -) +# Configuration +EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true" +MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001") +API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300")) # 5 minutes +EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0")) # seconds -__all__ = [ - "ExternalValidationResult", - "is_external_auth_enabled", - "validate_api_key_external", - "clear_cache", -] + +@dataclass +class ExternalValidationResult: + """Result from external API key validation.""" + valid: bool + user_id: Optional[str] = None + scopes: Optional[list] = None + rate_limit_requests: int = 60 + rate_limit_window: int = 60 + error: Optional[str] = None + cached_at: float = 0.0 + + +# In-memory cache for validation results +# Key: API key, Value: ExternalValidationResult +_validation_cache: dict[str, ExternalValidationResult] = {} + + +def is_external_auth_enabled() -> bool: + """Check if external authentication is enabled.""" + return EXTERNAL_AUTH_ENABLED + + +def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]: + """Get cached validation result if still valid.""" + result = _validation_cache.get(api_key) + if result and (time.time() - result.cached_at) < API_KEY_CACHE_TTL: + return result + return None + + +def _cache_result(api_key: str, result: ExternalValidationResult): + """Cache a validation 
result.""" + result.cached_at = time.time() + _validation_cache[api_key] = result + + # Clean up old entries periodically (keep cache size manageable) + if len(_validation_cache) > 1000: + now = time.time() + expired_keys = [ + k for k, v in _validation_cache.items() + if (now - v.cached_at) >= API_KEY_CACHE_TTL + ] + for k in expired_keys: + del _validation_cache[k] + + +async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]: + """ + Validate an API key against mana-core-auth service. + + Args: + api_key: The API key to validate (e.g., "sk_live_...") + scope: The required scope (e.g., "stt" or "tts") + + Returns: + ExternalValidationResult if external auth is enabled and the key was validated. + None if external auth is disabled or the service is unavailable (fallback to local). + """ + if not EXTERNAL_AUTH_ENABLED: + return None + + # Check cache first + cached = _get_cached_result(api_key) + if cached: + logger.debug(f"Using cached validation result for key prefix: {api_key[:12]}...") + # Check scope against cached result + if cached.valid and cached.scopes and scope not in cached.scopes: + return ExternalValidationResult( + valid=False, + error=f"API key does not have scope: {scope}", + ) + return cached + + # Call mana-core-auth validation endpoint + try: + async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client: + response = await client.post( + f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate", + json={"apiKey": api_key, "scope": scope}, + ) + + if response.status_code == 200: + data = response.json() + result = ExternalValidationResult( + valid=data.get("valid", False), + user_id=data.get("userId"), + scopes=data.get("scopes", []), + rate_limit_requests=data.get("rateLimit", {}).get("requests", 60), + rate_limit_window=data.get("rateLimit", {}).get("window", 60), + error=data.get("error"), + ) + _cache_result(api_key, result) + return result + else: + logger.warning( + f"External auth returned status 
{response.status_code}: {response.text}" + ) + # Don't cache errors - allow retry + return ExternalValidationResult( + valid=False, + error=f"Auth service returned {response.status_code}", + ) + + except httpx.TimeoutException: + logger.warning("External auth service timeout - falling back to local auth") + return None + except httpx.ConnectError: + logger.warning("Cannot connect to external auth service - falling back to local auth") + return None + except Exception as e: + logger.error(f"External auth error: {e}") + return None + + +def clear_cache(): + """Clear the validation cache (for testing or runtime updates).""" + global _validation_cache + _validation_cache.clear() + logger.info("External auth cache cleared") diff --git a/services/mana-tts/app/f5_service.py b/services/mana-tts/app/f5_service.py index d6494c212..79a247c5c 100644 --- a/services/mana-tts/app/f5_service.py +++ b/services/mana-tts/app/f5_service.py @@ -1,6 +1,6 @@ """ F5-TTS Service for voice cloning synthesis. -Uses f5-tts-mlx optimized for Apple Silicon. +CUDA version using f5-tts PyTorch package. """ import logging @@ -15,14 +15,12 @@ import numpy as np logger = logging.getLogger(__name__) # Global singleton for lazy initialization -_f5_model = None -_f5_model_name = None +_f5_api = None # Default model -DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "lucasnewman/f5-tts-mlx") +DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "F5-TTS") # Default generation parameters -DEFAULT_DURATION = 10.0 # seconds DEFAULT_STEPS = 32 DEFAULT_CFG_STRENGTH = 2.0 DEFAULT_SWAY_COEF = -1.0 @@ -40,35 +38,25 @@ class F5Result: def get_f5_model(model_name: str = DEFAULT_F5_MODEL): - """ - Get or create F5-TTS model instance (singleton pattern). 
+ """Get or create F5-TTS API instance (singleton pattern).""" + global _f5_api - Args: - model_name: HuggingFace model identifier - - Returns: - F5TTS model instance - """ - global _f5_model, _f5_model_name - - # Return existing model if same model name - if _f5_model is not None and _f5_model_name == model_name: - return _f5_model + if _f5_api is not None: + return _f5_api logger.info(f"Loading F5-TTS model: {model_name}") try: - from f5_tts_mlx import F5TTS + from f5_tts.api import F5TTS - _f5_model = F5TTS(model_name=model_name) - _f5_model_name = model_name - logger.info("F5-TTS model loaded successfully") - return _f5_model + _f5_api = F5TTS(model_type="F5-TTS") + logger.info("F5-TTS model loaded successfully (CUDA)") + return _f5_api except ImportError as e: - logger.error(f"Failed to import f5_tts_mlx: {e}") + logger.error(f"Failed to import f5_tts: {e}") raise RuntimeError( - "f5-tts-mlx not installed. Run: pip install f5-tts-mlx" + "f5-tts not installed. Run: pip install f5-tts" ) except Exception as e: logger.error(f"Failed to load F5-TTS model: {e}") @@ -77,7 +65,7 @@ def get_f5_model(model_name: str = DEFAULT_F5_MODEL): def is_f5_loaded() -> bool: """Check if F5-TTS model is currently loaded.""" - return _f5_model is not None + return _f5_api is not None async def synthesize_f5( @@ -103,13 +91,14 @@ async def synthesize_f5( cfg_strength: Classifier-free guidance strength sway_coef: Sway sampling coefficient speed: Speech speed multiplier - model_name: HuggingFace model identifier + model_name: Model identifier Returns: F5Result with audio data """ - # Get model - model = get_f5_model(model_name) + import asyncio + + api = get_f5_model(model_name) logger.info( f"Synthesizing with F5-TTS: text_length={len(text)}, " @@ -117,17 +106,26 @@ async def synthesize_f5( ) try: - # Generate audio - audio, sample_rate = model.generate( - text=text, - ref_audio_path=reference_audio_path, - ref_audio_text=reference_text, - duration=duration, - steps=steps, - 
cfg_strength=cfg_strength, - sway_coef=sway_coef, - speed=speed, - ) + # F5-TTS API infer method (runs synchronously, offload to thread) + loop = asyncio.get_event_loop() + + def _generate(): + wav, sr, _ = api.infer( + ref_file=reference_audio_path, + ref_text=reference_text, + gen_text=text, + nfe_step=steps, + cfg_strength=cfg_strength, + sway_sampling_coeff=sway_coef, + speed=speed, + ) + return wav, sr + + audio, sample_rate = await loop.run_in_executor(None, _generate) + + # Convert to numpy if needed + if not isinstance(audio, np.ndarray): + audio = np.array(audio, dtype=np.float32) # Calculate duration audio_duration = len(audio) / sample_rate @@ -152,24 +150,8 @@ async def synthesize_f5_from_bytes( audio_extension: str = ".wav", **kwargs, ) -> F5Result: - """ - Synthesize speech using F5-TTS with reference audio as bytes. - - Args: - text: Text to synthesize - reference_audio_bytes: Reference audio as bytes - reference_text: Transcript of the reference audio - audio_extension: File extension for temp file - **kwargs: Additional arguments passed to synthesize_f5 - - Returns: - F5Result with audio data - """ - # Save reference audio to temp file - with tempfile.NamedTemporaryFile( - suffix=audio_extension, - delete=False, - ) as tmp: + """Synthesize speech using F5-TTS with reference audio as bytes.""" + with tempfile.NamedTemporaryFile(suffix=audio_extension, delete=False) as tmp: tmp.write(reference_audio_bytes) tmp_path = tmp.name @@ -182,7 +164,6 @@ async def synthesize_f5_from_bytes( ) return result finally: - # Clean up temp file try: Path(tmp_path).unlink() except Exception: @@ -190,18 +171,7 @@ async def synthesize_f5_from_bytes( def estimate_duration(text: str, speed: float = 1.0) -> float: - """ - Estimate audio duration from text. 
- - Args: - text: Text to synthesize - speed: Speech speed multiplier - - Returns: - Estimated duration in seconds - """ - # Rough estimate: ~150 words per minute at normal speed - # Average word length: ~5 characters + """Estimate audio duration from text.""" words = len(text) / 5 minutes = words / 150 seconds = minutes * 60 diff --git a/services/mana-tts/app/kokoro_service.py b/services/mana-tts/app/kokoro_service.py index 2ce42d2ac..dbf7cb504 100644 --- a/services/mana-tts/app/kokoro_service.py +++ b/services/mana-tts/app/kokoro_service.py @@ -1,6 +1,6 @@ """ Kokoro TTS Service for fast preset voice synthesis. -Uses mlx-audio's Kokoro implementation optimized for Apple Silicon. +CUDA version using kokoro PyTorch package. """ import logging @@ -12,11 +12,10 @@ import numpy as np logger = logging.getLogger(__name__) # Global singleton for lazy initialization -_kokoro_model = None -_kokoro_model_name = None +_kokoro_pipeline = None # Default model -DEFAULT_KOKORO_MODEL = "mlx-community/Kokoro-82M-bf16" +DEFAULT_KOKORO_MODEL = "hexgrad/Kokoro-82M" # Available Kokoro voices (American Female/Male, British Female/Male) KOKORO_VOICES = { @@ -67,35 +66,25 @@ class KokoroResult: def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL): - """ - Get or create Kokoro model instance (singleton pattern). 
+ """Get or create Kokoro pipeline instance (singleton pattern).""" + global _kokoro_pipeline - Args: - model_name: HuggingFace model identifier - - Returns: - Kokoro model instance - """ - global _kokoro_model, _kokoro_model_name - - # Return existing model if same model name - if _kokoro_model is not None and _kokoro_model_name == model_name: - return _kokoro_model + if _kokoro_pipeline is not None: + return _kokoro_pipeline logger.info(f"Loading Kokoro model: {model_name}") try: - from mlx_audio.tts import load + from kokoro import KPipeline - _kokoro_model = load(model_name) - _kokoro_model_name = model_name - logger.info("Kokoro model loaded successfully") - return _kokoro_model + _kokoro_pipeline = KPipeline(lang_code="a") # 'a' for American English + logger.info("Kokoro pipeline loaded successfully") + return _kokoro_pipeline except ImportError as e: - logger.error(f"Failed to import mlx_audio: {e}") + logger.error(f"Failed to import kokoro: {e}") raise RuntimeError( - "mlx-audio not installed. Run: pip install mlx-audio" + "kokoro not installed. 
Run: pip install kokoro" ) except Exception as e: logger.error(f"Failed to load Kokoro model: {e}") @@ -104,7 +93,7 @@ def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL): def is_kokoro_loaded() -> bool: """Check if Kokoro model is currently loaded.""" - return _kokoro_model is not None + return _kokoro_pipeline is not None def get_available_voices() -> dict[str, str]: @@ -125,7 +114,7 @@ async def synthesize_kokoro( text: Text to synthesize voice: Voice ID from KOKORO_VOICES speed: Speech speed multiplier (0.5-2.0) - model_name: HuggingFace model identifier + model_name: Model identifier Returns: KokoroResult with audio data @@ -139,29 +128,18 @@ async def synthesize_kokoro( speed = max(0.5, min(2.0, speed)) # Get model - model = get_kokoro_model(model_name) + pipeline = get_kokoro_model(model_name) logger.info(f"Synthesizing with Kokoro: voice={voice}, speed={speed}, text_length={len(text)}") try: - # Generate audio using mlx-audio's generate method - # Returns a generator of GenerationResult objects - result_gen = model.generate( - text=text, - voice=voice, - speed=speed, - ) - - # Collect all audio chunks from the generator + # Generate audio using kokoro pipeline audio_chunks = [] - sample_rate = 24000 # Default, will be updated from result + sample_rate = 24000 # Kokoro default - for result in result_gen: - # Each result has audio, sample_rate, audio_duration (string) - sample_rate = result.sample_rate - - # Convert MLX array to numpy - audio_np = np.array(result.audio, dtype=np.float32) + for result in pipeline(text, voice=voice, speed=speed): + # result is a KPipelineResult with .audio (tensor) and .graphemes/.phonemes + audio_np = result.audio.numpy() audio_chunks.append(audio_np) # Concatenate all chunks diff --git a/services/mana-tts/app/vram_manager.py b/services/mana-tts/app/vram_manager.py new file mode 100644 index 000000000..89b5656ae --- /dev/null +++ b/services/mana-tts/app/vram_manager.py @@ -0,0 +1,114 @@ +""" +VRAM Manager — Automatic 
"""VRAM manager: automatic model unloading after idle timeout.

Tracks last usage time per model and unloads after a configurable timeout.
Designed for shared GPU environments (multiple services on one RTX 3090).

Usage in a service:
    from vram_manager import VramManager

    vram = VramManager(idle_timeout=300)  # 5 min

    # Before using a model
    vram.touch()

    # Call periodically (e.g., from a health check or background task)
    vram.check_and_unload(unload_fn=my_unload_function)
"""

import os
import time
import logging
import threading
from typing import Optional, Callable

logger = logging.getLogger(__name__)

# Default idle timeout in seconds before a loaded model is eligible for unload.
DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300"))  # 5 minutes


class VramManager:
    """Tracks last-use time of a GPU model and decides when to unload it.

    The manager never unloads anything itself: the owning service supplies an
    ``unload_fn`` to ``check_and_unload``. The background ``threading.Timer``
    started on touch/load only *logs* when the idle deadline passes (see
    ``_auto_check``); the service is expected to poll ``check_and_unload``.
    """

    def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"):
        """
        Args:
            idle_timeout: Seconds of inactivity after which the model counts as idle.
            service_name: Label used in log messages (one manager per service).
        """
        self.idle_timeout = idle_timeout
        self.service_name = service_name
        self.last_used: float = 0.0      # wall-clock time of last touch/load
        self.model_loaded: bool = False  # True while a model occupies VRAM
        self._lock = threading.Lock()    # guards last_used / model_loaded / _timer
        self._timer: Optional[threading.Timer] = None

    def touch(self) -> None:
        """Mark the model as recently used. Call before/after each inference."""
        with self._lock:
            self.last_used = time.time()
            self.model_loaded = True
            self._schedule_check()

    def mark_loaded(self) -> None:
        """Mark that a model has been loaded into VRAM."""
        with self._lock:
            self.model_loaded = True
            self.last_used = time.time()
            self._schedule_check()
        logger.info("[%s] Model loaded, idle timeout: %ss", self.service_name, self.idle_timeout)

    def mark_unloaded(self) -> None:
        """Mark that a model has been unloaded from VRAM; cancels the pending timer."""
        with self._lock:
            self.model_loaded = False
            if self._timer:
                self._timer.cancel()
                self._timer = None
        logger.info("[%s] Model unloaded, VRAM freed", self.service_name)

    def is_idle(self) -> bool:
        """Return True if a model is loaded and has been idle longer than the timeout."""
        # Snapshot under the lock so model_loaded and last_used agree with each other.
        with self._lock:
            if not self.model_loaded:
                return False
            return (time.time() - self.last_used) > self.idle_timeout

    def seconds_until_unload(self) -> Optional[float]:
        """Seconds until the model becomes unloadable, or None if not loaded."""
        with self._lock:
            if not self.model_loaded:
                return None
            remaining = self.idle_timeout - (time.time() - self.last_used)
        return max(0, remaining)

    def check_and_unload(self, unload_fn: Callable[[], None]) -> bool:
        """Unload via ``unload_fn`` if the model is idle.

        Args:
            unload_fn: Zero-argument callable that frees the model from VRAM.

        Returns:
            True if the model was unloaded, False otherwise (not idle, or
            ``unload_fn`` raised — the failure is logged and swallowed so a
            periodic caller keeps running).
        """
        if not self.is_idle():
            return False
        logger.info("[%s] Idle for >%ss, unloading model...", self.service_name, self.idle_timeout)
        try:
            # NOTE(review): the lock is NOT held across unload_fn (mark_unloaded
            # re-acquires it and threading.Lock is non-reentrant), so a concurrent
            # touch() can race the unload — same window as the original code.
            unload_fn()
        except Exception as e:
            logger.error("[%s] Failed to unload: %s", self.service_name, e)
            return False
        self.mark_unloaded()
        return True

    def _schedule_check(self) -> None:
        """(Re)start the idle-check timer. Caller must hold ``self._lock``."""
        if self._timer:
            self._timer.cancel()
        self._timer = threading.Timer(
            self.idle_timeout + 5,  # small buffer past the idle deadline
            self._auto_check,
        )
        self._timer.daemon = True  # never block interpreter shutdown
        self._timer.start()

    def _auto_check(self) -> None:
        """Timer callback: log idleness only.

        Actual unloading needs the service-specific unload_fn, so the service
        must still call ``check_and_unload`` periodically.
        """
        if self.is_idle():
            logger.info("[%s] Model idle for >%ss — ready to unload", self.service_name, self.idle_timeout)

    def status(self) -> dict:
        """Get current VRAM manager status as a JSON-ready dict."""
        remaining = self.seconds_until_unload()
        with self._lock:
            loaded = self.model_loaded
            idle = (time.time() - self.last_used) if loaded else None
        return {
            "model_loaded": loaded,
            "idle_seconds": round(idle, 1) if idle is not None else None,
            "idle_timeout": self.idle_timeout,
            # Guarded against the model being unloaded mid-call (remaining is None).
            "seconds_until_unload": round(remaining, 1) if remaining is not None else None,
        }
+ os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "") + +# HF token for model downloads +hf_token = os.environ.get("HF_TOKEN", "") +if hf_token: + os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token + +# Pre-initialize CUDA before importing diffusers +import torch +if torch.cuda.is_available(): + torch.cuda.init() + print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True) + print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True) +else: + print("WARNING: CUDA not available, LTX will be unusably slow on CPU", flush=True) + +import uvicorn +uvicorn.run("app.main:app", host="0.0.0.0", port=3026, log_level="info")