LAP-DEV committed on
Commit fbcf93f · verified · 1 Parent(s): b958773

Upload whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +758 -0
modules/whisper/whisper_base.py ADDED
@@ -0,0 +1,758 @@
+ import os
+ import torch
+ import whisper
+ import gradio as gr
+ import torchaudio
+ from abc import ABC, abstractmethod
+ from typing import BinaryIO, Union, Tuple, List, Optional
+ import numpy as np
+ from datetime import datetime
+ from faster_whisper.vad import VadOptions
+ from dataclasses import astuple
+ import gc
+ from copy import deepcopy
+ from modules.vad.silero_vad import merge_chunks, Segment
+ from modules.uvr.music_separator import MusicSeparator
+ from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
+                                  UVR_MODELS_DIR)
+ from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_plaintext, get_csv, write_file, safe_filename
+ from modules.utils.youtube_manager import get_ytdata, get_ytaudio
+ from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
+ from modules.whisper.whisper_parameter import *
+ from modules.diarize.diarizer import Diarizer
+ from modules.vad.silero_vad import SileroVAD
+ from modules.translation.nllb_inference import NLLBInference
+ from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
+ import faster_whisper
+
+ class WhisperBase(ABC):
+     def __init__(self,
+                  model_dir: str = WHISPER_MODELS_DIR,
+                  diarization_model_dir: str = DIARIZATION_MODELS_DIR,
+                  uvr_model_dir: str = UVR_MODELS_DIR,
+                  output_dir: str = OUTPUT_DIR,
+                  ):
+         self.model_dir = model_dir
+         self.output_dir = output_dir
+         os.makedirs(self.output_dir, exist_ok=True)
+         os.makedirs(self.model_dir, exist_ok=True)
+         self.diarizer = Diarizer(
+             model_dir=diarization_model_dir
+         )
+         self.vad = SileroVAD()
+         self.music_separator = MusicSeparator(
+             model_dir=uvr_model_dir,
+             output_dir=os.path.join(output_dir, "UVR")
+         )
+
+         self.model = None
+         self.current_model_size = None
+         self.available_models = whisper.available_models()
+         self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
+         # self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
+         self.translatable_models = whisper.available_models()
+         self.device = self.get_device()
+         self.available_compute_types = ["float16", "float32"]
+         self.current_compute_type = "float16" if self.device == "cuda" else "float32"
+
+     @abstractmethod
+     def transcribe(self,
+                    audio: Union[str, BinaryIO, np.ndarray],
+                    progress: gr.Progress = gr.Progress(),
+                    *whisper_params,
+                    ):
+         """Run inference with the whisper model to transcribe audio"""
+         pass
+
+     @abstractmethod
+     def update_model(self,
+                      model_size: str,
+                      compute_type: str,
+                      progress: gr.Progress = gr.Progress()
+                      ):
+         """Initialize the whisper model"""
+         pass
+
+     def run(self,
+             audio: Union[str, BinaryIO, np.ndarray],
+             progress: gr.Progress = gr.Progress(),
+             add_timestamp: bool = True,
+             *whisper_params,
+             ) -> Tuple[List[dict], float]:
+         """
+         Run transcription with conditional pre-processing and post-processing.
+         The VAD will be performed to remove noise from the audio input in pre-processing, if enabled.
+         The diarization will be performed in post-processing, if enabled.
+
+         Parameters
+         ----------
+         audio: Union[str, BinaryIO, np.ndarray]
+             Audio input. This can be a file path or binary type.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         add_timestamp: bool
+             Whether to add a timestamp at the end of the filename.
+         *whisper_params: tuple
+             Parameters related to whisper. These are handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         segments_result: List[dict]
+             list of dicts that includes start, end timestamps and transcribed text
+         elapsed_time: float
+             elapsed time for running
+         """
+
+         start_time = datetime.now()
+         params = WhisperParameters.as_value(*whisper_params)
+
+         # Get the offload params
+         default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+         whisper_params = default_params["whisper"]
+         diarization_params = default_params["diarization"]
+         bool_whisper_enable_offload = whisper_params["enable_offload"]
+         bool_diarization_enable_offload = diarization_params["enable_offload"]
+
+         if params.lang is None:
+             pass
+         elif params.lang == "Automatic Detection":
+             params.lang = None
+         else:
+             language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
+             params.lang = language_code_dict[params.lang]
+
+         if params.is_bgm_separate:
+             music, audio, _ = self.music_separator.separate(
+                 audio=audio,
+                 model_name=params.uvr_model_size,
+                 device=params.uvr_device,
+                 segment_size=params.uvr_segment_size,
+                 save_file=params.uvr_save_file,
+                 progress=progress
+             )
+
+             if audio.ndim >= 2:
+                 audio = audio.mean(axis=1)
+                 if self.music_separator.audio_info is None:
+                     origin_sample_rate = 16000
+                 else:
+                     origin_sample_rate = self.music_separator.audio_info.sample_rate
+                 audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
+
+             if params.uvr_enable_offload:
+                 self.music_separator.offload()
+             elapsed_time_bgm_sep = datetime.now() - start_time
+
+         origin_audio = deepcopy(audio)
+
+         if params.vad_filter:
+             # Explicit value set for float('inf') from gr.Number()
+             if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
+                 params.max_speech_duration_s = float('inf')
+
+             progress(0, desc="Filtering silent parts from audio...")
+             vad_options = VadOptions(
+                 threshold=params.threshold,
+                 min_speech_duration_ms=params.min_speech_duration_ms,
+                 max_speech_duration_s=params.max_speech_duration_s,
+                 min_silence_duration_ms=params.min_silence_duration_ms,
+                 speech_pad_ms=params.speech_pad_ms
+             )
+
+             vad_processed, speech_chunks = self.vad.run(
+                 audio=audio,
+                 vad_parameters=vad_options,
+                 progress=progress
+             )
+
+             try:
+                 if vad_processed.size > 0 and speech_chunks:
+                     if not isinstance(audio, np.ndarray):
+                         loaded_audio = faster_whisper.decode_audio(audio, sampling_rate=self.vad.sampling_rate)
+                     else:
+                         loaded_audio = audio
+                     # Convert speech_chunks to Segment objects and convert samples to seconds
+                     segments = [Segment(start=chunk['start']/self.vad.sampling_rate, end=chunk['end']/self.vad.sampling_rate) for chunk in speech_chunks]
+                     # merge_chunks only works on segments expressed in seconds
+                     merged_chunks = merge_chunks(segments, chunk_size=300, onset=0.0, offset=None)
+                     all_segments = []
+                     total_elapsed_time = 0.0
+                     for merged in merged_chunks:
+                         chunk_start = merged['start']
+                         chunk_end = merged['end']
+
+                         # To slice audio, convert chunk_start and chunk_end from seconds to samples by multiplying by the sampling rate.
+                         start_sample = int(chunk_start*self.vad.sampling_rate)
+                         end_sample = int(chunk_end*self.vad.sampling_rate)
+
+                         chunk_audio = loaded_audio[start_sample:end_sample]
+
+                         chunk_result, chunk_time = self.transcribe(
+                             chunk_audio,
+                             progress,
+                             *astuple(params)
+                         )
+                         # Offset timestamps
+                         for seg in chunk_result:
+                             seg['start'] += chunk_start
+                             seg['end'] += chunk_start
+                         all_segments.extend(chunk_result)
+                         total_elapsed_time += chunk_time
+                     result = all_segments
+                     elapsed_time = total_elapsed_time
+                 else:
+                     params.vad_filter = False
+             except Exception as e:
+                 print(f"Error transcribing file: {e}")
+                 params.vad_filter = False  # fall back to transcribing the full audio below
+
+         if not params.vad_filter:
+             result, elapsed_time = self.transcribe(
+                 audio,
+                 progress,
+                 *astuple(params)
+             )
+         if bool_whisper_enable_offload:
+             self.offload()
+
+         if params.is_diarize:
+             progress(0.99, desc="Diarizing speakers...")
+             result, elapsed_time_diarization = self.diarizer.run(
+                 audio=origin_audio,
+                 use_auth_token=params.hf_token,
+                 transcribed_result=result,
+                 device=params.diarization_device
+             )
+             if bool_diarization_enable_offload:
+                 self.diarizer.offload()
+
+         if not result:
+             print("Whisper did not detect any speech segments in the audio.")
+             result = list()
+
+         progress(1.0, desc="Processing done!")
+         total_elapsed_time = datetime.now() - start_time
+         return result, elapsed_time
+
+     def transcribe_file(self,
+                         files_audio: Optional[List] = None,
+                         files_video: Optional[List] = None,
+                         files_multi: Optional[List] = None,
+                         input_multi: str = "Audio",
+                         input_folder_path: Optional[str] = None,
+                         file_format: list = ["CSV"],
+                         add_timestamp: bool = True,
+                         translate_output: bool = False,
+                         translate_model: str = "",
+                         target_lang: str = "",
+                         add_timestamp_preview: bool = False,
+                         progress=gr.Progress(),
+                         *whisper_params,
+                         ) -> list:
+         """
+         Write subtitle file from Files
+
+         Parameters
+         ----------
+         files_audio: list
+             List of files to transcribe from gr.Audio()
+         files_video: list
+             List of files to transcribe from gr.Video()
+         files_multi: list
+             List of files to transcribe from gr.Files()
+         input_multi: str
+             Process single or multiple files
+         input_folder_path: str
+             Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
+             this will be used instead.
+         file_format: list
+             Subtitle file formats to write from gr.Dropdown(). Supported formats: [CSV, SRT, TXT]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
+         translate_output: bool
+             Translate output
+         translate_model: str
+             Translation model to use
+         target_lang: str
+             Target language to use
+         add_timestamp_preview: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp to the output preview
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper. These are handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file path to return to gr.Files()
+         """
+
+         try:
+             file_count_total = 0
+             files = ""
+
+             if input_multi == "Audio":
+                 files = files_audio
+             elif input_multi == "Video":
+                 files = files_video
+             else:
+                 files = files_multi
+             file_count_total = len(files)
+
+             if input_folder_path:
+                 files = get_media_files(input_folder_path)
+             if isinstance(files, str):
+                 files = [files]
+             if files and isinstance(files[0], gr.utils.NamedString):
+                 files = [file.name for file in files]
+
+             ## Initialization variables & start time
+             files_info = {}
+             files_to_download = {}
+             time_start = datetime.now()
+
+             ## Load parameters related with whisper
+             params = WhisperParameters.as_value(*whisper_params)
+
+             ## Load model to detect language
+             model = whisper.load_model("base")
+
+             for file in files:
+                 print(file)
+                 ## Detect language
+                 mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
+                 _, probs = model.detect_language(mel)
+                 file_language = ""
+                 file_lang_probs = ""
+                 for key, value in whisper.tokenizer.LANGUAGES.items():
+                     if key == str(max(probs, key=probs.get)):
+                         file_language = value.capitalize()
+                         for key_prob, value_prob in probs.items():
+                             if key == key_prob:
+                                 file_lang_probs = str((round(value_prob*100)))
+                                 break
+                         break
+                 transcribed_segments, time_for_task = self.run(
+                     file,
+                     progress,
+                     add_timestamp,
+                     *whisper_params,
+                 )
+                 # Define source language
+                 # source_lang = file_language
+                 if params.lang == "Automatic Detection" or (params.lang).strip() == "":
+                     source_lang = file_language
+                 else:
+                     source_lang = ((params.lang).strip()).capitalize()
+
+                 # Translate to English using Whisper built-in functionality
+                 transcription_note = ""
+                 if params.is_translate:
+                     if source_lang != "English":
+                         transcription_note = "To English"
+                         source_lang = "English"
+                     else:
+                         transcription_note = "Already in English"
+
+                 # Translate the transcribed segments
+                 translation_note = ""
+                 if translate_output:
+                     if source_lang != target_lang:
+                         self.nllb_inf = NLLBInference()
+                         if source_lang in NLLB_AVAILABLE_LANGS.keys():
+                             transcribed_segments = self.nllb_inf.translate_text(
+                                 input_list_dict=transcribed_segments,
+                                 model_size=translate_model,
+                                 src_lang=source_lang,
+                                 tgt_lang=target_lang,
+                                 speaker_diarization=params.is_diarize
+                             )
+                             translation_note = "To " + target_lang
+                         else:
+                             translation_note = source_lang + " not supported"
+                     else:
+                         translation_note = "Already in " + target_lang
+
+                 ## Get input filename & extension
+                 file_name, file_ext = os.path.splitext(os.path.basename(file))
+
+                 ## Get output as preview with or without timestamps
+                 if add_timestamp_preview:
+                     subtitle = get_txt(transcribed_segments)
+                 else:
+                     subtitle = get_plaintext(transcribed_segments)
+                 files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "lang": file_language, "lang_prob": file_lang_probs, "input_source_file": (file_name+file_ext), "translation": translation_note, "transcription": transcription_note}
+
+                 ## Add output file as txt, srt and/or csv
+                 for output_format in file_format:
+                     subtitle, file_path = self.generate_and_write_file(
+                         file_name=file_name,
+                         transcribed_segments=transcribed_segments,
+                         add_timestamp=add_timestamp,
+                         file_format=output_format.lower(),
+                         output_dir=self.output_dir
+                     )
+                     files_to_download[file_name+"_"+output_format.lower()] = {"path": file_path}
+
+             total_result = ""
+             total_info = ""
+             total_time = 0
+             file_count = 0
+             for file_name, info in files_info.items():
+
+                 file_count += 1
+
+                 if file_count > 1:
+                     total_info += f'\n'
+
+                 if file_count_total > 1:
+                     if file_count > 1:
+                         total_result += f'\n'
+                     total_result += f'« Transcription of media file \'{info["input_source_file"]}\': »\n\n'
+
+                 total_time += info["time_for_task"]
+                 total_result += f'{info["subtitle"]}'
+                 total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
+
+                 if params.is_translate:
+                     total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'
+
+                 if translate_output:
+                     total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'
+
+             time_end = datetime.now()
+             # total_info += f"\nTotal processing time:\t{self.format_time((time_end-time_start).total_seconds())}"
+
+             temp_file_count_text = "file"
+             if file_count != 1:
+                 temp_file_count_text += "s"
+             total_info += f"\nProcessed {file_count} {temp_file_count_text} in {self.format_time((time_end-time_start).total_seconds())}"
+
+             result_str = total_result.rstrip("\n")
+             result_file_path = [info['path'] for info in files_to_download.values()]
+
+             return [result_str, result_file_path, total_info]
+
+         except Exception as e:
+             print(f"Error transcribing file: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     def transcribe_mic(self,
+                        mic_audio: str,
+                        file_format: str = "SRT",
+                        add_timestamp: bool = True,
+                        progress=gr.Progress(),
+                        *whisper_params,
+                        ) -> list:
+         """
+         Write subtitle file from microphone
+
+         Parameters
+         ----------
+         mic_audio: str
+             Audio file path from gr.Microphone()
+         file_format: str
+             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper. These are handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file path to return to gr.Files()
+         """
+         try:
+             progress(0, desc="Loading Audio...")
+             transcribed_segments, time_for_task = self.run(
+                 mic_audio,
+                 progress,
+                 add_timestamp,
+                 *whisper_params,
+             )
+             progress(1, desc="Completed!")
+
+             subtitle, result_file_path = self.generate_and_write_file(
+                 file_name="Mic",
+                 transcribed_segments=transcribed_segments,
+                 add_timestamp=add_timestamp,
+                 file_format=file_format,
+                 output_dir=self.output_dir
+             )
+
+             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+             return [result_str, result_file_path]
+         except Exception as e:
+             print(f"Error transcribing file: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     def transcribe_youtube(self,
+                            youtube_link: str,
+                            file_format: str = "SRT",
+                            add_timestamp: bool = True,
+                            progress=gr.Progress(),
+                            *whisper_params,
+                            ) -> list:
+         """
+         Write subtitle file from Youtube
+
+         Parameters
+         ----------
+         youtube_link: str
+             URL of the Youtube video to transcribe from gr.Textbox()
+         file_format: str
+             Subtitle file format to write from gr.Dropdown(). Supported formats: [SRT, WebVTT, txt]
+         add_timestamp: bool
+             Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
+         progress: gr.Progress
+             Indicator to show progress directly in gradio.
+         *whisper_params: tuple
+             Parameters related to whisper. These are handled by the WhisperParameters data class.
+
+         Returns
+         ----------
+         result_str:
+             Result of transcription to return to gr.Textbox()
+         result_file_path:
+             Output file path to return to gr.Files()
+         """
+         try:
+             progress(0, desc="Loading Audio from Youtube...")
+             yt = get_ytdata(youtube_link)
+             audio = get_ytaudio(yt)
+
+             transcribed_segments, time_for_task = self.run(
+                 audio,
+                 progress,
+                 add_timestamp,
+                 *whisper_params,
+             )
+
+             progress(1, desc="Completed!")
+
+             file_name = safe_filename(yt.title)
+             subtitle, result_file_path = self.generate_and_write_file(
+                 file_name=file_name,
+                 transcribed_segments=transcribed_segments,
+                 add_timestamp=add_timestamp,
+                 file_format=file_format,
+                 output_dir=self.output_dir
+             )
+             result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
+
+             if os.path.exists(audio):
+                 os.remove(audio)
+
+             return [result_str, result_file_path]
+
+         except Exception as e:
+             print(f"Error transcribing file: {e}")
+         finally:
+             self.release_cuda_memory()
+
+     @staticmethod
+     def generate_and_write_file(file_name: str,
+                                 transcribed_segments: list,
+                                 add_timestamp: bool,
+                                 file_format: str,
+                                 output_dir: str
+                                 ) -> Tuple[str, str]:
+         """
+         Writes subtitle file
+
+         Parameters
+         ----------
+         file_name: str
+             Output file name
+         transcribed_segments: list
+             Text segments transcribed from audio
+         add_timestamp: bool
+             Determines whether to add a timestamp to the end of the filename.
+         file_format: str
+             File format to write. Supported formats: [SRT, WebVTT, txt, csv]
+         output_dir: str
+             Directory path of the output
+
+         Returns
+         ----------
+         content: str
+             Result of the transcription
+         output_path: str
+             Output file path
+         """
+         if add_timestamp:
+             # timestamp = datetime.now().strftime("%m%d%H%M%S")
+             timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
+             output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
+         else:
+             output_path = os.path.join(output_dir, f"{file_name}")
+
+         file_format = file_format.strip().lower()
+         if file_format == "srt":
+             content = get_srt(transcribed_segments)
+             output_path += '.srt'
+
+         elif file_format == "webvtt":
+             content = get_vtt(transcribed_segments)
+             output_path += '.vtt'
+
+         elif file_format == "txt":
+             content = get_txt(transcribed_segments)
+             output_path += '.txt'
+
+         elif file_format == "csv":
+             content = get_csv(transcribed_segments)
+             output_path += '.csv'
+
+         write_file(content, output_path)
+         return content, output_path
+
+     def offload(self):
+         """Offload the model and free up the memory"""
+         if self.model is not None:
+             del self.model
+             self.model = None
+         if self.device == "cuda":
+             self.release_cuda_memory()
+         gc.collect()
+
+     @staticmethod
+     def format_time(elapsed_time: float) -> str:
+         """
+         Get {hours} {minutes} {seconds} time format string
+
+         Parameters
+         ----------
+         elapsed_time: float
+             Elapsed time for transcription
+
+         Returns
+         ----------
+         Time format string
+         """
+         hours, rem = divmod(elapsed_time, 3600)
+         minutes, seconds = divmod(rem, 60)
+
+         time_str = ""
+
+         hours = round(hours)
+         if hours:
+             if hours == 1:
+                 time_str += f"{hours} hour "
+             else:
+                 time_str += f"{hours} hours "
+
+         minutes = round(minutes)
+         if minutes:
+             if minutes == 1:
+                 time_str += f"{minutes} minute "
+             else:
+                 time_str += f"{minutes} minutes "
+
+         seconds = round(seconds)
+         if seconds == 1:
+             time_str += f"{seconds} second"
+         else:
+             time_str += f"{seconds} seconds"
+
+         return time_str.strip()
+
+     @staticmethod
+     def get_device():
+         if torch.cuda.is_available():
+             return "cuda"
+         elif torch.backends.mps.is_available():
+             if not WhisperBase.is_sparse_api_supported():
+                 # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
+                 return "cpu"
+             return "mps"
+         else:
+             return "cpu"
+
+     @staticmethod
+     def is_sparse_api_supported():
+         if not torch.backends.mps.is_available():
+             return False
+
+         try:
+             device = torch.device("mps")
+             sparse_tensor = torch.sparse_coo_tensor(
+                 indices=torch.tensor([[0, 1], [2, 3]]),
+                 values=torch.tensor([1, 2]),
+                 size=(4, 4),
+                 device=device
+             )
+             return True
+         except RuntimeError:
+             return False
+
+     @staticmethod
+     def release_cuda_memory():
+         """Release memory"""
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+             torch.cuda.reset_max_memory_allocated()
+
+     @staticmethod
+     def remove_input_files(file_paths: List[str]):
+         """Remove gradio cached files"""
+         if not file_paths:
+             return
+
+         for file_path in file_paths:
+             if file_path and os.path.exists(file_path):
+                 os.remove(file_path)
+
+     @staticmethod
+     def cache_parameters(
+         params: WhisperValues,
+         file_format: str = "SRT",
+         add_timestamp: bool = True
+     ):
+         """Cache parameters to the yaml file"""
+         cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
+         param_to_cache = params.to_dict()
+
+         cached_yaml = {**cached_params, **param_to_cache}
+         cached_yaml["whisper"]["add_timestamp"] = add_timestamp
+         cached_yaml["whisper"]["file_format"] = file_format
+
+         suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
+         if suppress_token and isinstance(suppress_token, list):
+             cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
+
+         if cached_yaml["whisper"].get("lang", None) is None:
+             cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
+         else:
+             language_dict = whisper.tokenizer.LANGUAGES
+             cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
+
+         if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
+             cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
+
+         if cached_yaml is not None and cached_yaml:
+             save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
+
+     @staticmethod
+     def resample_audio(audio: Union[str, np.ndarray],
+                        new_sample_rate: int = 16000,
+                        original_sample_rate: Optional[int] = None,
+                        ) -> np.ndarray:
+         """Resample audio to a 16k sample rate, the standard for Whisper models"""
+         if isinstance(audio, str):
+             audio, original_sample_rate = torchaudio.load(audio)
+         else:
+             if original_sample_rate is None:
+                 raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
+             audio = torch.from_numpy(audio)
+         resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
+         resampled_audio = resampler(audio).numpy()
+         return resampled_audio
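
WhisperBase leaves `transcribe` and `update_model` abstract, so a concrete backend has to supply them before `run()`, `transcribe_file()`, `transcribe_mic()` or `transcribe_youtube()` can be used. The snippet below is only an illustrative sketch of such a subclass built on the openai-whisper package; it is not part of this commit, the class name is hypothetical, and the exact WhisperParameters field names (`model_size`, `compute_type`, `lang`) and the `(segments, elapsed_time)` return contract are assumptions inferred from how `run()` consumes the results.

# Illustrative sketch only -- not part of whisper_base.py in this commit.
import time
from typing import BinaryIO, Union

import gradio as gr
import numpy as np
import whisper

from modules.whisper.whisper_base import WhisperBase
from modules.whisper.whisper_parameter import WhisperParameters


class OpenAIWhisperExample(WhisperBase):
    def update_model(self, model_size: str, compute_type: str, progress: gr.Progress = gr.Progress()):
        # (Re)load the model whenever the requested size or compute type changes.
        progress(0, desc="Initializing model...")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = whisper.load_model(model_size, device=self.device, download_root=self.model_dir)

    def transcribe(self, audio: Union[str, BinaryIO, np.ndarray], progress: gr.Progress = gr.Progress(), *whisper_params):
        # Unpack the flattened gradio inputs back into the parameter dataclass
        # (field names here are assumptions about WhisperParameters).
        params = WhisperParameters.as_value(*whisper_params)
        if self.model is None or self.current_model_size != params.model_size:
            self.update_model(params.model_size, params.compute_type, progress)

        start = time.time()
        output = self.model.transcribe(audio, language=params.lang,
                                       fp16=(self.current_compute_type == "float16"))
        # run() appears to expect a list of {"start", "end", "text"} dicts plus the elapsed time.
        segments = [{"start": seg["start"], "end": seg["end"], "text": seg["text"]}
                    for seg in output["segments"]]
        return segments, time.time() - start

In the real application, the concrete subclass would then be wired into the gradio callbacks (`transcribe_file`, `transcribe_mic`, `transcribe_youtube`) defined above.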