LAP-DEV committed on
Commit b958773 · verified · 1 Parent(s): 18a379b

Delete modules/whisper/whisper_base.py

Files changed (1)
  1. modules/whisper/whisper_base.py +0 -775
modules/whisper/whisper_base.py DELETED
@@ -1,775 +0,0 @@
1
- import os
2
- import torch
3
- import whisper
4
- import gradio as gr
5
- import torchaudio
6
- from abc import ABC, abstractmethod
7
- from typing import BinaryIO, Union, Tuple, List
8
- import numpy as np
9
- from datetime import datetime
10
- from faster_whisper.vad import VadOptions
11
- from dataclasses import astuple
12
- import gc
13
- from copy import deepcopy
14
- from modules.vad.silero_vad import merge_chunks, Segment
15
- from modules.uvr.music_separator import MusicSeparator
16
- from modules.utils.paths import (WHISPER_MODELS_DIR, DIARIZATION_MODELS_DIR, OUTPUT_DIR, DEFAULT_PARAMETERS_CONFIG_PATH,
17
- UVR_MODELS_DIR)
18
- from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, get_plaintext, get_csv, write_file, safe_filename
19
- from modules.utils.youtube_manager import get_ytdata, get_ytaudio
20
- from modules.utils.files_manager import get_media_files, format_gradio_files, load_yaml, save_yaml
21
- from modules.whisper.whisper_parameter import *
22
- from modules.diarize.diarizer import Diarizer
23
- from modules.vad.silero_vad import SileroVAD
24
- from modules.translation.nllb_inference import NLLBInference
25
- from modules.translation.nllb_inference import NLLB_AVAILABLE_LANGS
26
- import faster_whisper
27
-
28
- class WhisperBase(ABC):
29
- def __init__(self,
30
- model_dir: str = WHISPER_MODELS_DIR,
31
- diarization_model_dir: str = DIARIZATION_MODELS_DIR,
32
- uvr_model_dir: str = UVR_MODELS_DIR,
33
- output_dir: str = OUTPUT_DIR,
34
- ):
35
- self.model_dir = model_dir
36
- self.output_dir = output_dir
37
- os.makedirs(self.output_dir, exist_ok=True)
38
- os.makedirs(self.model_dir, exist_ok=True)
39
- self.diarizer = Diarizer(
40
- model_dir=diarization_model_dir
41
- )
42
- self.vad = SileroVAD()
43
- self.music_separator = MusicSeparator(
44
- model_dir=uvr_model_dir,
45
- output_dir=os.path.join(output_dir, "UVR")
46
- )
47
-
48
- self.model = None
49
- self.current_model_size = None
50
- self.available_models = whisper.available_models()
51
- self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
52
- #self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
53
- self.translatable_models = whisper.available_models()
54
- self.device = self.get_device()
55
- self.available_compute_types = ["float16", "float32"]
56
- self.current_compute_type = "float16" if self.device == "cuda" else "float32"
57
-
58
- @abstractmethod
59
- def transcribe(self,
60
- audio: Union[str, BinaryIO, np.ndarray],
61
- progress: gr.Progress = gr.Progress(),
62
- *whisper_params,
63
- ):
64
- """Inference whisper model to transcribe"""
65
- pass
66
-
67
- @abstractmethod
68
- def update_model(self,
69
- model_size: str,
70
- compute_type: str,
71
- progress: gr.Progress = gr.Progress()
72
- ):
73
- """Initialize whisper model"""
74
- pass
75
-
76
- def run(self,
77
- audio: Union[str, BinaryIO, np.ndarray],
78
- progress: gr.Progress = gr.Progress(),
79
- add_timestamp: bool = True,
80
- *whisper_params,
81
- ) -> Tuple[List[dict], float]:
82
- """
83
- Run transcription with conditional pre-processing and post-processing.
84
- The VAD will be performed to filter out non-speech parts of the audio in pre-processing, if enabled.
85
- The diarization will be performed in post-processing, if enabled.
86
-
87
- Parameters
88
- ----------
89
- audio: Union[str, BinaryIO, np.ndarray]
90
- Audio input. This can be a file path or binary type.
91
- progress: gr.Progress
92
- Indicator to show progress directly in gradio.
93
- add_timestamp: bool
94
- Whether to add a timestamp at the end of the filename.
95
- *whisper_params: tuple
96
- Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
97
-
98
- Returns
99
- ----------
100
- segments_result: List[dict]
101
- list of dicts that includes start, end timestamps and transcribed text
102
- elapsed_time: float
103
- elapsed time for running
104
- """
105
-
106
- start_time = datetime.now()
107
- params = WhisperParameters.as_value(*whisper_params)
108
-
109
- # Get the offload params
110
- default_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
111
- whisper_params = default_params["whisper"]
112
- diarization_params = default_params["diarization"]
113
- bool_whisper_enable_offload = whisper_params["enable_offload"]
114
- bool_diarization_enable_offload = diarization_params["enable_offload"]
115
-
116
- if params.lang is None:
117
- pass
118
- elif params.lang == "Automatic Detection":
119
- params.lang = None
120
- else:
121
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
122
- params.lang = language_code_dict[params.lang]
123
-
124
- if params.is_bgm_separate:
125
- music, audio, _ = self.music_separator.separate(
126
- audio=audio,
127
- model_name=params.uvr_model_size,
128
- device=params.uvr_device,
129
- segment_size=params.uvr_segment_size,
130
- save_file=params.uvr_save_file,
131
- progress=progress
132
- )
133
-
134
- if audio.ndim >= 2:
135
- audio = audio.mean(axis=1)
136
- if self.music_separator.audio_info is None:
137
- origin_sample_rate = 16000
138
- else:
139
- origin_sample_rate = self.music_separator.audio_info.sample_rate
140
- audio = self.resample_audio(audio=audio, original_sample_rate=origin_sample_rate)
141
-
142
- if params.uvr_enable_offload:
143
- self.music_separator.offload()
144
- elapsed_time_bgm_sep = datetime.now() - start_time
145
-
146
- origin_audio = deepcopy(audio)
147
-
148
- if params.vad_filter:
149
- # gr.Number() cannot pass float('inf'), so treat None or values >= 9999 as unlimited
150
- if params.max_speech_duration_s is None or params.max_speech_duration_s >= 9999:
151
- params.max_speech_duration_s = float('inf')
152
-
153
- progress(0, desc="Filtering silent parts from audio...")
154
- vad_options = VadOptions(
155
- threshold=params.threshold,
156
- min_speech_duration_ms=params.min_speech_duration_ms,
157
- max_speech_duration_s=params.max_speech_duration_s,
158
- min_silence_duration_ms=params.min_silence_duration_ms,
159
- speech_pad_ms=params.speech_pad_ms
160
- )
161
-
162
- vad_processed, speech_chunks = self.vad.run(
163
- audio=audio,
164
- vad_parameters=vad_options,
165
- progress=progress
166
- )
167
-
168
- try:
169
- if vad_processed.size > 0 and speech_chunks:
170
- if not isinstance(audio, np.ndarray):
171
- loaded_audio = faster_whisper.decode_audio(audio, sampling_rate=self.vad.sampling_rate)
172
- else:
173
- loaded_audio = audio
174
- # Convert speech_chunks to Segment objects and convert samples to seconds
175
- segments = [Segment(start=chunk['start']/self.vad.sampling_rate, end=chunk['end']/self.vad.sampling_rate) for chunk in speech_chunks]
176
- # merge_chunks only works on segments expressed in seconds!
177
- merged_chunks = merge_chunks(segments, chunk_size=300, onset=0.0, offset=None)
178
- all_segments = []
179
- total_elapsed_time = 0.0
180
- for merged in merged_chunks:
181
- chunk_start = merged['start']
182
- chunk_end = merged['end']
183
-
184
- # To slice audio, convert chunk_start and chunk_end from seconds to samples by multiplying by the sampling rate.
185
- start_sample = int(chunk_start*self.vad.sampling_rate)
186
- end_sample = int(chunk_end*self.vad.sampling_rate)
187
-
188
- chunk_audio = loaded_audio[start_sample:end_sample]
189
-
190
- chunk_result, chunk_time = self.transcribe(
191
- chunk_audio,
192
- progress,
193
- *astuple(params)
194
- )
195
- # Offset timestamps
196
- for seg in chunk_result:
197
- seg['start'] += chunk_start
198
- seg['end'] += chunk_start
199
- all_segments.extend(chunk_result)
200
- total_elapsed_time += chunk_time
201
- result = all_segments
202
- elapsed_time = total_elapsed_time
203
- else:
204
- params.vad_filter = False
205
- except Exception as e:
206
- print(f"Error transcribing file: {e}")
207
-
208
- if not params.vad_filter:
209
- result, elapsed_time = self.transcribe(
210
- audio,
211
- progress,
212
- *astuple(params)
213
- )
214
- if bool_whisper_enable_offload:
215
- self.offload()
216
-
217
- if params.is_diarize:
218
- progress(0.99, desc="Diarizing speakers...")
219
- result, elapsed_time_diarization = self.diarizer.run(
220
- audio=origin_audio,
221
- use_auth_token=params.hf_token,
222
- transcribed_result=result,
223
- device=params.diarization_device
224
- )
225
- if bool_diarization_enable_offload:
226
- self.diarizer.offload()
227
-
228
- if not result:
229
- print(f"Whisper did not detected any speech segments in the audio.")
230
- result = list()
231
-
232
- progress(1.0, desc="Processing done!")
233
- total_elapsed_time = datetime.now() - start_time
234
- return result, elapsed_time
235
-
236
- def transcribe_file(self,
237
- files_audio: Optional[List] = None,
238
- files_video: Optional[List] = None,
239
- files_multi: Optional[List] = None,
240
- input_multi: str = "Audio",
241
- input_folder_path: Optional[str] = None,
242
- file_format: list = ["CSV"],
243
- add_timestamp: bool = True,
244
- translate_output: bool = False,
245
- translate_model: str = "",
246
- target_lang: str = "",
247
- add_timestamp_preview: bool = False,
248
- diarize_speakers: bool = False,
249
- progress=gr.Progress(),
250
- *whisper_params,
251
- ) -> list:
252
- """
253
- Write subtitle file from Files
254
-
255
- Parameters
256
- ----------
257
- files_audio: list
258
- List of files to transcribe from gr.Audio()
259
- files_video: list
260
- List of files to transcribe from gr.Video()
261
- files_multi: list
262
- List of files to transcribe from gr.Files_multi()
263
- input_multi: str
264
- Process single or multiple files
265
- input_folder_path: str
266
- Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
267
- this will be used instead.
268
- file_format: list
269
- Subtitle File format to write from gr.Dropdown(). Supported format: [CSV, SRT, TXT]
270
- add_timestamp: bool
271
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the subtitle filename.
272
- translate_output: bool
273
- Translate output
274
- translate_model: str
275
- Translation model to use
276
- target_lang: str
277
- Target language to use
278
- add_timestamp_preview: bool
279
- Boolean value from gr.Checkbox() that determines whether to add a timestamp to output preview
280
- diarize_speakers: bool
281
- Boolean value from gr.Checkbox() that determines whether to diarize speakers
282
- progress: gr.Progress
283
- Indicator to show progress directly in gradio.
284
- *whisper_params: tuple
285
- Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
286
-
287
- Returns
288
- ----------
289
- result_str:
290
- Result of transcription to return to gr.Textbox()
291
- result_file_path:
292
- Output file path to return to gr.Files()
293
- """
294
-
295
- try:
296
- file_count_total = 0
297
- files = ""
298
-
299
- if input_multi == "Audio":
300
- files = files_audio
301
- elif input_multi == "Video":
302
- files = files_video
303
- else:
304
- files = files_multi
305
- file_count_total = len(files)
306
-
307
- if input_folder_path:
308
- files = get_media_files(input_folder_path)
309
- if isinstance(files, str):
310
- files = [files]
311
- if files and isinstance(files[0], gr.utils.NamedString):
312
- files = [file.name for file in files]
313
-
314
- ## Initialization variables & start time
315
- files_info = {}
316
- files_to_download = {}
317
- time_start = datetime.now()
318
-
319
- ## Load parameters related to Whisper
320
- params = WhisperParameters.as_value(*whisper_params)
321
-
322
- ## Load model to detect language
323
- model = whisper.load_model("base")
324
-
325
- for file in files:
326
- print(file)
327
- ## Detect language
328
- mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(whisper.load_audio(file))).to(model.device)
329
- _, probs = model.detect_language(mel)
330
- file_language = ""
331
- file_lang_probs = ""
332
- for key,value in whisper.tokenizer.LANGUAGES.items():
333
- if key == str(max(probs, key=probs.get)):
334
- file_language = value.capitalize()
335
- for key_prob,value_prob in probs.items():
336
- if key == key_prob:
337
- file_lang_probs = str((round(value_prob*100)))
338
- break
339
- break
340
- transcribed_segments, time_for_task = self.run(
341
- file,
342
- progress,
343
- add_timestamp,
344
- *whisper_params,
345
- )
346
- # Define source language
347
- #source_lang = file_language
348
- if params.lang == "Automatic Detection" or (params.lang).strip() == "":
349
- source_lang = file_language
350
- else:
351
- source_lang = ((params.lang).strip()).capitalize()
352
-
353
- # Translate to English using Whisper built-in functionality
354
- transcription_note = ""
355
- if params.is_translate:
356
- if source_lang != "English":
357
- transcription_note = "To English"
358
- source_lang = "English"
359
- else:
360
- transcription_note = "Already in English"
361
-
362
- # Translate the transcribed segments
363
- translation_note = ""
364
- if translate_output:
365
- if source_lang != target_lang:
366
- self.nllb_inf = NLLBInference()
367
- if source_lang in NLLB_AVAILABLE_LANGS.keys():
368
- transcribed_segments = self.nllb_inf.translate_text(
369
- input_list_dict=transcribed_segments,
370
- model_size=translate_model,
371
- src_lang=source_lang,
372
- tgt_lang=target_lang,
373
- speaker_diarization=params.is_diarize
374
- )
375
- translation_note = "To " + target_lang
376
- else:
377
- translation_note = source_lang + " not supported"
378
- else:
379
- translation_note = "Already in " + target_lang
380
-
381
- ## Get input filename & extension
382
- file_name, file_ext = os.path.splitext(os.path.basename(file))
383
-
384
- ## Get output as preview with or without timestamps
385
- if add_timestamp_preview:
386
- subtitle = get_txt(transcribed_segments)
387
- else:
388
- subtitle = get_plaintext(transcribed_segments)
389
- files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "lang": file_language, "lang_prob": file_lang_probs, "input_source_file": (file_name+file_ext), "translation": translation_note, "transcription": transcription_note}
390
-
391
- ## Add output file as txt, srt and/or csv
392
- for output_format in file_format:
393
- subtitle, file_path = self.generate_and_write_file(
394
- file_name=file_name,
395
- transcribed_segments=transcribed_segments,
396
- add_timestamp=add_timestamp,
397
- file_format=output_format.lower(),
398
- output_dir=self.output_dir
399
- )
400
- files_to_download[file_name+"_"+output_format.lower()] = {"path": file_path}
401
-
402
- total_result = ""
403
- total_info = ""
404
- total_time = 0
405
- file_count = 0
406
- for file_name, info in files_info.items():
407
-
408
- file_count += 1
409
-
410
- if file_count > 1:
411
- total_info += f'\n'
412
-
413
- if file_count_total > 1:
414
- if file_count > 1:
415
- total_result += f'\n'
416
- total_result += f'« Transcription of media file \'{info["input_source_file"]}\': »\n\n'
417
-
418
- total_time += info["time_for_task"]
419
- total_result += f'{info["subtitle"]}'
420
- total_info += f'Media file:\t{info["input_source_file"]}\nLanguage:\t{info["lang"]} (probability {info["lang_prob"]}%)\n'
421
-
422
- if params.is_translate:
423
- total_info += f'Translation:\t{info["transcription"]}\n\t⤷ Handled by OpenAI Whisper\n'
424
-
425
- if translate_output:
426
- total_info += f'Translation:\t{info["translation"]}\n\t⤷ Handled by Facebook NLLB\n'
427
-
428
- time_end = datetime.now()
429
- #total_info += f"\nTotal processing time:\t{self.format_time((time_end-time_start).total_seconds())}"
430
-
431
- temp_file_count_text = "file"
432
- if file_count!=1:
433
- temp_file_count_text += "s"
434
- total_info += f"\nProcessed {file_count} {temp_file_count_text} in {self.format_time((time_end-time_start).total_seconds())}"
435
-
436
- result_str = total_result.rstrip("\n")
437
- result_str = self.transform_text_to_list(result_str)
438
- result_file_path = [info['path'] for info in files_to_download.values()]
439
-
440
- return [result_str,result_file_path,total_info]
441
-
442
- except Exception as e:
443
- print(f"Error transcribing file: {e}")
444
- finally:
445
- self.release_cuda_memory()
446
-
447
- def transcribe_mic(self,
448
- mic_audio: str,
449
- file_format: str = "SRT",
450
- add_timestamp: bool = True,
451
- progress=gr.Progress(),
452
- *whisper_params,
453
- ) -> list:
454
- """
455
- Write subtitle file from microphone
456
-
457
- Parameters
458
- ----------
459
- mic_audio: str
460
- Audio file path from gr.Microphone()
461
- file_format: str
462
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
463
- add_timestamp: bool
464
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
465
- progress: gr.Progress
466
- Indicator to show progress directly in gradio.
467
- *whisper_params: tuple
468
- Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
469
-
470
- Returns
471
- ----------
472
- result_str:
473
- Result of transcription to return to gr.Textbox()
474
- result_file_path:
475
- Output file path to return to gr.Files()
476
- """
477
- try:
478
- progress(0, desc="Loading Audio...")
479
- transcribed_segments, time_for_task = self.run(
480
- mic_audio,
481
- progress,
482
- add_timestamp,
483
- *whisper_params,
484
- )
485
- progress(1, desc="Completed!")
486
-
487
- subtitle, result_file_path = self.generate_and_write_file(
488
- file_name="Mic",
489
- transcribed_segments=transcribed_segments,
490
- add_timestamp=add_timestamp,
491
- file_format=file_format,
492
- output_dir=self.output_dir
493
- )
494
-
495
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
496
- return [result_str, result_file_path]
497
- except Exception as e:
498
- print(f"Error transcribing file: {e}")
499
- finally:
500
- self.release_cuda_memory()
501
-
502
- def transcribe_youtube(self,
503
- youtube_link: str,
504
- file_format: str = "SRT",
505
- add_timestamp: bool = True,
506
- progress=gr.Progress(),
507
- *whisper_params,
508
- ) -> list:
509
- """
510
- Write subtitle file from Youtube
511
-
512
- Parameters
513
- ----------
514
- youtube_link: str
515
- URL of the Youtube video to transcribe from gr.Textbox()
516
- file_format: str
517
- Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
518
- add_timestamp: bool
519
- Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
520
- progress: gr.Progress
521
- Indicator to show progress directly in gradio.
522
- *whisper_params: tuple
523
- Parameters related to Whisper. These are handled by the "WhisperParameters" data class.
524
-
525
- Returns
526
- ----------
527
- result_str:
528
- Result of transcription to return to gr.Textbox()
529
- result_file_path:
530
- Output file path to return to gr.Files()
531
- """
532
- try:
533
- progress(0, desc="Loading Audio from Youtube...")
534
- yt = get_ytdata(youtube_link)
535
- audio = get_ytaudio(yt)
536
-
537
- transcribed_segments, time_for_task = self.run(
538
- audio,
539
- progress,
540
- add_timestamp,
541
- *whisper_params,
542
- )
543
-
544
- progress(1, desc="Completed!")
545
-
546
- file_name = safe_filename(yt.title)
547
- subtitle, result_file_path = self.generate_and_write_file(
548
- file_name=file_name,
549
- transcribed_segments=transcribed_segments,
550
- add_timestamp=add_timestamp,
551
- file_format=file_format,
552
- output_dir=self.output_dir
553
- )
554
- result_str = f"Done in {self.format_time(time_for_task)}! Subtitle file is in the outputs folder.\n\n{subtitle}"
555
-
556
- if os.path.exists(audio):
557
- os.remove(audio)
558
-
559
- return [result_str, result_file_path]
560
-
561
- except Exception as e:
562
- print(f"Error transcribing file: {e}")
563
- finally:
564
- self.release_cuda_memory()
565
-
566
- @staticmethod
567
- def generate_and_write_file(file_name: str,
568
- transcribed_segments: list,
569
- add_timestamp: bool,
570
- file_format: str,
571
- output_dir: str
572
- ) -> Tuple[str, str]:
573
- """
574
- Writes subtitle file
575
-
576
- Parameters
577
- ----------
578
- file_name: str
579
- Output file name
580
- transcribed_segments: list
581
- Text segments transcribed from audio
582
- add_timestamp: bool
583
- Determines whether to add a timestamp to the end of the filename.
584
- file_format: str
585
- File format to write. Supported formats: [SRT, WebVTT, txt, csv]
586
- output_dir: str
587
- Directory path of the output
588
-
589
- Returns
590
- ----------
591
- content: str
592
- Result of the transcription
593
- output_path: str
594
- output file path
595
- """
596
- if add_timestamp:
597
- #timestamp = datetime.now().strftime("%m%d%H%M%S")
598
- timestamp = datetime.now().strftime("%Y%m%d %H%M%S")
599
- output_path = os.path.join(output_dir, f"{file_name} - {timestamp}")
600
- else:
601
- output_path = os.path.join(output_dir, f"{file_name}")
602
-
603
- file_format = file_format.strip().lower()
604
- if file_format == "srt":
605
- content = get_srt(transcribed_segments)
606
- output_path += '.srt'
607
-
608
- elif file_format == "webvtt":
609
- content = get_vtt(transcribed_segments)
610
- output_path += '.vtt'
611
-
612
- elif file_format == "txt":
613
- content = get_txt(transcribed_segments)
614
- output_path += '.txt'
615
-
616
- elif file_format == "csv":
617
- content = get_csv(transcribed_segments)
618
- output_path += '.csv'
619
-
620
- write_file(content, output_path)
621
- return content, output_path
622
-
623
- def offload(self):
624
- """Offload the model and free up the memory"""
625
- if self.model is not None:
626
- del self.model
627
- self.model = None
628
- if self.device == "cuda":
629
- self.release_cuda_memory()
630
- gc.collect()
631
-
632
- @staticmethod
633
- def transform_text_to_list(inputdata: str) -> list:
634
- outputdata = []
635
- temp_inputdata = (inputdata.strip("\n")).splitlines()
636
- for temp_line in temp_inputdata:
637
- temp_line_list = []
638
- temp_line_items = temp_line.split("\t")
639
- for temp_line_item in temp_line_items:
640
- temp_line_list.append(temp_line_item)
641
- outputdata.append(temp_line_list)
642
-
643
- return outputdata
644
-
645
- @staticmethod
646
- def format_time(elapsed_time: float) -> str:
647
- """
648
- Get {hours} {minutes} {seconds} time format string
649
-
650
- Parameters
651
- ----------
652
- elapsed_time: float
653
- Elapsed time for transcription
654
-
655
- Returns
656
- ----------
657
- Time format string
658
- """
659
- hours, rem = divmod(elapsed_time, 3600)
660
- minutes, seconds = divmod(rem, 60)
661
-
662
- time_str = ""
663
-
664
- hours = round(hours)
665
- if hours:
666
- if hours == 1:
667
- time_str += f"{hours} hour "
668
- else:
669
- time_str += f"{hours} hours "
670
-
671
- minutes = round(minutes)
672
- if minutes:
673
- if minutes == 1:
674
- time_str += f"{minutes} minute "
675
- else:
676
- time_str += f"{minutes} minutes "
677
-
678
- seconds = round(seconds)
679
- if seconds == 1:
680
- time_str += f"{seconds} second"
681
- else:
682
- time_str += f"{seconds} seconds"
683
-
684
- return time_str.strip()
685
-
686
- @staticmethod
687
- def get_device():
688
- if torch.cuda.is_available():
689
- return "cuda"
690
- elif torch.backends.mps.is_available():
691
- if not WhisperBase.is_sparse_api_supported():
692
- # Device `SparseMPS` is not supported for now. See : https://github.com/pytorch/pytorch/issues/87886
693
- return "cpu"
694
- return "mps"
695
- else:
696
- return "cpu"
697
-
698
- @staticmethod
699
- def is_sparse_api_supported():
700
- if not torch.backends.mps.is_available():
701
- return False
702
-
703
- try:
704
- device = torch.device("mps")
705
- sparse_tensor = torch.sparse_coo_tensor(
706
- indices=torch.tensor([[0, 1], [2, 3]]),
707
- values=torch.tensor([1, 2]),
708
- size=(4, 4),
709
- device=device
710
- )
711
- return True
712
- except RuntimeError:
713
- return False
714
-
715
- @staticmethod
716
- def release_cuda_memory():
717
- """Release memory"""
718
- if torch.cuda.is_available():
719
- torch.cuda.empty_cache()
720
- torch.cuda.reset_max_memory_allocated()
721
-
722
- @staticmethod
723
- def remove_input_files(file_paths: List[str]):
724
- """Remove gradio cached files"""
725
- if not file_paths:
726
- return
727
-
728
- for file_path in file_paths:
729
- if file_path and os.path.exists(file_path):
730
- os.remove(file_path)
731
-
732
- @staticmethod
733
- def cache_parameters(
734
- params: WhisperValues,
735
- file_format: str = "SRT",
736
- add_timestamp: bool = True
737
- ):
738
- """Cache parameters to the yaml file"""
739
- cached_params = load_yaml(DEFAULT_PARAMETERS_CONFIG_PATH)
740
- param_to_cache = params.to_dict()
741
-
742
- cached_yaml = {**cached_params, **param_to_cache}
743
- cached_yaml["whisper"]["add_timestamp"] = add_timestamp
744
- cached_yaml["whisper"]["file_format"] = file_format
745
-
746
- suppress_token = cached_yaml["whisper"].get("suppress_tokens", None)
747
- if suppress_token and isinstance(suppress_token, list):
748
- cached_yaml["whisper"]["suppress_tokens"] = str(suppress_token)
749
-
750
- if cached_yaml["whisper"].get("lang", None) is None:
751
- cached_yaml["whisper"]["lang"] = AUTOMATIC_DETECTION.unwrap()
752
- else:
753
- language_dict = whisper.tokenizer.LANGUAGES
754
- cached_yaml["whisper"]["lang"] = language_dict[cached_yaml["whisper"]["lang"]]
755
-
756
- if cached_yaml["vad"].get("max_speech_duration_s", float('inf')) == float('inf'):
757
- cached_yaml["vad"]["max_speech_duration_s"] = GRADIO_NONE_NUMBER_MAX
758
-
759
- if cached_yaml is not None and cached_yaml:
760
- save_yaml(cached_yaml, DEFAULT_PARAMETERS_CONFIG_PATH)
761
-
762
- @staticmethod
763
- def resample_audio(audio: Union[str, np.ndarray],
764
- new_sample_rate: int = 16000,
765
- original_sample_rate: Optional[int] = None,) -> np.ndarray:
766
- """Resamples audio to 16k sample rate, standard on Whisper model"""
767
- if isinstance(audio, str):
768
- audio, original_sample_rate = torchaudio.load(audio)
769
- else:
770
- if original_sample_rate is None:
771
- raise ValueError("original_sample_rate must be provided when audio is a numpy array.")
772
- audio = torch.from_numpy(audio)
773
- resampler = torchaudio.transforms.Resample(orig_freq=original_sample_rate, new_freq=new_sample_rate)
774
- resampled_audio = resampler(audio).numpy()
775
- return resampled_audio
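
For context on what this deletion removes: `WhisperBase` was the abstract pipeline class that concrete backends subclassed by implementing `update_model` and `transcribe`, while `WhisperBase.run()` wrapped those calls with VAD filtering, BGM separation, diarization, and subtitle writing. The sketch below is illustrative only, assuming the openai-whisper package and the module layout shown in the diff; the class name `SimpleWhisperInference` and the `WhisperParameters` attribute names `model_size` and `beam_size` are assumptions, not code from this repository.

```python
# Illustrative sketch, not part of this repository: a minimal concrete subclass of the
# removed WhisperBase, assuming the openai-whisper package and the module layout above.
import time
from typing import BinaryIO, Union

import gradio as gr
import numpy as np
import whisper

from modules.whisper.whisper_base import WhisperBase  # removed by this commit
from modules.whisper.whisper_parameter import WhisperParameters


class SimpleWhisperInference(WhisperBase):
    def update_model(self, model_size: str, compute_type: str,
                     progress: gr.Progress = gr.Progress()):
        """Load (or reload) the requested Whisper checkpoint."""
        progress(0, desc="Initializing model...")
        self.current_model_size = model_size
        self.current_compute_type = compute_type
        self.model = whisper.load_model(model_size, device=self.device,
                                        download_root=self.model_dir)

    def transcribe(self, audio: Union[str, BinaryIO, np.ndarray],
                   progress: gr.Progress = gr.Progress(), *whisper_params):
        """Return (segments, elapsed_time) in the shape WhisperBase.run() expects."""
        start = time.time()
        params = WhisperParameters.as_value(*whisper_params)

        # `model_size` is an assumed attribute of the WhisperParameters data class.
        if self.model is None or self.current_model_size != params.model_size:
            self.update_model(params.model_size, self.current_compute_type, progress)

        result = self.model.transcribe(
            audio,
            language=params.lang,
            task="translate" if params.is_translate else "transcribe",
            beam_size=params.beam_size,  # assumed attribute, passed through to decoding
            verbose=False,
        )
        segments = [
            {"start": seg["start"], "end": seg["end"], "text": seg["text"]}
            for seg in result["segments"]
        ]
        return segments, time.time() - start
```

A subclass like this plugs into the removed `run()` method, which handled language-code mapping, optional VAD and music separation before the `transcribe` call, and optional diarization and file writing after it.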