add transcription and fine-tuning
app.py
CHANGED
@@ -15,6 +15,7 @@ mms_transcribe = gr.Interface(
             label="Language",
             value="eng English",
         ),
+        gr.Textbox(label="Optional: Provide your own transcription"),
         # gr.Checkbox(label="Use Language Model (if available)", default=True),
     ],
     outputs="text",
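
Note that gr.Interface passes its inputs to fn positionally, so placing the new Textbox after the language dropdown makes it the third argument of transcribe, lining up with the new user_transcription parameter added in asr.py below. A minimal sketch of that mapping, assuming the gr.Audio and gr.Dropdown components around this hunk (ASR_LANGUAGES and the exact dropdown choices are guesses about the unshown parts of app.py):

import gradio as gr
from asr import transcribe, ASR_LANGUAGES  # assumed module layout

mms_transcribe = gr.Interface(
    fn=transcribe,  # called as transcribe(audio_data, lang, user_transcription)
    inputs=[
        gr.Audio(),  # -> audio_data
        gr.Dropdown(
            [f"{k} ({v})" for k, v in ASR_LANGUAGES.items()],
            label="Language",
            value="eng English",
        ),  # -> lang
        gr.Textbox(label="Optional: Provide your own transcription"),  # -> user_transcription
    ],
    outputs="text",
)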
asr.py
CHANGED
@@ -67,8 +67,7 @@ model = Wav2Vec2ForCTC.from_pretrained(MODEL_ID)
 # )
 
 
-def transcribe(audio_data=None, lang="eng (English)"):
-
+def transcribe(audio_data=None, lang="eng (English)", user_transcription=None):
     if not audio_data:
         return "<<ERROR: Empty Audio Input>>"
 
@@ -82,7 +81,6 @@ def transcribe(audio_data=None, lang="eng (English)"):
         )
     else:
         # file upload
-
         if not isinstance(audio_data, str):
             return "<<ERROR: Invalid Audio Input Instance: {}>>".format(type(audio_data))
         audio_samples = librosa.load(audio_data, sr=ASR_SAMPLING_RATE, mono=True)[0]
@@ -121,8 +119,53 @@ def transcribe(audio_data=None, lang="eng (English)"):
     # beam_search_result = beam_search_decoder(outputs.to("cpu"))
     # transcription = " ".join(beam_search_result[0][0].words).strip()
 
+    # If user-provided transcription is available, use it to fine-tune the model
+    if user_transcription:
+        # Update the model's weights using the user-provided transcription
+        model = fine_tune_model(model, processor, user_transcription, audio_samples, lang_code)
+        # This is a placeholder, you'll need to implement the actual fine-tuning logic
+        print(f"Fine-tuning the model with user-provided transcription: {user_transcription}")
+        # ...
+
     return transcription
 
+def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
+    # Convert the user-provided transcription to a tensor
+    transcription_tensor = processor.text_to_tensor(user_transcription)
+
+    # Create a new dataset with the user-provided transcription and audio samples
+    dataset = [(audio_samples, transcription_tensor)]
+
+    # Create a data loader for the new dataset
+    data_loader = torch.utils.data.DataLoader(dataset, batch_size=1, shuffle=False)
+
+    # Set the model to training mode
+    model.train()
+
+    # Define the loss function and optimizer
+    criterion = torch.nn.CTCLoss()
+    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
+
+    # Fine-tune the model on the new dataset
+    for epoch in range(5):  # fine-tune for 5 epochs
+        for batch in data_loader:
+            audio, transcription = batch
+            audio = audio.to(device)
+            transcription = transcription.to(device)
+
+            # Forward pass
+            outputs = model(audio)
+            loss = criterion(outputs, transcription)
+
+            # Backward pass
+            optimizer.zero_grad()
+            loss.backward()
+            optimizer.step()
+
+    # Set the model to evaluation mode
+    model.eval()
+
+    return model
 
 ASR_EXAMPLES = [
     ["upload/english.mp3", "eng (English)"],