mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 18:01:09 +02:00
chore(ai-services): adopt Windows GPU as source of truth for llm/stt/tts
The Windows GPU server has been the actual production home for these
services for some time, and the running code there has drifted ahead of
the repo. This sync pulls the live versions back into the repo so the
Windows box is no longer the only place those changes exist.
Pulled from C:\mana\services\* on mana-server-gpu (192.168.178.11):
mana-llm:
- src/main.py, src/config.py — small fixes (auth wiring, config tweaks)
- src/api_auth.py — NEW (cross-service GPU_API_KEY validator)
- service.pyw — Windows runner used by the ManaLLM scheduled task
(sets up logging redirect, loads .env, calls uvicorn)
mana-stt:
- app/main.py — substantial cleanup (684→392 lines), drops the
whisperx-as-separate-backend branching now that whisper_service.py
rolls whisperx in directly
- app/whisper_service.py — full CUDA + whisperx rewrite (158→358 lines)
- app/auth.py + external_auth.py — significantly expanded auth
- app/vram_manager.py — NEW (shared VRAM accounting helper)
- service.pyw — Windows runner with CUDA pre-init, FFmpeg PATH
injection, .env loading
- removed: app/whisper_service_cuda.py (folded into whisper_service.py)
- removed: app/whisperx_service.py (folded into whisper_service.py)
mana-tts:
- app/auth.py, external_auth.py — same auth expansion as stt
- app/f5_service.py, kokoro_service.py — Windows tweaks
- app/vram_manager.py — NEW (same shared helper as stt)
- service.pyw — Windows runner
mana-video-gen:
- service.pyw — Windows runner (no other changes; the .py code on the
GPU box is byte-identical to what's already in the repo)
The service.pyw files contain absolute Windows paths
(C:\mana\services\<svc>) and a hardcoded FFmpeg PATH for the tills user
profile. Kept as-is intentionally — they exist to be deployed to that
one machine and any abstraction layer would just hide what's actually
happening. Anyone redeploying to a different layout will need to edit
the path strings, which is a known and obvious change.
Mac-Mini infrastructure for these services (launchd plists, install
scripts, scripts/mac-mini/setup-{stt,tts}.sh, the Mac-flux2c image-gen
implementation) is still on disk and will be removed in a follow-up
commit, along with replacing mana-image-gen with the Windows
diffusers+CUDA implementation. This commit is just the live-code sync.
This commit is contained in:
parent
abe0a21966
commit
b8e18b7f82
20 changed files with 1623 additions and 1261 deletions
|
|
@ -23,3 +23,6 @@ CACHE_TTL=3600
|
|||
|
||||
# CORS
|
||||
CORS_ORIGINS=http://localhost:5173,https://mana.how
|
||||
|
||||
# API key for cross-service auth (validated in src/api_auth.py)
|
||||
GPU_API_KEY=
|
||||
|
|
|
|||
17
services/mana-llm/service.pyw
Normal file
17
services/mana-llm/service.pyw
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""mana-llm service runner (run with pythonw.exe to run headless)."""
|
||||
import os
|
||||
import sys
|
||||
os.chdir(r"C:\mana\services\mana-llm")
|
||||
sys.path.insert(0, r"C:\mana\services\mana-llm")
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(r"C:\mana\services\mana-llm\.env")
|
||||
|
||||
# Redirect stdout/stderr to log file
|
||||
log = open(r"C:\mana\services\mana-llm\service.log", "w", buffering=1)
|
||||
sys.stdout = log
|
||||
sys.stderr = log
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run("src.main:app", host="0.0.0.0", port=3025, log_level="info")
|
||||
53
services/mana-llm/src/api_auth.py
Normal file
53
services/mana-llm/src/api_auth.py
Normal file
|
|
@ -0,0 +1,53 @@
|
|||
"""
|
||||
Simple API Key Authentication Middleware for GPU Services.
|
||||
|
||||
Checks X-API-Key header or ?api_key query parameter.
|
||||
Skips auth for /health, /docs, /openapi.json, /redoc endpoints.
|
||||
|
||||
Environment variables:
|
||||
GPU_API_KEY: Required API key (if empty, auth is disabled)
|
||||
GPU_REQUIRE_AUTH: Enable/disable auth (default: true if GPU_API_KEY is set)
|
||||
"""
|
||||
|
||||
import os
|
||||
import logging
|
||||
from fastapi import Request
|
||||
from fastapi.responses import JSONResponse
|
||||
from starlette.middleware.base import BaseHTTPMiddleware
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GPU_API_KEY = os.getenv("GPU_API_KEY", "")
|
||||
GPU_REQUIRE_AUTH = os.getenv("GPU_REQUIRE_AUTH", "true" if GPU_API_KEY else "false").lower() == "true"
|
||||
|
||||
# Endpoints that don't require auth
|
||||
PUBLIC_PATHS = {"/health", "/docs", "/openapi.json", "/redoc", "/metrics"}
|
||||
|
||||
|
||||
class ApiKeyMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request: Request, call_next):
|
||||
# Skip auth if disabled
|
||||
if not GPU_REQUIRE_AUTH or not GPU_API_KEY:
|
||||
return await call_next(request)
|
||||
|
||||
# Skip auth for public endpoints
|
||||
if request.url.path in PUBLIC_PATHS:
|
||||
return await call_next(request)
|
||||
|
||||
# Check API key from header or query param
|
||||
api_key = request.headers.get("X-API-Key") or request.query_params.get("api_key")
|
||||
|
||||
if not api_key:
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Missing API key. Provide X-API-Key header."},
|
||||
)
|
||||
|
||||
if api_key != GPU_API_KEY:
|
||||
logger.warning(f"Invalid API key attempt from {request.client.host if request.client else 'unknown'}")
|
||||
return JSONResponse(
|
||||
status_code=401,
|
||||
content={"detail": "Invalid API key."},
|
||||
)
|
||||
|
||||
return await call_next(request)
|
||||
|
|
@ -51,6 +51,7 @@ class Settings(BaseSettings):
|
|||
class Config:
|
||||
env_file = ".env"
|
||||
env_file_encoding = "utf-8"
|
||||
extra = "ignore"
|
||||
|
||||
|
||||
settings = Settings()
|
||||
|
|
|
|||
|
|
@ -10,6 +10,7 @@ from fastapi.middleware.cors import CORSMiddleware
|
|||
from fastapi.responses import Response
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
|
||||
from src.api_auth import ApiKeyMiddleware
|
||||
from src.config import settings
|
||||
from src.models import (
|
||||
ChatCompletionRequest,
|
||||
|
|
@ -70,6 +71,7 @@ app.add_middleware(
|
|||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
app.add_middleware(ApiKeyMiddleware)
|
||||
|
||||
|
||||
# Health endpoint
|
||||
|
|
|
|||
|
|
@ -1,41 +1,271 @@
|
|||
"""
|
||||
API Key Authentication for Mana STT Service.
|
||||
Delegates to shared mana_auth package.
|
||||
API Key Authentication for ManaCore STT Service
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Local API keys: Configured via environment variables
|
||||
2. External API keys: Validated via mana-core-auth service (when EXTERNAL_AUTH_ENABLED=true)
|
||||
|
||||
Usage:
|
||||
# Local keys
|
||||
API_KEYS=sk-key1:name1,sk-key2:name2
|
||||
INTERNAL_API_KEY=sk-internal-xxx
|
||||
|
||||
# External auth (for user-created keys via mana.how)
|
||||
EXTERNAL_AUTH_ENABLED=true
|
||||
MANA_CORE_AUTH_URL=http://localhost:3001
|
||||
"""
|
||||
|
||||
# Re-export everything from shared auth for backward compatibility
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
# Add shared-python to path
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
|
||||
from fastapi import HTTPException, Security, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from mana_auth import (
|
||||
APIKey,
|
||||
AuthResult,
|
||||
RateLimitInfo,
|
||||
verify_api_key as _verify_api_key,
|
||||
get_api_key_stats,
|
||||
reload_api_keys,
|
||||
api_key_header,
|
||||
create_auth_dependency,
|
||||
)
|
||||
from mana_auth.external_auth import (
|
||||
ExternalValidationResult,
|
||||
from .external_auth import (
|
||||
is_external_auth_enabled,
|
||||
validate_api_key_external,
|
||||
ExternalValidationResult,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import Security, Request
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# STT-specific auth dependency
|
||||
verify_stt_key = create_auth_dependency("stt")
|
||||
# Configuration
|
||||
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
|
||||
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Unlimited internal key
|
||||
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
|
||||
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Per minute
|
||||
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKey:
|
||||
"""API Key with metadata."""
|
||||
key: str
|
||||
name: str
|
||||
is_internal: bool = False
|
||||
rate_limit: int = RATE_LIMIT_REQUESTS # Requests per window
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit tracking per key."""
|
||||
requests: list = field(default_factory=list)
|
||||
|
||||
def is_allowed(self, limit: int, window: int) -> bool:
|
||||
"""Check if request is allowed within rate limit."""
|
||||
now = time.time()
|
||||
# Remove old requests outside window
|
||||
self.requests = [t for t in self.requests if now - t < window]
|
||||
|
||||
if len(self.requests) >= limit:
|
||||
return False
|
||||
|
||||
self.requests.append(now)
|
||||
return True
|
||||
|
||||
def remaining(self, limit: int, window: int) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
now = time.time()
|
||||
self.requests = [t for t in self.requests if now - t < window]
|
||||
return max(0, limit - len(self.requests))
|
||||
|
||||
|
||||
# Parse API keys from environment
|
||||
def _parse_api_keys() -> dict[str, APIKey]:
|
||||
"""Parse API keys from environment variables."""
|
||||
keys = {}
|
||||
|
||||
# Parse comma-separated keys
|
||||
if API_KEYS_ENV:
|
||||
for entry in API_KEYS_ENV.split(","):
|
||||
entry = entry.strip()
|
||||
if ":" in entry:
|
||||
key, name = entry.split(":", 1)
|
||||
else:
|
||||
key, name = entry, "default"
|
||||
keys[key.strip()] = APIKey(key=key.strip(), name=name.strip())
|
||||
|
||||
# Add internal key with no rate limit
|
||||
if INTERNAL_API_KEY:
|
||||
keys[INTERNAL_API_KEY] = APIKey(
|
||||
key=INTERNAL_API_KEY,
|
||||
name="internal",
|
||||
is_internal=True,
|
||||
rate_limit=999999, # Effectively unlimited
|
||||
)
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
# Global state
|
||||
_api_keys = _parse_api_keys()
|
||||
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
|
||||
|
||||
# Security scheme
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Result of authentication check."""
|
||||
authenticated: bool
|
||||
key_name: Optional[str] = None
|
||||
is_internal: bool = False
|
||||
rate_limit_remaining: Optional[int] = None
|
||||
user_id: Optional[str] = None # Set when using external auth
|
||||
|
||||
|
||||
async def verify_api_key(
|
||||
request: Request,
|
||||
api_key: Optional[str] = Security(api_key_header),
|
||||
) -> AuthResult:
|
||||
"""Verify API key with STT scope."""
|
||||
return await _verify_api_key(request, scope="stt", api_key=api_key)
|
||||
"""
|
||||
Verify API key and check rate limits.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. External auth via mana-core-auth (for sk_live_ keys)
|
||||
2. Local auth via environment variables
|
||||
|
||||
Returns AuthResult with authentication status.
|
||||
Raises HTTPException if auth fails or rate limited.
|
||||
"""
|
||||
# Skip auth for health and docs endpoints
|
||||
path = request.url.path
|
||||
if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
|
||||
return AuthResult(authenticated=True, key_name="public")
|
||||
|
||||
# If auth not required, allow all
|
||||
if not REQUIRE_AUTH:
|
||||
return AuthResult(authenticated=True, key_name="anonymous")
|
||||
|
||||
# Check for API key
|
||||
if not api_key:
|
||||
logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing API key. Provide X-API-Key header.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
# Try external auth first for sk_live_ keys (user-created keys via mana.how)
|
||||
if api_key.startswith("sk_live_") and is_external_auth_enabled():
|
||||
external_result = await validate_api_key_external(api_key, "stt")
|
||||
|
||||
if external_result is not None:
|
||||
if external_result.valid:
|
||||
# Use rate limits from external auth
|
||||
rate_info = _rate_limits[api_key]
|
||||
limit = external_result.rate_limit_requests
|
||||
window = external_result.rate_limit_window
|
||||
|
||||
if not rate_info.is_allowed(limit, window):
|
||||
remaining = rate_info.remaining(limit, window)
|
||||
logger.warning(f"Rate limit exceeded for external key")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Rate limit exceeded. Try again in {window} seconds.",
|
||||
headers={
|
||||
"X-RateLimit-Limit": str(limit),
|
||||
"X-RateLimit-Remaining": str(remaining),
|
||||
"X-RateLimit-Reset": str(int(time.time()) + window),
|
||||
"Retry-After": str(window),
|
||||
},
|
||||
)
|
||||
|
||||
remaining = rate_info.remaining(limit, window)
|
||||
logger.debug(f"Authenticated external request from user {external_result.user_id} to {path}")
|
||||
|
||||
return AuthResult(
|
||||
authenticated=True,
|
||||
key_name="external",
|
||||
is_internal=False,
|
||||
rate_limit_remaining=remaining,
|
||||
user_id=external_result.user_id,
|
||||
)
|
||||
else:
|
||||
# External auth returned invalid
|
||||
logger.warning(f"External auth failed: {external_result.error}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=external_result.error or "Invalid API key.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
# If external_result is None, fall through to local auth
|
||||
|
||||
# Local auth: Validate key against environment variables
|
||||
if api_key not in _api_keys:
|
||||
logger.warning(f"Invalid API key attempt for {path}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid API key.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
key_info = _api_keys[api_key]
|
||||
|
||||
# Check rate limit (skip for internal keys)
|
||||
if not key_info.is_internal:
|
||||
rate_info = _rate_limits[api_key]
|
||||
if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
|
||||
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
|
||||
logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
|
||||
headers={
|
||||
"X-RateLimit-Limit": str(key_info.rate_limit),
|
||||
"X-RateLimit-Remaining": str(remaining),
|
||||
"X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
|
||||
"Retry-After": str(RATE_LIMIT_WINDOW),
|
||||
},
|
||||
)
|
||||
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
|
||||
else:
|
||||
remaining = None
|
||||
|
||||
logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
|
||||
|
||||
return AuthResult(
|
||||
authenticated=True,
|
||||
key_name=key_info.name,
|
||||
is_internal=key_info.is_internal,
|
||||
rate_limit_remaining=remaining,
|
||||
)
|
||||
|
||||
|
||||
def get_api_key_stats() -> dict:
|
||||
"""Get statistics about API keys (for admin endpoint)."""
|
||||
stats = {
|
||||
"total_keys": len(_api_keys),
|
||||
"auth_required": REQUIRE_AUTH,
|
||||
"rate_limit": {
|
||||
"requests_per_window": RATE_LIMIT_REQUESTS,
|
||||
"window_seconds": RATE_LIMIT_WINDOW,
|
||||
},
|
||||
"keys": [],
|
||||
}
|
||||
|
||||
for key, info in _api_keys.items():
|
||||
# Don't expose actual keys, just metadata
|
||||
masked_key = key[:8] + "..." if len(key) > 8 else "***"
|
||||
rate_info = _rate_limits.get(key, RateLimitInfo())
|
||||
stats["keys"].append({
|
||||
"name": info.name,
|
||||
"key_prefix": masked_key,
|
||||
"is_internal": info.is_internal,
|
||||
"requests_in_window": len(rate_info.requests),
|
||||
"remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW),
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reload_api_keys():
|
||||
"""Reload API keys from environment (for runtime updates)."""
|
||||
global _api_keys
|
||||
_api_keys = _parse_api_keys()
|
||||
logger.info(f"Reloaded {len(_api_keys)} API keys")
|
||||
|
|
|
|||
|
|
@ -1,22 +1,145 @@
|
|||
"""
|
||||
External API Key Validation — delegates to shared mana_auth package.
|
||||
External API Key Validation via mana-core-auth
|
||||
|
||||
When EXTERNAL_AUTH_ENABLED=true, API keys are validated against the
|
||||
central mana-core-auth service. This allows users to create and manage
|
||||
API keys from the mana.how web interface.
|
||||
|
||||
Results are cached for 5 minutes to reduce load on the auth service.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from mana_auth.external_auth import (
|
||||
ExternalValidationResult,
|
||||
is_external_auth_enabled,
|
||||
validate_api_key_external,
|
||||
clear_cache,
|
||||
)
|
||||
# Configuration
|
||||
EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true"
|
||||
MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001")
|
||||
API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300")) # 5 minutes
|
||||
EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0")) # seconds
|
||||
|
||||
__all__ = [
|
||||
"ExternalValidationResult",
|
||||
"is_external_auth_enabled",
|
||||
"validate_api_key_external",
|
||||
"clear_cache",
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ExternalValidationResult:
|
||||
"""Result from external API key validation."""
|
||||
valid: bool
|
||||
user_id: Optional[str] = None
|
||||
scopes: Optional[list] = None
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window: int = 60
|
||||
error: Optional[str] = None
|
||||
cached_at: float = 0.0
|
||||
|
||||
|
||||
# In-memory cache for validation results
|
||||
# Key: API key, Value: ExternalValidationResult
|
||||
_validation_cache: dict[str, ExternalValidationResult] = {}
|
||||
|
||||
|
||||
def is_external_auth_enabled() -> bool:
|
||||
"""Check if external authentication is enabled."""
|
||||
return EXTERNAL_AUTH_ENABLED
|
||||
|
||||
|
||||
def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]:
|
||||
"""Get cached validation result if still valid."""
|
||||
result = _validation_cache.get(api_key)
|
||||
if result and (time.time() - result.cached_at) < API_KEY_CACHE_TTL:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _cache_result(api_key: str, result: ExternalValidationResult):
|
||||
"""Cache a validation result."""
|
||||
result.cached_at = time.time()
|
||||
_validation_cache[api_key] = result
|
||||
|
||||
# Clean up old entries periodically (keep cache size manageable)
|
||||
if len(_validation_cache) > 1000:
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
k for k, v in _validation_cache.items()
|
||||
if (now - v.cached_at) >= API_KEY_CACHE_TTL
|
||||
]
|
||||
for k in expired_keys:
|
||||
del _validation_cache[k]
|
||||
|
||||
|
||||
async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]:
|
||||
"""
|
||||
Validate an API key against mana-core-auth service.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate (e.g., "sk_live_...")
|
||||
scope: The required scope (e.g., "stt" or "tts")
|
||||
|
||||
Returns:
|
||||
ExternalValidationResult if external auth is enabled and the key was validated.
|
||||
None if external auth is disabled or the service is unavailable (fallback to local).
|
||||
"""
|
||||
if not EXTERNAL_AUTH_ENABLED:
|
||||
return None
|
||||
|
||||
# Check cache first
|
||||
cached = _get_cached_result(api_key)
|
||||
if cached:
|
||||
logger.debug(f"Using cached validation result for key prefix: {api_key[:12]}...")
|
||||
# Check scope against cached result
|
||||
if cached.valid and cached.scopes and scope not in cached.scopes:
|
||||
return ExternalValidationResult(
|
||||
valid=False,
|
||||
error=f"API key does not have scope: {scope}",
|
||||
)
|
||||
return cached
|
||||
|
||||
# Call mana-core-auth validation endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate",
|
||||
json={"apiKey": api_key, "scope": scope},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = ExternalValidationResult(
|
||||
valid=data.get("valid", False),
|
||||
user_id=data.get("userId"),
|
||||
scopes=data.get("scopes", []),
|
||||
rate_limit_requests=data.get("rateLimit", {}).get("requests", 60),
|
||||
rate_limit_window=data.get("rateLimit", {}).get("window", 60),
|
||||
error=data.get("error"),
|
||||
)
|
||||
_cache_result(api_key, result)
|
||||
return result
|
||||
else:
|
||||
logger.warning(
|
||||
f"External auth returned status {response.status_code}: {response.text}"
|
||||
)
|
||||
# Don't cache errors - allow retry
|
||||
return ExternalValidationResult(
|
||||
valid=False,
|
||||
error=f"Auth service returned {response.status_code}",
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("External auth service timeout - falling back to local auth")
|
||||
return None
|
||||
except httpx.ConnectError:
|
||||
logger.warning("Cannot connect to external auth service - falling back to local auth")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"External auth error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def clear_cache():
|
||||
"""Clear the validation cache (for testing or runtime updates)."""
|
||||
global _validation_cache
|
||||
_validation_cache.clear()
|
||||
logger.info("External auth cache cleared")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Mana STT API Service
|
||||
Speech-to-Text with Whisper (MLX), WhisperX (CUDA), Voxtral (vLLM), and Mistral API (fallback)
|
||||
ManaCore STT API Service (WhisperX Edition)
|
||||
Speech-to-Text with WhisperX: transcription, word timestamps, speaker diarization.
|
||||
|
||||
Run with: uvicorn app.main:app --host 0.0.0.0 --port 3020
|
||||
"""
|
||||
|
|
@ -34,66 +34,42 @@ CORS_ORIGINS = os.getenv(
|
|||
"https://mana.how,https://chat.mana.how,http://localhost:5173"
|
||||
).split(",")
|
||||
|
||||
# vLLM configuration (disabled by default - has issues on macOS CPU)
|
||||
# vLLM configuration
|
||||
VLLM_URL = os.getenv("VLLM_URL", "http://localhost:8100")
|
||||
USE_VLLM = os.getenv("USE_VLLM", "false").lower() == "true"
|
||||
|
||||
# WhisperX configuration (CUDA GPU server)
|
||||
USE_WHISPERX = os.getenv("USE_WHISPERX", "false").lower() == "true"
|
||||
|
||||
|
||||
# Response models
|
||||
class WordInfo(BaseModel):
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: Optional[float] = None
|
||||
speaker: Optional[str] = None
|
||||
|
||||
|
||||
class SegmentInfo(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
speaker: Optional[str] = None
|
||||
|
||||
|
||||
class TranscriptionResponse(BaseModel):
|
||||
text: str
|
||||
language: Optional[str] = None
|
||||
model: str
|
||||
latency_ms: Optional[float] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
|
||||
|
||||
class WordTimestampResponse(BaseModel):
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: float = 0.0
|
||||
speaker: Optional[str] = None
|
||||
|
||||
|
||||
class SegmentResponse(BaseModel):
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
speaker: Optional[str] = None
|
||||
words: list[WordTimestampResponse] = []
|
||||
|
||||
|
||||
class UtteranceResponse(BaseModel):
|
||||
speaker: int
|
||||
text: str
|
||||
offset: int # milliseconds
|
||||
duration: int # milliseconds
|
||||
|
||||
|
||||
class RichTranscriptionResponse(BaseModel):
|
||||
"""Extended response with segments, utterances, and speaker diarization."""
|
||||
text: str
|
||||
language: Optional[str] = None
|
||||
model: str
|
||||
latency_ms: Optional[float] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
segments: list[SegmentResponse] = []
|
||||
utterances: list[UtteranceResponse] = []
|
||||
speakers: dict[str, str] = {}
|
||||
speaker_map: dict[str, int] = {}
|
||||
languages: list[str] = []
|
||||
primary_language: Optional[str] = None
|
||||
words: list[WordTimestampResponse] = []
|
||||
words: Optional[list[WordInfo]] = None
|
||||
segments: Optional[list[SegmentInfo]] = None
|
||||
speakers: Optional[list[str]] = None
|
||||
|
||||
|
||||
class HealthResponse(BaseModel):
|
||||
status: str
|
||||
whisper_loaded: bool
|
||||
whisperx_available: bool
|
||||
whisperx: bool
|
||||
vllm_available: bool
|
||||
vllm_url: Optional[str] = None
|
||||
mistral_api_available: bool
|
||||
|
|
@ -103,7 +79,6 @@ class HealthResponse(BaseModel):
|
|||
|
||||
class ModelsResponse(BaseModel):
|
||||
whisper: list
|
||||
whisperx: list
|
||||
voxtral_vllm: list
|
||||
default_whisper: str
|
||||
|
||||
|
|
@ -111,7 +86,6 @@ class ModelsResponse(BaseModel):
|
|||
# Track loaded models
|
||||
models_status = {
|
||||
"whisper_loaded": False,
|
||||
"whisperx_available": False,
|
||||
"vllm_available": False,
|
||||
}
|
||||
|
||||
|
|
@ -119,45 +93,28 @@ models_status = {
|
|||
@asynccontextmanager
|
||||
async def lifespan(app: FastAPI):
|
||||
"""Startup and shutdown events."""
|
||||
logger.info("Starting Mana STT Service...")
|
||||
logger.info("Starting ManaCore STT Service (WhisperX Edition)...")
|
||||
|
||||
# Check vLLM availability
|
||||
if USE_VLLM:
|
||||
from app.vllm_service import check_health
|
||||
health = await check_health()
|
||||
models_status["vllm_available"] = health.get("status") == "healthy"
|
||||
if models_status["vllm_available"]:
|
||||
logger.info(f"vLLM server available at {VLLM_URL}")
|
||||
else:
|
||||
logger.warning(f"vLLM server not available: {health}")
|
||||
|
||||
# Check WhisperX availability
|
||||
if USE_WHISPERX:
|
||||
try:
|
||||
from app.whisperx_service import is_available as whisperx_available
|
||||
models_status["whisperx_available"] = whisperx_available()
|
||||
if models_status["whisperx_available"]:
|
||||
logger.info("WhisperX (CUDA) available")
|
||||
else:
|
||||
logger.warning("WhisperX not available (whisperx package not installed)")
|
||||
except Exception as e:
|
||||
logger.warning(f"WhisperX check failed: {e}")
|
||||
|
||||
# Check Mistral API
|
||||
from app.voxtral_api_service import is_available as api_available
|
||||
if api_available():
|
||||
logger.info("Mistral API fallback configured")
|
||||
|
||||
# Optionally preload Whisper
|
||||
if PRELOAD_MODELS:
|
||||
logger.info("Preloading Whisper model...")
|
||||
try:
|
||||
from app.whisper_service import get_whisper_model
|
||||
get_whisper_model(DEFAULT_WHISPER_MODEL)
|
||||
models_status["whisper_loaded"] = True
|
||||
logger.info("Whisper model preloaded")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to preload Whisper: {e}")
|
||||
# Always preload WhisperX model at startup (avoids timeout on first request)
|
||||
logger.info("Preloading WhisperX model...")
|
||||
try:
|
||||
from app.whisper_service import get_whisper_model
|
||||
get_whisper_model(DEFAULT_WHISPER_MODEL)
|
||||
models_status["whisper_loaded"] = True
|
||||
logger.info("WhisperX model preloaded successfully")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to preload WhisperX: {e}")
|
||||
|
||||
logger.info(f"STT Service ready on port {PORT}")
|
||||
yield
|
||||
|
|
@ -166,9 +123,9 @@ async def lifespan(app: FastAPI):
|
|||
|
||||
# Create FastAPI app
|
||||
app = FastAPI(
|
||||
title="Mana STT Service",
|
||||
description="Speech-to-Text API with Whisper (MLX), Voxtral (vLLM), and Mistral API",
|
||||
version="2.0.0",
|
||||
title="ManaCore STT Service",
|
||||
description="Speech-to-Text API with WhisperX (word timestamps + speaker diarization)",
|
||||
version="3.0.0",
|
||||
lifespan=lifespan,
|
||||
)
|
||||
|
||||
|
|
@ -193,13 +150,15 @@ async def health_check():
|
|||
return HealthResponse(
|
||||
status="healthy",
|
||||
whisper_loaded=models_status["whisper_loaded"],
|
||||
whisperx_available=models_status["whisperx_available"],
|
||||
whisperx=True,
|
||||
vllm_available=vllm_health.get("status") == "healthy",
|
||||
vllm_url=VLLM_URL if USE_VLLM else None,
|
||||
mistral_api_available=api_available(),
|
||||
auth_required=REQUIRE_AUTH,
|
||||
models={
|
||||
"default_whisper": DEFAULT_WHISPER_MODEL,
|
||||
"engine": "whisperx",
|
||||
"features": ["transcription", "word_timestamps", "speaker_diarization"],
|
||||
},
|
||||
)
|
||||
|
||||
|
|
@ -212,17 +171,8 @@ async def list_models(auth: AuthResult = Depends(verify_api_key)):
|
|||
|
||||
vllm_models = await get_models()
|
||||
|
||||
whisperx_models = []
|
||||
if USE_WHISPERX:
|
||||
try:
|
||||
from app.whisperx_service import AVAILABLE_MODELS as wx_models
|
||||
whisperx_models = wx_models
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
return ModelsResponse(
|
||||
whisper=whisper_models,
|
||||
whisperx=whisperx_models,
|
||||
voxtral_vllm=vllm_models,
|
||||
default_whisper=DEFAULT_WHISPER_MODEL,
|
||||
)
|
||||
|
|
@ -234,16 +184,22 @@ async def transcribe_whisper(
|
|||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
|
||||
model: Optional[str] = Form(None, description="Whisper model to use"),
|
||||
align: bool = Form(True, description="Enable word-level timestamp alignment"),
|
||||
diarize: bool = Form(False, description="Enable speaker diarization"),
|
||||
min_speakers: Optional[int] = Form(None, description="Min expected speakers (helps diarization)"),
|
||||
max_speakers: Optional[int] = Form(None, description="Max expected speakers"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Whisper (Lightning MLX).
|
||||
Transcribe audio using WhisperX.
|
||||
|
||||
Best for: General transcription, many languages
|
||||
Supported formats: mp3, wav, m4a, flac, ogg, webm
|
||||
Features:
|
||||
- Word-level timestamps (align=true, default)
|
||||
- Speaker diarization (diarize=true, opt-in)
|
||||
|
||||
Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4
|
||||
Max file size: 100MB
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
|
|
@ -274,20 +230,51 @@ async def transcribe_whisper(
|
|||
filename=file.filename,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
align=align,
|
||||
diarize=diarize,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
)
|
||||
|
||||
models_status["whisper_loaded"] = True
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return TranscriptionResponse(
|
||||
# Build response
|
||||
resp = TranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=f"whisper-{model_name}",
|
||||
model=f"whisperx-{model_name}",
|
||||
latency_ms=latency_ms,
|
||||
duration_seconds=result.duration,
|
||||
)
|
||||
|
||||
# Add word timestamps if available
|
||||
if result.words:
|
||||
resp.words = [
|
||||
WordInfo(
|
||||
word=w.word,
|
||||
start=w.start,
|
||||
end=w.end,
|
||||
score=w.score,
|
||||
speaker=w.speaker,
|
||||
)
|
||||
for w in result.words
|
||||
]
|
||||
|
||||
# Add segments
|
||||
if result.segments:
|
||||
resp.segments = [
|
||||
SegmentInfo(**s) for s in result.segments
|
||||
]
|
||||
|
||||
# Add speakers
|
||||
if result.speakers:
|
||||
resp.speakers = result.speakers
|
||||
|
||||
return resp
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Whisper transcription error: {e}")
|
||||
logger.error(f"WhisperX transcription error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
|
|
@ -296,22 +283,10 @@ async def transcribe_voxtral(
|
|||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: str = Form("de", description="Language code"),
|
||||
use_realtime: bool = Form(False, description="Use Realtime 4B model for lower latency"),
|
||||
use_realtime: bool = Form(False, description="Use Realtime 4B model"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Voxtral via vLLM server.
|
||||
|
||||
Models:
|
||||
- Voxtral Mini 3B (default): Best quality
|
||||
- Voxtral Realtime 4B: Lower latency (<500ms)
|
||||
|
||||
Falls back to Mistral API if vLLM is unavailable.
|
||||
|
||||
Supported formats: mp3, wav, m4a, flac, ogg, webm
|
||||
Max file size: 100MB
|
||||
"""
|
||||
# Add rate limit headers
|
||||
"""Transcribe audio using Voxtral via vLLM or Mistral API."""
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
|
|
@ -345,49 +320,30 @@ async def transcribe_voxtral(
|
|||
if USE_VLLM:
|
||||
health = await check_health()
|
||||
if health.get("status") == "healthy":
|
||||
logger.info("Using vLLM for Voxtral transcription")
|
||||
if use_realtime:
|
||||
result = await transcribe_with_realtime(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename,
|
||||
language=language,
|
||||
audio_bytes=audio_bytes, filename=file.filename, language=language,
|
||||
)
|
||||
else:
|
||||
result = await vllm_transcribe(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename,
|
||||
language=language,
|
||||
audio_bytes=audio_bytes, filename=file.filename, language=language,
|
||||
)
|
||||
|
||||
return TranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=result.model,
|
||||
latency_ms=result.latency_ms,
|
||||
duration_seconds=result.duration_seconds,
|
||||
text=result.text, language=result.language, model=result.model,
|
||||
latency_ms=result.latency_ms, duration_seconds=result.duration_seconds,
|
||||
)
|
||||
|
||||
# Fallback to Mistral API
|
||||
if api_available():
|
||||
logger.info("Falling back to Mistral API")
|
||||
result = await api_transcribe(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename,
|
||||
language=language,
|
||||
audio_bytes=audio_bytes, filename=file.filename, language=language,
|
||||
)
|
||||
|
||||
return TranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=result.model,
|
||||
latency_ms=None,
|
||||
text=result.text, language=result.language, model=result.model,
|
||||
duration_seconds=result.duration_seconds,
|
||||
)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Voxtral not available. Start vLLM server or configure MISTRAL_API_KEY."
|
||||
)
|
||||
raise HTTPException(status_code=503, detail="Voxtral not available.")
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
|
|
@ -396,273 +352,30 @@ async def transcribe_voxtral(
|
|||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/transcribe/voxtral/api", response_model=TranscriptionResponse)
|
||||
async def transcribe_voxtral_api(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
|
||||
diarization: bool = Form(False, description="Enable speaker diarization"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using Mistral's Voxtral API directly.
|
||||
|
||||
Features:
|
||||
- Speaker diarization
|
||||
- Auto language detection
|
||||
- High quality (~4% WER)
|
||||
|
||||
Requires MISTRAL_API_KEY environment variable.
|
||||
"""
|
||||
# Add rate limit headers
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
from app.voxtral_api_service import is_available, transcribe_audio_bytes
|
||||
|
||||
if not is_available():
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="Mistral API not configured. Set MISTRAL_API_KEY environment variable."
|
||||
)
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
try:
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) > 100 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail="File too large (max 100MB)")
|
||||
|
||||
result = await transcribe_audio_bytes(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename,
|
||||
language=language,
|
||||
diarization=diarization,
|
||||
)
|
||||
|
||||
return TranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=result.model,
|
||||
duration_seconds=result.duration_seconds,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Mistral API error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/transcribe/whisperx", response_model=RichTranscriptionResponse)
|
||||
async def transcribe_whisperx(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
|
||||
model: Optional[str] = Form(None, description="Whisper model to use"),
|
||||
diarization: bool = Form(True, description="Enable speaker diarization"),
|
||||
alignment: bool = Form(True, description="Enable word-level alignment"),
|
||||
min_speakers: Optional[int] = Form(None, description="Minimum expected speakers"),
|
||||
max_speakers: Optional[int] = Form(None, description="Maximum expected speakers"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe audio using WhisperX (CUDA GPU).
|
||||
|
||||
Returns rich transcription with:
|
||||
- Word-level timestamps (via forced alignment)
|
||||
- Speaker diarization (via pyannote.audio)
|
||||
- Memoro-compatible utterances with speaker IDs
|
||||
|
||||
Requires NVIDIA GPU with CUDA and USE_WHISPERX=true.
|
||||
Diarization requires HF_TOKEN with pyannote model access.
|
||||
|
||||
Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4
|
||||
Max file size: 100MB
|
||||
"""
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
if not USE_WHISPERX:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="WhisperX not enabled. Set USE_WHISPERX=true on a CUDA-capable server."
|
||||
)
|
||||
|
||||
if not file.filename:
|
||||
raise HTTPException(status_code=400, detail="No file provided")
|
||||
|
||||
allowed_extensions = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm", ".mp4"}
|
||||
ext = os.path.splitext(file.filename)[1].lower()
|
||||
if ext not in allowed_extensions:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Unsupported file type: {ext}. Allowed: {allowed_extensions}"
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
try:
|
||||
from app.whisperx_service import transcribe_audio_bytes
|
||||
|
||||
audio_bytes = await file.read()
|
||||
if len(audio_bytes) > 100 * 1024 * 1024:
|
||||
raise HTTPException(status_code=400, detail="File too large (max 100MB)")
|
||||
|
||||
model_name = model or DEFAULT_WHISPER_MODEL
|
||||
|
||||
result = await transcribe_audio_bytes(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
enable_diarization=diarization,
|
||||
enable_alignment=alignment,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
)
|
||||
|
||||
latency_ms = (time.time() - start_time) * 1000
|
||||
|
||||
return RichTranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=f"whisperx-{model_name}",
|
||||
latency_ms=latency_ms,
|
||||
duration_seconds=result.duration_seconds,
|
||||
segments=[
|
||||
SegmentResponse(
|
||||
start=s.start,
|
||||
end=s.end,
|
||||
text=s.text,
|
||||
speaker=s.speaker,
|
||||
words=[
|
||||
WordTimestampResponse(
|
||||
word=w.word,
|
||||
start=w.start,
|
||||
end=w.end,
|
||||
score=w.score,
|
||||
speaker=w.speaker,
|
||||
)
|
||||
for w in s.words
|
||||
],
|
||||
)
|
||||
for s in result.segments
|
||||
],
|
||||
utterances=[
|
||||
UtteranceResponse(
|
||||
speaker=u.speaker,
|
||||
text=u.text,
|
||||
offset=u.offset,
|
||||
duration=u.duration,
|
||||
)
|
||||
for u in result.utterances
|
||||
],
|
||||
speakers=result.speakers,
|
||||
speaker_map={k: v for k, v in result.speaker_map.items()},
|
||||
languages=result.languages,
|
||||
primary_language=result.primary_language,
|
||||
words=[
|
||||
WordTimestampResponse(
|
||||
word=w.word,
|
||||
start=w.start,
|
||||
end=w.end,
|
||||
score=w.score,
|
||||
speaker=w.speaker,
|
||||
)
|
||||
for w in result.words
|
||||
],
|
||||
)
|
||||
|
||||
except HTTPException:
|
||||
raise
|
||||
except ImportError:
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail="WhisperX not installed. Install with: pip install -r requirements-cuda.txt"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"WhisperX transcription error: {e}")
|
||||
raise HTTPException(status_code=500, detail=str(e))
|
||||
|
||||
|
||||
@app.post("/transcribe/auto", response_model=TranscriptionResponse)
|
||||
async def transcribe_auto(
|
||||
response: Response,
|
||||
file: UploadFile = File(..., description="Audio file to transcribe"),
|
||||
language: Optional[str] = Form(None, description="Language hint"),
|
||||
prefer: str = Form("whisper", description="Preferred: 'whisper', 'whisperx', or 'voxtral'"),
|
||||
prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"),
|
||||
auth: AuthResult = Depends(verify_api_key),
|
||||
):
|
||||
"""
|
||||
Transcribe with automatic model selection and fallback.
|
||||
|
||||
Fallback chain:
|
||||
- whisper: Whisper → WhisperX → Voxtral → Mistral API
|
||||
- whisperx: WhisperX → Whisper → Voxtral → Mistral API
|
||||
- voxtral: Voxtral → WhisperX → Whisper → Mistral API
|
||||
"""
|
||||
# Add rate limit headers
|
||||
"""Auto-select best model with fallback chain."""
|
||||
if auth.rate_limit_remaining is not None:
|
||||
response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)
|
||||
|
||||
async def try_whisperx_simple():
|
||||
"""Try WhisperX and return as simple TranscriptionResponse."""
|
||||
if not USE_WHISPERX:
|
||||
raise RuntimeError("WhisperX not enabled")
|
||||
from app.whisperx_service import transcribe_audio_bytes as wx_transcribe
|
||||
audio_bytes = await file.read()
|
||||
result = await wx_transcribe(
|
||||
audio_bytes=audio_bytes,
|
||||
filename=file.filename or "audio.wav",
|
||||
language=language,
|
||||
enable_diarization=False,
|
||||
enable_alignment=False,
|
||||
)
|
||||
return TranscriptionResponse(
|
||||
text=result.text,
|
||||
language=result.language,
|
||||
model=f"whisperx-{DEFAULT_WHISPER_MODEL}",
|
||||
latency_ms=None,
|
||||
duration_seconds=result.duration_seconds,
|
||||
)
|
||||
|
||||
# Build fallback chain based on preference
|
||||
if prefer == "whisperx":
|
||||
chain = [
|
||||
("WhisperX", try_whisperx_simple),
|
||||
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
|
||||
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
|
||||
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
|
||||
]
|
||||
elif prefer == "voxtral":
|
||||
chain = [
|
||||
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
|
||||
("WhisperX", try_whisperx_simple),
|
||||
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
|
||||
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
|
||||
]
|
||||
else:
|
||||
chain = [
|
||||
("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
|
||||
("WhisperX", try_whisperx_simple),
|
||||
("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
|
||||
("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
|
||||
]
|
||||
|
||||
last_error = None
|
||||
for name, fn in chain:
|
||||
if prefer == "voxtral":
|
||||
try:
|
||||
result = await fn()
|
||||
return result
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
logger.warning(f"{name} failed: {e}")
|
||||
return await transcribe_voxtral(response, file, language or "de", False, auth)
|
||||
except Exception:
|
||||
await file.seek(0)
|
||||
|
||||
raise HTTPException(
|
||||
status_code=503,
|
||||
detail=f"All transcription backends failed. Last error: {last_error}"
|
||||
)
|
||||
return await transcribe_whisper(response, file, language, None, True, False, None, None, auth)
|
||||
else:
|
||||
try:
|
||||
return await transcribe_whisper(response, file, language, None, True, False, None, None, auth)
|
||||
except Exception:
|
||||
await file.seek(0)
|
||||
return await transcribe_voxtral(response, file, language or "de", False, auth)
|
||||
|
||||
|
||||
@app.exception_handler(Exception)
|
||||
|
|
@ -676,9 +389,4 @@ async def global_exception_handler(request, exc):
|
|||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(
|
||||
"app.main:app",
|
||||
host="0.0.0.0",
|
||||
port=PORT,
|
||||
reload=False,
|
||||
)
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=PORT, reload=False)
|
||||
|
|
|
|||
114
services/mana-stt/app/vram_manager.py
Normal file
114
services/mana-stt/app/vram_manager.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
VRAM Manager — Automatic model unloading after idle timeout.
|
||||
|
||||
Tracks last usage time per model and unloads after configurable timeout.
|
||||
Designed for shared GPU environments (multiple services on one RTX 3090).
|
||||
|
||||
Usage in a service:
|
||||
from vram_manager import VramManager
|
||||
|
||||
vram = VramManager(idle_timeout=300) # 5 min
|
||||
|
||||
# Before using a model
|
||||
vram.touch()
|
||||
|
||||
# Call periodically (e.g., from health check or background task)
|
||||
vram.check_idle(unload_fn=my_unload_function)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300")) # 5 minutes
|
||||
|
||||
|
||||
class VramManager:
|
||||
def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"):
|
||||
self.idle_timeout = idle_timeout
|
||||
self.service_name = service_name
|
||||
self.last_used: float = 0.0
|
||||
self.model_loaded: bool = False
|
||||
self._lock = threading.Lock()
|
||||
self._timer: Optional[threading.Timer] = None
|
||||
|
||||
def touch(self):
|
||||
"""Mark the model as recently used. Call before/after each inference."""
|
||||
with self._lock:
|
||||
self.last_used = time.time()
|
||||
self.model_loaded = True
|
||||
self._schedule_check()
|
||||
|
||||
def mark_loaded(self):
|
||||
"""Mark that a model has been loaded into VRAM."""
|
||||
with self._lock:
|
||||
self.model_loaded = True
|
||||
self.last_used = time.time()
|
||||
self._schedule_check()
|
||||
logger.info(f"[{self.service_name}] Model loaded, idle timeout: {self.idle_timeout}s")
|
||||
|
||||
def mark_unloaded(self):
|
||||
"""Mark that a model has been unloaded from VRAM."""
|
||||
with self._lock:
|
||||
self.model_loaded = False
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
logger.info(f"[{self.service_name}] Model unloaded, VRAM freed")
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the model has been idle longer than the timeout."""
|
||||
if not self.model_loaded:
|
||||
return False
|
||||
return (time.time() - self.last_used) > self.idle_timeout
|
||||
|
||||
def seconds_until_unload(self) -> Optional[float]:
|
||||
"""Seconds until the model will be unloaded, or None if not loaded."""
|
||||
if not self.model_loaded:
|
||||
return None
|
||||
remaining = self.idle_timeout - (time.time() - self.last_used)
|
||||
return max(0, remaining)
|
||||
|
||||
def check_and_unload(self, unload_fn: Callable[[], None]) -> bool:
|
||||
"""Check if idle and unload if so. Returns True if unloaded."""
|
||||
if self.is_idle():
|
||||
logger.info(f"[{self.service_name}] Idle for >{self.idle_timeout}s, unloading model...")
|
||||
try:
|
||||
unload_fn()
|
||||
self.mark_unloaded()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Failed to unload: {e}")
|
||||
return False
|
||||
|
||||
def _schedule_check(self):
|
||||
"""Schedule an idle check after the timeout period."""
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
|
||||
self._timer = threading.Timer(
|
||||
self.idle_timeout + 5, # Small buffer
|
||||
self._auto_check,
|
||||
)
|
||||
self._timer.daemon = True
|
||||
self._timer.start()
|
||||
|
||||
def _auto_check(self):
|
||||
"""Auto-triggered idle check (called by timer)."""
|
||||
# This is just a log — actual unloading needs the unload_fn
|
||||
# which depends on the service. The service should call check_and_unload.
|
||||
if self.is_idle():
|
||||
logger.info(f"[{self.service_name}] Model idle for >{self.idle_timeout}s — ready to unload")
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Get current VRAM manager status."""
|
||||
return {
|
||||
"model_loaded": self.model_loaded,
|
||||
"idle_seconds": round(time.time() - self.last_used, 1) if self.model_loaded else None,
|
||||
"idle_timeout": self.idle_timeout,
|
||||
"seconds_until_unload": round(self.seconds_until_unload(), 1) if self.model_loaded else None,
|
||||
}
|
||||
|
|
@ -1,6 +1,11 @@
|
|||
"""
|
||||
Whisper STT Service using Lightning Whisper MLX
|
||||
Optimized for Apple Silicon (M1/M2/M3/M4)
|
||||
Whisper STT Service using WhisperX (CUDA)
|
||||
Provides: transcription, word-level timestamps, speaker diarization.
|
||||
|
||||
WhisperX pipeline:
|
||||
1. faster-whisper for transcription
|
||||
2. wav2vec2 for forced alignment (precise word timestamps)
|
||||
3. pyannote-audio for speaker diarization
|
||||
"""
|
||||
|
||||
import os
|
||||
|
|
@ -8,12 +13,55 @@ import tempfile
|
|||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy load to avoid import errors if not installed
|
||||
_whisper_model = None
|
||||
# Lazy load
|
||||
_whisperx_model = None
|
||||
_align_model = None
|
||||
_align_metadata = None
|
||||
_diarize_pipeline = None
|
||||
|
||||
# Config
|
||||
HF_TOKEN = os.getenv("HF_TOKEN", "")
|
||||
|
||||
# VRAM management — unload after 10 min idle (STT uses ~3GB)
|
||||
from app.vram_manager import VramManager
|
||||
_vram = VramManager(
|
||||
idle_timeout=int(os.getenv("VRAM_IDLE_TIMEOUT", "600")),
|
||||
service_name="mana-stt",
|
||||
)
|
||||
|
||||
|
||||
def unload_models():
|
||||
"""Unload all WhisperX models from GPU to free VRAM."""
|
||||
global _whisperx_model, _align_model, _align_metadata, _diarize_pipeline
|
||||
import torch
|
||||
|
||||
if _whisperx_model is not None:
|
||||
del _whisperx_model
|
||||
_whisperx_model = None
|
||||
if _align_model is not None:
|
||||
del _align_model
|
||||
_align_model = None
|
||||
_align_metadata = None
|
||||
if _diarize_pipeline is not None:
|
||||
del _diarize_pipeline
|
||||
_diarize_pipeline = None
|
||||
|
||||
torch.cuda.empty_cache()
|
||||
_vram.mark_unloaded()
|
||||
logger.info("WhisperX models unloaded, VRAM freed")
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordSegment:
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: Optional[float] = None
|
||||
speaker: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
|
|
@ -22,83 +70,231 @@ class TranscriptionResult:
|
|||
language: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
segments: Optional[list] = None
|
||||
words: Optional[list[WordSegment]] = field(default_factory=list)
|
||||
speakers: Optional[list[str]] = field(default_factory=list)
|
||||
|
||||
|
||||
def get_whisper_model(model_name: str = "large-v3", batch_size: int = 12):
|
||||
"""Get or create Whisper model instance (singleton pattern)."""
|
||||
global _whisper_model
|
||||
def get_whisper_model(model_name: str = "large-v3", **kwargs):
|
||||
"""Get or create WhisperX model instance (singleton)."""
|
||||
global _whisperx_model
|
||||
|
||||
if _whisper_model is None:
|
||||
logger.info(f"Loading Whisper model: {model_name}")
|
||||
try:
|
||||
from lightning_whisper_mlx import LightningWhisperMLX
|
||||
_whisper_model = LightningWhisperMLX(
|
||||
model=model_name,
|
||||
batch_size=batch_size,
|
||||
quant=None # Use full precision for best quality
|
||||
)
|
||||
logger.info(f"Whisper model loaded successfully: {model_name}")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import lightning_whisper_mlx: {e}")
|
||||
raise RuntimeError(
|
||||
"lightning-whisper-mlx not installed. "
|
||||
"Run: pip install lightning-whisper-mlx"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Whisper model: {e}")
|
||||
raise
|
||||
if _whisperx_model is not None:
|
||||
return _whisperx_model
|
||||
|
||||
return _whisper_model
|
||||
logger.info(f"Loading WhisperX model: {model_name}")
|
||||
try:
|
||||
import whisperx
|
||||
|
||||
device = os.getenv("WHISPER_DEVICE", "cuda")
|
||||
compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
|
||||
|
||||
default_language = os.getenv("WHISPER_DEFAULT_LANGUAGE", "de")
|
||||
_whisperx_model = whisperx.load_model(
|
||||
model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
language=default_language,
|
||||
)
|
||||
logger.info(f"WhisperX model loaded: {model_name} on {device} ({compute_type})")
|
||||
_vram.mark_loaded()
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import whisperx: {e}")
|
||||
raise RuntimeError("whisperx not installed. Run: pip install whisperx")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load WhisperX model: {e}")
|
||||
raise
|
||||
|
||||
return _whisperx_model
|
||||
|
||||
|
||||
def _get_align_model(language: str, device: str = "cuda"):
|
||||
"""Get or create alignment model for a language."""
|
||||
global _align_model, _align_metadata
|
||||
|
||||
import whisperx
|
||||
|
||||
# Reload if language changed (alignment models are language-specific)
|
||||
if _align_model is None:
|
||||
logger.info(f"Loading alignment model for language: {language}")
|
||||
_align_model, _align_metadata = whisperx.load_align_model(
|
||||
language_code=language,
|
||||
device=device,
|
||||
)
|
||||
logger.info("Alignment model loaded")
|
||||
|
||||
return _align_model, _align_metadata
|
||||
|
||||
|
||||
def _get_diarize_pipeline(device: str = "cuda"):
|
||||
"""Get or create speaker diarization pipeline."""
|
||||
global _diarize_pipeline
|
||||
|
||||
if _diarize_pipeline is not None:
|
||||
return _diarize_pipeline
|
||||
|
||||
import torch
|
||||
from pyannote.audio import Pipeline
|
||||
|
||||
token = HF_TOKEN or os.getenv("HUGGING_FACE_HUB_TOKEN", "")
|
||||
if not token:
|
||||
logger.warning("No HF_TOKEN set — speaker diarization may fail for gated models")
|
||||
|
||||
logger.info("Loading speaker diarization pipeline (pyannote)...")
|
||||
_diarize_pipeline = Pipeline.from_pretrained(
|
||||
"pyannote/speaker-diarization-3.1",
|
||||
token=token,
|
||||
)
|
||||
_diarize_pipeline.to(torch.device(device))
|
||||
logger.info("Diarization pipeline loaded")
|
||||
return _diarize_pipeline
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: str = "large-v3",
|
||||
align: bool = True,
|
||||
diarize: bool = False,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
) -> TranscriptionResult:
|
||||
"""
|
||||
Transcribe audio file using Lightning Whisper MLX.
|
||||
Transcribe audio using WhisperX with optional alignment and diarization.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file (mp3, wav, m4a, etc.)
|
||||
language: Optional language code (e.g., 'de', 'en'). Auto-detect if None.
|
||||
audio_path: Path to audio file
|
||||
language: Language code (auto-detect if None)
|
||||
model_name: Whisper model to use
|
||||
align: Enable word-level timestamp alignment
|
||||
diarize: Enable speaker diarization
|
||||
min_speakers: Minimum expected speakers (helps diarization)
|
||||
max_speakers: Maximum expected speakers
|
||||
|
||||
Returns:
|
||||
TranscriptionResult with text and metadata
|
||||
TranscriptionResult with text, word timestamps, and speaker info
|
||||
"""
|
||||
import whisperx
|
||||
|
||||
device = os.getenv("WHISPER_DEVICE", "cuda")
|
||||
model = get_whisper_model(model_name)
|
||||
|
||||
logger.info(f"Transcribing: {audio_path}")
|
||||
logger.info(f"Transcribing: {audio_path} (align={align}, diarize={diarize})")
|
||||
|
||||
try:
|
||||
# Lightning Whisper MLX returns dict with 'text' key
|
||||
result = model.transcribe(
|
||||
audio_path=audio_path,
|
||||
language=language,
|
||||
)
|
||||
# Check and unload if idle, then reload
|
||||
_vram.check_and_unload(unload_models)
|
||||
_vram.touch()
|
||||
|
||||
# Handle different return formats
|
||||
if isinstance(result, dict):
|
||||
text = result.get("text", "")
|
||||
segments = result.get("segments", [])
|
||||
detected_language = result.get("language", language)
|
||||
else:
|
||||
text = str(result)
|
||||
segments = []
|
||||
detected_language = language
|
||||
# Step 1: Load audio
|
||||
audio = whisperx.load_audio(audio_path)
|
||||
|
||||
logger.info(f"Transcription complete: {len(text)} characters")
|
||||
# Step 2: Transcribe with faster-whisper
|
||||
transcribe_kwargs = {"batch_size": 16}
|
||||
if language:
|
||||
transcribe_kwargs["language"] = language
|
||||
result = model.transcribe(audio, **transcribe_kwargs)
|
||||
detected_language = result.get("language", language or "en")
|
||||
|
||||
return TranscriptionResult(
|
||||
text=text.strip(),
|
||||
language=detected_language,
|
||||
segments=segments,
|
||||
)
|
||||
# Step 3: Align (word-level timestamps)
|
||||
if align and result["segments"]:
|
||||
try:
|
||||
align_model, metadata = _get_align_model(detected_language, device)
|
||||
result = whisperx.align(
|
||||
result["segments"],
|
||||
align_model,
|
||||
metadata,
|
||||
audio,
|
||||
device,
|
||||
return_char_alignments=False,
|
||||
)
|
||||
logger.info("Word alignment complete")
|
||||
except Exception as e:
|
||||
logger.warning(f"Alignment failed (continuing without): {e}")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {e}")
|
||||
raise
|
||||
# Step 4: Diarize (speaker identification)
|
||||
if diarize:
|
||||
try:
|
||||
import torch
|
||||
import torchaudio
|
||||
|
||||
diarize_pipe = _get_diarize_pipeline(device)
|
||||
|
||||
# pyannote needs waveform as tensor, not the whisperx audio array
|
||||
waveform = torch.from_numpy(audio).unsqueeze(0).float()
|
||||
diarize_input = {"waveform": waveform, "sample_rate": 16000}
|
||||
|
||||
diarize_kwargs = {}
|
||||
if min_speakers is not None:
|
||||
diarize_kwargs["min_speakers"] = min_speakers
|
||||
if max_speakers is not None:
|
||||
diarize_kwargs["max_speakers"] = max_speakers
|
||||
|
||||
diarize_output = diarize_pipe(diarize_input, **diarize_kwargs)
|
||||
|
||||
# pyannote 4.x returns DiarizeOutput, extract the Annotation
|
||||
if hasattr(diarize_output, "speaker_diarization"):
|
||||
diarize_annotation = diarize_output.speaker_diarization
|
||||
else:
|
||||
diarize_annotation = diarize_output
|
||||
|
||||
# Convert pyannote output to DataFrame for whisperx
|
||||
import pandas as pd
|
||||
diarize_rows = []
|
||||
for turn, _, speaker in diarize_annotation.itertracks(yield_label=True):
|
||||
diarize_rows.append({
|
||||
"start": turn.start,
|
||||
"end": turn.end,
|
||||
"speaker": speaker,
|
||||
})
|
||||
|
||||
diarize_df = pd.DataFrame(diarize_rows)
|
||||
result = whisperx.assign_word_speakers(diarize_df, result)
|
||||
logger.info("Speaker diarization complete")
|
||||
except Exception as e:
|
||||
logger.warning(f"Diarization failed (continuing without): {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Build response
|
||||
segments = result.get("segments", [])
|
||||
full_text_parts = []
|
||||
all_words = []
|
||||
speaker_set = set()
|
||||
|
||||
for seg in segments:
|
||||
full_text_parts.append(seg.get("text", ""))
|
||||
speaker = seg.get("speaker")
|
||||
if speaker:
|
||||
speaker_set.add(speaker)
|
||||
|
||||
for word_info in seg.get("words", []):
|
||||
all_words.append(WordSegment(
|
||||
word=word_info.get("word", ""),
|
||||
start=word_info.get("start", 0.0),
|
||||
end=word_info.get("end", 0.0),
|
||||
score=word_info.get("score"),
|
||||
speaker=word_info.get("speaker", speaker),
|
||||
))
|
||||
|
||||
text = " ".join(full_text_parts)
|
||||
|
||||
_vram.touch()
|
||||
logger.info(
|
||||
f"Transcription complete: {len(text)} chars, "
|
||||
f"{len(all_words)} words, {len(speaker_set)} speakers"
|
||||
)
|
||||
|
||||
return TranscriptionResult(
|
||||
text=text.strip(),
|
||||
language=detected_language,
|
||||
segments=[{
|
||||
"start": s.get("start", 0),
|
||||
"end": s.get("end", 0),
|
||||
"text": s.get("text", ""),
|
||||
"speaker": s.get("speaker"),
|
||||
} for s in segments],
|
||||
words=all_words,
|
||||
speakers=sorted(speaker_set),
|
||||
)
|
||||
|
||||
|
||||
async def transcribe_audio_bytes(
|
||||
|
|
@ -106,53 +302,57 @@ async def transcribe_audio_bytes(
|
|||
filename: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: str = "large-v3",
|
||||
align: bool = True,
|
||||
diarize: bool = False,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
) -> TranscriptionResult:
|
||||
"""
|
||||
Transcribe audio from bytes (for API uploads).
|
||||
"""Transcribe audio from bytes (for API uploads)."""
|
||||
import asyncio
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio file bytes
|
||||
filename: Original filename (for extension detection)
|
||||
language: Optional language code
|
||||
model_name: Whisper model to use
|
||||
|
||||
Returns:
|
||||
TranscriptionResult
|
||||
"""
|
||||
# Get file extension
|
||||
ext = Path(filename).suffix or ".wav"
|
||||
|
||||
# Write to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=tmp_path,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
# Run in thread pool to avoid blocking the event loop
|
||||
loop = asyncio.get_event_loop()
|
||||
result = await loop.run_in_executor(
|
||||
None,
|
||||
lambda: transcribe_audio(
|
||||
audio_path=tmp_path,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
align=align,
|
||||
diarize=diarize,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
),
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Available models for Lightning Whisper MLX
|
||||
# Available models
|
||||
AVAILABLE_MODELS = [
|
||||
"tiny",
|
||||
"small",
|
||||
"tiny.en",
|
||||
"base",
|
||||
"base.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"medium",
|
||||
"large",
|
||||
"medium.en",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3", # Recommended for Mac Mini
|
||||
"distil-small.en",
|
||||
"distil-medium.en",
|
||||
"large-v3",
|
||||
"large-v3-turbo",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -1,175 +0,0 @@
|
|||
"""
|
||||
Whisper STT Service using faster-whisper (CUDA)
|
||||
Optimized for NVIDIA GPUs (RTX 3090 etc.)
|
||||
|
||||
Drop-in replacement for whisper_service.py (MLX version).
|
||||
Uses faster-whisper with CTranslate2 for GPU-accelerated inference.
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy load to avoid import errors if not installed
|
||||
_whisper_model = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class TranscriptionResult:
|
||||
text: str
|
||||
language: Optional[str] = None
|
||||
duration: Optional[float] = None
|
||||
segments: Optional[list] = None
|
||||
|
||||
|
||||
def get_whisper_model(model_name: str = "large-v3", **kwargs):
|
||||
"""Get or create Whisper model instance (singleton pattern)."""
|
||||
global _whisper_model
|
||||
|
||||
if _whisper_model is None:
|
||||
logger.info(f"Loading Whisper model: {model_name}")
|
||||
try:
|
||||
from faster_whisper import WhisperModel
|
||||
|
||||
# Use CUDA with float16 for RTX 3090
|
||||
compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
|
||||
device = os.getenv("WHISPER_DEVICE", "cuda")
|
||||
|
||||
_whisper_model = WhisperModel(
|
||||
model_name,
|
||||
device=device,
|
||||
compute_type=compute_type,
|
||||
)
|
||||
logger.info(f"Whisper model loaded: {model_name} on {device} ({compute_type})")
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import faster_whisper: {e}")
|
||||
raise RuntimeError(
|
||||
"faster-whisper not installed. "
|
||||
"Run: pip install faster-whisper"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Whisper model: {e}")
|
||||
raise
|
||||
|
||||
return _whisper_model
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: str = "large-v3",
|
||||
) -> TranscriptionResult:
|
||||
"""
|
||||
Transcribe audio file using faster-whisper (CUDA).
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file (mp3, wav, m4a, etc.)
|
||||
language: Optional language code (e.g., 'de', 'en'). Auto-detect if None.
|
||||
model_name: Whisper model to use
|
||||
|
||||
Returns:
|
||||
TranscriptionResult with text and metadata
|
||||
"""
|
||||
model = get_whisper_model(model_name)
|
||||
|
||||
logger.info(f"Transcribing: {audio_path}")
|
||||
|
||||
try:
|
||||
segments, info = model.transcribe(
|
||||
audio_path,
|
||||
language=language,
|
||||
beam_size=5,
|
||||
vad_filter=True, # Filter out silence
|
||||
)
|
||||
|
||||
# Collect all segments
|
||||
all_segments = []
|
||||
full_text_parts = []
|
||||
for segment in segments:
|
||||
full_text_parts.append(segment.text)
|
||||
all_segments.append({
|
||||
"start": segment.start,
|
||||
"end": segment.end,
|
||||
"text": segment.text,
|
||||
})
|
||||
|
||||
text = " ".join(full_text_parts)
|
||||
detected_language = info.language if info else language
|
||||
|
||||
logger.info(f"Transcription complete: {len(text)} characters, language={detected_language}")
|
||||
|
||||
return TranscriptionResult(
|
||||
text=text.strip(),
|
||||
language=detected_language,
|
||||
duration=info.duration if info else None,
|
||||
segments=all_segments,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Transcription failed: {e}")
|
||||
raise
|
||||
|
||||
|
||||
async def transcribe_audio_bytes(
|
||||
audio_bytes: bytes,
|
||||
filename: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: str = "large-v3",
|
||||
) -> TranscriptionResult:
|
||||
"""
|
||||
Transcribe audio from bytes (for API uploads).
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio file bytes
|
||||
filename: Original filename (for extension detection)
|
||||
language: Optional language code
|
||||
model_name: Whisper model to use
|
||||
|
||||
Returns:
|
||||
TranscriptionResult
|
||||
"""
|
||||
# Get file extension
|
||||
ext = Path(filename).suffix or ".wav"
|
||||
|
||||
# Write to temp file
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
result = transcribe_audio(
|
||||
audio_path=tmp_path,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
)
|
||||
return result
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
# Available models for faster-whisper
|
||||
AVAILABLE_MODELS = [
|
||||
"tiny",
|
||||
"tiny.en",
|
||||
"base",
|
||||
"base.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"medium",
|
||||
"medium.en",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"large-v3-turbo",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
]
|
||||
|
|
@ -1,419 +0,0 @@
|
|||
"""
|
||||
WhisperX STT Service using faster-whisper + pyannote (CUDA)
|
||||
Optimized for NVIDIA GPUs (RTX 3090 etc.)
|
||||
|
||||
Features:
|
||||
- Word-level timestamps via forced alignment
|
||||
- Speaker diarization via pyannote.audio
|
||||
- Segment-level timestamps with speaker labels
|
||||
- VAD filtering for silence removal
|
||||
|
||||
Requires HuggingFace token for pyannote models:
|
||||
export HF_TOKEN=hf_xxx
|
||||
# Accept terms at: https://huggingface.co/pyannote/speaker-diarization-3.1
|
||||
"""
|
||||
|
||||
import os
|
||||
import tempfile
|
||||
import logging
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Lazy-loaded singletons
|
||||
_whisper_model = None
|
||||
_align_model = None
|
||||
_align_metadata = None
|
||||
_diarize_pipeline = None
|
||||
|
||||
HF_TOKEN = os.getenv("HF_TOKEN")
|
||||
DEFAULT_MODEL = os.getenv("WHISPER_MODEL", "large-v3")
|
||||
DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
|
||||
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
|
||||
BATCH_SIZE = int(os.getenv("WHISPERX_BATCH_SIZE", "16"))
|
||||
|
||||
|
||||
@dataclass
|
||||
class WordInfo:
|
||||
"""Word with timestamp."""
|
||||
word: str
|
||||
start: float
|
||||
end: float
|
||||
score: float = 0.0
|
||||
speaker: Optional[str] = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class SegmentInfo:
|
||||
"""Segment with speaker and word-level detail."""
|
||||
start: float
|
||||
end: float
|
||||
text: str
|
||||
speaker: Optional[str] = None
|
||||
words: list[WordInfo] = field(default_factory=list)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Utterance:
|
||||
"""Speaker utterance in Memoro-compatible format."""
|
||||
speaker: int
|
||||
text: str
|
||||
offset: int # milliseconds
|
||||
duration: int # milliseconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class WhisperXResult:
|
||||
"""Rich transcription result with alignment and diarization."""
|
||||
text: str
|
||||
language: Optional[str] = None
|
||||
duration_seconds: Optional[float] = None
|
||||
segments: list[SegmentInfo] = field(default_factory=list)
|
||||
utterances: list[Utterance] = field(default_factory=list)
|
||||
speakers: dict[str, str] = field(default_factory=dict)
|
||||
speaker_map: dict[str, int] = field(default_factory=dict)
|
||||
languages: list[str] = field(default_factory=list)
|
||||
primary_language: Optional[str] = None
|
||||
words: list[WordInfo] = field(default_factory=list)
|
||||
|
||||
|
||||
def get_whisper_model(model_name: str = None):
|
||||
"""Load faster-whisper model (singleton)."""
|
||||
global _whisper_model
|
||||
model_name = model_name or DEFAULT_MODEL
|
||||
|
||||
if _whisper_model is None:
|
||||
logger.info(f"Loading WhisperX model: {model_name} on {DEVICE} ({COMPUTE_TYPE})")
|
||||
try:
|
||||
import whisperx
|
||||
|
||||
_whisper_model = whisperx.load_model(
|
||||
model_name,
|
||||
device=DEVICE,
|
||||
compute_type=COMPUTE_TYPE,
|
||||
)
|
||||
logger.info(f"WhisperX model loaded: {model_name}")
|
||||
except ImportError:
|
||||
raise RuntimeError(
|
||||
"whisperx not installed. "
|
||||
"Run: pip install whisperx"
|
||||
)
|
||||
|
||||
return _whisper_model
|
||||
|
||||
|
||||
def get_align_model(language_code: str):
|
||||
"""Load alignment model for a specific language (cached per language)."""
|
||||
global _align_model, _align_metadata
|
||||
|
||||
try:
|
||||
import whisperx
|
||||
|
||||
_align_model, _align_metadata = whisperx.load_align_model(
|
||||
language_code=language_code,
|
||||
device=DEVICE,
|
||||
)
|
||||
logger.info(f"Alignment model loaded for language: {language_code}")
|
||||
return _align_model, _align_metadata
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load alignment model for {language_code}: {e}")
|
||||
return None, None
|
||||
|
||||
|
||||
def get_diarize_pipeline():
|
||||
"""Load pyannote speaker diarization pipeline (singleton)."""
|
||||
global _diarize_pipeline
|
||||
|
||||
if _diarize_pipeline is None:
|
||||
if not HF_TOKEN:
|
||||
logger.warning("HF_TOKEN not set — diarization disabled")
|
||||
return None
|
||||
|
||||
try:
|
||||
import whisperx
|
||||
|
||||
_diarize_pipeline = whisperx.DiarizationPipeline(
|
||||
use_auth_token=HF_TOKEN,
|
||||
device=DEVICE,
|
||||
)
|
||||
logger.info("Diarization pipeline loaded")
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to load diarization pipeline: {e}")
|
||||
return None
|
||||
|
||||
return _diarize_pipeline
|
||||
|
||||
|
||||
def _build_utterances(segments: list[SegmentInfo]) -> tuple[list[Utterance], dict[str, str], dict[str, int]]:
|
||||
"""
|
||||
Build Memoro-compatible utterances from diarized segments.
|
||||
Groups consecutive segments by the same speaker.
|
||||
"""
|
||||
if not segments:
|
||||
return [], {}, {}
|
||||
|
||||
# Collect unique speakers
|
||||
speaker_labels = sorted(set(
|
||||
s.speaker for s in segments if s.speaker is not None
|
||||
))
|
||||
speaker_map: dict[str, int] = {}
|
||||
speakers: dict[str, str] = {}
|
||||
for idx, label in enumerate(speaker_labels):
|
||||
speaker_map[label] = idx
|
||||
speakers[str(idx)] = label
|
||||
|
||||
# Merge consecutive segments with the same speaker
|
||||
utterances: list[Utterance] = []
|
||||
current_speaker = None
|
||||
current_text_parts: list[str] = []
|
||||
current_start = 0.0
|
||||
current_end = 0.0
|
||||
|
||||
for seg in segments:
|
||||
sp = seg.speaker or "SPEAKER_00"
|
||||
if sp != current_speaker:
|
||||
# Flush previous
|
||||
if current_speaker is not None and current_text_parts:
|
||||
utterances.append(Utterance(
|
||||
speaker=speaker_map.get(current_speaker, 0),
|
||||
text=" ".join(current_text_parts).strip(),
|
||||
offset=int(current_start * 1000),
|
||||
duration=int((current_end - current_start) * 1000),
|
||||
))
|
||||
current_speaker = sp
|
||||
current_text_parts = [seg.text]
|
||||
current_start = seg.start
|
||||
current_end = seg.end
|
||||
else:
|
||||
current_text_parts.append(seg.text)
|
||||
current_end = seg.end
|
||||
|
||||
# Flush last
|
||||
if current_speaker is not None and current_text_parts:
|
||||
utterances.append(Utterance(
|
||||
speaker=speaker_map.get(current_speaker, 0),
|
||||
text=" ".join(current_text_parts).strip(),
|
||||
offset=int(current_start * 1000),
|
||||
duration=int((current_end - current_start) * 1000),
|
||||
))
|
||||
|
||||
return utterances, speakers, speaker_map
|
||||
|
||||
|
||||
def transcribe_audio(
|
||||
audio_path: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
enable_diarization: bool = True,
|
||||
enable_alignment: bool = True,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
) -> WhisperXResult:
|
||||
"""
|
||||
Transcribe audio with WhisperX: alignment + diarization.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
language: Language code (e.g., 'de', 'en'). Auto-detect if None.
|
||||
model_name: Whisper model to use
|
||||
enable_diarization: Run speaker diarization
|
||||
enable_alignment: Run forced word alignment
|
||||
min_speakers: Minimum expected speakers (hint for diarization)
|
||||
max_speakers: Maximum expected speakers (hint for diarization)
|
||||
|
||||
Returns:
|
||||
WhisperXResult with full transcription, segments, utterances, speakers
|
||||
"""
|
||||
import whisperx
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
# 1. Load audio
|
||||
audio = whisperx.load_audio(audio_path)
|
||||
audio_duration = len(audio) / 16000 # whisperx resamples to 16kHz
|
||||
|
||||
# 2. Transcribe with faster-whisper
|
||||
model = get_whisper_model(model_name)
|
||||
transcribe_result = model.transcribe(
|
||||
audio,
|
||||
batch_size=BATCH_SIZE,
|
||||
language=language,
|
||||
)
|
||||
|
||||
detected_language = transcribe_result.get("language", language or "en")
|
||||
raw_segments = transcribe_result.get("segments", [])
|
||||
|
||||
logger.info(
|
||||
f"Transcription: {len(raw_segments)} segments, "
|
||||
f"language={detected_language}, "
|
||||
f"duration={audio_duration:.1f}s"
|
||||
)
|
||||
|
||||
# 3. Forced alignment (word-level timestamps)
|
||||
if enable_alignment and raw_segments:
|
||||
align_model, align_metadata = get_align_model(detected_language)
|
||||
if align_model is not None:
|
||||
try:
|
||||
transcribe_result = whisperx.align(
|
||||
raw_segments,
|
||||
align_model,
|
||||
align_metadata,
|
||||
audio,
|
||||
DEVICE,
|
||||
return_char_alignments=False,
|
||||
)
|
||||
raw_segments = transcribe_result.get("segments", raw_segments)
|
||||
logger.info("Word alignment complete")
|
||||
except Exception as e:
|
||||
logger.warning(f"Alignment failed, using segment-level timestamps: {e}")
|
||||
|
||||
# 4. Speaker diarization
|
||||
if enable_diarization:
|
||||
diarize_pipeline = get_diarize_pipeline()
|
||||
if diarize_pipeline is not None:
|
||||
try:
|
||||
diarize_kwargs = {}
|
||||
if min_speakers is not None:
|
||||
diarize_kwargs["min_speakers"] = min_speakers
|
||||
if max_speakers is not None:
|
||||
diarize_kwargs["max_speakers"] = max_speakers
|
||||
|
||||
diarize_segments = diarize_pipeline(
|
||||
audio_path,
|
||||
**diarize_kwargs,
|
||||
)
|
||||
transcribe_result = whisperx.assign_word_speakers(
|
||||
diarize_segments, transcribe_result
|
||||
)
|
||||
raw_segments = transcribe_result.get("segments", raw_segments)
|
||||
logger.info("Diarization complete")
|
||||
except Exception as e:
|
||||
logger.warning(f"Diarization failed: {e}")
|
||||
|
||||
# 5. Build structured result
|
||||
segments: list[SegmentInfo] = []
|
||||
all_words: list[WordInfo] = []
|
||||
full_text_parts: list[str] = []
|
||||
|
||||
for seg in raw_segments:
|
||||
seg_words: list[WordInfo] = []
|
||||
for w in seg.get("words", []):
|
||||
wi = WordInfo(
|
||||
word=w.get("word", ""),
|
||||
start=w.get("start", 0.0),
|
||||
end=w.get("end", 0.0),
|
||||
score=w.get("score", 0.0),
|
||||
speaker=w.get("speaker"),
|
||||
)
|
||||
seg_words.append(wi)
|
||||
all_words.append(wi)
|
||||
|
||||
segment = SegmentInfo(
|
||||
start=seg.get("start", 0.0),
|
||||
end=seg.get("end", 0.0),
|
||||
text=seg.get("text", "").strip(),
|
||||
speaker=seg.get("speaker"),
|
||||
words=seg_words,
|
||||
)
|
||||
segments.append(segment)
|
||||
full_text_parts.append(segment.text)
|
||||
|
||||
full_text = " ".join(full_text_parts)
|
||||
|
||||
# 6. Build utterances (Memoro-compatible)
|
||||
utterances, speakers, speaker_map = _build_utterances(segments)
|
||||
|
||||
latency = time.time() - start_time
|
||||
logger.info(f"WhisperX complete in {latency:.1f}s: {len(full_text)} chars, {len(speakers)} speakers")
|
||||
|
||||
return WhisperXResult(
|
||||
text=full_text,
|
||||
language=detected_language,
|
||||
duration_seconds=audio_duration,
|
||||
segments=segments,
|
||||
utterances=utterances,
|
||||
speakers=speakers,
|
||||
speaker_map=speaker_map,
|
||||
languages=[detected_language] if detected_language else [],
|
||||
primary_language=detected_language,
|
||||
words=all_words,
|
||||
)
|
||||
|
||||
|
||||
async def transcribe_audio_bytes(
|
||||
audio_bytes: bytes,
|
||||
filename: str,
|
||||
language: Optional[str] = None,
|
||||
model_name: Optional[str] = None,
|
||||
enable_diarization: bool = True,
|
||||
enable_alignment: bool = True,
|
||||
min_speakers: Optional[int] = None,
|
||||
max_speakers: Optional[int] = None,
|
||||
) -> WhisperXResult:
|
||||
"""
|
||||
Transcribe audio from bytes (for API uploads).
|
||||
|
||||
Args:
|
||||
audio_bytes: Raw audio file bytes
|
||||
filename: Original filename (for extension detection)
|
||||
language: Optional language code
|
||||
model_name: Whisper model to use
|
||||
enable_diarization: Run speaker diarization
|
||||
enable_alignment: Run forced word alignment
|
||||
min_speakers: Min expected speakers
|
||||
max_speakers: Max expected speakers
|
||||
|
||||
Returns:
|
||||
WhisperXResult with full transcription data
|
||||
"""
|
||||
ext = Path(filename).suffix or ".wav"
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
|
||||
tmp.write(audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
return transcribe_audio(
|
||||
audio_path=tmp_path,
|
||||
language=language,
|
||||
model_name=model_name,
|
||||
enable_diarization=enable_diarization,
|
||||
enable_alignment=enable_alignment,
|
||||
min_speakers=min_speakers,
|
||||
max_speakers=max_speakers,
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
os.unlink(tmp_path)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
|
||||
def is_available() -> bool:
|
||||
"""Check if WhisperX dependencies are installed."""
|
||||
try:
|
||||
import whisperx
|
||||
return True
|
||||
except ImportError:
|
||||
return False
|
||||
|
||||
|
||||
AVAILABLE_MODELS = [
|
||||
"tiny",
|
||||
"tiny.en",
|
||||
"base",
|
||||
"base.en",
|
||||
"small",
|
||||
"small.en",
|
||||
"medium",
|
||||
"medium.en",
|
||||
"large-v1",
|
||||
"large-v2",
|
||||
"large-v3",
|
||||
"large-v3-turbo",
|
||||
"distil-large-v2",
|
||||
"distil-large-v3",
|
||||
]
|
||||
34
services/mana-stt/service.pyw
Normal file
34
services/mana-stt/service.pyw
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
"""mana-stt service runner."""
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.chdir(r"C:\mana\services\mana-stt")
|
||||
sys.path.insert(0, r"C:\mana\services\mana-stt")
|
||||
|
||||
# Redirect stdout/stderr to log file FIRST (before any imports that warn)
|
||||
log = open(r"C:\mana\services\mana-stt\service.log", "w", buffering=1)
|
||||
sys.stdout = log
|
||||
sys.stderr = log
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(r"C:\mana\services\mana-stt\.env")
|
||||
|
||||
# Ensure FFmpeg is in PATH
|
||||
ffmpeg_dir = r"C:\Users\tills\AppData\Local\Microsoft\WinGet\Links"
|
||||
if ffmpeg_dir not in os.environ.get("PATH", ""):
|
||||
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
# Set HF token
|
||||
hf_token = os.environ.get("HF_TOKEN", "")
|
||||
if hf_token:
|
||||
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
|
||||
|
||||
# Pre-initialize CUDA before importing whisperx (avoids hangs)
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True)
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=3020, log_level="info")
|
||||
|
|
@ -1,39 +1,271 @@
|
|||
"""
|
||||
API Key Authentication for Mana TTS Service.
|
||||
Delegates to shared mana_auth package.
|
||||
API Key Authentication for ManaCore STT Service
|
||||
|
||||
Supports two authentication modes:
|
||||
1. Local API keys: Configured via environment variables
|
||||
2. External API keys: Validated via mana-core-auth service (when EXTERNAL_AUTH_ENABLED=true)
|
||||
|
||||
Usage:
|
||||
# Local keys
|
||||
API_KEYS=sk-key1:name1,sk-key2:name2
|
||||
INTERNAL_API_KEY=sk-internal-xxx
|
||||
|
||||
# External auth (for user-created keys via mana.how)
|
||||
EXTERNAL_AUTH_ENABLED=true
|
||||
MANA_CORE_AUTH_URL=http://localhost:3001
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
from collections import defaultdict
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
|
||||
from fastapi import HTTPException, Security, Request
|
||||
from fastapi.security import APIKeyHeader
|
||||
|
||||
from mana_auth import (
|
||||
APIKey,
|
||||
AuthResult,
|
||||
RateLimitInfo,
|
||||
verify_api_key as _verify_api_key,
|
||||
get_api_key_stats,
|
||||
reload_api_keys,
|
||||
api_key_header,
|
||||
create_auth_dependency,
|
||||
)
|
||||
from mana_auth.external_auth import (
|
||||
ExternalValidationResult,
|
||||
from .external_auth import (
|
||||
is_external_auth_enabled,
|
||||
validate_api_key_external,
|
||||
ExternalValidationResult,
|
||||
)
|
||||
|
||||
from typing import Optional
|
||||
from fastapi import Security, Request
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# TTS-specific auth dependency
|
||||
verify_tts_key = create_auth_dependency("tts")
|
||||
# Configuration
|
||||
API_KEYS_ENV = os.getenv("API_KEYS", "") # Format: "sk-key1:name1,sk-key2:name2"
|
||||
INTERNAL_API_KEY = os.getenv("INTERNAL_API_KEY", "") # Unlimited internal key
|
||||
REQUIRE_AUTH = os.getenv("REQUIRE_AUTH", "true").lower() == "true"
|
||||
RATE_LIMIT_REQUESTS = int(os.getenv("RATE_LIMIT_REQUESTS", "60")) # Per minute
|
||||
RATE_LIMIT_WINDOW = int(os.getenv("RATE_LIMIT_WINDOW", "60")) # Seconds
|
||||
|
||||
|
||||
@dataclass
|
||||
class APIKey:
|
||||
"""API Key with metadata."""
|
||||
key: str
|
||||
name: str
|
||||
is_internal: bool = False
|
||||
rate_limit: int = RATE_LIMIT_REQUESTS # Requests per window
|
||||
|
||||
|
||||
@dataclass
|
||||
class RateLimitInfo:
|
||||
"""Rate limit tracking per key."""
|
||||
requests: list = field(default_factory=list)
|
||||
|
||||
def is_allowed(self, limit: int, window: int) -> bool:
|
||||
"""Check if request is allowed within rate limit."""
|
||||
now = time.time()
|
||||
# Remove old requests outside window
|
||||
self.requests = [t for t in self.requests if now - t < window]
|
||||
|
||||
if len(self.requests) >= limit:
|
||||
return False
|
||||
|
||||
self.requests.append(now)
|
||||
return True
|
||||
|
||||
def remaining(self, limit: int, window: int) -> int:
|
||||
"""Get remaining requests in current window."""
|
||||
now = time.time()
|
||||
self.requests = [t for t in self.requests if now - t < window]
|
||||
return max(0, limit - len(self.requests))
|
||||
|
||||
|
||||
# Parse API keys from environment
|
||||
def _parse_api_keys() -> dict[str, APIKey]:
|
||||
"""Parse API keys from environment variables."""
|
||||
keys = {}
|
||||
|
||||
# Parse comma-separated keys
|
||||
if API_KEYS_ENV:
|
||||
for entry in API_KEYS_ENV.split(","):
|
||||
entry = entry.strip()
|
||||
if ":" in entry:
|
||||
key, name = entry.split(":", 1)
|
||||
else:
|
||||
key, name = entry, "default"
|
||||
keys[key.strip()] = APIKey(key=key.strip(), name=name.strip())
|
||||
|
||||
# Add internal key with no rate limit
|
||||
if INTERNAL_API_KEY:
|
||||
keys[INTERNAL_API_KEY] = APIKey(
|
||||
key=INTERNAL_API_KEY,
|
||||
name="internal",
|
||||
is_internal=True,
|
||||
rate_limit=999999, # Effectively unlimited
|
||||
)
|
||||
|
||||
return keys
|
||||
|
||||
|
||||
# Global state
|
||||
_api_keys = _parse_api_keys()
|
||||
_rate_limits: dict[str, RateLimitInfo] = defaultdict(RateLimitInfo)
|
||||
|
||||
# Security scheme
|
||||
api_key_header = APIKeyHeader(name="X-API-Key", auto_error=False)
|
||||
|
||||
|
||||
@dataclass
|
||||
class AuthResult:
|
||||
"""Result of authentication check."""
|
||||
authenticated: bool
|
||||
key_name: Optional[str] = None
|
||||
is_internal: bool = False
|
||||
rate_limit_remaining: Optional[int] = None
|
||||
user_id: Optional[str] = None # Set when using external auth
|
||||
|
||||
|
||||
async def verify_api_key(
|
||||
request: Request,
|
||||
api_key: Optional[str] = Security(api_key_header),
|
||||
) -> AuthResult:
|
||||
"""Verify API key with TTS scope."""
|
||||
return await _verify_api_key(request, scope="tts", api_key=api_key)
|
||||
"""
|
||||
Verify API key and check rate limits.
|
||||
|
||||
Supports two authentication modes:
|
||||
1. External auth via mana-core-auth (for sk_live_ keys)
|
||||
2. Local auth via environment variables
|
||||
|
||||
Returns AuthResult with authentication status.
|
||||
Raises HTTPException if auth fails or rate limited.
|
||||
"""
|
||||
# Skip auth for health and docs endpoints
|
||||
path = request.url.path
|
||||
if path in ["/health", "/docs", "/openapi.json", "/redoc"]:
|
||||
return AuthResult(authenticated=True, key_name="public")
|
||||
|
||||
# If auth not required, allow all
|
||||
if not REQUIRE_AUTH:
|
||||
return AuthResult(authenticated=True, key_name="anonymous")
|
||||
|
||||
# Check for API key
|
||||
if not api_key:
|
||||
logger.warning(f"Missing API key for {path} from {request.client.host if request.client else 'unknown'}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Missing API key. Provide X-API-Key header.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
# Try external auth first for sk_live_ keys (user-created keys via mana.how)
|
||||
if api_key.startswith("sk_live_") and is_external_auth_enabled():
|
||||
external_result = await validate_api_key_external(api_key, "stt")
|
||||
|
||||
if external_result is not None:
|
||||
if external_result.valid:
|
||||
# Use rate limits from external auth
|
||||
rate_info = _rate_limits[api_key]
|
||||
limit = external_result.rate_limit_requests
|
||||
window = external_result.rate_limit_window
|
||||
|
||||
if not rate_info.is_allowed(limit, window):
|
||||
remaining = rate_info.remaining(limit, window)
|
||||
logger.warning(f"Rate limit exceeded for external key")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Rate limit exceeded. Try again in {window} seconds.",
|
||||
headers={
|
||||
"X-RateLimit-Limit": str(limit),
|
||||
"X-RateLimit-Remaining": str(remaining),
|
||||
"X-RateLimit-Reset": str(int(time.time()) + window),
|
||||
"Retry-After": str(window),
|
||||
},
|
||||
)
|
||||
|
||||
remaining = rate_info.remaining(limit, window)
|
||||
logger.debug(f"Authenticated external request from user {external_result.user_id} to {path}")
|
||||
|
||||
return AuthResult(
|
||||
authenticated=True,
|
||||
key_name="external",
|
||||
is_internal=False,
|
||||
rate_limit_remaining=remaining,
|
||||
user_id=external_result.user_id,
|
||||
)
|
||||
else:
|
||||
# External auth returned invalid
|
||||
logger.warning(f"External auth failed: {external_result.error}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail=external_result.error or "Invalid API key.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
# If external_result is None, fall through to local auth
|
||||
|
||||
# Local auth: Validate key against environment variables
|
||||
if api_key not in _api_keys:
|
||||
logger.warning(f"Invalid API key attempt for {path}")
|
||||
raise HTTPException(
|
||||
status_code=401,
|
||||
detail="Invalid API key.",
|
||||
headers={"WWW-Authenticate": "ApiKey"},
|
||||
)
|
||||
|
||||
key_info = _api_keys[api_key]
|
||||
|
||||
# Check rate limit (skip for internal keys)
|
||||
if not key_info.is_internal:
|
||||
rate_info = _rate_limits[api_key]
|
||||
if not rate_info.is_allowed(key_info.rate_limit, RATE_LIMIT_WINDOW):
|
||||
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
|
||||
logger.warning(f"Rate limit exceeded for key '{key_info.name}'")
|
||||
raise HTTPException(
|
||||
status_code=429,
|
||||
detail=f"Rate limit exceeded. Try again in {RATE_LIMIT_WINDOW} seconds.",
|
||||
headers={
|
||||
"X-RateLimit-Limit": str(key_info.rate_limit),
|
||||
"X-RateLimit-Remaining": str(remaining),
|
||||
"X-RateLimit-Reset": str(int(time.time()) + RATE_LIMIT_WINDOW),
|
||||
"Retry-After": str(RATE_LIMIT_WINDOW),
|
||||
},
|
||||
)
|
||||
remaining = rate_info.remaining(key_info.rate_limit, RATE_LIMIT_WINDOW)
|
||||
else:
|
||||
remaining = None
|
||||
|
||||
logger.debug(f"Authenticated request from '{key_info.name}' to {path}")
|
||||
|
||||
return AuthResult(
|
||||
authenticated=True,
|
||||
key_name=key_info.name,
|
||||
is_internal=key_info.is_internal,
|
||||
rate_limit_remaining=remaining,
|
||||
)
|
||||
|
||||
|
||||
def get_api_key_stats() -> dict:
|
||||
"""Get statistics about API keys (for admin endpoint)."""
|
||||
stats = {
|
||||
"total_keys": len(_api_keys),
|
||||
"auth_required": REQUIRE_AUTH,
|
||||
"rate_limit": {
|
||||
"requests_per_window": RATE_LIMIT_REQUESTS,
|
||||
"window_seconds": RATE_LIMIT_WINDOW,
|
||||
},
|
||||
"keys": [],
|
||||
}
|
||||
|
||||
for key, info in _api_keys.items():
|
||||
# Don't expose actual keys, just metadata
|
||||
masked_key = key[:8] + "..." if len(key) > 8 else "***"
|
||||
rate_info = _rate_limits.get(key, RateLimitInfo())
|
||||
stats["keys"].append({
|
||||
"name": info.name,
|
||||
"key_prefix": masked_key,
|
||||
"is_internal": info.is_internal,
|
||||
"requests_in_window": len(rate_info.requests),
|
||||
"remaining": rate_info.remaining(info.rate_limit, RATE_LIMIT_WINDOW),
|
||||
})
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
def reload_api_keys():
|
||||
"""Reload API keys from environment (for runtime updates)."""
|
||||
global _api_keys
|
||||
_api_keys = _parse_api_keys()
|
||||
logger.info(f"Reloaded {len(_api_keys)} API keys")
|
||||
|
|
|
|||
|
|
@ -1,22 +1,145 @@
|
|||
"""
|
||||
External API Key Validation — delegates to shared mana_auth package.
|
||||
External API Key Validation via mana-core-auth
|
||||
|
||||
When EXTERNAL_AUTH_ENABLED=true, API keys are validated against the
|
||||
central mana-core-auth service. This allows users to create and manage
|
||||
API keys from the mana.how web interface.
|
||||
|
||||
Results are cached for 5 minutes to reduce load on the auth service.
|
||||
"""
|
||||
|
||||
import sys
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import httpx
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "shared-python"))
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from mana_auth.external_auth import (
|
||||
ExternalValidationResult,
|
||||
is_external_auth_enabled,
|
||||
validate_api_key_external,
|
||||
clear_cache,
|
||||
)
|
||||
# Configuration
|
||||
EXTERNAL_AUTH_ENABLED = os.getenv("EXTERNAL_AUTH_ENABLED", "false").lower() == "true"
|
||||
MANA_CORE_AUTH_URL = os.getenv("MANA_CORE_AUTH_URL", "http://localhost:3001")
|
||||
API_KEY_CACHE_TTL = int(os.getenv("API_KEY_CACHE_TTL", "300")) # 5 minutes
|
||||
EXTERNAL_AUTH_TIMEOUT = float(os.getenv("EXTERNAL_AUTH_TIMEOUT", "5.0")) # seconds
|
||||
|
||||
__all__ = [
|
||||
"ExternalValidationResult",
|
||||
"is_external_auth_enabled",
|
||||
"validate_api_key_external",
|
||||
"clear_cache",
|
||||
]
|
||||
|
||||
@dataclass
|
||||
class ExternalValidationResult:
|
||||
"""Result from external API key validation."""
|
||||
valid: bool
|
||||
user_id: Optional[str] = None
|
||||
scopes: Optional[list] = None
|
||||
rate_limit_requests: int = 60
|
||||
rate_limit_window: int = 60
|
||||
error: Optional[str] = None
|
||||
cached_at: float = 0.0
|
||||
|
||||
|
||||
# In-memory cache for validation results
|
||||
# Key: API key, Value: ExternalValidationResult
|
||||
_validation_cache: dict[str, ExternalValidationResult] = {}
|
||||
|
||||
|
||||
def is_external_auth_enabled() -> bool:
|
||||
"""Check if external authentication is enabled."""
|
||||
return EXTERNAL_AUTH_ENABLED
|
||||
|
||||
|
||||
def _get_cached_result(api_key: str) -> Optional[ExternalValidationResult]:
|
||||
"""Get cached validation result if still valid."""
|
||||
result = _validation_cache.get(api_key)
|
||||
if result and (time.time() - result.cached_at) < API_KEY_CACHE_TTL:
|
||||
return result
|
||||
return None
|
||||
|
||||
|
||||
def _cache_result(api_key: str, result: ExternalValidationResult):
|
||||
"""Cache a validation result."""
|
||||
result.cached_at = time.time()
|
||||
_validation_cache[api_key] = result
|
||||
|
||||
# Clean up old entries periodically (keep cache size manageable)
|
||||
if len(_validation_cache) > 1000:
|
||||
now = time.time()
|
||||
expired_keys = [
|
||||
k for k, v in _validation_cache.items()
|
||||
if (now - v.cached_at) >= API_KEY_CACHE_TTL
|
||||
]
|
||||
for k in expired_keys:
|
||||
del _validation_cache[k]
|
||||
|
||||
|
||||
async def validate_api_key_external(api_key: str, scope: str) -> Optional[ExternalValidationResult]:
|
||||
"""
|
||||
Validate an API key against mana-core-auth service.
|
||||
|
||||
Args:
|
||||
api_key: The API key to validate (e.g., "sk_live_...")
|
||||
scope: The required scope (e.g., "stt" or "tts")
|
||||
|
||||
Returns:
|
||||
ExternalValidationResult if external auth is enabled and the key was validated.
|
||||
None if external auth is disabled or the service is unavailable (fallback to local).
|
||||
"""
|
||||
if not EXTERNAL_AUTH_ENABLED:
|
||||
return None
|
||||
|
||||
# Check cache first
|
||||
cached = _get_cached_result(api_key)
|
||||
if cached:
|
||||
logger.debug(f"Using cached validation result for key prefix: {api_key[:12]}...")
|
||||
# Check scope against cached result
|
||||
if cached.valid and cached.scopes and scope not in cached.scopes:
|
||||
return ExternalValidationResult(
|
||||
valid=False,
|
||||
error=f"API key does not have scope: {scope}",
|
||||
)
|
||||
return cached
|
||||
|
||||
# Call mana-core-auth validation endpoint
|
||||
try:
|
||||
async with httpx.AsyncClient(timeout=EXTERNAL_AUTH_TIMEOUT) as client:
|
||||
response = await client.post(
|
||||
f"{MANA_CORE_AUTH_URL}/api/v1/api-keys/validate",
|
||||
json={"apiKey": api_key, "scope": scope},
|
||||
)
|
||||
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
result = ExternalValidationResult(
|
||||
valid=data.get("valid", False),
|
||||
user_id=data.get("userId"),
|
||||
scopes=data.get("scopes", []),
|
||||
rate_limit_requests=data.get("rateLimit", {}).get("requests", 60),
|
||||
rate_limit_window=data.get("rateLimit", {}).get("window", 60),
|
||||
error=data.get("error"),
|
||||
)
|
||||
_cache_result(api_key, result)
|
||||
return result
|
||||
else:
|
||||
logger.warning(
|
||||
f"External auth returned status {response.status_code}: {response.text}"
|
||||
)
|
||||
# Don't cache errors - allow retry
|
||||
return ExternalValidationResult(
|
||||
valid=False,
|
||||
error=f"Auth service returned {response.status_code}",
|
||||
)
|
||||
|
||||
except httpx.TimeoutException:
|
||||
logger.warning("External auth service timeout - falling back to local auth")
|
||||
return None
|
||||
except httpx.ConnectError:
|
||||
logger.warning("Cannot connect to external auth service - falling back to local auth")
|
||||
return None
|
||||
except Exception as e:
|
||||
logger.error(f"External auth error: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def clear_cache():
|
||||
"""Clear the validation cache (for testing or runtime updates)."""
|
||||
global _validation_cache
|
||||
_validation_cache.clear()
|
||||
logger.info("External auth cache cleared")
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
F5-TTS Service for voice cloning synthesis.
|
||||
Uses f5-tts-mlx optimized for Apple Silicon.
|
||||
CUDA version using f5-tts PyTorch package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -15,14 +15,12 @@ import numpy as np
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global singleton for lazy initialization
|
||||
_f5_model = None
|
||||
_f5_model_name = None
|
||||
_f5_api = None
|
||||
|
||||
# Default model
|
||||
DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "lucasnewman/f5-tts-mlx")
|
||||
DEFAULT_F5_MODEL = os.getenv("F5_MODEL", "F5-TTS")
|
||||
|
||||
# Default generation parameters
|
||||
DEFAULT_DURATION = 10.0 # seconds
|
||||
DEFAULT_STEPS = 32
|
||||
DEFAULT_CFG_STRENGTH = 2.0
|
||||
DEFAULT_SWAY_COEF = -1.0
|
||||
|
|
@ -40,35 +38,25 @@ class F5Result:
|
|||
|
||||
|
||||
def get_f5_model(model_name: str = DEFAULT_F5_MODEL):
|
||||
"""
|
||||
Get or create F5-TTS model instance (singleton pattern).
|
||||
"""Get or create F5-TTS API instance (singleton pattern)."""
|
||||
global _f5_api
|
||||
|
||||
Args:
|
||||
model_name: HuggingFace model identifier
|
||||
|
||||
Returns:
|
||||
F5TTS model instance
|
||||
"""
|
||||
global _f5_model, _f5_model_name
|
||||
|
||||
# Return existing model if same model name
|
||||
if _f5_model is not None and _f5_model_name == model_name:
|
||||
return _f5_model
|
||||
if _f5_api is not None:
|
||||
return _f5_api
|
||||
|
||||
logger.info(f"Loading F5-TTS model: {model_name}")
|
||||
|
||||
try:
|
||||
from f5_tts_mlx import F5TTS
|
||||
from f5_tts.api import F5TTS
|
||||
|
||||
_f5_model = F5TTS(model_name=model_name)
|
||||
_f5_model_name = model_name
|
||||
logger.info("F5-TTS model loaded successfully")
|
||||
return _f5_model
|
||||
_f5_api = F5TTS(model_type="F5-TTS")
|
||||
logger.info("F5-TTS model loaded successfully (CUDA)")
|
||||
return _f5_api
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import f5_tts_mlx: {e}")
|
||||
logger.error(f"Failed to import f5_tts: {e}")
|
||||
raise RuntimeError(
|
||||
"f5-tts-mlx not installed. Run: pip install f5-tts-mlx"
|
||||
"f5-tts not installed. Run: pip install f5-tts"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load F5-TTS model: {e}")
|
||||
|
|
@ -77,7 +65,7 @@ def get_f5_model(model_name: str = DEFAULT_F5_MODEL):
|
|||
|
||||
def is_f5_loaded() -> bool:
|
||||
"""Check if F5-TTS model is currently loaded."""
|
||||
return _f5_model is not None
|
||||
return _f5_api is not None
|
||||
|
||||
|
||||
async def synthesize_f5(
|
||||
|
|
@ -103,13 +91,14 @@ async def synthesize_f5(
|
|||
cfg_strength: Classifier-free guidance strength
|
||||
sway_coef: Sway sampling coefficient
|
||||
speed: Speech speed multiplier
|
||||
model_name: HuggingFace model identifier
|
||||
model_name: Model identifier
|
||||
|
||||
Returns:
|
||||
F5Result with audio data
|
||||
"""
|
||||
# Get model
|
||||
model = get_f5_model(model_name)
|
||||
import asyncio
|
||||
|
||||
api = get_f5_model(model_name)
|
||||
|
||||
logger.info(
|
||||
f"Synthesizing with F5-TTS: text_length={len(text)}, "
|
||||
|
|
@ -117,17 +106,26 @@ async def synthesize_f5(
|
|||
)
|
||||
|
||||
try:
|
||||
# Generate audio
|
||||
audio, sample_rate = model.generate(
|
||||
text=text,
|
||||
ref_audio_path=reference_audio_path,
|
||||
ref_audio_text=reference_text,
|
||||
duration=duration,
|
||||
steps=steps,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_coef=sway_coef,
|
||||
speed=speed,
|
||||
)
|
||||
# F5-TTS API infer method (runs synchronously, offload to thread)
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
def _generate():
|
||||
wav, sr, _ = api.infer(
|
||||
ref_file=reference_audio_path,
|
||||
ref_text=reference_text,
|
||||
gen_text=text,
|
||||
nfe_step=steps,
|
||||
cfg_strength=cfg_strength,
|
||||
sway_sampling_coeff=sway_coef,
|
||||
speed=speed,
|
||||
)
|
||||
return wav, sr
|
||||
|
||||
audio, sample_rate = await loop.run_in_executor(None, _generate)
|
||||
|
||||
# Convert to numpy if needed
|
||||
if not isinstance(audio, np.ndarray):
|
||||
audio = np.array(audio, dtype=np.float32)
|
||||
|
||||
# Calculate duration
|
||||
audio_duration = len(audio) / sample_rate
|
||||
|
|
@ -152,24 +150,8 @@ async def synthesize_f5_from_bytes(
|
|||
audio_extension: str = ".wav",
|
||||
**kwargs,
|
||||
) -> F5Result:
|
||||
"""
|
||||
Synthesize speech using F5-TTS with reference audio as bytes.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
reference_audio_bytes: Reference audio as bytes
|
||||
reference_text: Transcript of the reference audio
|
||||
audio_extension: File extension for temp file
|
||||
**kwargs: Additional arguments passed to synthesize_f5
|
||||
|
||||
Returns:
|
||||
F5Result with audio data
|
||||
"""
|
||||
# Save reference audio to temp file
|
||||
with tempfile.NamedTemporaryFile(
|
||||
suffix=audio_extension,
|
||||
delete=False,
|
||||
) as tmp:
|
||||
"""Synthesize speech using F5-TTS with reference audio as bytes."""
|
||||
with tempfile.NamedTemporaryFile(suffix=audio_extension, delete=False) as tmp:
|
||||
tmp.write(reference_audio_bytes)
|
||||
tmp_path = tmp.name
|
||||
|
||||
|
|
@ -182,7 +164,6 @@ async def synthesize_f5_from_bytes(
|
|||
)
|
||||
return result
|
||||
finally:
|
||||
# Clean up temp file
|
||||
try:
|
||||
Path(tmp_path).unlink()
|
||||
except Exception:
|
||||
|
|
@ -190,18 +171,7 @@ async def synthesize_f5_from_bytes(
|
|||
|
||||
|
||||
def estimate_duration(text: str, speed: float = 1.0) -> float:
|
||||
"""
|
||||
Estimate audio duration from text.
|
||||
|
||||
Args:
|
||||
text: Text to synthesize
|
||||
speed: Speech speed multiplier
|
||||
|
||||
Returns:
|
||||
Estimated duration in seconds
|
||||
"""
|
||||
# Rough estimate: ~150 words per minute at normal speed
|
||||
# Average word length: ~5 characters
|
||||
"""Estimate audio duration from text."""
|
||||
words = len(text) / 5
|
||||
minutes = words / 150
|
||||
seconds = minutes * 60
|
||||
|
|
|
|||
|
|
@ -1,6 +1,6 @@
|
|||
"""
|
||||
Kokoro TTS Service for fast preset voice synthesis.
|
||||
Uses mlx-audio's Kokoro implementation optimized for Apple Silicon.
|
||||
CUDA version using kokoro PyTorch package.
|
||||
"""
|
||||
|
||||
import logging
|
||||
|
|
@ -12,11 +12,10 @@ import numpy as np
|
|||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Global singleton for lazy initialization
|
||||
_kokoro_model = None
|
||||
_kokoro_model_name = None
|
||||
_kokoro_pipeline = None
|
||||
|
||||
# Default model
|
||||
DEFAULT_KOKORO_MODEL = "mlx-community/Kokoro-82M-bf16"
|
||||
DEFAULT_KOKORO_MODEL = "hexgrad/Kokoro-82M"
|
||||
|
||||
# Available Kokoro voices (American Female/Male, British Female/Male)
|
||||
KOKORO_VOICES = {
|
||||
|
|
@ -67,35 +66,25 @@ class KokoroResult:
|
|||
|
||||
|
||||
def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL):
|
||||
"""
|
||||
Get or create Kokoro model instance (singleton pattern).
|
||||
"""Get or create Kokoro pipeline instance (singleton pattern)."""
|
||||
global _kokoro_pipeline
|
||||
|
||||
Args:
|
||||
model_name: HuggingFace model identifier
|
||||
|
||||
Returns:
|
||||
Kokoro model instance
|
||||
"""
|
||||
global _kokoro_model, _kokoro_model_name
|
||||
|
||||
# Return existing model if same model name
|
||||
if _kokoro_model is not None and _kokoro_model_name == model_name:
|
||||
return _kokoro_model
|
||||
if _kokoro_pipeline is not None:
|
||||
return _kokoro_pipeline
|
||||
|
||||
logger.info(f"Loading Kokoro model: {model_name}")
|
||||
|
||||
try:
|
||||
from mlx_audio.tts import load
|
||||
from kokoro import KPipeline
|
||||
|
||||
_kokoro_model = load(model_name)
|
||||
_kokoro_model_name = model_name
|
||||
logger.info("Kokoro model loaded successfully")
|
||||
return _kokoro_model
|
||||
_kokoro_pipeline = KPipeline(lang_code="a") # 'a' for American English
|
||||
logger.info("Kokoro pipeline loaded successfully")
|
||||
return _kokoro_pipeline
|
||||
|
||||
except ImportError as e:
|
||||
logger.error(f"Failed to import mlx_audio: {e}")
|
||||
logger.error(f"Failed to import kokoro: {e}")
|
||||
raise RuntimeError(
|
||||
"mlx-audio not installed. Run: pip install mlx-audio"
|
||||
"kokoro not installed. Run: pip install kokoro"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load Kokoro model: {e}")
|
||||
|
|
@ -104,7 +93,7 @@ def get_kokoro_model(model_name: str = DEFAULT_KOKORO_MODEL):
|
|||
|
||||
def is_kokoro_loaded() -> bool:
|
||||
"""Check if Kokoro model is currently loaded."""
|
||||
return _kokoro_model is not None
|
||||
return _kokoro_pipeline is not None
|
||||
|
||||
|
||||
def get_available_voices() -> dict[str, str]:
|
||||
|
|
@ -125,7 +114,7 @@ async def synthesize_kokoro(
|
|||
text: Text to synthesize
|
||||
voice: Voice ID from KOKORO_VOICES
|
||||
speed: Speech speed multiplier (0.5-2.0)
|
||||
model_name: HuggingFace model identifier
|
||||
model_name: Model identifier
|
||||
|
||||
Returns:
|
||||
KokoroResult with audio data
|
||||
|
|
@ -139,29 +128,18 @@ async def synthesize_kokoro(
|
|||
speed = max(0.5, min(2.0, speed))
|
||||
|
||||
# Get model
|
||||
model = get_kokoro_model(model_name)
|
||||
pipeline = get_kokoro_model(model_name)
|
||||
|
||||
logger.info(f"Synthesizing with Kokoro: voice={voice}, speed={speed}, text_length={len(text)}")
|
||||
|
||||
try:
|
||||
# Generate audio using mlx-audio's generate method
|
||||
# Returns a generator of GenerationResult objects
|
||||
result_gen = model.generate(
|
||||
text=text,
|
||||
voice=voice,
|
||||
speed=speed,
|
||||
)
|
||||
|
||||
# Collect all audio chunks from the generator
|
||||
# Generate audio using kokoro pipeline
|
||||
audio_chunks = []
|
||||
sample_rate = 24000 # Default, will be updated from result
|
||||
sample_rate = 24000 # Kokoro default
|
||||
|
||||
for result in result_gen:
|
||||
# Each result has audio, sample_rate, audio_duration (string)
|
||||
sample_rate = result.sample_rate
|
||||
|
||||
# Convert MLX array to numpy
|
||||
audio_np = np.array(result.audio, dtype=np.float32)
|
||||
for result in pipeline(text, voice=voice, speed=speed):
|
||||
# result is a KPipelineResult with .audio (tensor) and .graphemes/.phonemes
|
||||
audio_np = result.audio.numpy()
|
||||
audio_chunks.append(audio_np)
|
||||
|
||||
# Concatenate all chunks
|
||||
|
|
|
|||
114
services/mana-tts/app/vram_manager.py
Normal file
114
services/mana-tts/app/vram_manager.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
"""
|
||||
VRAM Manager — Automatic model unloading after idle timeout.
|
||||
|
||||
Tracks last usage time per model and unloads after configurable timeout.
|
||||
Designed for shared GPU environments (multiple services on one RTX 3090).
|
||||
|
||||
Usage in a service:
|
||||
from vram_manager import VramManager
|
||||
|
||||
vram = VramManager(idle_timeout=300) # 5 min
|
||||
|
||||
# Before using a model
|
||||
vram.touch()
|
||||
|
||||
# Call periodically (e.g., from health check or background task)
|
||||
vram.check_idle(unload_fn=my_unload_function)
|
||||
"""
|
||||
|
||||
import os
|
||||
import time
|
||||
import logging
|
||||
import threading
|
||||
from typing import Optional, Callable
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_IDLE_TIMEOUT = int(os.getenv("VRAM_IDLE_TIMEOUT", "300")) # 5 minutes
|
||||
|
||||
|
||||
class VramManager:
|
||||
def __init__(self, idle_timeout: int = DEFAULT_IDLE_TIMEOUT, service_name: str = "unknown"):
|
||||
self.idle_timeout = idle_timeout
|
||||
self.service_name = service_name
|
||||
self.last_used: float = 0.0
|
||||
self.model_loaded: bool = False
|
||||
self._lock = threading.Lock()
|
||||
self._timer: Optional[threading.Timer] = None
|
||||
|
||||
def touch(self):
|
||||
"""Mark the model as recently used. Call before/after each inference."""
|
||||
with self._lock:
|
||||
self.last_used = time.time()
|
||||
self.model_loaded = True
|
||||
self._schedule_check()
|
||||
|
||||
def mark_loaded(self):
|
||||
"""Mark that a model has been loaded into VRAM."""
|
||||
with self._lock:
|
||||
self.model_loaded = True
|
||||
self.last_used = time.time()
|
||||
self._schedule_check()
|
||||
logger.info(f"[{self.service_name}] Model loaded, idle timeout: {self.idle_timeout}s")
|
||||
|
||||
def mark_unloaded(self):
|
||||
"""Mark that a model has been unloaded from VRAM."""
|
||||
with self._lock:
|
||||
self.model_loaded = False
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
self._timer = None
|
||||
logger.info(f"[{self.service_name}] Model unloaded, VRAM freed")
|
||||
|
||||
def is_idle(self) -> bool:
|
||||
"""Check if the model has been idle longer than the timeout."""
|
||||
if not self.model_loaded:
|
||||
return False
|
||||
return (time.time() - self.last_used) > self.idle_timeout
|
||||
|
||||
def seconds_until_unload(self) -> Optional[float]:
|
||||
"""Seconds until the model will be unloaded, or None if not loaded."""
|
||||
if not self.model_loaded:
|
||||
return None
|
||||
remaining = self.idle_timeout - (time.time() - self.last_used)
|
||||
return max(0, remaining)
|
||||
|
||||
def check_and_unload(self, unload_fn: Callable[[], None]) -> bool:
|
||||
"""Check if idle and unload if so. Returns True if unloaded."""
|
||||
if self.is_idle():
|
||||
logger.info(f"[{self.service_name}] Idle for >{self.idle_timeout}s, unloading model...")
|
||||
try:
|
||||
unload_fn()
|
||||
self.mark_unloaded()
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"[{self.service_name}] Failed to unload: {e}")
|
||||
return False
|
||||
|
||||
def _schedule_check(self):
|
||||
"""Schedule an idle check after the timeout period."""
|
||||
if self._timer:
|
||||
self._timer.cancel()
|
||||
|
||||
self._timer = threading.Timer(
|
||||
self.idle_timeout + 5, # Small buffer
|
||||
self._auto_check,
|
||||
)
|
||||
self._timer.daemon = True
|
||||
self._timer.start()
|
||||
|
||||
def _auto_check(self):
|
||||
"""Auto-triggered idle check (called by timer)."""
|
||||
# This is just a log — actual unloading needs the unload_fn
|
||||
# which depends on the service. The service should call check_and_unload.
|
||||
if self.is_idle():
|
||||
logger.info(f"[{self.service_name}] Model idle for >{self.idle_timeout}s — ready to unload")
|
||||
|
||||
def status(self) -> dict:
|
||||
"""Get current VRAM manager status."""
|
||||
return {
|
||||
"model_loaded": self.model_loaded,
|
||||
"idle_seconds": round(time.time() - self.last_used, 1) if self.model_loaded else None,
|
||||
"idle_timeout": self.idle_timeout,
|
||||
"seconds_until_unload": round(self.seconds_until_unload(), 1) if self.model_loaded else None,
|
||||
}
|
||||
17
services/mana-tts/service.pyw
Normal file
17
services/mana-tts/service.pyw
Normal file
|
|
@ -0,0 +1,17 @@
|
|||
"""mana-tts service runner."""
|
||||
import os
|
||||
import sys
|
||||
os.chdir(r"C:\mana\services\mana-tts")
|
||||
sys.path.insert(0, r"C:\mana\services\mana-tts")
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(r"C:\mana\services\mana-tts\.env")
|
||||
|
||||
# Redirect stdout/stderr to log file
|
||||
log = open(r"C:\mana\services\mana-tts\service.log", "w", buffering=1)
|
||||
sys.stdout = log
|
||||
sys.stderr = log
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=3022, log_level="info")
|
||||
37
services/mana-video-gen/service.pyw
Normal file
37
services/mana-video-gen/service.pyw
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
"""mana-video-gen service runner."""
|
||||
import os
|
||||
import sys
|
||||
|
||||
os.chdir(r"C:\mana\services\mana-video-gen")
|
||||
sys.path.insert(0, r"C:\mana\services\mana-video-gen")
|
||||
|
||||
# Redirect stdout/stderr to log file FIRST
|
||||
log = open(r"C:\mana\services\mana-video-gen\service.log", "w", buffering=1)
|
||||
sys.stdout = log
|
||||
sys.stderr = log
|
||||
|
||||
# Load .env file
|
||||
from dotenv import load_dotenv
|
||||
load_dotenv(r"C:\mana\services\mana-video-gen\.env")
|
||||
|
||||
# Ensure FFmpeg is in PATH (LTX needs it for video encoding)
|
||||
ffmpeg_dir = r"C:\Users\tills\AppData\Local\Microsoft\WinGet\Links"
|
||||
if ffmpeg_dir not in os.environ.get("PATH", ""):
|
||||
os.environ["PATH"] = ffmpeg_dir + os.pathsep + os.environ.get("PATH", "")
|
||||
|
||||
# HF token for model downloads
|
||||
hf_token = os.environ.get("HF_TOKEN", "")
|
||||
if hf_token:
|
||||
os.environ["HUGGING_FACE_HUB_TOKEN"] = hf_token
|
||||
|
||||
# Pre-initialize CUDA before importing diffusers
|
||||
import torch
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.init()
|
||||
print(f"CUDA initialized: {torch.cuda.get_device_name(0)}", flush=True)
|
||||
print(f"VRAM: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f} GB", flush=True)
|
||||
else:
|
||||
print("WARNING: CUDA not available, LTX will be unusably slow on CPU", flush=True)
|
||||
|
||||
import uvicorn
|
||||
uvicorn.run("app.main:app", host="0.0.0.0", port=3026, log_level="info")
|
||||
Loading…
Add table
Add a link
Reference in a new issue