mirror of
https://github.com/Memo-2023/mana-monorepo.git
synced 2026-05-14 20:01:09 +02:00
fix(stt): properly encode audio as base64 for Voxtral
This commit is contained in:
parent
e7e3561463
commit
a2233dc366
1 changed file with 23 additions and 26 deletions
|
|
@ -106,29 +106,28 @@ def transcribe_audio(
|
|||
VoxtralTranscriptionResult with transcribed text
|
||||
"""
|
||||
import torch
|
||||
import soundfile as sf
|
||||
import numpy as np
|
||||
|
||||
model, processor = get_voxtral_model(model_name)
|
||||
|
||||
logger.info(f"Transcribing with Voxtral: {audio_path}")
|
||||
|
||||
try:
|
||||
# Load audio
|
||||
audio_array, sample_rate = sf.read(audio_path)
|
||||
# Load audio file as bytes and encode to base64
|
||||
with open(audio_path, "rb") as f:
|
||||
audio_bytes = f.read()
|
||||
audio_base64 = base64.b64encode(audio_bytes).decode("utf-8")
|
||||
|
||||
# Convert stereo to mono if needed
|
||||
if len(audio_array.shape) > 1:
|
||||
audio_array = np.mean(audio_array, axis=1)
|
||||
|
||||
# Resample to 24kHz (Voxtral's expected sample rate)
|
||||
target_sr = 24000
|
||||
if sample_rate != target_sr:
|
||||
from scipy import signal
|
||||
|
||||
num_samples = int(len(audio_array) * target_sr / sample_rate)
|
||||
audio_array = signal.resample(audio_array, num_samples)
|
||||
sample_rate = target_sr
|
||||
# Determine audio format from extension
|
||||
ext = Path(audio_path).suffix.lower()
|
||||
mime_types = {
|
||||
".wav": "audio/wav",
|
||||
".mp3": "audio/mpeg",
|
||||
".m4a": "audio/m4a",
|
||||
".flac": "audio/flac",
|
||||
".ogg": "audio/ogg",
|
||||
".webm": "audio/webm",
|
||||
}
|
||||
mime_type = mime_types.get(ext, "audio/wav")
|
||||
|
||||
# Language mapping for prompts
|
||||
lang_names = {
|
||||
|
|
@ -143,26 +142,24 @@ def transcribe_audio(
|
|||
}
|
||||
lang_name = lang_names.get(language, "German")
|
||||
|
||||
# Create transcription prompt
|
||||
# Create transcription prompt with base64 audio
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "audio_url", "audio_url": {"url": f"data:audio/wav;base64,PLACEHOLDER"}},
|
||||
{"type": "audio_url", "audio_url": {"url": f"data:{mime_type};base64,{audio_base64}"}},
|
||||
{"type": "text", "text": f"Transcribe this audio in {lang_name}. Only output the transcription, nothing else."},
|
||||
],
|
||||
}
|
||||
]
|
||||
|
||||
# Apply chat template and process inputs
|
||||
prompt = processor.apply_chat_template(messages, tokenize=False)
|
||||
|
||||
# Process audio with the processor
|
||||
inputs = processor(
|
||||
text=prompt,
|
||||
audios=[audio_array],
|
||||
sampling_rate=sample_rate,
|
||||
inputs = processor.apply_chat_template(
|
||||
messages,
|
||||
tokenize=True,
|
||||
add_generation_prompt=True,
|
||||
return_tensors="pt",
|
||||
return_dict=True,
|
||||
)
|
||||
|
||||
# Move to same device as model
|
||||
|
|
@ -178,7 +175,7 @@ def transcribe_audio(
|
|||
)
|
||||
|
||||
# Decode only the generated tokens (exclude input)
|
||||
input_len = inputs.get("input_ids", inputs.get("input_features")).shape[-1]
|
||||
input_len = inputs["input_ids"].shape[-1]
|
||||
text = processor.batch_decode(
|
||||
generated_ids[:, input_len:],
|
||||
skip_special_tokens=True,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue