import gradio as gr
import librosa
import os
import logging
from pathlib import Path

import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
import numpy as np
import spaces

# Logging setup
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

ASR_SAMPLING_RATE = 16_000

ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso.strip()] = name.strip()

MODEL_ID = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)


def safe_process_file(file_obj):
    try:
        logger.debug(f"Processing file: {file_obj}")

        # Use Path for safe handling of the uploaded file path
        file_path = Path(file_obj)
        logger.debug(f"Loading audio from file path: {file_path}")

        # Load the audio with librosa, resampling to the ASR sampling rate
        audio_samples, sr = librosa.load(str(file_path), sr=ASR_SAMPLING_RATE, mono=True)

        safe_name = f"audio_{file_path.stem}.wav"
        logger.debug(f"File processed successfully: {safe_name}")
        return audio_samples, sr, safe_name
    except Exception as e:
        logger.error(f"Error processing file {getattr(file_obj, 'name', 'unknown')}: {str(e)}")
        raise


def transcribe_multiple_files(audio_files, lang, transcription):
    transcriptions = []
    # for audio_file in audio_files:
    try:
        audio_samples, sr, safe_name = safe_process_file(audio_files)
        logger.debug(f"Transcribing file {audio_files}: {safe_name}")
        logger.debug(f"Language selected: {lang}")
        logger.debug(f"User-provided transcription: {transcription}")

        result = transcribe_file(model, audio_samples, lang, transcription)
        logger.debug(f"Transcription result: {result}")
        transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
    except Exception as e:
        logger.error(f"Error in transcription process: {str(e)}")
        transcriptions.append(f"Error processing file: {str(e)}\n")

    return "\n".join(transcriptions)


@spaces.GPU
def transcribe_file(model, audio_samples, lang, user_transcription):
    # if not audio_samples:
    #     return "<>"

    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    # Set device and move model and inputs to it
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)

    # If a user-provided transcription is available, use it to fine-tune the model
    # if user_transcription:
    #     model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
    #     logger.debug(f"Fine-tuning the model with user-provided transcription: {user_transcription}")

    return transcription


@spaces.GPU
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
    # Run the adaptation on GPU if available
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Convert the user-provided transcription to label ids
    transcription_tensor = processor.tokenizer(user_transcription, return_tensors="pt")

    # Create a new single-example dataset with the user-provided transcription and audio samples
    dataset = [(audio_samples, transcription_tensor)]

    # Create a data loader for the new dataset (identity collate keeps the raw
    # numpy audio and the tokenized labels as-is)
    data_loader = torch.utils.data.DataLoader(
        dataset, batch_size=1, shuffle=False, collate_fn=lambda batch: batch[0]
    )

    # Set the model to training mode
    model.train()

    # Define the optimizer; Wav2Vec2ForCTC computes the CTC loss itself when labels are passed
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Fine-tune the model on the new dataset
    for epoch in range(5):  # fine-tune for 5 epochs
        for batch in data_loader:
            audio, transcription = batch

            # Forward pass: re-extract features and let the model compute the CTC loss
            inputs = processor(audio, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt").to(device)
            labels = transcription["input_ids"].to(device)
            loss = model(**inputs, labels=labels).loss

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

    # Set the model back to evaluation mode
    model.eval()

    return model


ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
    # ["upload/tamil.mp3", "tam (Tamil)"],
    # ["upload/burmese.mp3", "mya (Burmese)"],
]

ASR_NOTE = """
The demo above does not use beam-search decoding with a language model.
Check out the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""
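
# A minimal sketch of how the functions above could be exposed as a Gradio app.
# This wiring is an assumption for illustration only: the component labels, the
# "iso (Name)" dropdown format (inferred from ASR_EXAMPLES and from how
# transcribe_file parses the language string), and the launch() call are not
# part of the original file.
if __name__ == "__main__":
    demo = gr.Interface(
        fn=transcribe_multiple_files,
        inputs=[
            gr.Audio(type="filepath", label="Audio"),
            gr.Dropdown(
                choices=[f"{iso} ({name})" for iso, name in ASR_LANGUAGES.items()],
                value="eng (English)",
                label="Language",
            ),
            gr.Textbox(label="Optional reference transcription"),
        ],
        outputs=gr.Textbox(label="Transcription"),
        examples=ASR_EXAMPLES,
        article=ASR_NOTE,
    )
    demo.launch()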