managarten/services/mana-stt/app/whisper_service.py
Till JS b8e18b7f82 chore(ai-services): adopt Windows GPU as source of truth for llm/stt/tts
The Windows GPU server has been the actual production home for these
services for some time, and the running code there has drifted ahead of
the repo. This sync pulls the live versions back into the repo so the
Windows box is no longer the only place those changes exist.

Pulled from C:\mana\services\* on mana-server-gpu (192.168.178.11):

mana-llm:
- src/main.py, src/config.py — small fixes (auth wiring, config tweaks)
- src/api_auth.py — NEW (cross-service GPU_API_KEY validator; see the
  sketch after this list)
- service.pyw — Windows runner used by the ManaLLM scheduled task
  (sets up logging redirect, loads .env, calls uvicorn)
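
For anyone reviewing without access to the GPU box, a rough sketch of what
the api_auth.py validator amounts to; FastAPI, the X-API-Key header, and the
function name are assumptions here, only the shared GPU_API_KEY env var is
confirmed by this commit:

    # Hypothetical reconstruction, not the committed file.
    import os
    import secrets

    from fastapi import Header, HTTPException

    def require_gpu_api_key(x_api_key: str = Header(default="")) -> None:
        """Reject requests whose X-API-Key doesn't match the shared GPU_API_KEY."""
        expected = os.getenv("GPU_API_KEY", "")
        # compare_digest avoids leaking key length/prefix via timing
        if not expected or not secrets.compare_digest(x_api_key, expected):
            raise HTTPException(status_code=401, detail="invalid or missing API key")

Routes would then be guarded with Depends(require_gpu_api_key).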

mana-stt:
- app/main.py — substantial cleanup (684→392 lines), drops the
  whisperx-as-separate-backend branching now that whisper_service.py
  rolls whisperx in directly
- app/whisper_service.py — full CUDA + whisperx rewrite (158→358 lines)
- app/auth.py + external_auth.py — significantly expanded auth
- app/vram_manager.py — NEW (shared VRAM accounting helper; interface
  sketched after this list)
- service.pyw — Windows runner with CUDA pre-init, FFmpeg PATH
  injection, .env loading
- removed: app/whisper_service_cuda.py (folded into whisper_service.py)
- removed: app/whisperx_service.py (folded into whisper_service.py)
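
The VramManager surface can be read off its call sites in whisper_service.py
below (the constructor kwargs and the four methods are exactly what the code
uses); the internals here are a sketch, not the committed helper:

    import time
    import logging

    logger = logging.getLogger(__name__)

    class VramManager:
        """Tracks model load state and last-use time for idle-based unloading."""

        def __init__(self, idle_timeout: int, service_name: str):
            self.idle_timeout = idle_timeout  # seconds of idle before unload
            self.service_name = service_name
            self._loaded = False
            self._last_used = 0.0

        def mark_loaded(self):
            self._loaded = True
            self._last_used = time.monotonic()

        def mark_unloaded(self):
            self._loaded = False

        def touch(self):
            self._last_used = time.monotonic()

        def check_and_unload(self, unload_fn):
            # Unload only when loaded and idle past the timeout.
            if self._loaded and time.monotonic() - self._last_used > self.idle_timeout:
                logger.info(f"{self.service_name}: idle > {self.idle_timeout}s, unloading")
                unload_fn()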

mana-tts:
- app/auth.py, external_auth.py — same auth expansion as stt
- app/f5_service.py, kokoro_service.py — Windows tweaks
- app/vram_manager.py — NEW (same shared helper as stt)
- service.pyw — Windows runner

mana-video-gen:
- service.pyw — Windows runner (no other changes; the .py code on the
  GPU box is byte-identical to what's already in the repo)

The service.pyw files contain absolute Windows paths
(C:\mana\services\<svc>) and a hardcoded FFmpeg PATH for the tills user
profile. Kept as-is intentionally — they exist to be deployed to that
one machine and any abstraction layer would just hide what's actually
happening. Anyone redeploying to a different layout will need to edit
the path strings, which is a known and obvious change.
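
For context, the pattern those runners follow is roughly this (assembled from
the descriptions above; the concrete paths, port, and app module below are
placeholders, not copies of any of the four files):

    # service.pyw pattern, illustrative only
    import os
    import sys

    SERVICE_DIR = r"C:\mana\services\mana-stt"  # placeholder
    os.chdir(SERVICE_DIR)

    # logging redirect: pythonw.exe has no console, so capture everything
    log = open(os.path.join(SERVICE_DIR, "service.log"), "a", buffering=1)
    sys.stdout = sys.stderr = log

    # FFmpeg PATH injection (the real files point at the tills user profile)
    os.environ["PATH"] = r"C:\ffmpeg\bin" + os.pathsep + os.environ["PATH"]

    # .env loading before anything reads config
    from dotenv import load_dotenv
    load_dotenv(os.path.join(SERVICE_DIR, ".env"))

    # CUDA pre-init: touch the driver at startup so failures surface
    # immediately rather than on the first request (stt runner only)
    import torch
    torch.cuda.init()

    import uvicorn
    uvicorn.run("app.main:app", host="0.0.0.0", port=8001)  # port is a placeholder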

Mac-Mini infrastructure for these services (launchd plists, install
scripts, scripts/mac-mini/setup-{stt,tts}.sh, the Mac-flux2c image-gen
implementation) is still on disk and will be removed in a follow-up
commit, along with replacing mana-image-gen with the Windows
diffusers+CUDA implementation. This commit is just the live-code sync.
2026-04-08 12:46:03 +02:00


"""
Whisper STT Service using WhisperX (CUDA)
Provides: transcription, word-level timestamps, speaker diarization.
WhisperX pipeline:
1. faster-whisper for transcription
2. wav2vec2 for forced alignment (precise word timestamps)
3. pyannote-audio for speaker diarization
"""
import os
import tempfile
import logging
from pathlib import Path
from typing import Optional
from dataclasses import dataclass, field

logger = logging.getLogger(__name__)

# Lazy load
_whisperx_model = None
_align_model = None
_align_metadata = None
_align_language = None  # language code the current alignment model was loaded for
_diarize_pipeline = None

# Config
HF_TOKEN = os.getenv("HF_TOKEN", "")

# VRAM management — unload after 10 min idle (STT uses ~3GB)
from app.vram_manager import VramManager

_vram = VramManager(
    idle_timeout=int(os.getenv("VRAM_IDLE_TIMEOUT", "600")),
    service_name="mana-stt",
)


def unload_models():
    """Unload all WhisperX models from GPU to free VRAM."""
    global _whisperx_model, _align_model, _align_metadata, _align_language, _diarize_pipeline
    import torch

    if _whisperx_model is not None:
        del _whisperx_model
        _whisperx_model = None
    if _align_model is not None:
        del _align_model
        _align_model = None
        _align_metadata = None
        _align_language = None
    if _diarize_pipeline is not None:
        del _diarize_pipeline
        _diarize_pipeline = None
    torch.cuda.empty_cache()
    _vram.mark_unloaded()
    logger.info("WhisperX models unloaded, VRAM freed")


@dataclass
class WordSegment:
    word: str
    start: float
    end: float
    score: Optional[float] = None
    speaker: Optional[str] = None


@dataclass
class TranscriptionResult:
    text: str
    language: Optional[str] = None
    duration: Optional[float] = None
    segments: Optional[list] = None
    words: Optional[list[WordSegment]] = field(default_factory=list)
    speakers: Optional[list[str]] = field(default_factory=list)


def get_whisper_model(model_name: str = "large-v3", **kwargs):
    """Get or create WhisperX model instance (singleton)."""
    global _whisperx_model
    if _whisperx_model is not None:
        return _whisperx_model

    logger.info(f"Loading WhisperX model: {model_name}")
    try:
        import whisperx

        device = os.getenv("WHISPER_DEVICE", "cuda")
        compute_type = os.getenv("WHISPER_COMPUTE_TYPE", "float16")
        default_language = os.getenv("WHISPER_DEFAULT_LANGUAGE", "de")
        _whisperx_model = whisperx.load_model(
            model_name,
            device=device,
            compute_type=compute_type,
            language=default_language,
        )
        logger.info(f"WhisperX model loaded: {model_name} on {device} ({compute_type})")
        _vram.mark_loaded()
    except ImportError as e:
        logger.error(f"Failed to import whisperx: {e}")
        raise RuntimeError("whisperx not installed. Run: pip install whisperx")
    except Exception as e:
        logger.error(f"Failed to load WhisperX model: {e}")
        raise
    return _whisperx_model


def _get_align_model(language: str, device: str = "cuda"):
    """Get or create alignment model for a language."""
    global _align_model, _align_metadata, _align_language
    import whisperx

    # Reload if language changed (alignment models are language-specific)
    if _align_model is None or _align_language != language:
        logger.info(f"Loading alignment model for language: {language}")
        _align_model, _align_metadata = whisperx.load_align_model(
            language_code=language,
            device=device,
        )
        _align_language = language
        logger.info("Alignment model loaded")
    return _align_model, _align_metadata


def _get_diarize_pipeline(device: str = "cuda"):
    """Get or create speaker diarization pipeline."""
    global _diarize_pipeline
    if _diarize_pipeline is not None:
        return _diarize_pipeline

    import torch
    from pyannote.audio import Pipeline

    token = HF_TOKEN or os.getenv("HUGGING_FACE_HUB_TOKEN", "")
    if not token:
        logger.warning("No HF_TOKEN set — speaker diarization may fail for gated models")
    logger.info("Loading speaker diarization pipeline (pyannote)...")
    _diarize_pipeline = Pipeline.from_pretrained(
        "pyannote/speaker-diarization-3.1",
        token=token,
    )
    _diarize_pipeline.to(torch.device(device))
    logger.info("Diarization pipeline loaded")
    return _diarize_pipeline


def transcribe_audio(
    audio_path: str,
    language: Optional[str] = None,
    model_name: str = "large-v3",
    align: bool = True,
    diarize: bool = False,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
) -> TranscriptionResult:
    """
    Transcribe audio using WhisperX with optional alignment and diarization.

    Args:
        audio_path: Path to audio file
        language: Language code (auto-detect if None)
        model_name: Whisper model to use
        align: Enable word-level timestamp alignment
        diarize: Enable speaker diarization
        min_speakers: Minimum expected speakers (helps diarization)
        max_speakers: Maximum expected speakers

    Returns:
        TranscriptionResult with text, word timestamps, and speaker info
    """
    import whisperx

    device = os.getenv("WHISPER_DEVICE", "cuda")

    # Run the idle check *before* grabbing the model; checking after
    # get_whisper_model() could unload the instance we just acquired.
    # touch() then resets the idle clock for this request.
    _vram.check_and_unload(unload_models)
    model = get_whisper_model(model_name)
    _vram.touch()

    logger.info(f"Transcribing: {audio_path} (align={align}, diarize={diarize})")

    # Step 1: Load audio
    audio = whisperx.load_audio(audio_path)

    # Step 2: Transcribe with faster-whisper
    transcribe_kwargs = {"batch_size": 16}
    if language:
        transcribe_kwargs["language"] = language
    result = model.transcribe(audio, **transcribe_kwargs)
    detected_language = result.get("language", language or "en")

    # Step 3: Align (word-level timestamps)
    if align and result["segments"]:
        try:
            align_model, metadata = _get_align_model(detected_language, device)
            result = whisperx.align(
                result["segments"],
                align_model,
                metadata,
                audio,
                device,
                return_char_alignments=False,
            )
            logger.info("Word alignment complete")
        except Exception as e:
            logger.warning(f"Alignment failed (continuing without): {e}")

    # Step 4: Diarize (speaker identification)
    if diarize:
        try:
            import torch

            diarize_pipe = _get_diarize_pipeline(device)

            # pyannote needs waveform as tensor, not the whisperx audio array
            waveform = torch.from_numpy(audio).unsqueeze(0).float()
            diarize_input = {"waveform": waveform, "sample_rate": 16000}

            diarize_kwargs = {}
            if min_speakers is not None:
                diarize_kwargs["min_speakers"] = min_speakers
            if max_speakers is not None:
                diarize_kwargs["max_speakers"] = max_speakers
            diarize_output = diarize_pipe(diarize_input, **diarize_kwargs)

            # pyannote 4.x returns DiarizeOutput, extract the Annotation
            if hasattr(diarize_output, "speaker_diarization"):
                diarize_annotation = diarize_output.speaker_diarization
            else:
                diarize_annotation = diarize_output

            # Convert pyannote output to DataFrame for whisperx
            import pandas as pd

            diarize_rows = []
            for turn, _, speaker in diarize_annotation.itertracks(yield_label=True):
                diarize_rows.append({
                    "start": turn.start,
                    "end": turn.end,
                    "speaker": speaker,
                })
            diarize_df = pd.DataFrame(diarize_rows)
            result = whisperx.assign_word_speakers(diarize_df, result)
            logger.info("Speaker diarization complete")
        except Exception as e:
            logger.warning(f"Diarization failed (continuing without): {e}")
            import traceback

            traceback.print_exc()

    # Build response
    segments = result.get("segments", [])
    full_text_parts = []
    all_words = []
    speaker_set = set()

    for seg in segments:
        full_text_parts.append(seg.get("text", ""))
        speaker = seg.get("speaker")
        if speaker:
            speaker_set.add(speaker)
        for word_info in seg.get("words", []):
            all_words.append(WordSegment(
                word=word_info.get("word", ""),
                start=word_info.get("start", 0.0),
                end=word_info.get("end", 0.0),
                score=word_info.get("score"),
                speaker=word_info.get("speaker", speaker),
            ))

    text = " ".join(full_text_parts)
    _vram.touch()
    logger.info(
        f"Transcription complete: {len(text)} chars, "
        f"{len(all_words)} words, {len(speaker_set)} speakers"
    )
    return TranscriptionResult(
        text=text.strip(),
        language=detected_language,
        segments=[{
            "start": s.get("start", 0),
            "end": s.get("end", 0),
            "text": s.get("text", ""),
            "speaker": s.get("speaker"),
        } for s in segments],
        words=all_words,
        speakers=sorted(speaker_set),
    )


async def transcribe_audio_bytes(
    audio_bytes: bytes,
    filename: str,
    language: Optional[str] = None,
    model_name: str = "large-v3",
    align: bool = True,
    diarize: bool = False,
    min_speakers: Optional[int] = None,
    max_speakers: Optional[int] = None,
) -> TranscriptionResult:
    """Transcribe audio from bytes (for API uploads)."""
    import asyncio

    ext = Path(filename).suffix or ".wav"
    with tempfile.NamedTemporaryFile(suffix=ext, delete=False) as tmp:
        tmp.write(audio_bytes)
        tmp_path = tmp.name
    try:
        # Run in thread pool to avoid blocking the event loop
        loop = asyncio.get_running_loop()
        result = await loop.run_in_executor(
            None,
            lambda: transcribe_audio(
                audio_path=tmp_path,
                language=language,
                model_name=model_name,
                align=align,
                diarize=diarize,
                min_speakers=min_speakers,
                max_speakers=max_speakers,
            ),
        )
        return result
    finally:
        try:
            os.unlink(tmp_path)
        except Exception:
            pass


# Available models
AVAILABLE_MODELS = [
    "tiny",
    "tiny.en",
    "base",
    "base.en",
    "small",
    "small.en",
    "medium",
    "medium.en",
    "large-v1",
    "large-v2",
    "large-v3",
    "large-v3-turbo",
    "distil-large-v2",
    "distil-large-v3",
]
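

# Illustrative usage (not part of the synced file): how a caller would drive
# this module directly. The audio path and speaker bound are placeholders;
# diarize=True additionally requires HF_TOKEN for the gated pyannote model.
if __name__ == "__main__":
    result = transcribe_audio(
        "meeting.wav",      # placeholder path
        language="de",      # skip auto-detect
        align=True,         # word-level timestamps via wav2vec2
        diarize=True,       # speaker labels via pyannote
        max_speakers=4,
    )
    print(result.text)
    for w in result.words[:10]:
        print(f"{w.start:6.2f}-{w.end:6.2f} {w.speaker or '?'} {w.word}")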