import gradio as gr
import librosa
import os
import logging
from pathlib import Path
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
import numpy as np
import spaces


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

# MMS models operate on 16 kHz mono audio.
ASR_SAMPLING_RATE = 16_000

# Map ISO codes to language names; each line of the file is "<iso> <language name>".
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso.strip()] = name.strip()

MODEL_ID = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
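# Note: transcribe_file() below swaps per-language adapter weights into this
# shared checkpoint via model.load_adapter(), so one instance serves all languages.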


def safe_process_file(file_obj):
    """Load an uploaded audio file, resampled to the rate the ASR model expects."""
    try:
        logger.debug(f"Processing file: {file_obj}")

        # Gradio passes uploads as filesystem paths.
        file_path = Path(file_obj)
        logger.debug(f"Loading audio from file path: {file_path}")

        audio_samples, sr = librosa.load(
            str(file_path), sr=ASR_SAMPLING_RATE, mono=True
        )

        safe_name = f"audio_{file_path.stem}.wav"
        logger.debug(f"File processed successfully: {safe_name}")
        return audio_samples, sr, safe_name
    except Exception as e:
        logger.error(f"Error processing file {getattr(file_obj, 'name', 'unknown')}: {e}")
        raise


def transcribe_multiple_files(audio_files, lang, transcription):
    # Despite the name, the Gradio callback receives a single file path per call.
    transcriptions = []
    try:
        audio_samples, sr, safe_name = safe_process_file(audio_files)
        logger.debug(f"Transcribing file {audio_files}: {safe_name}")
        logger.debug(f"Language selected: {lang}")
        logger.debug(f"User-provided transcription: {transcription}")

        result = transcribe_file(model, audio_samples, lang, transcription)
        logger.debug(f"Transcription result: {result}")

        transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
    except Exception as e:
        logger.error(f"Error in transcription process: {e}")
        transcriptions.append(f"Error processing file: {e}\n")
    return "\n".join(transcriptions)


@spaces.GPU
def transcribe_file(model, audio_samples, lang, user_transcription):
    # lang arrives as "eng (English)"; the ISO code is the first token.
    lang_code = lang.split()[0]

    # Switch the tokenizer vocabulary and model adapter to the target language.
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    # Greedy CTC decoding: take the most likely token at each frame.
    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)

    # When the user supplies a corrected transcription, use it to fine-tune the
    # model in place on this clip.
    if user_transcription:
        fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)

    return transcription


@spaces.GPU
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
    # Encode the corrected transcription as CTC label ids for the current
    # target language (the caller already ran set_target_lang).
    labels = processor(text=user_transcription, return_tensors="pt").input_ids

    # This single corrected clip is the entire training set.
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)
    labels = labels.to(device)

    model.train()
    # A small learning rate keeps a 1B-parameter model from diverging on one clip.
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-5)

    # Passing labels to Wav2Vec2ForCTC makes the forward pass compute the CTC
    # loss internally, avoiding hand-built torch.nn.CTCLoss inputs.
    for epoch in range(5):
        outputs = model(**inputs, labels=labels)
        loss = outputs.loss

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    return model
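
# Hypothetical standalone use (file name and transcript are illustrative):
#   samples, sr, _ = safe_process_file("clip.wav")
#   fine_tune_model(model, processor, "expected transcript", samples, "eng")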


ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
]


ASR_NOTE = """
The demo above does not use beam-search decoding with a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""