not-lain committed on
Commit b8354e9 · 1 Parent(s): ad44b76
Files changed (1)
  1. app.py +55 -41
app.py CHANGED
@@ -179,44 +179,56 @@ def transcribe(audio, task="transcribe"):
         raise gr.Error("No audio file submitted!")
 
     device = "cuda" if torch.cuda.is_available() else "cpu"
-    compute_type = "float16" # can be changed to "int8" if low on GPU memory
+    compute_type = "float16"
     batch_size = 8 # reduced batch size to be conservative with memory
 
-    # 1. Load model and transcribe
-    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
-    audio_input = whisperx.load_audio(audio)
-    result = model.transcribe(audio_input, batch_size=batch_size)
-
-    # Clear GPU memory
-    del model
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # 2. Align whisper output
-    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
-    result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
-
-    # Clear GPU memory
-    del model_a
-    gc.collect()
-    torch.cuda.empty_cache()
-
-    # 3. Assign speaker labels
-    diarize_model = whisperx.DiarizationPipeline(device=device)
-    diarize_segments = diarize_model(audio_input)
-
-    # Combine transcription with speaker diarization
-    result = whisperx.assign_word_speakers(diarize_segments, result)
-
-    # Format output with speaker labels and timestamps
-    formatted_text = ""
-    for segment in result["segments"]:
-        speaker = f"[Speaker {segment['speaker']}]" if "speaker" in segment else ""
-        start_time = f"{segment.get('start', 0):.2f}"
-        end_time = f"{segment.get('end', 0):.2f}"
-        formatted_text += f"[{start_time}s - {end_time}s] {speaker}: {segment['text']}\n"
-
-    return formatted_text
+    try:
+        # 1. Load model and transcribe
+        model = whisperx.load_model("large-v2", device, compute_type=compute_type)
+        audio_input = whisperx.load_audio(audio)
+        result = model.transcribe(audio_input, batch_size=batch_size)
+
+        # Clear GPU memory
+        del model
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # 2. Align whisper output
+        model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
+        result = whisperx.align(result["segments"], model_a, metadata, audio_input, device, return_char_alignments=False)
+
+        # Clear GPU memory
+        del model_a
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        # 3. Assign speaker labels
+        diarize_model = whisperx.DiarizationPipeline(device=device)
+        diarize_segments = diarize_model(audio_input)
+
+        # Combine transcription with speaker diarization
+        result = whisperx.assign_word_speakers(diarize_segments, result)
+
+        # Format output with speaker labels and timestamps
+        formatted_text = []
+        for segment in result["segments"]:
+            if not isinstance(segment, dict):
+                continue
+
+            speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
+            start_time = f"{float(segment.get('start', 0)):.2f}"
+            end_time = f"{float(segment.get('end', 0)):.2f}"
+            text = segment.get('text', '').strip()
+            formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")
+
+        return "\n".join(formatted_text)
+
+    except Exception as e:
+        raise gr.Error(f"Transcription failed: {str(e)}")
+    finally:
+        # Ensure GPU memory is cleared even if an error occurs
+        gc.collect()
+        torch.cuda.empty_cache()
 
 
 @spaces.GPU(duration=120)
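
Note on this hunk: the headline change is the try/except/finally wrapper (the finally block runs whether or not a stage throws, so a failed request no longer leaves models resident on the GPU), together with the switch from string concatenation to a list joined once at the end. The reworked formatting loop can be sanity-checked in isolation; below is a minimal, self-contained sketch with a fabricated segment list (real WhisperX segments carry more keys, and the SPEAKER_00-style labels shown here are the kind produced by the diarization step):

# Self-contained check of the new formatting loop; the segment dicts are
# fabricated stand-ins for WhisperX output after assign_word_speakers().
segments = [
    {"speaker": "SPEAKER_00", "start": 0.0, "end": 2.5, "text": " Hello there."},
    {"speaker": "SPEAKER_01", "start": 2.5, "end": 4.1, "text": " Hi!"},
    "bad-entry",  # exercises the new isinstance() guard
]

formatted_text = []
for segment in segments:
    if not isinstance(segment, dict):
        continue
    speaker = f"[Speaker {segment.get('speaker', 'Unknown')}]"
    start_time = f"{float(segment.get('start', 0)):.2f}"
    end_time = f"{float(segment.get('end', 0)):.2f}"
    text = segment.get("text", "").strip()
    formatted_text.append(f"[{start_time}s - {end_time}s] {speaker}: {text}")

print("\n".join(formatted_text))
# [0.00s - 2.50s] [Speaker SPEAKER_00]: Hello there.
# [2.50s - 4.10s] [Speaker SPEAKER_01]: Hi!
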
@@ -330,13 +342,15 @@ erase_tab = gr.Interface(
 transcribe_tab = gr.Interface(
     fn=main,
     inputs=[
-        gr.Number(6, interactive=False),
-        gr.Audio(type="filepath"),
-        gr.Radio(["transcribe", "translate"], label="Task", value="transcribe"),
+        gr.Number(value=6, visible=False, precision=0), # API number
+        gr.Audio(type="filepath", label="Audio File"),
+        gr.Radio(choices=["transcribe", "translate"], value="transcribe", label="Task", visible=True),
     ],
-    outputs="text",
+    outputs=gr.Textbox(label="Transcription"),
+    title="Audio Transcription",
+    description="Upload an audio file to extract text using WhisperX with speaker diarization",
     api_name="transcribe",
-    description="Upload an audio file to extract text using Whisper Large V3",
+    examples=[]
 )
 
 demo = gr.TabbedInterface(
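
With the tab keeping api_name="transcribe" and the leading gr.Number now hidden, a client call still has to supply all three inputs in order. A hedged sketch using gradio_client (the Space id is a placeholder, not the real one, and recent gradio_client versions want handle_file() for audio uploads):

from gradio_client import Client, handle_file

client = Client("user/space-name")  # placeholder Space id
result = client.predict(
    6,                          # hidden gr.Number (commented "# API number" above)
    handle_file("sample.wav"),  # gr.Audio(type="filepath") input
    "transcribe",               # gr.Radio task: "transcribe" or "translate"
    api_name="/transcribe",
)
print(result)  # the formatted "[start - end] [Speaker ...]: text" lines
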
 