import gradio as gr
import librosa
import os
import logging
from pathlib import Path
import torch
from transformers import Wav2Vec2ForCTC, AutoProcessor
import numpy as np
import spaces

# Configure logging
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

ASR_SAMPLING_RATE = 16_000

# Each line of all_langs.tsv is "<iso> <name>", separated by a single space
ASR_LANGUAGES = {}
with open("data/asr/all_langs.tsv") as f:
    for line in f:
        iso, name = line.split(" ", 1)
        ASR_LANGUAGES[iso.strip()] = name.strip()

MODEL_ID = "facebook/mms-1b-all"

processor = AutoProcessor.from_pretrained(MODEL_ID)
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)

def safe_process_file(file_obj):
    try:
        logger.debug(f"Processing file: {file_obj}")
        
        # Use Path for safe path handling
        file_path = Path(file_obj)
        
        logger.debug(f"Loading audio from file path: {file_path}")
        
        # Load and resample the audio with librosa
        audio_samples, sr = librosa.load(str(file_path), sr=ASR_SAMPLING_RATE, mono=True)
        
        safe_name = f"audio_{file_path.stem}.wav"
        logger.debug(f"File processed successfully: {safe_name}")
        return audio_samples, sr, safe_name
    except Exception as e:
        logger.error(f"Error processing file {file_obj}: {e}")
        raise

def transcribe_multiple_files(audio_files, lang, transcription):
    # NOTE: despite its name, this function currently processes a single
    # file path; the original per-file batch loop was left commented out.
    transcriptions = []
    try:
        audio_samples, sr, safe_name = safe_process_file(audio_files)
        logger.debug(f"Transcribing file {audio_files}: {safe_name}")
        logger.debug(f"Language selected: {lang}")
        logger.debug(f"User-provided transcription: {transcription}")
        
        result = transcribe_file(model, audio_samples, lang, transcription)
        logger.debug(f"Transcription result: {result}")

        
        transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
    except Exception as e:
        logger.error(f"Error in transcription process: {str(e)}")
        transcriptions.append(f"Error processing file: {str(e)}\n")
    return "\n".join(transcriptions)

@spaces.GPU
def transcribe_file(model, audio_samples, lang, user_transcription):
    # `not audio_samples` is ambiguous for a numpy array, so check explicitly
    if audio_samples is None or len(audio_samples) == 0:
        return "<<ERROR: Empty Audio Input>>"

    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )

    # set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    model.to(device)
    inputs = inputs.to(device)

    with torch.no_grad():
        outputs = model(**inputs).logits

    ids = torch.argmax(outputs, dim=-1)[0]
    transcription = processor.decode(ids)

    # If a user-provided transcription were available, it could be used to
    # fine-tune the model (disabled; see fine_tune_model below):
    # if user_transcription:
    #     model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
    #     logger.debug(f"Fine-tuning the model with user-provided transcription: {user_transcription}")

    return transcription
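
# ASR_NOTE at the bottom of this file points out that the demo above only
# performs greedy argmax decoding. The helper below is a hedged sketch, not
# part of the original app, of beam-search decoding with an n-gram language
# model via pyctcdecode; it assumes pyctcdecode is installed and that
# "path/to/lm.arpa" is a hypothetical KenLM model you supply for the target
# language.
def transcribe_file_with_lm(audio_samples, lang, lm_path="path/to/lm.arpa"):
    from pyctcdecode import build_ctcdecoder  # assumed extra dependency

    lang_code = lang.split()[0]
    processor.tokenizer.set_target_lang(lang_code)
    model.load_adapter(lang_code)

    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    )
    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # pyctcdecode expects a (time, vocab) matrix of log-probabilities
    log_probs = torch.log_softmax(logits, dim=-1).cpu().numpy()

    # Build the decoder with the vocabulary sorted by token id
    vocab = [
        tok
        for tok, _ in sorted(
            processor.tokenizer.get_vocab().items(), key=lambda kv: kv[1]
        )
    ]
    decoder = build_ctcdecoder(labels=vocab, kenlm_model_path=lm_path)
    return decoder.decode(log_probs)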

@spaces.GPU
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
    # Resolve the device explicitly (it was undefined in this scope before)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Tokenize the user-provided transcription into label ids
    # (AutoProcessor has no .tokenize(); go through the tokenizer)
    labels = processor.tokenizer(user_transcription, return_tensors="pt").input_ids.to(device)

    # Prepare the audio features once; this single pair is the whole "dataset"
    inputs = processor(
        audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt"
    ).to(device)

    # Set the model to training mode
    model.train()

    # CTC loss with the tokenizer's pad token as the blank symbol
    criterion = torch.nn.CTCLoss(blank=processor.tokenizer.pad_token_id, zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    # Fine-tune the model on the single example
    for epoch in range(5):  # fine-tune for 5 epochs
        # Forward pass
        logits = model(**inputs).logits                     # (batch, time, vocab)
        log_probs = logits.log_softmax(-1).transpose(0, 1)  # CTCLoss wants (time, batch, vocab)
        input_lengths = torch.full(
            (log_probs.size(1),), log_probs.size(0), dtype=torch.long
        )
        target_lengths = torch.tensor([labels.size(1)], dtype=torch.long)

        loss = criterion(log_probs, labels, input_lengths, target_lengths)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    # Set the model back to evaluation mode
    model.eval()

    return model

ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
    # ["upload/tamil.mp3", "tam (Tamil)"],
    # ["upload/burmese.mp3",  "mya (Burmese)"],
]

ASR_NOTE = """
The above demo doesn't use beam-search decoding using a language model. 
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
"""