# MMS_1_10 / asr.py
import gradio as gr
import librosa
import os
import logging
from pathlib import Path
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
import numpy as np
import spaces
# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
ASR_SAMPLING_RATE = 16_000
ASR_LANGUAGES = {}
# Each line holds an ISO code followed by the language name (tab- or space-separated)
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(maxsplit=1)
        ASR_LANGUAGES[iso.strip()] = name.strip()
MODEL_ID = "facebook/mms-1b-all"
processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
def safe_process_file(file_obj):
    try:
        logger.debug(f"Processing file: {file_obj}")
        # Use Path for safe path handling
        file_path = Path(file_obj)
        logger.debug(f"Loading audio from file path: {file_path}")
        # Use librosa to load and resample the audio to the model's rate
        audio_samples, sr = librosa.load(
            str(file_path), sr=ASR_SAMPLING_RATE, mono=True
        )
        safe_name = f"audio_{file_path.stem}.wav"
        logger.debug(f"File processed successfully: {safe_name}")
        return audio_samples, sr, safe_name
    except Exception as e:
        logger.error(f"Error processing file {file_obj}: {e}")
        raise
def transcribe_multiple_files(audio_files, lang, transcription):
    transcriptions = []
    # Accept either a single file path or a list of paths
    files = audio_files if isinstance(audio_files, (list, tuple)) else [audio_files]
    for audio_file in files:
        try:
            audio_samples, sr, safe_name = safe_process_file(audio_file)
            logger.debug(f"Transcribing file {audio_file}: {safe_name}")
            logger.debug(f"Language selected: {lang}")
            logger.debug(f"User-provided transcription: {transcription}")
            result = transcribe_file(model, audio_samples, lang, transcription)
            logger.debug(f"Transcription result: {result}")
            transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
        except Exception as e:
            logger.error(f"Error in transcription process: {e}")
            transcriptions.append(f"Error processing file: {e}\n")
    return "\n".join(transcriptions)
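
# Hypothetical direct usage (assumes the bundled sample clip exists):
#   print(transcribe_multiple_files(["upload/english.mp3"], "eng (English)", ""))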
@spaces.GPU
def transcribe_file(model, audio_samples, lang, user_transcription):
    # Guard against empty input (truthiness checks fail on numpy arrays)
    if audio_samples is None or len(audio_samples) == 0:
        return "<<ERROR: Empty Audio Input>>"
    # The language value looks like "eng (English)"; the ISO code comes first
    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)
    with torch.no_grad():
        outputs = model(**inputs).logits
    # Greedy (argmax) CTC decoding
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)
    # If a user-provided transcription is available, it could be used to
    # fine-tune the model (currently disabled):
    # if user_transcription:
    #     model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
    #     logger.debug(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
    return transcription
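
# The note at the bottom of this file explains that transcribe_file uses plain
# greedy (argmax) decoding. Below is a minimal sketch of LM-boosted beam-search
# decoding with pyctcdecode; the package, the helper name, and the KenLM path
# are assumptions, not part of this Space.
def lm_decode_sketch(logits):
    from pyctcdecode import build_ctcdecoder  # assumed installed

    # pyctcdecode wants the vocabulary ordered by token id
    vocab = [
        tok
        for tok, _ in sorted(processor.tokenizer.get_vocab().items(), key=lambda kv: kv[1])
    ]
    decoder = build_ctcdecoder(vocab, kenlm_model_path="path/to/lm.arpa")  # hypothetical path
    # decode() expects a (time, vocab) numpy array; drop the batch dimension
    return decoder.decode(logits.cpu().numpy()[0])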
@spaces.GPU
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    # Convert the user-provided transcription into CTC label ids
    labels = processor.tokenizer(user_transcription, return_tensors="pt").input_ids.to(device)
    # Prepare the audio features once; there is only a single training example
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    ).to(device)
    # Set the model to training mode
    model.train()
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
    # Fine-tune the model on the single example for 5 epochs;
    # Wav2Vec2ForCTC computes the CTC loss internally when labels are given
    for epoch in range(5):
        loss = model(**inputs, labels=labels).loss
        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
    # Set the model back to evaluation mode
    model.eval()
    return model
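
# Hypothetical usage, mirroring the disabled call in transcribe_file:
#   model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
#   model.save_pretrained("mms-1b-all-finetuned")  # assumed output directory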
ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
    # ["upload/tamil.mp3", "tam (Tamil)"],
    # ["upload/burmese.mp3", "mya (Burmese)"],
]
ASR_NOTE = """
The demo above does not use beam-search decoding with a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""