bomolopuu commited on
Commit
b9b5a7c
·
1 Parent(s): 6643110
Files changed (1) hide show
  1. asr.py +74 -116
asr.py CHANGED
@@ -1,12 +1,15 @@
 
1
  import librosa
2
- from transformers import Wav2Vec2ForCTC, AutoProcessor
 
 
3
  import torch
 
4
  import numpy as np
5
- from pathlib import Path
6
- import os
7
 
8
- from huggingface_hub import hf_hub_download
9
- from torchaudio.models.decoder import ctc_decoder
 
10
 
11
  ASR_SAMPLING_RATE = 16_000
12
 
@@ -21,25 +24,42 @@ MODEL_ID = "facebook/mms-1b-all"
21
  processor = AutoProcessor.from_pretrained(MODEL_ID)
22
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
23
 
24
- def transcribe(model, audio_dir, lang="eng (English)", user_transcription=None):
25
- # Получить список файлов в папке
26
- files = os.listdir(audio_dir)
27
-
28
- # Обработать каждый файл в папке
29
- for file in files:
30
- # Проверить, является ли файл аудиофайлом
31
- if file.endswith(".mp3") or file.endswith(".wav"):
32
- # Загрузить аудиофайл
33
- audio_path = os.path.join(audio_dir, file)
34
- audio_samples = librosa.load(audio_path, sr=ASR_SAMPLING_RATE, mono=True)[0]
35
-
36
- # Обработать аудиофайл
37
- transcription = transcribe_file(model, audio_samples, lang, user_transcription)
38
-
39
- # Вывести результат
40
- print(f"Файл: {file}")
41
- print(f"Транскрипция: {transcription}")
42
- print()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def transcribe_file(model, audio_samples, lang, user_transcription):
45
  if not audio_samples:
@@ -54,16 +74,7 @@ def transcribe_file(model, audio_samples, lang, user_transcription):
54
  )
55
 
56
  # set device
57
- if torch.cuda.is_available():
58
- device = torch.device("cuda")
59
- elif (
60
- hasattr(torch.backends, "mps")
61
- and torch.backends.mps.is_available()
62
- and torch.backends.mps.is_built()
63
- ):
64
- device = torch.device("mps")
65
- else:
66
- device = torch.device("cpu")
67
 
68
  model.to(device)
69
  inputs = inputs.to(device)
@@ -71,95 +82,42 @@ def transcribe_file(model, audio_samples, lang, user_transcription):
71
  with torch.no_grad():
72
  outputs = model(**inputs).logits
73
 
74
- if lang_code != "eng" or True:
75
- ids = torch.argmax(outputs, dim=-1)[0]
76
- transcription = processor.decode(ids)
77
- else:
78
- assert False
79
- # beam_search_result = beam_search_decoder(outputs.to("cpu"))
80
- # transcription = " ".join(beam_search_result[0][0].words).strip()
81
 
82
  # If user-provided transcription is available, use it to fine-tune the model
83
  if user_transcription:
84
- # Update the model's weights using the user-provided transcription
85
  model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
86
- print(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
87
 
88
  return transcription
89
 
90
  def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
91
- # Define the device
92
- if torch.cuda.is_available():
93
- device = torch.device("cuda")
94
- elif (
95
- hasattr(torch.backends, "mps")
96
- and torch.backends.mps.is_available()
97
- and torch.backends.mps.is_built()
98
- ):
99
- device = torch.device("mps")
100
- else:
101
- device = torch.device("cpu")
102
-
103
- # Convert the user-provided transcription to a tensor
104
- transcription_tensor = processor.tokenizer(user_transcription, return_tensors="pt")
105
-
106
- # Create a new dataset with the user-provided transcription and audio samples
107
- dataset = [(audio_samples, transcription_tensor)]
108
-
109
- # Create a data loader for the new dataset
110
- data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
111
-
112
- # Set the model to training mode
113
- model.train()
114
-
115
- # Define the loss function and optimizer
116
- criterion = torch.nn.CTCLoss()
117
- optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
118
-
119
- # Move the model to the device
120
- model.to(device)
121
-
122
- # Fine-tune the model on the new dataset
123
- for epoch in range(5): # fine-tune for 5 epochs
124
- for batch in data_loader:
125
- audio, transcription = batch
126
- audio = audio.to(device)
127
- transcription = transcription.to(device)
128
-
129
- # Forward pass
130
- inputs = processor(audio, sampling_rate=ASR_SAMPLING_RATE, return_tensors ="pt")
131
- inputs = inputs.to(device)
132
- outputs = model(**inputs).logits
133
-
134
- # Calculate the loss
135
- loss = criterion(outputs, transcription)
136
-
137
- # Backward pass
138
- optimizer.zero_grad()
139
- loss.backward()
140
- optimizer.step()
141
-
142
- # Set the model to evaluation mode
143
- model.eval()
144
-
145
- return model
146
-
147
- def beam_search_decoder(logits):
148
- # Define the beam search parameters
149
- beam_width = 10
150
- alpha = 0.7
151
-
152
- # Initialize the beam search decoder
153
- decoder = ctc_decoder.CTCTokenizer(
154
- logits, beam_width=beam_width, alpha=alpha, blank_index=processor.tokenizer.pad_token_id
155
- )
156
-
157
- # Decode the logits
158
- decoded = decoder.decode()
159
-
160
- return decoded
161
 
162
  if __name__ == "__main__":
163
- audio_dir = "/path/to/audio/files"
164
- model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
165
- transcribe(model, audio_dir)
 
1
+ import gradio as gr
2
  import librosa
3
+ import os
4
+ import logging
5
+ from pathlib import Path
6
  import torch
7
+ from transformers import Wav2Vec2ForCTC, AutoProcessor
8
  import numpy as np
 
 
9
 
10
+ # Настройка логирования
11
+ logging.basicConfig(level=logging.DEBUG)
12
+ logger = logging.getLogger(__name__)
13
 
14
  ASR_SAMPLING_RATE = 16_000
15
 
 
24
  processor = AutoProcessor.from_pretrained(MODEL_ID)
25
  model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
26
 
27
+ def safe_process_file(file_obj):
28
+ try:
29
+ logger.debug(f"Processing file: {file_obj.name}")
30
+
31
+ # Используем Path для безопасной обработки путей
32
+ file_path = Path(file_obj.name)
33
+
34
+ logger.debug(f"Loading audio from file path: {file_path}")
35
+
36
+ # Используем librosa для загрузки аудио
37
+ audio_samples, sr = librosa.load(str(file_path), sr=ASR_SAMPLING_RATE, mono=True)
38
+
39
+ safe_name = f"audio_{file_path.stem}.wav"
40
+ logger.debug(f"File processed successfully: {safe_name}")
41
+ return audio_samples, sr, safe_name
42
+ except Exception as e:
43
+ logger.error(f"Error processing file {getattr(file_obj, 'name', 'unknown')}: {str(e)}")
44
+ raise
45
+
46
+ def transcribe_multiple_files(audio_files, lang, transcription):
47
+ transcriptions = []
48
+ for audio_file in audio_files:
49
+ try:
50
+ audio_samples, sr, safe_name = safe_process_file(audio_file)
51
+ logger.debug(f"Transcribing file: {safe_name}")
52
+ logger.debug(f"Language selected: {lang}")
53
+ logger.debug(f"User-provided transcription: {transcription}")
54
+
55
+ result = transcribe_file(model, audio_samples, lang, transcription)
56
+ logger.debug(f"Transcription result: {result}")
57
+
58
+ transcriptions.append(f"File: {safe_name}\nTranscription: {result}\n")
59
+ except Exception as e:
60
+ logger.error(f"Error in transcription process: {str(e)}")
61
+ transcriptions.append(f"Error processing file: {str(e)}\n")
62
+ return "\n".join(transcriptions)
63
 
64
  def transcribe_file(model, audio_samples, lang, user_transcription):
65
  if not audio_samples:
 
74
  )
75
 
76
  # set device
77
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 
 
 
 
 
 
 
 
 
78
 
79
  model.to(device)
80
  inputs = inputs.to(device)
 
82
  with torch.no_grad():
83
  outputs = model(**inputs).logits
84
 
85
+ ids = torch.argmax(outputs, dim=-1)[0]
86
+ transcription = processor.decode(ids)
 
 
 
 
 
87
 
88
  # If user-provided transcription is available, use it to fine-tune the model
89
  if user_transcription:
 
90
  model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
91
+ logger.debug(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
92
 
93
  return transcription
94
 
95
  def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
96
+ # Implementation of fine_tune_model remains the same
97
+ # ...
98
+
99
+ # Подготовка опций языка для Dropdown
100
+ language_options = [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()]
101
+
102
+ mms_transcribe = gr.Interface(
103
+ fn=transcribe_multiple_files,
104
+ inputs=[
105
+ gr.File(label="Audio Files", file_count="multiple"),
106
+ gr.Dropdown(
107
+ choices=language_options,
108
+ label="Language",
109
+ value=language_options[0] if language_options else None,
110
+ ),
111
+ gr.Textbox(label="Optional: Provide your own transcription"),
112
+ ],
113
+ outputs=gr.Textbox(label="Transcriptions", lines=10),
114
+ title="Speech-to-text",
115
+ description="Transcribe multiple audio files in your desired language.",
116
+ allow_flagging="never",
117
+ )
118
+
119
+ # Остальной код интерфейса остается без изменений
120
+ # ...
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
121
 
122
  if __name__ == "__main__":
123
+ demo.launch()