bomolopuu committed
Commit 67ce7a9 · 1 Parent(s): f3731ec

add transcription and fine-tuning

Files changed (2):
  1. app.py +1 -0
  2. asr.py +46 -3
app.py CHANGED
@@ -15,6 +15,7 @@ mms_transcribe = gr.Interface(
             label="Language",
             value="eng English",
         ),
+        gr.Textbox(label="Optional: Provide your own transcription"),
         # gr.Checkbox(label="Use Language Model (if available)", default=True),
     ],
     outputs="text",
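
For orientation, a minimal sketch of what the inputs list looks like after this one-line change. Only the diff context above is certain; the gr.Audio component, the Dropdown choices, and the from asr import transcribe import are assumptions filled in for illustration. The point is that gr.Interface passes input values to the callback positionally, so the new Textbox becomes the third argument of transcribe.

import gradio as gr
from asr import transcribe  # assumed import; the diff only shows the Interface

mms_transcribe = gr.Interface(
    fn=transcribe,
    inputs=[
        gr.Audio(),                   # -> audio_data
        gr.Dropdown(                  # -> lang
            choices=["eng English"],  # assumed; only label and value appear in the diff
            label="Language",
            value="eng English",
        ),
        gr.Textbox(label="Optional: Provide your own transcription"),  # -> user_transcription (new)
        # gr.Checkbox(label="Use Language Model (if available)", default=True),
    ],
    outputs="text",
)

Leaving the Textbox empty submits an empty string, which is falsy, so asr.py's fine-tuning branch below is skipped for callers who ignore the new field.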
asr.py CHANGED
@@ -67,8 +67,7 @@ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
 # )
 
 
-def transcribe(audio_data=None, lang="eng (English)"):
-
+def transcribe(audio_data=None, lang="eng (English)", user_transcription=None):
     if not audio_data:
         return "<<ERROR: Empty Audio Input>>"
 
@@ -82,7 +81,6 @@ def transcribe(audio_data=None, lang="eng (English)"):
         )
     else:
         # file upload
-
         if not isinstance(audio_data, str):
             return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
         audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
@@ -121,8 +119,53 @@ def transcribe(audio_data=None, lang="eng (English)"):
     # beam_search_result = beam_search_decoder(outputs.to("cpu"))
     # transcription = " ".join(beam_search_result[0][0].words).strip()
 
+    # If user-provided transcription is available, use it to fine-tune the model
+    if user_transcription:
+        # Update the model's weights using the user-provided transcription
+        model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
+        # This is a placeholder, you'll need to implement the actual fine-tuning logic
+        print(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
+        # ...
+
     return transcription
 
 
+def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
+    # Convert the user-provided transcription to a tensor
+    transcription_tensor = processor.text_to_tensor(user_transcription)
+
+    # Create a new dataset with the user-provided transcription and audio samples
+    dataset = [(audio_samples, transcription_tensor)]
+
+    # Create a data loader for the new dataset
+    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
+
+    # Set the model to training mode
+    model.train()
+
+    # Define the loss function and optimizer
+    criterion = torch.nn.CTCLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    # Fine-tune the model on the new dataset
+    for epoch in range(5):  # fine-tune for 5 epochs
+        for batch in data_loader:
+            audio, transcription = batch
+            audio = audio.to(device)
+            transcription = transcription.to(device)
+
+            # Forward pass
+            outputs = model(audio)
+            loss = criterion(outputs, transcription)
+
+            # Backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    return model
+
 ASR_EXAMPLES = [
     ["upload/english.mp3", "eng (English)"],