Mirror of https://github.com/Memo-2023/mana-monorepo.git (synced 2026-05-14 22:01:09 +02:00)
update(infra): mana-stt WhisperX + diarization, mana-notify templates, CD pipeline updates
mana-stt: add WhisperX service with CUDA GPU support, speaker diarization, and auto-fallback chain.
mana-notify: add locale fallback and default templates for task reminders.
CD: update deployment pipeline and docker-compose configuration.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent a03de84e79
commit da3a140f21
9 changed files with 1280 additions and 463 deletions
@@ -120,6 +120,13 @@ func (e *Engine) SeedDefaults(ctx context.Context) {
			body:    `<!DOCTYPE html><html><body><h1>{{.eventTitle}}</h1><p>Wann: {{.eventTime}}</p>{{if .eventLocation}}<p>Wo: {{.eventLocation}}</p>{{end}}<p><a href="{{.eventUrl}}">Termin anzeigen</a></p></body></html>`,
			vars:    `{"eventTitle": "Titel", "eventTime": "Zeit", "eventLocation": "Ort (optional)", "eventUrl": "Link"}`,
		},
		{
			slug:    "task-reminder",
			channel: "email",
			subject: "Erinnerung: {{.taskTitle}}",
			body:    `<!DOCTYPE html><html><body><h1>{{.taskTitle}}</h1>{{if .dueDate}}<p>Fällig: {{.dueDate}}</p>{{end}}<p><a href="{{.taskUrl}}">Aufgabe anzeigen</a></p></body></html>`,
			vars:    `{"taskTitle": "Aufgabentitel", "dueDate": "Fälligkeitsdatum (optional)", "taskUrl": "Link zur Aufgabe"}`,
		},
	}

	for _, d := range defaults {
@@ -11,6 +11,24 @@ WHISPER_MODEL=large-v3
# Options: voxtral-mini-3b, voxtral-realtime-4b, voxtral-small-24b
VOXTRAL_MODEL=voxtral-realtime-4b

# WhisperX (CUDA GPU Server)
# Enable WhisperX for rich transcription (diarization, word alignment)
# Requires NVIDIA GPU + requirements-cuda.txt
USE_WHISPERX=false

# WhisperX batch size (higher = faster but more VRAM, 16 works well for RTX 3090)
WHISPERX_BATCH_SIZE=16

# Device and compute type for CUDA
# WHISPER_DEVICE=cuda
# WHISPER_COMPUTE_TYPE=float16

# HuggingFace token for pyannote speaker diarization models
# Required for diarization. Accept terms at:
# https://huggingface.co/pyannote/speaker-diarization-3.1
# https://huggingface.co/pyannote/segmentation-3.0
HF_TOKEN=

# Model Loading
# Set to true to preload models on startup (slower startup, faster first request)
PRELOAD_MODELS=false
@@ -1,6 +1,6 @@
"""
ManaCore STT API Service
-Speech-to-Text with Whisper (MLX), Voxtral (vLLM), and Mistral API (fallback)
Speech-to-Text with Whisper (MLX), WhisperX (CUDA), Voxtral (vLLM), and Mistral API (fallback)

Run with: uvicorn app.main:app --host 0.0.0.0 --port 3020
"""
@@ -38,6 +38,9 @@ CORS_ORIGINS = os.getenv(
VLLM_URL = os.getenv("VLLM_URL", "http://localhost:8100")
USE_VLLM = os.getenv("USE_VLLM", "false").lower() == "true"

# WhisperX configuration (CUDA GPU server)
USE_WHISPERX = os.getenv("USE_WHISPERX", "false").lower() == "true"


# Response models
class TranscriptionResponse(BaseModel):
@@ -48,9 +51,49 @@
    duration_seconds: Optional[float] = None


class WordTimestampResponse(BaseModel):
    word: str
    start: float
    end: float
    score: float = 0.0
    speaker: Optional[str] = None


class SegmentResponse(BaseModel):
    start: float
    end: float
    text: str
    speaker: Optional[str] = None
    words: list[WordTimestampResponse] = []


class UtteranceResponse(BaseModel):
    speaker: int
    text: str
    offset: int  # milliseconds
    duration: int  # milliseconds


class RichTranscriptionResponse(BaseModel):
    """Extended response with segments, utterances, and speaker diarization."""
    text: str
    language: Optional[str] = None
    model: str
    latency_ms: Optional[float] = None
    duration_seconds: Optional[float] = None
    segments: list[SegmentResponse] = []
    utterances: list[UtteranceResponse] = []
    speakers: dict[str, str] = {}
    speaker_map: dict[str, int] = {}
    languages: list[str] = []
    primary_language: Optional[str] = None
    words: list[WordTimestampResponse] = []


class HealthResponse(BaseModel):
    status: str
    whisper_loaded: bool
    whisperx_available: bool
    vllm_available: bool
    vllm_url: Optional[str] = None
    mistral_api_available: bool
@@ -60,6 +103,7 @@ class HealthResponse(BaseModel):

class ModelsResponse(BaseModel):
    whisper: list
    whisperx: list
    voxtral_vllm: list
    default_whisper: str
@@ -67,6 +111,7 @@ class ModelsResponse(BaseModel):
# Track loaded models
models_status = {
    "whisper_loaded": False,
    "whisperx_available": False,
    "vllm_available": False,
}
@@ -86,6 +131,18 @@ async def lifespan(app: FastAPI):
        else:
            logger.warning(f"vLLM server not available: {health}")

    # Check WhisperX availability
    if USE_WHISPERX:
        try:
            from app.whisperx_service import is_available as whisperx_available
            models_status["whisperx_available"] = whisperx_available()
            if models_status["whisperx_available"]:
                logger.info("WhisperX (CUDA) available")
            else:
                logger.warning("WhisperX not available (whisperx package not installed)")
        except Exception as e:
            logger.warning(f"WhisperX check failed: {e}")

    # Check Mistral API
    from app.voxtral_api_service import is_available as api_available
    if api_available():
@@ -136,6 +193,7 @@ async def health_check():
    return HealthResponse(
        status="healthy",
        whisper_loaded=models_status["whisper_loaded"],
        whisperx_available=models_status["whisperx_available"],
        vllm_available=vllm_health.get("status") == "healthy",
        vllm_url=VLLM_URL if USE_VLLM else None,
        mistral_api_available=api_available(),
@@ -154,8 +212,17 @@ async def list_models(auth: AuthResult = Depends(verify_api_key)):

    vllm_models = await get_models()

    whisperx_models = []
    if USE_WHISPERX:
        try:
            from app.whisperx_service import AVAILABLE_MODELS as wx_models
            whisperx_models = wx_models
        except ImportError:
            pass

    return ModelsResponse(
        whisper=whisper_models,
        whisperx=whisperx_models,
        voxtral_vllm=vllm_models,
        default_whisper=DEFAULT_WHISPER_MODEL,
    )
@@ -386,50 +453,216 @@ async def transcribe_voxtral_api(
        raise HTTPException(status_code=500, detail=str(e))


@app.post("/transcribe/whisperx", response_model=RichTranscriptionResponse)
async def transcribe_whisperx(
    response: Response,
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: Optional[str] = Form(None, description="Language code (auto-detect if not provided)"),
    model: Optional[str] = Form(None, description="Whisper model to use"),
    diarization: bool = Form(True, description="Enable speaker diarization"),
    alignment: bool = Form(True, description="Enable word-level alignment"),
    min_speakers: Optional[int] = Form(None, description="Minimum expected speakers"),
    max_speakers: Optional[int] = Form(None, description="Maximum expected speakers"),
    auth: AuthResult = Depends(verify_api_key),
):
    """
    Transcribe audio using WhisperX (CUDA GPU).

    Returns rich transcription with:
    - Word-level timestamps (via forced alignment)
    - Speaker diarization (via pyannote.audio)
    - Memoro-compatible utterances with speaker IDs

    Requires NVIDIA GPU with CUDA and USE_WHISPERX=true.
    Diarization requires HF_TOKEN with pyannote model access.

    Supported formats: mp3, wav, m4a, flac, ogg, webm, mp4
    Max file size: 100MB
    """
    if auth.rate_limit_remaining is not None:
        response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)

    if not USE_WHISPERX:
        raise HTTPException(
            status_code=503,
            detail="WhisperX not enabled. Set USE_WHISPERX=true on a CUDA-capable server."
        )

    if not file.filename:
        raise HTTPException(status_code=400, detail="No file provided")

    allowed_extensions = {".mp3", ".wav", ".m4a", ".flac", ".ogg", ".webm", ".mp4"}
    ext = os.path.splitext(file.filename)[1].lower()
    if ext not in allowed_extensions:
        raise HTTPException(
            status_code=400,
            detail=f"Unsupported file type: {ext}. Allowed: {allowed_extensions}"
        )

    start_time = time.time()

    try:
        from app.whisperx_service import transcribe_audio_bytes

        audio_bytes = await file.read()
        if len(audio_bytes) > 100 * 1024 * 1024:
            raise HTTPException(status_code=400, detail="File too large (max 100MB)")

        model_name = model or DEFAULT_WHISPER_MODEL

        result = await transcribe_audio_bytes(
            audio_bytes=audio_bytes,
            filename=file.filename,
            language=language,
            model_name=model_name,
            enable_diarization=diarization,
            enable_alignment=alignment,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )

        latency_ms = (time.time() - start_time) * 1000

        return RichTranscriptionResponse(
            text=result.text,
            language=result.language,
            model=f"whisperx-{model_name}",
            latency_ms=latency_ms,
            duration_seconds=result.duration_seconds,
            segments=[
                SegmentResponse(
                    start=s.start,
                    end=s.end,
                    text=s.text,
                    speaker=s.speaker,
                    words=[
                        WordTimestampResponse(
                            word=w.word,
                            start=w.start,
                            end=w.end,
                            score=w.score,
                            speaker=w.speaker,
                        )
                        for w in s.words
                    ],
                )
                for s in result.segments
            ],
            utterances=[
                UtteranceResponse(
                    speaker=u.speaker,
                    text=u.text,
                    offset=u.offset,
                    duration=u.duration,
                )
                for u in result.utterances
            ],
            speakers=result.speakers,
            speaker_map={k: v for k, v in result.speaker_map.items()},
            languages=result.languages,
            primary_language=result.primary_language,
            words=[
                WordTimestampResponse(
                    word=w.word,
                    start=w.start,
                    end=w.end,
                    score=w.score,
                    speaker=w.speaker,
                )
                for w in result.words
            ],
        )

    except HTTPException:
        raise
    except ImportError:
        raise HTTPException(
            status_code=503,
            detail="WhisperX not installed. Install with: pip install -r requirements-cuda.txt"
        )
    except Exception as e:
        logger.error(f"WhisperX transcription error: {e}")
        raise HTTPException(status_code=500, detail=str(e))

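For orientation, a minimal client sketch against this new endpoint (not part of the diff). The port comes from the uvicorn command in the module docstring; the host, the X-API-Key header name, and the sample file name are assumptions and would need to match the mana-core-auth setup.

# Hypothetical client call to POST /transcribe/whisperx; form fields mirror the endpoint signature above.
import httpx

with open("meeting.m4a", "rb") as f:  # placeholder recording
    resp = httpx.post(
        "http://localhost:3020/transcribe/whisperx",
        headers={"X-API-Key": "<your key>"},  # assumed header name
        files={"file": ("meeting.m4a", f, "audio/mp4")},
        data={
            "language": "de",
            "diarization": "true",
            "alignment": "true",
            "min_speakers": "2",
        },
        timeout=600.0,
    )
resp.raise_for_status()
rich = resp.json()
for utt in rich["utterances"]:
    print(f'[speaker {utt["speaker"]} @ {utt["offset"]} ms] {utt["text"]}')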
@app.post("/transcribe/auto", response_model=TranscriptionResponse)
async def transcribe_auto(
    response: Response,
    file: UploadFile = File(..., description="Audio file to transcribe"),
    language: Optional[str] = Form(None, description="Language hint"),
-    prefer: str = Form("whisper", description="Preferred: 'whisper' or 'voxtral'"),
    prefer: str = Form("whisper", description="Preferred: 'whisper', 'whisperx', or 'voxtral'"),
    auth: AuthResult = Depends(verify_api_key),
):
    """
    Transcribe with automatic model selection and fallback.

    Fallback chain:
-    1. Preferred model (whisper or voxtral)
-    2. Alternative model
-    3. Mistral API
    - whisper: Whisper → WhisperX → Voxtral → Mistral API
    - whisperx: WhisperX → Whisper → Voxtral → Mistral API
    - voxtral: Voxtral → WhisperX → Whisper → Mistral API
    """
    # Add rate limit headers
    if auth.rate_limit_remaining is not None:
        response.headers["X-RateLimit-Remaining"] = str(auth.rate_limit_remaining)

-    if prefer == "voxtral":
-        try:
-            return await transcribe_voxtral(file, language or "de", False)
-        except Exception as e:
-            logger.warning(f"Voxtral failed, trying Whisper: {e}")
-            await file.seek(0)
-            try:
-                return await transcribe_whisper(file, language, None)
-            except Exception as e2:
-                logger.warning(f"Whisper failed, trying API: {e2}")
-                await file.seek(0)
-                return await transcribe_voxtral_api(file, language, False)
    async def try_whisperx_simple():
        """Try WhisperX and return as simple TranscriptionResponse."""
        if not USE_WHISPERX:
            raise RuntimeError("WhisperX not enabled")
        from app.whisperx_service import transcribe_audio_bytes as wx_transcribe
        audio_bytes = await file.read()
        result = await wx_transcribe(
            audio_bytes=audio_bytes,
            filename=file.filename or "audio.wav",
            language=language,
            enable_diarization=False,
            enable_alignment=False,
        )
        return TranscriptionResponse(
            text=result.text,
            language=result.language,
            model=f"whisperx-{DEFAULT_WHISPER_MODEL}",
            latency_ms=None,
            duration_seconds=result.duration_seconds,
        )

    # Build fallback chain based on preference
    if prefer == "whisperx":
        chain = [
            ("WhisperX", try_whisperx_simple),
            ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
            ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
            ("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
        ]
    elif prefer == "voxtral":
        chain = [
            ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
            ("WhisperX", try_whisperx_simple),
            ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
            ("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
        ]
    else:
        chain = [
            ("Whisper", lambda: transcribe_whisper(response, file, language, None, auth)),
            ("WhisperX", try_whisperx_simple),
            ("Voxtral", lambda: transcribe_voxtral(response, file, language or "de", False, auth)),
            ("Mistral API", lambda: transcribe_voxtral_api(response, file, language, False, auth)),
        ]

    last_error = None
    for name, fn in chain:
        try:
-            return await transcribe_whisper(file, language, None)
            result = await fn()
            return result
        except Exception as e:
-            logger.warning(f"Whisper failed, trying Voxtral: {e}")
            last_error = e
            logger.warning(f"{name} failed: {e}")
            await file.seek(0)
-            try:
-                return await transcribe_voxtral(file, language or "de", False)
-            except Exception as e2:
-                logger.warning(f"Voxtral failed, trying API: {e2}")
-                await file.seek(0)
-                return await transcribe_voxtral_api(file, language, False)

    raise HTTPException(
        status_code=503,
        detail=f"All transcription backends failed. Last error: {last_error}"
    )


@app.exception_handler(Exception)
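The same client pattern as the sketch above also covers the auto endpoint; a brief variant (host and auth header again assumed) that asks for the WhisperX-first chain:

# Hypothetical call to POST /transcribe/auto preferring WhisperX, falling back per the chain above.
import httpx

with open("voicemail.ogg", "rb") as f:  # placeholder recording
    resp = httpx.post(
        "http://localhost:3020/transcribe/auto",
        headers={"X-API-Key": "<your key>"},  # assumed header name
        files={"file": ("voicemail.ogg", f, "audio/ogg")},
        data={"prefer": "whisperx", "language": "de"},
        timeout=600.0,
    )
resp.raise_for_status()
print(resp.json()["model"], "->", resp.json()["text"][:80])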
services/mana-stt/app/whisperx_service.py (new file, 419 lines)
@@ -0,0 +1,419 @@
"""
WhisperX STT Service using faster-whisper + pyannote (CUDA)
Optimized for NVIDIA GPUs (RTX 3090 etc.)

Features:
- Word-level timestamps via forced alignment
- Speaker diarization via pyannote.audio
- Segment-level timestamps with speaker labels
- VAD filtering for silence removal

Requires HuggingFace token for pyannote models:
    export HF_TOKEN=hf_xxx
    # Accept terms at: https://huggingface.co/pyannote/speaker-diarization-3.1
"""

import os
import tempfile
import logging
import time
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

# Lazy-loaded singletons
_whisper_model = None
_align_model = None
_align_metadata = None
_diarize_pipeline = None

HF_TOKEN = os.getenv("HF_TOKEN")
DEFAULT_MODEL = os.getenv("WHISPER_MODEL", "large-v3")
DEVICE = os.getenv("WHISPER_DEVICE", "cuda")
COMPUTE_TYPE = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
BATCH_SIZE = int(os.getenv("WHISPERX_BATCH_SIZE", "16"))


@dataclass
class WordInfo:
    """Word with timestamp."""
    word: str
    start: float
    end: float
    score: float = 0.0
    speaker: Optional[str] = None


@dataclass
class SegmentInfo:
    """Segment with speaker and word-level detail."""
    start: float
    end: float
    text: str
    speaker: Optional[str] = None
    words: list[WordInfo] = field(default_factory=list)


@dataclass
class Utterance:
    """Speaker utterance in Memoro-compatible format."""
    speaker: int
    text: str
    offset: int  # milliseconds
    duration: int  # milliseconds


@dataclass
class WhisperXResult:
    """Rich transcription result with alignment and diarization."""
    text: str
    language: Optional[str] = None
    duration_seconds: Optional[float] = None
    segments: list[SegmentInfo] = field(default_factory=list)
    utterances: list[Utterance] = field(default_factory=list)
    speakers: dict[str, str] = field(default_factory=dict)
    speaker_map: dict[str, int] = field(default_factory=dict)
    languages: list[str] = field(default_factory=list)
    primary_language: Optional[str] = None
    words: list[WordInfo] = field(default_factory=list)


def get_whisper_model(model_name: str = None):
    """Load faster-whisper model (singleton)."""
    global _whisper_model
    model_name = model_name or DEFAULT_MODEL

    if _whisper_model is None:
        logger.info(f"Loading WhisperX model: {model_name} on {DEVICE} ({COMPUTE_TYPE})")
        try:
            import whisperx

            _whisper_model = whisperx.load_model(
                model_name,
                device=DEVICE,
                compute_type=COMPUTE_TYPE,
            )
            logger.info(f"WhisperX model loaded: {model_name}")
        except ImportError:
            raise RuntimeError(
                "whisperx not installed. "
                "Run: pip install whisperx"
            )

    return _whisper_model


def get_align_model(language_code: str):
    """Load alignment model for a specific language (cached per language)."""
    global _align_model, _align_metadata

    try:
        import whisperx

        _align_model, _align_metadata = whisperx.load_align_model(
            language_code=language_code,
            device=DEVICE,
        )
        logger.info(f"Alignment model loaded for language: {language_code}")
        return _align_model, _align_metadata
    except Exception as e:
        logger.warning(f"Failed to load alignment model for {language_code}: {e}")
        return None, None


def get_diarize_pipeline():
    """Load pyannote speaker diarization pipeline (singleton)."""
    global _diarize_pipeline

    if _diarize_pipeline is None:
        if not HF_TOKEN:
            logger.warning("HF_TOKEN not set — diarization disabled")
            return None

        try:
            import whisperx

            _diarize_pipeline = whisperx.DiarizationPipeline(
                use_auth_token=HF_TOKEN,
                device=DEVICE,
            )
            logger.info("Diarization pipeline loaded")
        except Exception as e:
            logger.warning(f"Failed to load diarization pipeline: {e}")
            return None

    return _diarize_pipeline


def _build_utterances(segments: list[SegmentInfo]) -> tuple[list[Utterance], dict[str, str], dict[str, int]]:
    """
    Build Memoro-compatible utterances from diarized segments.
    Groups consecutive segments by the same speaker.
    """
    if not segments:
        return [], {}, {}

    # Collect unique speakers
    speaker_labels = sorted(set(
        s.speaker for s in segments if s.speaker is not None
    ))
    speaker_map: dict[str, int] = {}
    speakers: dict[str, str] = {}
    for idx, label in enumerate(speaker_labels):
        speaker_map[label] = idx
        speakers[str(idx)] = label

    # Merge consecutive segments with the same speaker
    utterances: list[Utterance] = []
    current_speaker = None
    current_text_parts: list[str] = []
    current_start = 0.0
    current_end = 0.0

    for seg in segments:
        sp = seg.speaker or "SPEAKER_00"
        if sp != current_speaker:
            # Flush previous
            if current_speaker is not None and current_text_parts:
                utterances.append(Utterance(
                    speaker=speaker_map.get(current_speaker, 0),
                    text=" ".join(current_text_parts).strip(),
                    offset=int(current_start * 1000),
                    duration=int((current_end - current_start) * 1000),
                ))
            current_speaker = sp
            current_text_parts = [seg.text]
            current_start = seg.start
            current_end = seg.end
        else:
            current_text_parts.append(seg.text)
            current_end = seg.end

    # Flush last
    if current_speaker is not None and current_text_parts:
        utterances.append(Utterance(
            speaker=speaker_map.get(current_speaker, 0),
            text=" ".join(current_text_parts).strip(),
            offset=int(current_start * 1000),
            duration=int((current_end - current_start) * 1000),
        ))

    return utterances, speakers, speaker_map


def transcribe_audio(
    audio_path: str,
    language: Optional[str] = None,
    model_name: Optional[str] = None,
    enable_diarization: bool = True,
    enable_alignment: bool = True,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
) -> WhisperXResult:
    """
    Transcribe audio with WhisperX: alignment + diarization.

    Args:
        audio_path: Path to audio file
        language: Language code (e.g., 'de', 'en'). Auto-detect if None.
        model_name: Whisper model to use
        enable_diarization: Run speaker diarization
        enable_alignment: Run forced word alignment
        min_speakers: Minimum expected speakers (hint for diarization)
        max_speakers: Maximum expected speakers (hint for diarization)

    Returns:
        WhisperXResult with full transcription, segments, utterances, speakers
    """
    import whisperx

    start_time = time.time()

    # 1. Load audio
    audio = whisperx.load_audio(audio_path)
    audio_duration = len(audio) / 16000  # whisperx resamples to 16kHz

    # 2. Transcribe with faster-whisper
    model = get_whisper_model(model_name)
    transcribe_result = model.transcribe(
        audio,
        batch_size=BATCH_SIZE,
        language=language,
    )

    detected_language = transcribe_result.get("language", language or "en")
    raw_segments = transcribe_result.get("segments", [])

    logger.info(
        f"Transcription: {len(raw_segments)} segments, "
        f"language={detected_language}, "
        f"duration={audio_duration:.1f}s"
    )

    # 3. Forced alignment (word-level timestamps)
    if enable_alignment and raw_segments:
        align_model, align_metadata = get_align_model(detected_language)
        if align_model is not None:
            try:
                transcribe_result = whisperx.align(
                    raw_segments,
                    align_model,
                    align_metadata,
                    audio,
                    DEVICE,
                    return_char_alignments=False,
                )
                raw_segments = transcribe_result.get("segments", raw_segments)
                logger.info("Word alignment complete")
            except Exception as e:
                logger.warning(f"Alignment failed, using segment-level timestamps: {e}")

    # 4. Speaker diarization
    if enable_diarization:
        diarize_pipeline = get_diarize_pipeline()
        if diarize_pipeline is not None:
            try:
                diarize_kwargs = {}
                if min_speakers is not None:
                    diarize_kwargs["min_speakers"] = min_speakers
                if max_speakers is not None:
                    diarize_kwargs["max_speakers"] = max_speakers

                diarize_segments = diarize_pipeline(
                    audio_path,
                    **diarize_kwargs,
                )
                transcribe_result = whisperx.assign_word_speakers(
                    diarize_segments, transcribe_result
                )
                raw_segments = transcribe_result.get("segments", raw_segments)
                logger.info("Diarization complete")
            except Exception as e:
                logger.warning(f"Diarization failed: {e}")

    # 5. Build structured result
    segments: list[SegmentInfo] = []
    all_words: list[WordInfo] = []
    full_text_parts: list[str] = []

    for seg in raw_segments:
        seg_words: list[WordInfo] = []
        for w in seg.get("words", []):
            wi = WordInfo(
                word=w.get("word", ""),
                start=w.get("start", 0.0),
                end=w.get("end", 0.0),
                score=w.get("score", 0.0),
                speaker=w.get("speaker"),
            )
            seg_words.append(wi)
            all_words.append(wi)

        segment = SegmentInfo(
            start=seg.get("start", 0.0),
            end=seg.get("end", 0.0),
            text=seg.get("text", "").strip(),
            speaker=seg.get("speaker"),
            words=seg_words,
        )
        segments.append(segment)
        full_text_parts.append(segment.text)

    full_text = " ".join(full_text_parts)

    # 6. Build utterances (Memoro-compatible)
    utterances, speakers, speaker_map = _build_utterances(segments)

    latency = time.time() - start_time
    logger.info(f"WhisperX complete in {latency:.1f}s: {len(full_text)} chars, {len(speakers)} speakers")

    return WhisperXResult(
        text=full_text,
        language=detected_language,
        duration_seconds=audio_duration,
        segments=segments,
        utterances=utterances,
        speakers=speakers,
        speaker_map=speaker_map,
        languages=[detected_language] if detected_language else [],
        primary_language=detected_language,
        words=all_words,
    )


async def transcribe_audio_bytes(
    audio_bytes: bytes,
    filename: str,
    language: Optional[str] = None,
    model_name: Optional[str] = None,
    enable_diarization: bool = True,
    enable_alignment: bool = True,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
) -> WhisperXResult:
    """
    Transcribe audio from bytes (for API uploads).

    Args:
        audio_bytes: Raw audio file bytes
        filename: Original filename (for extension detection)
        language: Optional language code
        model_name: Whisper model to use
        enable_diarization: Run speaker diarization
        enable_alignment: Run forced word alignment
        min_speakers: Min expected speakers
        max_speakers: Max expected speakers

    Returns:
        WhisperXResult with full transcription data
    """
    ext = Path(filename).suffix or ".wav"

    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name

    try:
        return transcribe_audio(
            audio_path=tmp_path,
            language=language,
            model_name=model_name,
            enable_diarization=enable_diarization,
            enable_alignment=enable_alignment,
            min_speakers=min_speakers,
            max_speakers=max_speakers,
        )
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass


def is_available() -> bool:
    """Check if WhisperX dependencies are installed."""
    try:
        import whisperx
        return True
    except ImportError:
        return False


AVAILABLE_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v1",
    "large-v2",
    "large-v3",
    "large-v3-turbo",
    "distil-large-v2",
    "distil-large-v3",
]
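A small smoke-test sketch for this new module, run directly on the GPU host (not part of this commit; the audio file name is a placeholder, and HF_TOKEN plus the CUDA environment are assumed to be set up as described in the env section above):

# Exercises whisperx_service end to end: availability check, transcription, utterance printout.
import asyncio
from app.whisperx_service import is_available, transcribe_audio_bytes

async def main() -> None:
    if not is_available():
        raise SystemExit("whisperx is not installed; see requirements-cuda.txt below")
    with open("meeting.wav", "rb") as f:  # placeholder recording
        audio = f.read()
    result = await transcribe_audio_bytes(
        audio_bytes=audio,
        filename="meeting.wav",
        language="de",
        enable_diarization=True,
        enable_alignment=True,
    )
    print(f"{result.primary_language}, {result.duration_seconds:.1f}s, {len(result.speakers)} speakers")
    for u in result.utterances:
        print(f"[speaker {u.speaker} @ {u.offset} ms] {u.text}")

if __name__ == "__main__":
    asyncio.run(main())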
services/mana-stt/requirements-cuda.txt (new file, 35 lines)
@@ -0,0 +1,35 @@
# ManaCore STT Service Dependencies
# For GPU Server (NVIDIA RTX 3090 / CUDA)

# Web Framework
fastapi==0.115.6
uvicorn[standard]==0.34.0
python-multipart==0.0.20

# Audio Processing
pydub==0.25.1
soundfile==0.13.1

# WhisperX (CUDA) — includes faster-whisper + alignment
whisperx @ git+https://github.com/m-bain/whisperX.git

# faster-whisper with CTranslate2 (CUDA backend)
faster-whisper>=1.1.0

# Speaker Diarization (pyannote.audio)
# Requires HF_TOKEN with accepted terms:
# https://huggingface.co/pyannote/speaker-diarization-3.1
# https://huggingface.co/pyannote/segmentation-3.0
pyannote.audio>=3.3.0

# PyTorch CUDA — install separately for your CUDA version:
# pip install torch torchaudio --index-url https://download.pytorch.org/whl/cu121
torch>=2.5.0
torchaudio>=2.5.0

# Utilities
numpy>=1.26.0
tqdm>=4.67.0

# External Auth (mana-core-auth integration)
httpx>=0.27.0
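After installing these requirements (PyTorch first from the CUDA wheel index, per the comment above), a quick sanity-check sketch for the GPU host; this is not part of the commit and only uses standard torch and whisperx imports:

# Verifies that CUDA-enabled torch and the whisperx package are importable before
# setting USE_WHISPERX=true; prints the detected GPU (e.g. an RTX 3090).
import torch

if torch.cuda.is_available():
    print("CUDA OK:", torch.cuda.get_device_name(0))
else:
    print("CUDA not available; keep USE_WHISPERX=false on this host")

try:
    import whisperx  # noqa: F401
    print("whisperx import OK")
except ImportError as exc:
    print("whisperx missing:", exc)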