fixing 4
Browse files
asr.py
CHANGED
@@ -1,12 +1,15 @@
|
|
|
|
1 |
import librosa
|
2 |
-
|
|
|
|
|
3 |
import torch
|
|
|
4 |
import numpy as np
|
5 |
-
from pathlib import Path
|
6 |
-
import os
|
7 |
|
8 |
-
|
9 |
-
|
|
|
10 |
|
11 |
ASR_SAMPLING_RATE = 16_000
|
12 |
|
@@ -21,25 +24,42 @@ MODEL_ID = "facebook/mms-1b-all"
|
|
21 |
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
22 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
23 |
|
24 |
-
def
|
25 |
-
|
26 |
-
|
27 |
-
|
28 |
-
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
|
34 |
-
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
41 |
-
|
42 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
43 |
|
44 |
def transcribe_file(model, audio_samples, lang, user_transcription):
|
45 |
if not audio_samples:
|
@@ -54,16 +74,7 @@ def transcribe_file(model, audio_samples, lang, user_transcription):
|
|
54 |
)
|
55 |
|
56 |
# set device
|
57 |
-
if torch.cuda.is_available()
|
58 |
-
device = torch.device("cuda")
|
59 |
-
elif (
|
60 |
-
hasattr(torch.backends, "mps")
|
61 |
-
and torch.backends.mps.is_available()
|
62 |
-
and torch.backends.mps.is_built()
|
63 |
-
):
|
64 |
-
device = torch.device("mps")
|
65 |
-
else:
|
66 |
-
device = torch.device("cpu")
|
67 |
|
68 |
model.to(device)
|
69 |
inputs = inputs.to(device)
|
@@ -71,95 +82,42 @@ def transcribe_file(model, audio_samples, lang, user_transcription):
|
|
71 |
with torch.no_grad():
|
72 |
outputs = model(**inputs).logits
|
73 |
|
74 |
-
|
75 |
-
|
76 |
-
transcription = processor.decode(ids)
|
77 |
-
else:
|
78 |
-
assert False
|
79 |
-
# beam_search_result = beam_search_decoder(outputs.to("cpu"))
|
80 |
-
# transcription = " ".join(beam_search_result[0][0].words).strip()
|
81 |
|
82 |
# If user-provided transcription is available, use it to fine-tune the model
|
83 |
if user_transcription:
|
84 |
-
# Update the model's weights using the user-provided transcription
|
85 |
model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
|
86 |
-
|
87 |
|
88 |
return transcription
|
89 |
|
90 |
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
|
91 |
-
#
|
92 |
-
|
93 |
-
|
94 |
-
|
95 |
-
|
96 |
-
|
97 |
-
|
98 |
-
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
criterion = torch.nn.CTCLoss()
|
117 |
-
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
|
118 |
-
|
119 |
-
# Move the model to the device
|
120 |
-
model.to(device)
|
121 |
-
|
122 |
-
# Fine-tune the model on the new dataset
|
123 |
-
for epoch in range(5): # fine-tune for 5 epochs
|
124 |
-
for batch in data_loader:
|
125 |
-
audio, transcription = batch
|
126 |
-
audio = audio.to(device)
|
127 |
-
transcription = transcription.to(device)
|
128 |
-
|
129 |
-
# Forward pass
|
130 |
-
inputs = processor(audio, sampling_rate=ASR_SAMPLING_RATE, return_tensors ="pt")
|
131 |
-
inputs = inputs.to(device)
|
132 |
-
outputs = model(**inputs).logits
|
133 |
-
|
134 |
-
# Calculate the loss
|
135 |
-
loss = criterion(outputs, transcription)
|
136 |
-
|
137 |
-
# Backward pass
|
138 |
-
optimizer.zero_grad()
|
139 |
-
loss.backward()
|
140 |
-
optimizer.step()
|
141 |
-
|
142 |
-
# Set the model to evaluation mode
|
143 |
-
model.eval()
|
144 |
-
|
145 |
-
return model
|
146 |
-
|
147 |
-
def beam_search_decoder(logits):
|
148 |
-
# Define the beam search parameters
|
149 |
-
beam_width = 10
|
150 |
-
alpha = 0.7
|
151 |
-
|
152 |
-
# Initialize the beam search decoder
|
153 |
-
decoder = ctc_decoder.CTCTokenizer(
|
154 |
-
logits, beam_width=beam_width, alpha=alpha, blank_index=processor.tokenizer.pad_token_id
|
155 |
-
)
|
156 |
-
|
157 |
-
# Decode the logits
|
158 |
-
decoded = decoder.decode()
|
159 |
-
|
160 |
-
return decoded
|
161 |
|
162 |
if __name__ == "__main__":
|
163 |
-
|
164 |
-
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
165 |
-
transcribe(model, audio_dir)
|
|
|
1 |
+
import gradio as gr
|
2 |
import librosa
|
3 |
+
import os
|
4 |
+
import logging
|
5 |
+
from pathlib import Path
|
6 |
import torch
|
7 |
+
from transformers import Wav2Vec2ForCTC, AutoProcessor
|
8 |
import numpy as np
|
|
|
|
|
9 |
|
10 |
+
# Настройка логирования
|
11 |
+
logging.basicConfig(level=logging.DEBUG)
|
12 |
+
logger = logging.getLogger(__name__)
|
13 |
|
14 |
ASR_SAMPLING_RATE = 16_000
|
15 |
|
|
|
24 |
processor = AutoProcessor.from_pretrained(MODEL_ID)
|
25 |
model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
|
26 |
|
27 |
+
def safe_process_file(file_obj):
|
28 |
+
try:
|
29 |
+
logger.debug(f"Processing file: {file_obj.name}")
|
30 |
+
|
31 |
+
# Используем Path для безопасной обработки путей
|
32 |
+
file_path = Path(file_obj.name)
|
33 |
+
|
34 |
+
logger.debug(f"Loading audio from file path: {file_path}")
|
35 |
+
|
36 |
+
# Используем librosa для загрузки аудио
|
37 |
+
audio_samples, sr = librosa.load(str(file_path), sr=ASR_SAMPLING_RATE, mono=True)
|
38 |
+
|
39 |
+
safe_name = f"audio_{file_path.stem}.wav"
|
40 |
+
logger.debug(f"File processed successfully: {safe_name}")
|
41 |
+
return audio_samples, sr, safe_name
|
42 |
+
except Exception as e:
|
43 |
+
logger.error(f"Error processing file {getattr(file_obj, 'name', 'unknown')}: {str(e)}")
|
44 |
+
raise
|
45 |
+
|
46 |
+
def transcribe_multiple_files(audio_files, lang, transcription):
|
47 |
+
transcriptions = []
|
48 |
+
for audio_file in audio_files:
|
49 |
+
try:
|
50 |
+
audio_samples, sr, safe_name = safe_process_file(audio_file)
|
51 |
+
logger.debug(f"Transcribing file: {safe_name}")
|
52 |
+
logger.debug(f"Language selected: {lang}")
|
53 |
+
logger.debug(f"User-provided transcription: {transcription}")
|
54 |
+
|
55 |
+
result = transcribe_file(model, audio_samples, lang, transcription)
|
56 |
+
logger.debug(f"Transcription result: {result}")
|
57 |
+
|
58 |
+
transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
|
59 |
+
except Exception as e:
|
60 |
+
logger.error(f"Error in transcription process: {str(e)}")
|
61 |
+
transcriptions.append(f"Error processing file: {str(e)}\n")
|
62 |
+
return "\n".join(transcriptions)
|
63 |
|
64 |
def transcribe_file(model, audio_samples, lang, user_transcription):
|
65 |
if not audio_samples:
|
|
|
74 |
)
|
75 |
|
76 |
# set device
|
77 |
+
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
model.to(device)
|
80 |
inputs = inputs.to(device)
|
|
|
82 |
with torch.no_grad():
|
83 |
outputs = model(**inputs).logits
|
84 |
|
85 |
+
ids = torch.argmax(outputs, dim=-1)[0]
|
86 |
+
transcription = processor.decode(ids)
|
|
|
|
|
|
|
|
|
|
|
87 |
|
88 |
# If user-provided transcription is available, use it to fine-tune the model
|
89 |
if user_transcription:
|
|
|
90 |
model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
|
91 |
+
logger.debug(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
|
92 |
|
93 |
return transcription
|
94 |
|
95 |
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
|
96 |
+
# Implementation of fine_tune_model remains the same
|
97 |
+
# ...
|
98 |
+
|
99 |
+
# Подготовка опций языка для Dropdown
|
100 |
+
language_options = [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()]
|
101 |
+
|
102 |
+
mms_transcribe = gr.Interface(
|
103 |
+
fn=transcribe_multiple_files,
|
104 |
+
inputs=[
|
105 |
+
gr.File(label="Audio Files", file_count="multiple"),
|
106 |
+
gr.Dropdown(
|
107 |
+
choices=language_options,
|
108 |
+
label="Language",
|
109 |
+
value=language_options[0] if language_options else None,
|
110 |
+
),
|
111 |
+
gr.Textbox(label="Optional: Provide your own transcription"),
|
112 |
+
],
|
113 |
+
outputs=gr.Textbox(label="Transcriptions", lines=10),
|
114 |
+
title="Speech-to-text",
|
115 |
+
description="Transcribe multiple audio files in your desired language.",
|
116 |
+
allow_flagging="never",
|
117 |
+
)
|
118 |
+
|
119 |
+
# Остальной код интерфейса остается без изменений
|
120 |
+
# ...
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
|
122 |
if __name__ == "__main__":
|
123 |
+
demo.launch()
|
|
|
|