Nithin Rao Koluguri commited on
Commit
7e3fe0c
·
1 Parent(s): 025dfc0

Add support for longer audio inference

Browse files

Signed-off-by: Nithin Rao Koluguri <nithinraok>

Files changed (1) hide show
  1. app.py +34 -1
app.py CHANGED
@@ -90,6 +90,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
90
  try:
91
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
92
  audio = AudioSegment.from_file(audio_path)
 
93
  except Exception as load_e:
94
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
95
  # Return an update to hide the button
@@ -137,9 +138,27 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
137
  transcribe_path = audio_path
138
  info_path_name = original_path_name
139
 
 
 
140
  try:
141
  model.to(device)
 
142
  gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
143
  output = model.transcribe([transcribe_path], timestamps=True)
144
 
145
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
@@ -194,7 +213,20 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
194
  # Return an update to hide the button
195
  return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
196
  finally:
 
197
  try:
 
 
 
 
 
 
 
 
 
 
 
 
198
  if 'model' in locals() and hasattr(model, 'cpu'):
199
  if device == 'cuda':
200
  model.cpu()
@@ -204,6 +236,7 @@ def get_transcripts_and_raw_times(audio_path, session_dir):
204
  except Exception as cleanup_e:
205
  print(f"Error during model cleanup: {cleanup_e}")
206
  gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
 
207
 
208
  finally:
209
  if processed_audio_path and os.path.exists(processed_audio_path):
@@ -253,7 +286,7 @@ article = (
253
  "<ul style='font-size: 1.1em;'>"
254
  " <li>Automatic punctuation and capitalization</li>"
255
  " <li>Accurate word-level timestamps (click on a segment in the table below to play it!)</li>"
256
- " <li>Efficiently transcribes long audio segments (up to 20 minutes) <small>(For even longer audios, see <a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py' target='_blank'>this script</a>)</small></li>"
257
  " <li>Robust performance on spoken numbers, and song lyrics transcription </li>"
258
  "</ul>"
259
  "<p style='font-size: 1.1em;'>"
 
90
  try:
91
  gr.Info(f"Loading audio: {original_path_name}", duration=2)
92
  audio = AudioSegment.from_file(audio_path)
93
+ duration_sec = audio.duration_seconds
94
  except Exception as load_e:
95
  gr.Error(f"Failed to load audio file {original_path_name}: {load_e}", duration=None)
96
  # Return an update to hide the button
 
138
  transcribe_path = audio_path
139
  info_path_name = original_path_name
140
 
141
+ # Flag to track if long audio settings were applied
142
+ long_audio_settings_applied = False
143
  try:
144
  model.to(device)
145
+ model.to(torch.float32)
146
  gr.Info(f"Transcribing {info_path_name} on {device}...", duration=2)
147
+
148
+ # Check duration and apply specific settings for long audio
149
+ if duration_sec > 900: # 15 minutes
150
+ try:
151
+ gr.Info("Audio longer than 15 minutes. Applying optimized settings for long transcription.", duration=3)
152
+ print("Applying long audio settings: Local Attention and Chunking.")
153
+ model.change_attention_model("rel_pos_local_attn", [256,256])
154
+ model.change_subsampling_conv_chunking_factor(1) # 1 = auto select
155
+ long_audio_settings_applied = True
156
+ except Exception as setting_e:
157
+ gr.Warning(f"Could not apply long audio settings: {setting_e}", duration=5)
158
+ print(f"Warning: Failed to apply long audio settings: {setting_e}")
159
+ # Proceed without long audio settings if applying them failed
160
+
161
+ model.to(torch.bfloat16)
162
  output = model.transcribe([transcribe_path], timestamps=True)
163
 
164
  if not output or not isinstance(output, list) or not output[0] or not hasattr(output[0], 'timestamp') or not output[0].timestamp or 'segment' not in output[0].timestamp:
 
213
  # Return an update to hide the button
214
  return vis_data, raw_times_data, audio_path, gr.DownloadButton(visible=False)
215
  finally:
216
+ # --- Model Cleanup ---
217
  try:
218
+ # Revert settings if they were applied for long audio
219
+ if long_audio_settings_applied:
220
+ try:
221
+ print("Reverting long audio settings.")
222
+ model.change_attention_model("rel_pos", [-1,-1])
223
+ model.change_subsampling_conv_chunking_factor(-1)
224
+ long_audio_settings_applied = False # Reset flag
225
+ except Exception as revert_e:
226
+ print(f"Warning: Failed to revert long audio settings: {revert_e}")
227
+ gr.Warning(f"Issue reverting model settings after long transcription: {revert_e}", duration=5)
228
+
229
+ # Original cleanup
230
  if 'model' in locals() and hasattr(model, 'cpu'):
231
  if device == 'cuda':
232
  model.cpu()
 
236
  except Exception as cleanup_e:
237
  print(f"Error during model cleanup: {cleanup_e}")
238
  gr.Warning(f"Issue during model cleanup: {cleanup_e}", duration=5)
239
+ # --- End Model Cleanup ---
240
 
241
  finally:
242
  if processed_audio_path and os.path.exists(processed_audio_path):
 
286
  "<ul style='font-size: 1.1em;'>"
287
  " <li>Automatic punctuation and capitalization</li>"
288
  " <li>Accurate word-level timestamps (click on a segment in the table below to play it!)</li>"
289
+ " <li>Efficiently transcribes long audio segments (<strong>updated to support upto 3 hours</strong>) <small>(For even longer audios, see <a href='https://github.com/NVIDIA/NeMo/blob/main/examples/asr/asr_chunked_inference/rnnt/speech_to_text_buffered_infer_rnnt.py' target='_blank'>this script</a>)</small></li>"
290
  " <li>Robust performance on spoken numbers, and song lyrics transcription </li>"
291
  "</ul>"
292
  "<p style='font-size: 1.1em;'>"