fix(stt): use correct AutoModel for Voxtral multimodal architecture

This commit is contained in:
Till-JS 2026-01-27 01:58:32 +01:00
parent 9bd699aec7
commit 49255ac794

View file

@@ -38,7 +38,7 @@ def get_voxtral_model(model_name: str = "mistralai/Voxtral-Mini-3B-2507"):
logger.info(f"Loading Voxtral model: {model_name}")
try:
import torch
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
from transformers import AutoModel, AutoProcessor
# Determine device
if torch.backends.mps.is_available():
@@ -59,14 +59,18 @@ def get_voxtral_model(model_name: str = "mistralai/Voxtral-Mini-3B-2507"):
trust_remote_code=True,
)
# Load model
_voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
# Load model - Voxtral uses AutoModel, not AutoModelForSpeechSeq2Seq
_voxtral_model = AutoModel.from_pretrained(
model_name,
torch_dtype=torch_dtype,
device_map="auto",
device_map="auto" if device != "mps" else None,
trust_remote_code=True,
)
# Move to MPS if available (device_map doesn't support MPS)
if device == "mps":
_voxtral_model = _voxtral_model.to(device)
logger.info(f"Voxtral model loaded successfully on {device}")
except ImportError as e:
@@ -90,6 +94,9 @@ def transcribe_audio(
"""
Transcribe audio file using Voxtral.
Voxtral is a multimodal audio understanding model that can be prompted
for transcription tasks.
Args:
audio_path: Path to audio file
language: Target language for transcription
@@ -100,6 +107,7 @@ def transcribe_audio(
"""
import torch
import soundfile as sf
import numpy as np
model, processor = get_voxtral_model(model_name)
@@ -109,37 +117,70 @@ def transcribe_audio(
# Load audio
audio_array, sample_rate = sf.read(audio_path)
# Resample to 16kHz if needed
if sample_rate != 16000:
import numpy as np
# Convert stereo to mono if needed
if len(audio_array.shape) > 1:
audio_array = np.mean(audio_array, axis=1)
# Resample to 24kHz (Voxtral's expected sample rate)
target_sr = 24000
if sample_rate != target_sr:
from scipy import signal
num_samples = int(len(audio_array) * 16000 / sample_rate)
num_samples = int(len(audio_array) * target_sr / sample_rate)
audio_array = signal.resample(audio_array, num_samples)
sample_rate = 16000
sample_rate = target_sr
# Process audio
# Language mapping for prompts
lang_names = {
"de": "German",
"en": "English",
"fr": "French",
"es": "Spanish",
"pt": "Portuguese",
"it": "Italian",
"nl": "Dutch",
"hi": "Hindi",
}
lang_name = lang_names.get(language, "German")
# Create transcription prompt
messages = [
{
"role": "user",
"content": [
{"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,PLACEHOLDER"}},
{"type": "text", "text": f"Transcribe this audio in {lang_name}. Only output the transcription, nothing else."},
],
}
]
# Apply chat template and process inputs
prompt = processor.apply_chat_template(messages, tokenize=False)
# Process audio with the processor
inputs = processor(
audio_array,
text=prompt,
audios=[audio_array],
sampling_rate=sample_rate,
return_tensors="pt",
)
# Move to same device as model
device = next(model.parameters()).device
inputs = {k: v.to(device) for k, v in inputs.items()}
inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
# Generate transcription
with torch.no_grad():
generated_ids = model.generate(
**inputs,
max_new_tokens=448,
language=language,
max_new_tokens=512,
do_sample=False,
)
# Decode
# Decode only the generated tokens (exclude input)
input_len = inputs.get("input_ids", inputs.get("input_features")).shape[-1]
text = processor.batch_decode(
generated_ids,
generated_ids[:, input_len:],
skip_special_tokens=True,
)[0]