Revert "tokenizeR"

Browse files

This reverts commit 78b8142b0963bfa1c8eb08e63a1dbd9de9962a12.
asr.py
CHANGED
@@ -92,32 +92,53 @@ def transcribe_file(model, audio_samples, lang, user_transcription):
|
|
92 |
|
93 |
#return transcription
|
94 |
|
95 |
-
|
96 |
-
#
|
97 |
-
|
98 |
-
|
99 |
-
#
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
)
|
118 |
-
|
119 |
-
#
|
120 |
-
|
121 |
-
|
122 |
-
|
123 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
|
93 |
#return transcription
|
94 |
|
95 |
+
def fine_tune_model(model, processor, user_transcription, audio_samples, lang_code):
    """Fine-tune a CTC ASR model on one user-corrected (audio, text) pair.

    Args:
        model: CTC-based ASR model (e.g. Wav2Vec2ForCTC); updated in place.
        processor: paired processor (feature extractor + tokenizer) for `model`.
        user_transcription: user-provided reference text used as the label.
        audio_samples: raw audio samples at ASR_SAMPLING_RATE — TODO confirm
            the caller always passes mono float samples at that rate.
        lang_code: language code; currently unused, kept for interface parity.

    Returns:
        The fine-tuned model, switched back to evaluation mode.
    """
    # Derive the device from the model's own parameters; the original body
    # referenced an undefined global `device`, which raised NameError.
    device = next(model.parameters()).device

    # Tokenize the reference text once. `processor(text=...)` routes to the
    # underlying tokenizer — the original called `processor.tokenize`, which
    # does not exist on Wav2Vec2Processor. NOTE(review): verify this matches
    # the processor class actually used elsewhere in this file.
    labels = processor(text=user_transcription, return_tensors="pt").input_ids.to(device)

    # Preprocess the audio once — it is loop-invariant, so hoist it out of the
    # training loop (the original re-ran the processor on every iteration).
    features = processor(audio_samples, sampling_rate=ASR_SAMPLING_RATE, return_tensors="pt")
    features = {name: tensor.to(device) for name, tensor in features.items()}

    model.train()

    # zero_infinity guards against the inf losses CTC produces when the
    # target is longer than the downsampled input sequence.
    criterion = torch.nn.CTCLoss(zero_infinity=True)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.001)

    for _ in range(5):  # fine-tune for 5 epochs on the single example
        logits = model(**features).logits  # (batch, time, vocab)

        # CTCLoss expects (time, batch, vocab) log-probabilities plus explicit
        # input/target lengths; the original passed only two positional
        # arguments, which raises a TypeError.
        log_probs = torch.nn.functional.log_softmax(logits, dim=-1).transpose(0, 1)
        input_lengths = torch.full(
            (log_probs.size(1),), log_probs.size(0), dtype=torch.long, device=device
        )
        target_lengths = torch.tensor([labels.size(1)], dtype=torch.long, device=device)

        loss = criterion(log_probs, labels, input_lengths, target_lengths)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    model.eval()
    return model
|
134 |
+
|
135 |
+
# Example inputs shown in the demo UI: [audio file path, language label].
# Non-English rows are kept but disabled pending verified sample files.
ASR_EXAMPLES = [
    ["upload/english.mp3", "eng (English)"],
    # ["upload/tamil.mp3", "tam (Tamil)"],
    # ["upload/burmese.mp3", "mya (Burmese)"],
]
|
140 |
+
|
141 |
+
ASR_NOTE = """
|
142 |
+
The above demo doesn't use beam-search decoding using a language model.
|
143 |
+
Checkout the instructions [here](https://huggingface.co/facebook/mms-1b-all) on how to run LM decoding for better accuracy.
|
144 |
+
"""
|