mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 20:01:09 +02:00
fix(stt): use correct AutoModel for Voxtral multimodal architecture
This commit is contained in:
parent
9bd699aec7
commit
49255ac794
1 changed files with 57 additions and 16 deletions
|
|
@ -38,7 +38,7 @@ def get_voxtral_model(model_name: str = "mistralai/Voxtral-Mini-3B-2507"):
|
|||
logger.info(f"Loading Voxtral model: {model_name}")
|
||||
try:
|
||||
import torch
|
||||
from transformers import AutoModelForSpeechSeq2Seq, AutoProcessor
|
||||
from transformers import AutoModel, AutoProcessor
|
||||
|
||||
# Determine device
|
||||
if torch.backends.mps.is_available():
|
||||
|
|
@ -59,14 +59,18 @@ def get_voxtral_model(model_name: str = "mistralai/Voxtral-Mini-3B-2507"):
|
|||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# Load model
|
||||
_voxtral_model = AutoModelForSpeechSeq2Seq.from_pretrained(
|
||||
# Load model - Voxtral uses AutoModel, not AutoModelForSpeechSeq2Seq
|
||||
_voxtral_model = AutoModel.from_pretrained(
|
||||
model_name,
|
||||
torch_dtype=torch_dtype,
|
||||
device_map="auto",
|
||||
device_map="auto" if device != "mps" else None,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
# Move to MPS if available (device_map doesn't support MPS)
|
||||
if device == "mps":
|
||||
_voxtral_model = _voxtral_model.to(device)
|
||||
|
||||
logger.info(f"Voxtral model loaded successfully on {device}")
|
||||
|
||||
except ImportError as e:
|
||||
|
|
@ -90,6 +94,9 @@ def transcribe_audio(
|
|||
"""
|
||||
Transcribe audio file using Voxtral.
|
||||
|
||||
Voxtral is a multimodal audio understanding model that can be prompted
|
||||
for transcription tasks.
|
||||
|
||||
Args:
|
||||
audio_path: Path to audio file
|
||||
language: Target language for transcription
|
||||
|
|
@ -100,6 +107,7 @@ def transcribe_audio(
|
|||
"""
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
model, processor = get_voxtral_model(model_name)
|
||||
|
||||
|
|
@ -109,37 +117,70 @@ def transcribe_audio(
|
|||
# Load audio
|
||||
audio_array, sample_rate = sf.read(audio_path)
|
||||
|
||||
# Resample to 16kHz if needed
|
||||
if sample_rate != 16000:
|
||||
import numpy as np
|
||||
# Convert stereo to mono if needed
|
||||
if len(audio_array.shape) > 1:
|
||||
audio_array = np.mean(audio_array, axis=1)
|
||||
|
||||
# Resample to 24kHz (Voxtral's expected sample rate)
|
||||
target_sr = 24000
|
||||
if sample_rate != target_sr:
|
||||
from scipy import signal
|
||||
|
||||
num_samples = int(len(audio_array) * 16000 / sample_rate)
|
||||
num_samples = int(len(audio_array) * target_sr / sample_rate)
|
||||
audio_array = signal.resample(audio_array, num_samples)
|
||||
sample_rate = 16000
|
||||
sample_rate = target_sr
|
||||
|
||||
# Process audio
|
||||
# Language mapping for prompts
|
||||
lang_names = {
|
||||
"de": "German",
|
||||
"en": "English",
|
||||
"fr": "French",
|
||||
"es": "Spanish",
|
||||
"pt": "Portuguese",
|
||||
"it": "Italian",
|
||||
"nl": "Dutch",
|
||||
"hi": "Hindi",
|
||||
}
|
||||
lang_name = lang_names.get(language, "German")
|
||||
|
||||
# Create transcription prompt
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,PLACEHOLDER"}},
|
||||
{"type": "text", "text": f"Transcribe this audio in {lang_name}. Only output the transcription, nothing else."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Apply chat template and process inputs
|
||||
prompt = processor.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
# Process audio with the processor
|
||||
inputs = processor(
|
||||
audio_array,
|
||||
text=prompt,
|
||||
audios=[audio_array],
|
||||
sampling_rate=sample_rate,
|
||||
return_tensors="pt",
|
||||
)
|
||||
|
||||
# Move to same device as model
|
||||
device = next(model.parameters()).device
|
||||
inputs = {k: v.to(device) for k, v in inputs.items()}
|
||||
inputs = {k: v.to(device) if hasattr(v, 'to') else v for k, v in inputs.items()}
|
||||
|
||||
# Generate transcription
|
||||
with torch.no_grad():
|
||||
generated_ids = model.generate(
|
||||
**inputs,
|
||||
max_new_tokens=448,
|
||||
language=language,
|
||||
max_new_tokens=512,
|
||||
do_sample=False,
|
||||
)
|
||||
|
||||
# Decode
|
||||
# Decode only the generated tokens (exclude input)
|
||||
input_len = inputs.get("input_ids", inputs.get("input_features")).shape[-1]
|
||||
text = processor.batch_decode(
|
||||
generated_ids,
|
||||
generated_ids[:, input_len:],
|
||||
skip_special_tokens=True,
|
||||
)[0]
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue