diff --git a/services/mana-stt/app/voxtral_service.py b/services/mana-stt/app/voxtral_service.py
index 88cf8fd5d..00ab0d737 100644
--- a/services/mana-stt/app/voxtral_service.py
+++ b/services/mana-stt/app/voxtral_service.py
@@ -38,7 +38,7 @@ def get_voxtral_model(model_name: str = "mistralai/Voxtral-Mini-3B-2507"):
     logger.info(f"Loading Voxtral model: {model_name}")
     try:
         import torch
-        from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
+        from transformers import AutoModel, AutoProcessor

         # Determine device
         if torch.backends.mps.is_available():
@@ -59,14 +59,18 @@
             trust_remote_code=True,
         )

-        # Load model
-        _voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
+        # Load model - Voxtral uses AutoModel, not AutoModelForSpeechSeq2Seq
+        _voxtral_model = AutoModel.from_pretrained(
             model_name,
             torch_dtype=torch_dtype,
-            device_map="auto",
+            device_map="auto" if device != "mps" else None,
             trust_remote_code=True,
         )

+        # Move to MPS if available (device_map doesn't support MPS)
+        if device == "mps":
+            _voxtral_model = _voxtral_model.to(device)
+
         logger.info(f"Voxtral model loaded successfully on {device}")

     except ImportError as e:
@@ -90,6 +94,9 @@
     """
     Transcribe audio file using Voxtral.

+    Voxtral is a multimodal audio understanding model that can be prompted
+    for transcription tasks.
+
     Args:
         audio_path: Path to audio file
         language: Target language for transcription
@@ -100,6 +107,7 @@
     """
     import torch
     import soundfile as sf
+    import numpy as np

     model, processor = get_voxtral_model(model_name)

@@ -109,37 +117,70 @@
     # Load audio
     audio_array, sample_rate = sf.read(audio_path)

-    # Resample to 16kHz if needed
-    if sample_rate != 16000:
-        import numpy as np
+    # Convert stereo to mono if needed
+    if len(audio_array.shape) > 1:
+        audio_array = np.mean(audio_array, axis=1)
+
+    # Resample to 24kHz (Voxtral's expected sample rate)
+    target_sr = 24000
+    if sample_rate != target_sr:
         from scipy import signal
-        num_samples = int(len(audio_array) * 16000 / sample_rate)
+        num_samples = int(len(audio_array) * target_sr / sample_rate)
         audio_array = signal.resample(audio_array, num_samples)
-        sample_rate = 16000
+        sample_rate = target_sr

-    # Process audio
+    # Language mapping for prompts
+    lang_names = {
+        "de": "German",
+        "en": "English",
+        "fr": "French",
+        "es": "Spanish",
+        "pt": "Portuguese",
+        "it": "Italian",
+        "nl": "Dutch",
+        "hi": "Hindi",
+    }
+    lang_name = lang_names.get(language, "German")
+
+    # Create transcription prompt
+    messages = [
+        {
+            "role": "user",
+            "content": [
+                {"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,PLACEHOLDER"}},
+                {"type": "text", "text": f"Transcribe this audio in {lang_name}. Only output the transcription, nothing else."},
+            ],
+        }
+    ]
+
+    # Apply chat template and process inputs
+    prompt = processor.apply_chat_template(messages, tokenize=False)
+
+    # Process audio with the processor
     inputs = processor(
-        audio_array,
+        text=prompt,
+        audios=[audio_array],
         sampling_rate=sample_rate,
         return_tensors="pt",
     )

     # Move to same device as model
     device = next(model.parameters()).device
-    inputs = {k: v.to(device) for k, v in inputs.items()}
+    inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}

     # Generate transcription
     with torch.no_grad():
         generated_ids = model.generate(
             **inputs,
-            max_new_tokens=448,
-            language=language,
+            max_new_tokens=512,
+            do_sample=False,
         )

-    # Decode
+    # Decode only the generated tokens (exclude input)
+    input_len = inputs.get("input_ids", inputs.get("input_features")).shape[-1]
     text = processor.batch_decode(
-        generated_ids,
+        generated_ids[:, input_len:],
         skip_special_tokens=True,
     )[0]
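
A minimal smoke test for the new code path (a sketch, not part of the patch itself): it assumes the service module is importable as app.voxtral_service, that a local German-language file named sample_de.wav exists, and that transcribe_audio ultimately returns the decoded string, since the last hunk above is truncated before the function's return statement.

    # Sketch only. Assumptions: module path "app.voxtral_service",
    # a local test file "sample_de.wav", and a plain-string return value.
    from app.voxtral_service import transcribe_audio

    # The service converts stereo to mono and resamples to 24 kHz internally,
    # so any reasonable WAV input should work here.
    text = transcribe_audio("sample_de.wav", language="de")
    print(text)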