jhj0517 committed
Commit 31adf69 · 2 Parent(s): 2353351 0b0f426

Merge master

.gitignore CHANGED
@@ -1,3 +1,7 @@
 venv/
 ui/__pycache__/
 outputs/

+*.wav
+*.png
+*.mp4
+*.mp3
 venv/
 ui/__pycache__/
 outputs/
app.py CHANGED
@@ -1,7 +1,8 @@
 import os
 import argparse

-from modules.whisper.whisper_Inference import WhisperInference
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
@@ -15,68 +16,150 @@ class App:
    def __init__(self, args):
        self.args = args
        self.app = gr.Blocks(css=CSS, theme=self.args.theme)
-        self.whisper_inf = FasterWhisperInference(
-            model_dir=self.args.faster_whisper_model_dir,
            output_dir=self.args.output_dir,
-            args=self.args
        )
        print(f"Use \"{self.args.whisper_type}\" implementation")
        print(f"Device \"{self.whisper_inf.device}\" is detected")
        self.nllb_inf = NLLBInference(
            model_dir=self.args.nllb_model_dir,
-            output_dir=self.args.output_dir
        )
        self.deepl_api = DeepLAPI(
-            output_dir=self.args.output_dir
        )

-    def init_whisper(self):
-        # Temporal fix of the issue : https://github.com/jhj0517/Whisper-WebUI/issues/144
-        os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
-
-        whisper_type = self.args.whisper_type.lower().strip()
-
-        if whisper_type in ["faster_whisper", "faster-whisper", "fasterwhisper"]:
-            whisper_inf = FasterWhisperInference(
-                model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        elif whisper_type in ["whisper"]:
-            whisper_inf = WhisperInference(
-                model_dir=self.args.whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        elif whisper_type in ["insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
-                              "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"]:
-            whisper_inf = InsanelyFastWhisperInference(
-                model_dir=self.args.insanely_fast_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        else:
-            whisper_inf = FasterWhisperInference(
-                model_dir=self.args.faster_whisper_model_dir,
-                output_dir=self.args.output_dir,
-                args=self.args
-            )
-        return whisper_inf

-    @staticmethod
-    def open_folder(folder_path: str):
-        if os.path.exists(folder_path):
-            os.system(f"start {folder_path}")
-        else:
-            print(f"The folder {folder_path} does not exist.")

-    @staticmethod
-    def on_change_models(model_size: str):
-        translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
-        if model_size not in translatable_model:
-            return gr.Checkbox(visible=False, value=False, interactive=False)
-        else:
-            return gr.Checkbox(visible=True, value=False, label="Translate to English?", interactive=True)

    def launch(self):
        with self.app:
@@ -85,84 +168,28 @@ class App:
                    gr.Markdown(MARKDOWN, elem_id="md_project")
            with gr.Tabs():
                with gr.TabItem("File"):  # tab1
-                    with gr.Row():
                        input_file = gr.Files(type="filepath", label="Upload File here")
-                    with gr.Row():
-                        dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
-                                               label="Model")
-                        dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
-                                              value="Automatic Detection", label="Language")
-                        dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
-                    with gr.Row():
-                        cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
-                    with gr.Row():
-                        cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
-                    with gr.Accordion("Advanced Parameters", open=False):
-                        nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
-                        nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
-                        nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
-                        dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
-                        nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
-                        nb_patience = gr.Number(label="Patience", value=1, interactive=True)
-                        cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
-                        tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
-                        sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
-                        nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
-                    with gr.Accordion("VAD", open=False):
-                        cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
-                        sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
-                        nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
-                        nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
-                        nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
-                        nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
-                        nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
-                    with gr.Accordion("Diarization", open=False):
-                        cb_diarize = gr.Checkbox(label="Enable Diarization")
-                        tb_hf_token = gr.Text(label="HuggingFace Token", value="",
-                                              info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
-                                                   "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
-                        dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
-                    with gr.Accordion("Insanely Fast Whisper Parameters", open=False, visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
-                        nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
-                        nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
                    with gr.Row():
                        btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                    with gr.Row():
-                        tb_indicator = gr.Textbox(label="Output", scale=6)
                        files_subtitles = gr.Files(label="Downloadable output file", scale=3, interactive=False)

-                    params = [input_file, dd_file_format, cb_timestamp]
-
-                    whisper_params = WhisperParameters(model_size=dd_model,
-                                                       lang=dd_lang,
-                                                       is_translate=cb_translate,
-                                                       beam_size=nb_beam_size,
-                                                       log_prob_threshold=nb_log_prob_threshold,
-                                                       no_speech_threshold=nb_no_speech_threshold,
-                                                       compute_type=dd_compute_type,
-                                                       best_of=nb_best_of,
-                                                       patience=nb_patience,
-                                                       condition_on_previous_text=cb_condition_on_previous_text,
-                                                       initial_prompt=tb_initial_prompt,
-                                                       temperature=sd_temperature,
-                                                       compression_ratio_threshold=nb_compression_ratio_threshold,
-                                                       vad_filter=cb_vad_filter,
-                                                       threshold=sd_threshold,
-                                                       min_speech_duration_ms=nb_min_speech_duration_ms,
-                                                       max_speech_duration_s=nb_max_speech_duration_s,
-                                                       min_silence_duration_ms=nb_min_silence_duration_ms,
-                                                       window_size_sample=nb_window_size_sample,
-                                                       speech_pad_ms=nb_speech_pad_ms,
-                                                       chunk_length_s=nb_chunk_length_s,
-                                                       batch_size=nb_batch_size,
-                                                       is_diarize=cb_diarize,
-                                                       hf_token=tb_hf_token,
-                                                       diarization_device=dd_diarization_device)
-
                    btn_run.click(fn=self.whisper_inf.transcribe_file,
-                                  inputs=params+whisper_params.as_list(),
                                  outputs=[tb_indicator, files_subtitles])
-                    dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])

                with gr.TabItem("Youtube"):  # tab2
                    with gr.Row():
@@ -173,164 +200,44 @@ class App:
173
  with gr.Column():
174
  tb_title = gr.Label(label="Youtube Title")
175
  tb_description = gr.Textbox(label="Youtube Description", max_lines=15)
176
- with gr.Row():
177
- dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
178
- label="Model")
179
- dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
180
- value="Automatic Detection", label="Language")
181
- dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
182
- with gr.Row():
183
- cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
184
- with gr.Row():
185
- cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
186
- interactive=True)
187
- with gr.Accordion("Advanced Parameters", open=False):
188
- nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
189
- nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
190
- nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
191
- dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
192
- nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
193
- nb_patience = gr.Number(label="Patience", value=1, interactive=True)
194
- cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
195
- tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
196
- sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
197
- nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True)
198
- with gr.Accordion("VAD", open=False):
199
- cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
200
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
201
- nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
202
- nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
203
- nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
204
- nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
205
- nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
206
- with gr.Accordion("Diarization", open=False):
207
- cb_diarize = gr.Checkbox(label="Enable Diarization")
208
- tb_hf_token = gr.Text(label="HuggingFace Token", value="",
209
- info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
210
- "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
211
- dd_diarization_device = gr.Dropdown(label="Device", choices=self.whisper_inf.diarizer.get_available_device(), value=self.whisper_inf.diarizer.get_device())
212
- with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
213
- visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
214
- nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
215
- nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
216
  with gr.Row():
217
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
218
  with gr.Row():
219
- tb_indicator = gr.Textbox(label="Output", scale=6)
220
  files_subtitles = gr.Files(label="Downloadable output file", scale=3)
 
221
 
222
  params = [tb_youtubelink, dd_file_format, cb_timestamp]
223
- whisper_params = WhisperParameters(model_size=dd_model,
224
- lang=dd_lang,
225
- is_translate=cb_translate,
226
- beam_size=nb_beam_size,
227
- log_prob_threshold=nb_log_prob_threshold,
228
- no_speech_threshold=nb_no_speech_threshold,
229
- compute_type=dd_compute_type,
230
- best_of=nb_best_of,
231
- patience=nb_patience,
232
- condition_on_previous_text=cb_condition_on_previous_text,
233
- initial_prompt=tb_initial_prompt,
234
- temperature=sd_temperature,
235
- compression_ratio_threshold=nb_compression_ratio_threshold,
236
- vad_filter=cb_vad_filter,
237
- threshold=sd_threshold,
238
- min_speech_duration_ms=nb_min_speech_duration_ms,
239
- max_speech_duration_s=nb_max_speech_duration_s,
240
- min_silence_duration_ms=nb_min_silence_duration_ms,
241
- window_size_sample=nb_window_size_sample,
242
- speech_pad_ms=nb_speech_pad_ms,
243
- chunk_length_s=nb_chunk_length_s,
244
- batch_size=nb_batch_size,
245
- is_diarize=cb_diarize,
246
- hf_token=tb_hf_token,
247
- diarization_device=dd_diarization_device)
248
 
249
  btn_run.click(fn=self.whisper_inf.transcribe_youtube,
250
  inputs=params + whisper_params.as_list(),
251
  outputs=[tb_indicator, files_subtitles])
252
  tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
253
  outputs=[img_thumbnail, tb_title, tb_description])
254
- dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
255
 
256
  with gr.TabItem("Mic"): # tab3
257
  with gr.Row():
258
  mic_input = gr.Microphone(label="Record with Mic", type="filepath", interactive=True)
259
- with gr.Row():
260
- dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
261
- label="Model")
262
- dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
263
- value="Automatic Detection", label="Language")
264
- dd_file_format = gr.Dropdown(["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
265
- with gr.Row():
266
- cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
267
- with gr.Accordion("Advanced Parameters", open=False):
268
- nb_beam_size = gr.Number(label="Beam Size", value=1, precision=0, interactive=True)
269
- nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True)
270
- nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True)
271
- dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types, value=self.whisper_inf.current_compute_type, interactive=True)
272
- nb_best_of = gr.Number(label="Best Of", value=5, interactive=True)
273
- nb_patience = gr.Number(label="Patience", value=1, interactive=True)
274
- cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True)
275
- tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True)
276
- sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True)
277
- with gr.Accordion("VAD", open=False):
278
- cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
279
- sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5, info="Lower it to be more sensitive to small sounds.")
280
- nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250)
281
- nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999)
282
- nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000)
283
- nb_window_size_sample = gr.Number(label="Window Size (samples)", precision=0, value=1024)
284
- nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400)
285
- with gr.Accordion("Diarization", open=False):
286
- cb_diarize = gr.Checkbox(label="Enable Diarization")
287
- tb_hf_token = gr.Text(label="HuggingFace Token", value="",
288
- info="This is only needed the first time you download the model. If you already have models, you don't need to enter. "
289
- "To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
290
- dd_diarization_device = gr.Dropdown(label="Device",
291
- choices=self.whisper_inf.diarizer.get_available_device(),
292
- value=self.whisper_inf.diarizer.get_device())
293
- with gr.Accordion("Insanely Fast Whisper Parameters", open=False,
294
- visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
295
- nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
296
- nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
297
  with gr.Row():
298
  btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
299
  with gr.Row():
300
- tb_indicator = gr.Textbox(label="Output", scale=6)
301
  files_subtitles = gr.Files(label="Downloadable output file", scale=3)
 
302
 
303
  params = [mic_input, dd_file_format]
304
- whisper_params = WhisperParameters(model_size=dd_model,
305
- lang=dd_lang,
306
- is_translate=cb_translate,
307
- beam_size=nb_beam_size,
308
- log_prob_threshold=nb_log_prob_threshold,
309
- no_speech_threshold=nb_no_speech_threshold,
310
- compute_type=dd_compute_type,
311
- best_of=nb_best_of,
312
- patience=nb_patience,
313
- condition_on_previous_text=cb_condition_on_previous_text,
314
- initial_prompt=tb_initial_prompt,
315
- temperature=sd_temperature,
316
- compression_ratio_threshold=nb_compression_ratio_threshold,
317
- vad_filter=cb_vad_filter,
318
- threshold=sd_threshold,
319
- min_speech_duration_ms=nb_min_speech_duration_ms,
320
- max_speech_duration_s=nb_max_speech_duration_s,
321
- min_silence_duration_ms=nb_min_silence_duration_ms,
322
- window_size_sample=nb_window_size_sample,
323
- speech_pad_ms=nb_speech_pad_ms,
324
- chunk_length_s=nb_chunk_length_s,
325
- batch_size=nb_batch_size,
326
- is_diarize=cb_diarize,
327
- hf_token=tb_hf_token,
328
- diarization_device=dd_diarization_device)
329
 
330
  btn_run.click(fn=self.whisper_inf.transcribe_mic,
331
  inputs=params + whisper_params.as_list(),
332
  outputs=[tb_indicator, files_subtitles])
333
- dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
334
 
335
  with gr.TabItem("T2T Translation"): # tab 4
336
  with gr.Row():
@@ -350,17 +257,25 @@ class App:
                                                            self.deepl_api.available_target_langs.keys()))
                        with gr.Row():
                            cb_deepl_ispro = gr.Checkbox(label="Pro User?", value=False)
                        with gr.Row():
                            btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
                        with gr.Row():
                            tb_indicator = gr.Textbox(label="Output", scale=5)
                            files_subtitles = gr.Files(label="Downloadable output file", scale=3)

                        btn_run.click(fn=self.deepl_api.translate_deepl,
                                      inputs=[tb_authkey, file_subs, dd_deepl_sourcelang, dd_deepl_targetlang,
-                                              cb_deepl_ispro],
                                      outputs=[tb_indicator, files_subtitles])

                    with gr.TabItem("NLLB"):  # sub tab2
                        with gr.Row():
                            dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
@@ -369,6 +284,8 @@ class App:
                                                             choices=self.nllb_inf.available_source_langs)
                            dd_nllb_targetlang = gr.Dropdown(label="Target Language",
                                                             choices=self.nllb_inf.available_target_langs)
                        with gr.Row():
                            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                                       interactive=True)
@@ -377,33 +294,53 @@ class App:
                        with gr.Row():
                            tb_indicator = gr.Textbox(label="Output", scale=5)
                            files_subtitles = gr.Files(label="Downloadable output file", scale=3)
                        with gr.Column():
                            md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")

                        btn_run.click(fn=self.nllb_inf.translate_file,
-                                      inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang, cb_timestamp],
                                      outputs=[tb_indicator, files_subtitles])

            # Launch the app with optional gradio settings
-            launch_args = {}
-            if self.args.share:
-                launch_args['share'] = self.args.share
-            if self.args.server_name:
-                launch_args['server_name'] = self.args.server_name
-            if self.args.server_port:
-                launch_args['server_port'] = self.args.server_port
-            if self.args.username and self.args.password:
-                launch_args['auth'] = (self.args.username, self.args.password)
-            if self.args.root_path:
-                launch_args['root_path'] = self.args.root_path
-            launch_args['inbrowser'] = True
-
-            self.app.queue(api_open=False).launch(**launch_args)


# Create the parser for command-line arguments
parser = argparse.ArgumentParser()
-parser.add_argument('--whisper_type', type=str, default="faster-whisper", help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]')
parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
@@ -412,12 +349,19 @@ parser.add_argument('--username', type=str, default=None, help='Gradio authentic
parser.add_argument('--password', type=str, default=None, help='Gradio authentication password')
parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
-parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='enable api or not')
-parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"), help='Directory path of the whisper model')
-parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"), help='Directory path of the faster-whisper model')
-parser.add_argument('--insanely_fast_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "insanely-fast-whisper"), help='Directory path of the insanely-fast-whisper model')
-parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"), help='Directory path of the diarization model')
-parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"), help='Directory path of the Facebook NLLB model')
parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
_args = parser.parse_args()
 
 import os
 import argparse
+import gradio as gr

+from modules.whisper.whisper_factory import WhisperFactory
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference

    def __init__(self, args):
        self.args = args
        self.app = gr.Blocks(css=CSS, theme=self.args.theme)
+        self.whisper_inf = WhisperFactory.create_whisper_inference(
+            whisper_type=self.args.whisper_type,
+            whisper_model_dir=self.args.whisper_model_dir,
+            faster_whisper_model_dir=self.args.faster_whisper_model_dir,
+            insanely_fast_whisper_model_dir=self.args.insanely_fast_whisper_model_dir,
            output_dir=self.args.output_dir,
        )
        print(f"Use \"{self.args.whisper_type}\" implementation")
        print(f"Device \"{self.whisper_inf.device}\" is detected")
        self.nllb_inf = NLLBInference(
            model_dir=self.args.nllb_model_dir,
+            output_dir=os.path.join(self.args.output_dir, "translations")
        )
        self.deepl_api = DeepLAPI(
+            output_dir=os.path.join(self.args.output_dir, "translations")
        )

+    def create_whisper_parameters(self):
+        with gr.Row():
+            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2",
+                                   label="Model")
+            dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
+                                  value="Automatic Detection", label="Language")
+            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
+        with gr.Row():
+            cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
+        with gr.Row():
+            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+                                       interactive=True)
+        with gr.Accordion("Advanced Parameters", open=False):
+            nb_beam_size = gr.Number(label="Beam Size", value=5, precision=0, interactive=True,
+                                     info="Beam size to use for decoding.")
+            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True,
+                                              info="If the average log probability over sampled tokens is below this value, treat as failed.")
+            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True,
+                                               info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
+            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
+                                          value=self.whisper_inf.current_compute_type, interactive=True,
+                                          info="Select the type of computation to perform.")
+            nb_best_of = gr.Number(label="Best Of", value=5, interactive=True,
+                                   info="Number of candidates when sampling with non-zero temperature.")
+            nb_patience = gr.Number(label="Patience", value=1, interactive=True,
+                                    info="Beam search patience factor.")
+            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True,
+                                                        interactive=True,
+                                                        info="Condition on previous text during decoding.")
+            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=0.5,
+                                                        minimum=0, maximum=1, step=0.01, interactive=True,
+                                                        info="Resets prompt if temperature is above this value."
+                                                             " Arg has effect only if 'Condition On Previous Text' is True.")
+            tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
+                                           info="Initial prompt to use for decoding.")
+            sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True,
+                                       info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
+            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True,
+                                                       info="If the gzip compression ratio is above this value, treat as failed.")
+            with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+                nb_length_penalty = gr.Number(label="Length Penalty", value=1,
+                                              info="Exponential length penalty constant.")
+                nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=1,
+                                                  info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
+                nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=0, precision=0,
+                                                    info="Prevent repetitions of n-grams with this size (set 0 to disable).")
+                tb_prefix = gr.Textbox(label="Prefix", value=lambda: None,
+                                       info="Optional text to provide as a prefix for the first window.")
+                cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=True,
+                                                info="Suppress blank outputs at the beginning of the sampling.")
+                tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value="[-1]",
+                                                info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
+                nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=1.0,
+                                                     info="The initial timestamp cannot be later than this.")
+                cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=False,
+                                                 info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
+                tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value="\"'“¿([{-",
+                                                     info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
+                tb_append_punctuations = gr.Textbox(label="Append Punctuations", value="\"'.。,,!!??::”)]}、",
+                                                    info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
+                nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: None, precision=0,
+                                              info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
+                nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: None, precision=0,
+                                            info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
+                nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)",
+                                                               value=lambda: None,
+                                                               info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
+                tb_hotwords = gr.Textbox(label="Hotwords", value=None,
+                                         info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
+                nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=None,
+                                                            info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
+                nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=1, precision=0,
+                                                           info="Number of segments to consider for the language detection.")
+            with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
+                nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
+                nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)

+        with gr.Accordion("VAD", open=False):
+            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
+            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5,
+                                     info="Lower it to be more sensitive to small sounds.")
+            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250,
+                                                  info="Final speech chunks shorter than this time are thrown out")
+            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999,
+                                                 info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
+                                                      " than this time will be split at the timestamp of the last silence that"
+                                                      " lasts more than 100ms (if any), to prevent aggressive cutting.")
+            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000,
+                                                   info="In the end of each speech chunk wait for this time"
+                                                        " before separating it")
+            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400,
+                                         info="Final speech chunks are padded by this time each side")

+        with gr.Accordion("Diarization", open=False):
+            cb_diarize = gr.Checkbox(label="Enable Diarization")
+            tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                  info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+            dd_diarization_device = gr.Dropdown(label="Device",
+                                                choices=self.whisper_inf.diarizer.get_available_device(),
+                                                value=self.whisper_inf.diarizer.get_device())
+
+        dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
+
+        return (
+            WhisperParameters(
+                model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
+                log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
+                compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
+                condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
+                temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
+                vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
+                max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
+                speech_pad_ms=nb_speech_pad_ms, chunk_length_s=nb_chunk_length_s, batch_size=nb_batch_size,
+                is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
+                length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
+                no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
+                suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
+                word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
+                append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, chunk_length=nb_chunk_length,
+                hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
+                language_detection_threshold=nb_language_detection_threshold,
+                language_detection_segments=nb_language_detection_segments,
+                prompt_reset_on_temperature=sld_prompt_reset_on_temperature
+            ),
+            dd_file_format,
+            cb_timestamp
+        )
 
    def launch(self):
        with self.app:

                    gr.Markdown(MARKDOWN, elem_id="md_project")
            with gr.Tabs():
                with gr.TabItem("File"):  # tab1
+                    with gr.Column():
                        input_file = gr.Files(type="filepath", label="Upload File here")
+                        tb_input_folder = gr.Textbox(label="Input Folder Path (Optional)",
+                                                     info="Optional: Specify the folder path where the input files are located, if you prefer to use local files instead of uploading them."
+                                                          " Leave this field empty if you do not wish to use a local path.",
+                                                     visible=self.args.colab,
+                                                     value="")
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                    with gr.Row():
                        btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                    with gr.Row():
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                        files_subtitles = gr.Files(label="Downloadable output file", scale=3, interactive=False)
+                        btn_openfolder = gr.Button('📂', scale=1)

+                    params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
                    btn_run.click(fn=self.whisper_inf.transcribe_file,
+                                  inputs=params + whisper_params.as_list(),
                                  outputs=[tb_indicator, files_subtitles])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)

                with gr.TabItem("Youtube"):  # tab2
                    with gr.Row():

                        with gr.Column():
                            tb_title = gr.Label(label="Youtube Title")
                            tb_description = gr.Textbox(label="Youtube Description", max_lines=15)
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                    with gr.Row():
                        btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                    with gr.Row():
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                        files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                        btn_openfolder = gr.Button('📂', scale=1)

                    params = [tb_youtubelink, dd_file_format, cb_timestamp]

                    btn_run.click(fn=self.whisper_inf.transcribe_youtube,
                                  inputs=params + whisper_params.as_list(),
                                  outputs=[tb_indicator, files_subtitles])
                    tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
                                          outputs=[img_thumbnail, tb_title, tb_description])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)

                with gr.TabItem("Mic"):  # tab3
                    with gr.Row():
                        mic_input = gr.Microphone(label="Record with Mic", type="filepath", interactive=True)
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                    with gr.Row():
                        btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                    with gr.Row():
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                        files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                        btn_openfolder = gr.Button('📂', scale=1)

                    params = [mic_input, dd_file_format]

                    btn_run.click(fn=self.whisper_inf.transcribe_mic,
                                  inputs=params + whisper_params.as_list(),
                                  outputs=[tb_indicator, files_subtitles])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
                with gr.TabItem("T2T Translation"):  # tab 4
                    with gr.Row():

                                                            self.deepl_api.available_target_langs.keys()))
                        with gr.Row():
                            cb_deepl_ispro = gr.Checkbox(label="Pro User?", value=False)
+                        with gr.Row():
+                            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+                                                       interactive=True)
                        with gr.Row():
                            btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
                        with gr.Row():
                            tb_indicator = gr.Textbox(label="Output", scale=5)
                            files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                            btn_openfolder = gr.Button('📂', scale=1)

                        btn_run.click(fn=self.deepl_api.translate_deepl,
                                      inputs=[tb_authkey, file_subs, dd_deepl_sourcelang, dd_deepl_targetlang,
+                                              cb_deepl_ispro, cb_timestamp],
                                      outputs=[tb_indicator, files_subtitles])

+                        btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                                             inputs=None,
+                                             outputs=None)
+
                    with gr.TabItem("NLLB"):  # sub tab2
                        with gr.Row():
                            dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",

                                                             choices=self.nllb_inf.available_source_langs)
                            dd_nllb_targetlang = gr.Dropdown(label="Target Language",
                                                             choices=self.nllb_inf.available_target_langs)
+                        with gr.Row():
+                            nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
                        with gr.Row():
                            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                                       interactive=True)

                        with gr.Row():
                            tb_indicator = gr.Textbox(label="Output", scale=5)
                            files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                            btn_openfolder = gr.Button('📂', scale=1)
                        with gr.Column():
                            md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")

                        btn_run.click(fn=self.nllb_inf.translate_file,
+                                      inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang,
+                                              nb_max_length, cb_timestamp],
                                      outputs=[tb_indicator, files_subtitles])

+                        btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                                             inputs=None,
+                                             outputs=None)
+
            # Launch the app with optional gradio settings
+            args = self.args
+
+            self.app.queue(
+                api_open=args.api_open
+            ).launch(
+                share=args.share,
+                server_name=args.server_name,
+                server_port=args.server_port,
+                auth=(args.username, args.password) if args.username and args.password else None,
+                root_path=args.root_path,
+                inbrowser=args.inbrowser
+            )
+
+    @staticmethod
+    def open_folder(folder_path: str):
+        if os.path.exists(folder_path):
+            os.system(f"start {folder_path}")
+        else:
+            print(f"The folder {folder_path} does not exist.")
+
+    @staticmethod
+    def on_change_models(model_size: str):
+        translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
+        if model_size not in translatable_model:
+            return gr.Checkbox(visible=False, value=False, interactive=False)
+        else:
+            return gr.Checkbox(visible=True, value=False, label="Translate to English?", interactive=True)


# Create the parser for command-line arguments
parser = argparse.ArgumentParser()
+parser.add_argument('--whisper_type', type=str, default="faster-whisper",
+                    help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]')
parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')

parser.add_argument('--password', type=str, default=None, help='Gradio authentication password')
parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
+parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
+parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
+parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"),
+                    help='Directory path of the whisper model')
+parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"),
+                    help='Directory path of the faster-whisper model')
+parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
+                    default=os.path.join("models", "Whisper", "insanely-fast-whisper"),
+                    help='Directory path of the insanely-fast-whisper model')
+parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"),
+                    help='Directory path of the diarization model')
+parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"),
+                    help='Directory path of the Facebook NLLB model')
parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
_args = parser.parse_args()
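Note: modules/whisper/whisper_factory.py is imported above but is not part of this commit's diff, so its contents are not shown here. As a hedged sketch only, a factory like the one called in __init__ could simply normalize the --whisper_type string and dispatch to one of the three inference implementations (the real module may do more, e.g. model-dir handling):

# Hypothetical sketch -- not the actual modules/whisper/whisper_factory.py from this commit.
from modules.whisper.whisper_Inference import WhisperInference
from modules.whisper.faster_whisper_inference import FasterWhisperInference
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference


class WhisperFactory:
    @staticmethod
    def create_whisper_inference(whisper_type: str,
                                 whisper_model_dir: str,
                                 faster_whisper_model_dir: str,
                                 insanely_fast_whisper_model_dir: str,
                                 output_dir: str):
        # Normalize variants like "faster_whisper" / "faster-whisper" before matching.
        whisper_type = whisper_type.strip().lower().replace("_", "-")
        if whisper_type == "whisper":
            return WhisperInference(model_dir=whisper_model_dir, output_dir=output_dir)
        if whisper_type == "insanely-fast-whisper":
            return InsanelyFastWhisperInference(model_dir=insanely_fast_whisper_model_dir,
                                                output_dir=output_dir)
        # Fall back to faster-whisper, matching the parser's default --whisper_type.
        return FasterWhisperInference(model_dir=faster_whisper_model_dir, output_dir=output_dir)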
docker-compose.yaml ADDED
@@ -0,0 +1,31 @@
+version: '3.8'
+
+services:
+  app:
+    build: .
+    image: jhj0517/whisper-webui:latest
+
+    volumes:
+      # Update paths to mount models and output paths to your custom paths like this, e.g:
+      # - C:/whisper-models/custom-path:/Whisper-WebUI/models
+      # - C:/whisper-webui-outputs/custom-path:/Whisper-WebUI/outputs
+      - /Whisper-WebUI/models
+      - /Whisper-WebUI/outputs
+
+    ports:
+      - "7860:7860"
+
+    stdin_open: true
+    tty: true
+
+    entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0",]
+
+    # If you're not using nvidia GPU, Update device to match yours.
+    # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [ gpu ]
modules/diarize/audio_loader.py CHANGED
@@ -1,7 +1,11 @@
 import os
 import subprocess
 from functools import lru_cache
 from typing import Optional, Union

 import numpy as np
 import torch
@@ -24,32 +28,43 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH) # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


-def load_audio(file: str, sr: int = SAMPLE_RATE):
    """
-    Open an audio file and read as mono waveform, resampling as necessary

    Parameters
    ----------
-    file: str
-        The audio file to open

    sr: int
-        The sample rate to resample the audio if necessary

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
    try:
-        # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI to be installed.
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
-            file,
            "-f",
            "s16le",
            "-ac",
@@ -63,6 +78,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
 
+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
+
 import os
 import subprocess
 from functools import lru_cache
 from typing import Optional, Union
+from scipy.io.wavfile import write
+import tempfile

 import numpy as np
 import torch

 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token


+def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
    """
+    Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.

    Parameters
    ----------
+    file: Union[str, np.ndarray]
+        The audio file to open or a numpy array containing the audio data.

    sr: int
+        The sample rate to resample the audio if necessary.

    Returns
    -------
    A NumPy array containing the audio waveform, in float32 dtype.
    """
+    if isinstance(file, np.ndarray):
+        if file.dtype != np.float32:
+            file = file.astype(np.float32)
+        if file.ndim > 1:
+            file = np.mean(file, axis=1)
+
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+        write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
+        temp_file_path = temp_file.name
+        temp_file.close()
+    else:
+        temp_file_path = file
+
    try:
        cmd = [
            "ffmpeg",
            "-nostdin",
            "-threads",
            "0",
            "-i",
+            temp_file_path,
            "-f",
            "s16le",
            "-ac",

        out = subprocess.run(cmd, capture_output=True, check=True).stdout
    except subprocess.CalledProcessError as e:
        raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    finally:
+        if isinstance(file, np.ndarray):
+            os.remove(temp_file_path)

    return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
 
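Since load_audio() now accepts an in-memory numpy array as well as a path, a short usage sketch may help (hypothetical snippet, not part of this commit; it assumes the array is already mono float32 at the module's SAMPLE_RATE):

# Hypothetical usage sketch for the updated load_audio(); not part of this commit.
import numpy as np
from modules.diarize.audio_loader import load_audio

# From a file path, as before (decoded and resampled through ffmpeg):
waveform = load_audio("sample.wav")  # "sample.wav" is a placeholder path

# From a numpy array already held in memory: the array is written to a
# temporary .wav file, decoded through ffmpeg, then the temp file is removed.
one_second_of_silence = np.zeros(16000, dtype=np.float32)
waveform = load_audio(one_second_of_silence)
print(waveform.dtype, waveform.shape)  # float32 mono samples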
modules/diarize/diarize_pipeline.py CHANGED
@@ -1,3 +1,5 @@
 import numpy as np
 import pandas as pd
 import os

+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
+
 import numpy as np
 import pandas as pd
 import os
modules/diarize/diarizer.py CHANGED
@@ -1,9 +1,9 @@
 import os
 import torch
-from typing import List
 import time
 import logging
-import spaces

 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
@@ -20,9 +20,8 @@ class Diarizer:
        os.makedirs(self.model_dir, exist_ok=True)
        self.pipe = None

-    @spaces.GPU
    def run(self,
-            audio: str,
            transcribed_result: List[dict],
            use_auth_token: str,
            device: str
@@ -75,7 +74,6 @@ class Diarizer:
        elapsed_time = time.time() - start_time
        return diarized_result["segments"], elapsed_time

-    @spaces.GPU
    def update_pipe(self,
                    use_auth_token: str,
                    device: str
@@ -113,7 +111,6 @@ class Diarizer:
        logger.disabled = False

    @staticmethod
-    @spaces.GPU
    def get_device():
        if torch.cuda.is_available():
            return "cuda"
@@ -123,7 +120,6 @@ class Diarizer:
            return "cpu"

    @staticmethod
-    @spaces.GPU
    def get_available_device():
        devices = ["cpu"]
        if torch.cuda.is_available():
 
 import os
 import torch
+from typing import List, Union, BinaryIO
+import numpy as np
 import time
 import logging

 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio

        os.makedirs(self.model_dir, exist_ok=True)
        self.pipe = None

    def run(self,
+            audio: Union[str, BinaryIO, np.ndarray],
            transcribed_result: List[dict],
            use_auth_token: str,
            device: str

        elapsed_time = time.time() - start_time
        return diarized_result["segments"], elapsed_time

    def update_pipe(self,
                    use_auth_token: str,
                    device: str

        logger.disabled = False

    @staticmethod
    def get_device():
        if torch.cuda.is_available():
            return "cuda"

            return "cpu"

    @staticmethod
    def get_available_device():
        devices = ["cpu"]
        if torch.cuda.is_available():
modules/translation/deepl_api.py CHANGED
@@ -83,7 +83,7 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {

 class DeepLAPI:
    def __init__(self,
-                 output_dir: str
                 ):
        self.api_interval = 1
        self.max_text_batch_size = 50
@@ -97,6 +97,7 @@ class DeepLAPI:
                         source_lang: str,
                         target_lang: str,
                         is_pro: bool,
                         progress=gr.Progress()) -> list:
        """
        Translate subtitle files using DeepL API
@@ -112,6 +113,8 @@ class DeepLAPI:
            Target language of the file to transcribe from gr.Dropdown()
        is_pro: str
            Boolean value that is about pro user or not from gr.Checkbox().
        progress: gr.Progress
            Indicator to show progress directly in gradio.

@@ -141,11 +144,6 @@ class DeepLAPI:
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_srt(parsed_dicts)
-                timestamp = datetime.now().strftime("%m%d%H%M%S")
-
-                file_name = file_name[:-9]
-                output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.srt")
-                write_file(subtitle, output_path)

            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)
@@ -161,22 +159,25 @@ class DeepLAPI:
                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_vtt(parsed_dicts)
                timestamp = datetime.now().strftime("%m%d%H%M%S")

-                file_name = file_name[:-9]
-                output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")

-                write_file(subtitle, output_path)

-            files_info[file_name] = subtitle
        total_result = ''
-        for file_name, subtitle in files_info.items():
            total_result += '------------------------------------\n'
            total_result += f'{file_name}\n\n'
-            total_result += f'{subtitle}'
-
        gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
-        return [gr_str, output_path]

    def request_deepl_translate(self,
                                auth_key: str,
 

 class DeepLAPI:
    def __init__(self,
+                 output_dir: str = os.path.join("outputs", "translations")
                 ):
        self.api_interval = 1
        self.max_text_batch_size = 50

                         source_lang: str,
                         target_lang: str,
                         is_pro: bool,
+                         add_timestamp: bool,
                         progress=gr.Progress()) -> list:
        """
        Translate subtitle files using DeepL API

            Target language of the file to transcribe from gr.Dropdown()
        is_pro: str
            Boolean value that is about pro user or not from gr.Checkbox().
+        add_timestamp: bool
+            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
        progress: gr.Progress
            Indicator to show progress directly in gradio.


                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_srt(parsed_dicts)

            elif file_ext == ".vtt":
                parsed_dicts = parse_vtt(file_path=file_path)

                    progress(batch_end / len(parsed_dicts), desc="Translating..")

                subtitle = get_serialized_vtt(parsed_dicts)
+
+            if add_timestamp:
                timestamp = datetime.now().strftime("%m%d%H%M%S")
+                file_name += f"-{timestamp}"

+            output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
+            write_file(subtitle, output_path)

+            files_info[file_name] = {"subtitle": subtitle, "path": output_path}

        total_result = ''
+        for file_name, info in files_info.items():
            total_result += '------------------------------------\n'
            total_result += f'{file_name}\n\n'
+            total_result += f'{info["subtitle"]}'

        gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
+
+        output_file_paths = [item["path"] for key, item in files_info.items()]
+        return [gr_str, output_file_paths]

    def request_deepl_translate(self,
                                auth_key: str,
modules/translation/nllb_inference.py CHANGED
@@ -1,15 +1,14 @@
 from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import gradio as gr
 import os
-import spaces

 from modules.translation.translation_base import TranslationBase


 class NLLBInference(TranslationBase):
    def __init__(self,
-                 model_dir: str,
-                 output_dir: str
                 ):
        super().__init__(
            model_dir=model_dir,
@@ -21,14 +20,16 @@ class NLLBInference(TranslationBase):
        self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
        self.pipeline = None

-    @spaces.GPU(duration=120)
    def translate(self,
-                  text: str
                  ):
-        result = self.pipeline(text)
        return result[0]['translation_text']

-    @spaces.GPU(duration=120)
    def update_model(self,
                     model_size: str,
                     src_lang: str,
@@ -39,10 +40,13 @@ class NLLBInference(TranslationBase):
        print("\nInitializing NLLB Model..\n")
        progress(0, desc="Initializing NLLB Model..")
        self.current_model_size = model_size
        self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
-                                                           cache_dir=self.model_dir)
        self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
-                                                       cache_dir=os.path.join(self.model_dir, "tokenizers"))
        src_lang = NLLB_AVAILABLE_LANGS[src_lang]
        tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
        self.pipeline = pipeline("translation",
@@ -52,6 +56,18 @@ class NLLBInference(TranslationBase):
                                 tgt_lang=tgt_lang,
                                 device=self.device)

 NLLB_AVAILABLE_LANGS = {
    "Acehnese (Arabic script)": "ace_Arab",
    "Acehnese (Latin script)": "ace_Latn",
  "Acehnese (Latin script)": "ace_Latn",
 
1
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
2
  import gradio as gr
3
  import os
 
4
 
5
  from modules.translation.translation_base import TranslationBase
6
 
7
 
8
  class NLLBInference(TranslationBase):
9
  def __init__(self,
10
+ model_dir: str = os.path.join("models", "NLLB"),
11
+ output_dir: str = os.path.join("outputs", "translations")
12
  ):
13
  super().__init__(
14
  model_dir=model_dir,
 
20
  self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
21
  self.pipeline = None
22
 
 
23
  def translate(self,
24
+ text: str,
25
+ max_length: int
26
  ):
27
+ result = self.pipeline(
28
+ text,
29
+ max_length=max_length
30
+ )
31
  return result[0]['translation_text']
32
 
 
33
  def update_model(self,
34
  model_size: str,
35
  src_lang: str,
 
40
  print("\nInitializing NLLB Model..\n")
41
  progress(0, desc="Initializing NLLB Model..")
42
  self.current_model_size = model_size
43
+ local_files_only = self.is_model_exists(self.current_model_size)
44
  self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
45
+ cache_dir=self.model_dir,
46
+ local_files_only=local_files_only)
47
  self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
48
+ cache_dir=os.path.join(self.model_dir, "tokenizers"),
49
+ local_files_only=local_files_only)
50
  src_lang = NLLB_AVAILABLE_LANGS[src_lang]
51
  tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
52
  self.pipeline = pipeline("translation",
 
56
  tgt_lang=tgt_lang,
57
  device=self.device)
58
 
59
+ def is_model_exists(self,
60
+ model_size: str):
61
+ """Check if model exists or not (Only facebook model)"""
62
+ prefix = "models--facebook--"
63
+ _id, model_size_name = model_size.split("/")
64
+ model_dir_name = prefix + model_size_name
65
+ model_dir_path = os.path.join(self.model_dir, model_dir_name)
66
+ if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
67
+ return True
68
+ return False
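The `is_model_exists()` check above relies on the Hugging Face cache layout, where a repo ID such as `facebook/nllb-200-distilled-600M` is cached under a `models--facebook--<name>` folder inside `model_dir`. A minimal sketch of that path derivation, assuming the default `models/NLLB` directory and that example model ID:

```python
import os

# Assumed inputs for illustration; any facebook/* NLLB repo ID follows the same pattern.
model_size = "facebook/nllb-200-distilled-600M"
cache_dir = os.path.join("models", "NLLB")

_owner, model_name = model_size.split("/")
model_dir_path = os.path.join(cache_dir, f"models--facebook--{model_name}")

# Mirrors the check above: the cache folder must exist and be non-empty.
already_downloaded = os.path.exists(model_dir_path) and bool(os.listdir(model_dir_path))
print(already_downloaded)
```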
69
+
70
+
71
  NLLB_AVAILABLE_LANGS = {
72
  "Acehnese (Arabic script)": "ace_Arab",
73
  "Acehnese (Latin script)": "ace_Latn",
modules/translation/translation_base.py CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
4
  from abc import ABC, abstractmethod
5
  from typing import List
6
  from datetime import datetime
7
- import spaces
8
 
9
  from modules.whisper.whisper_parameter import *
10
  from modules.utils.subtitle_manager import *
@@ -12,8 +11,9 @@ from modules.utils.subtitle_manager import *
12
 
13
  class TranslationBase(ABC):
14
  def __init__(self,
15
- model_dir: str,
16
- output_dir: str):
 
17
  super().__init__()
18
  self.model = None
19
  self.model_dir = model_dir
@@ -24,14 +24,13 @@ class TranslationBase(ABC):
24
  self.device = self.get_device()
25
 
26
  @abstractmethod
27
- @spaces.GPU(duration=120)
28
  def translate(self,
29
- text: str
 
30
  ):
31
  pass
32
 
33
  @abstractmethod
34
- @spaces.GPU(duration=120)
35
  def update_model(self,
36
  model_size: str,
37
  src_lang: str,
@@ -40,12 +39,12 @@ class TranslationBase(ABC):
40
  ):
41
  pass
42
 
43
- @spaces.GPU(duration=120)
44
  def translate_file(self,
45
  fileobjs: list,
46
  model_size: str,
47
  src_lang: str,
48
  tgt_lang: str,
 
49
  add_timestamp: bool,
50
  progress=gr.Progress()) -> list:
51
  """
@@ -61,6 +60,8 @@ class TranslationBase(ABC):
61
  Source language of the file to translate from gr.Dropdown()
62
  tgt_lang: str
63
  Target language of the file to translate from gr.Dropdown()
 
 
64
  add_timestamp: bool
65
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
66
  progress: gr.Progress
@@ -88,50 +89,44 @@ class TranslationBase(ABC):
88
  total_progress = len(parsed_dicts)
89
  for index, dic in enumerate(parsed_dicts):
90
  progress(index / total_progress, desc="Translating..")
91
- translated_text = self.translate(dic["sentence"])
92
  dic["sentence"] = translated_text
93
  subtitle = get_serialized_srt(parsed_dicts)
94
 
95
- timestamp = datetime.now().strftime("%m%d%H%M%S")
96
- if add_timestamp:
97
- output_path = os.path.join("outputs", "", f"{file_name}-{timestamp}.srt")
98
- else:
99
- output_path = os.path.join("outputs", "", f"{file_name}.srt")
100
-
101
  elif file_ext == ".vtt":
102
  parsed_dicts = parse_vtt(file_path=file_path)
103
  total_progress = len(parsed_dicts)
104
  for index, dic in enumerate(parsed_dicts):
105
  progress(index / total_progress, desc="Translating..")
106
- translated_text = self.translate(dic["sentence"])
107
  dic["sentence"] = translated_text
108
  subtitle = get_serialized_vtt(parsed_dicts)
109
 
 
110
  timestamp = datetime.now().strftime("%m%d%H%M%S")
111
- if add_timestamp:
112
- output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
113
- else:
114
- output_path = os.path.join(self.output_dir, "", f"{file_name}.vtt")
115
 
 
116
  write_file(subtitle, output_path)
117
- files_info[file_name] = subtitle
 
118
 
119
  total_result = ''
120
- for file_name, subtitle in files_info.items():
121
  total_result += '------------------------------------\n'
122
  total_result += f'{file_name}\n\n'
123
- total_result += f'{subtitle}'
124
-
125
  gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
126
- return [gr_str, output_path]
 
 
 
127
  except Exception as e:
128
  print(f"Error: {str(e)}")
129
  finally:
130
  self.release_cuda_memory()
131
- self.remove_input_files([fileobj.name for fileobj in fileobjs])
132
 
133
  @staticmethod
134
- @spaces.GPU(duration=120)
135
  def get_device():
136
  if torch.cuda.is_available():
137
  return "cuda"
@@ -141,7 +136,6 @@ class TranslationBase(ABC):
141
  return "cpu"
142
 
143
  @staticmethod
144
- @spaces.GPU(duration=120)
145
  def release_cuda_memory():
146
  if torch.cuda.is_available():
147
  torch.cuda.empty_cache()
 
4
  from abc import ABC, abstractmethod
5
  from typing import List
6
  from datetime import datetime
 
7
 
8
  from modules.whisper.whisper_parameter import *
9
  from modules.utils.subtitle_manager import *
 
11
 
12
  class TranslationBase(ABC):
13
  def __init__(self,
14
+ model_dir: str = os.path.join("models", "NLLB"),
15
+ output_dir: str = os.path.join("outputs", "translations")
16
+ ):
17
  super().__init__()
18
  self.model = None
19
  self.model_dir = model_dir
 
24
  self.device = self.get_device()
25
 
26
  @abstractmethod
 
27
  def translate(self,
28
+ text: str,
29
+ max_length: int
30
  ):
31
  pass
32
 
33
  @abstractmethod
 
34
  def update_model(self,
35
  model_size: str,
36
  src_lang: str,
 
39
  ):
40
  pass
41
 
 
42
  def translate_file(self,
43
  fileobjs: list,
44
  model_size: str,
45
  src_lang: str,
46
  tgt_lang: str,
47
+ max_length: int,
48
  add_timestamp: bool,
49
  progress=gr.Progress()) -> list:
50
  """
 
60
  Source language of the file to translate from gr.Dropdown()
61
  tgt_lang: str
62
  Target language of the file to translate from gr.Dropdown()
63
+ max_length: int
64
+ Max length per line to translate
65
  add_timestamp: bool
66
  Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
67
  progress: gr.Progress
 
89
  total_progress = len(parsed_dicts)
90
  for index, dic in enumerate(parsed_dicts):
91
  progress(index / total_progress, desc="Translating..")
92
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
93
  dic["sentence"] = translated_text
94
  subtitle = get_serialized_srt(parsed_dicts)
95
 
 
 
 
 
 
 
96
  elif file_ext == ".vtt":
97
  parsed_dicts = parse_vtt(file_path=file_path)
98
  total_progress = len(parsed_dicts)
99
  for index, dic in enumerate(parsed_dicts):
100
  progress(index / total_progress, desc="Translating..")
101
+ translated_text = self.translate(dic["sentence"], max_length=max_length)
102
  dic["sentence"] = translated_text
103
  subtitle = get_serialized_vtt(parsed_dicts)
104
 
105
+ if add_timestamp:
106
  timestamp = datetime.now().strftime("%m%d%H%M%S")
107
+ file_name += f"-{timestamp}"
 
 
 
108
 
109
+ output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
110
  write_file(subtitle, output_path)
111
+
112
+ files_info[file_name] = {"subtitle": subtitle, "path": output_path}
113
 
114
  total_result = ''
115
+ for file_name, info in files_info.items():
116
  total_result += '------------------------------------\n'
117
  total_result += f'{file_name}\n\n'
118
+ total_result += f'{info["subtitle"]}'
 
119
  gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
120
+
121
+ output_file_paths = [item["path"] for key, item in files_info.items()]
122
+ return [gr_str, output_file_paths]
123
+
124
  except Exception as e:
125
  print(f"Error: {str(e)}")
126
  finally:
127
  self.release_cuda_memory()
 
128
 
129
  @staticmethod
 
130
  def get_device():
131
  if torch.cuda.is_available():
132
  return "cuda"
 
136
  return "cpu"
137
 
138
  @staticmethod
 
139
  def release_cuda_memory():
140
  if torch.cuda.is_available():
141
  torch.cuda.empty_cache()
modules/utils/files_manager.py ADDED
@@ -0,0 +1,39 @@
+ import os
+ import fnmatch
+
+ from gradio.utils import NamedString
+
+
+ def get_media_files(folder_path, include_sub_directory=False):
+ video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
+ audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
+ media_extensions = video_extensions + audio_extensions
+
+ media_files = []
+
+ if include_sub_directory:
+ for root, _, files in os.walk(folder_path):
+ for extension in media_extensions:
+ media_files.extend(
+ os.path.join(root, file) for file in fnmatch.filter(files, extension)
+ if os.path.exists(os.path.join(root, file))
+ )
+ else:
+ for extension in media_extensions:
+ media_files.extend(
+ os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
+ if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
+ )
+
+ return media_files
+
+
+ def format_gradio_files(files: list):
+ if not files:
+ return files
+
+ gradio_files = []
+ for file in files:
+ gradio_files.append(NamedString(file))
+ return gradio_files
+
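A short usage sketch for the helpers above, assuming a hypothetical `media/` folder next to the app; `format_gradio_files()` wraps the plain paths in gradio's `NamedString` so they can stand in for the objects returned by `gr.Files()`:

```python
from modules.utils.files_manager import get_media_files, format_gradio_files

# Hypothetical folder; only the video/audio extensions listed above are collected.
paths = get_media_files("media", include_sub_directory=True)
files = format_gradio_files(paths)

for f in files:
    print(f)  # each entry is a NamedString wrapping one media file path
```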
modules/utils/subtitle_manager.py CHANGED
@@ -1,7 +1,5 @@
  import re

- # Zero GPU
- import spaces

  def timeformat_srt(time):
  hours = time // 3600
@@ -119,7 +117,7 @@ def get_serialized_vtt(dicts):
  output += f'{dic["sentence"]}\n\n'
  return output

- @spaces.GPU(duration=120)
+
  def safe_filename(name):
  from app import _args
  INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
modules/utils/youtube_manager.py CHANGED
@@ -1,4 +1,4 @@
- from pytube import YouTube
+ from pytubefix import YouTube
  import os

modules/vad/silero_vad.py CHANGED
@@ -1,21 +1,25 @@
1
- from faster_whisper.vad import VadOptions
 
 
2
  import numpy as np
3
- from typing import BinaryIO, Union, List, Optional
4
  import warnings
5
  import faster_whisper
 
6
  import gradio as gr
7
- import spaces
8
 
9
 
10
  class SileroVAD:
11
  def __init__(self):
12
  self.sampling_rate = 16000
 
 
13
 
14
- @spaces.GPU
15
  def run(self,
16
  audio: Union[str, BinaryIO, np.ndarray],
17
  vad_parameters: VadOptions,
18
- progress: gr.Progress = gr.Progress()):
 
19
  """
20
  Run VAD
21
 
@@ -30,8 +34,10 @@ class SileroVAD:
30
 
31
  Returns
32
  ----------
33
- audio: np.ndarray
34
  Pre-processed audio with VAD
 
 
35
  """
36
 
37
  sampling_rate = self.sampling_rate
@@ -54,11 +60,10 @@ class SileroVAD:
54
  audio = self.collect_chunks(audio, speech_chunks)
55
  duration_after_vad = audio.shape[0] / sampling_rate
56
 
57
- return audio
58
 
59
- @staticmethod
60
- @spaces.GPU
61
  def get_speech_timestamps(
 
62
  audio: np.ndarray,
63
  vad_options: Optional[VadOptions] = None,
64
  progress: gr.Progress = gr.Progress(),
@@ -75,6 +80,10 @@ class SileroVAD:
75
  Returns:
76
  List of dicts containing begin and end samples of each speech chunk.
77
  """
 
 
 
 
78
  if vad_options is None:
79
  vad_options = VadOptions(**kwargs)
80
 
@@ -82,15 +91,8 @@ class SileroVAD:
82
  min_speech_duration_ms = vad_options.min_speech_duration_ms
83
  max_speech_duration_s = vad_options.max_speech_duration_s
84
  min_silence_duration_ms = vad_options.min_silence_duration_ms
85
- window_size_samples = vad_options.window_size_samples
86
  speech_pad_ms = vad_options.speech_pad_ms
87
-
88
- if window_size_samples not in [512, 1024, 1536]:
89
- warnings.warn(
90
- "Unusual window_size_samples! Supported window_size_samples:\n"
91
- " - [512, 1024, 1536] for 16000 sampling_rate"
92
- )
93
-
94
  sampling_rate = 16000
95
  min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
96
  speech_pad_samples = sampling_rate * speech_pad_ms / 1000
@@ -104,8 +106,7 @@ class SileroVAD:
104
 
105
  audio_length_samples = len(audio)
106
 
107
- model = faster_whisper.vad.get_vad_model()
108
- state = model.get_initial_state(batch_size=1)
109
 
110
  speech_probs = []
111
  for current_start_sample in range(0, audio_length_samples, window_size_samples):
@@ -114,7 +115,7 @@ class SileroVAD:
114
  chunk = audio[current_start_sample: current_start_sample + window_size_samples]
115
  if len(chunk) < window_size_samples:
116
  chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
117
- speech_prob, state = model(chunk, state, sampling_rate)
118
  speech_probs.append(speech_prob)
119
 
120
  triggered = False
@@ -210,6 +211,9 @@ class SileroVAD:
210
 
211
  return speeches
212
 
 
 
 
213
  @staticmethod
214
  def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
215
  """Collects and concatenates audio chunks."""
@@ -241,3 +245,20 @@ class SileroVAD:
241
  f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
242
  )
1
+ # Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
2
+
3
+ from faster_whisper.vad import VadOptions, get_vad_model
4
  import numpy as np
5
+ from typing import BinaryIO, Union, List, Optional, Tuple
6
  import warnings
7
  import faster_whisper
8
+ from faster_whisper.transcribe import SpeechTimestampsMap, Segment
9
  import gradio as gr
 
10
 
11
 
12
  class SileroVAD:
13
  def __init__(self):
14
  self.sampling_rate = 16000
15
+ self.window_size_samples = 512
16
+ self.model = None
17
 
 
18
  def run(self,
19
  audio: Union[str, BinaryIO, np.ndarray],
20
  vad_parameters: VadOptions,
21
+ progress: gr.Progress = gr.Progress()
22
+ ) -> Tuple[np.ndarray, List[dict]]:
23
  """
24
  Run VAD
25
 
 
34
 
35
  Returns
36
  ----------
37
+ np.ndarray
38
  Pre-processed audio with VAD
39
+ List[dict]
40
+ Chunks of speeches to be used to restore the timestamps later
41
  """
42
 
43
  sampling_rate = self.sampling_rate
 
60
  audio = self.collect_chunks(audio, speech_chunks)
61
  duration_after_vad = audio.shape[0] / sampling_rate
62
 
63
+ return audio, speech_chunks
64
 
 
 
65
  def get_speech_timestamps(
66
+ self,
67
  audio: np.ndarray,
68
  vad_options: Optional[VadOptions] = None,
69
  progress: gr.Progress = gr.Progress(),
 
80
  Returns:
81
  List of dicts containing begin and end samples of each speech chunk.
82
  """
83
+
84
+ if self.model is None:
85
+ self.update_model()
86
+
87
  if vad_options is None:
88
  vad_options = VadOptions(**kwargs)
89
 
 
91
  min_speech_duration_ms = vad_options.min_speech_duration_ms
92
  max_speech_duration_s = vad_options.max_speech_duration_s
93
  min_silence_duration_ms = vad_options.min_silence_duration_ms
94
+ window_size_samples = self.window_size_samples
95
  speech_pad_ms = vad_options.speech_pad_ms
 
 
 
 
 
 
 
96
  sampling_rate = 16000
97
  min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
98
  speech_pad_samples = sampling_rate * speech_pad_ms / 1000
 
106
 
107
  audio_length_samples = len(audio)
108
 
109
+ state, context = self.model.get_initial_states(batch_size=1)
 
110
 
111
  speech_probs = []
112
  for current_start_sample in range(0, audio_length_samples, window_size_samples):
 
115
  chunk = audio[current_start_sample: current_start_sample + window_size_samples]
116
  if len(chunk) < window_size_samples:
117
  chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
118
+ speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
119
  speech_probs.append(speech_prob)
120
 
121
  triggered = False
 
211
 
212
  return speeches
213
 
214
+ def update_model(self):
215
+ self.model = get_vad_model()
216
+
217
  @staticmethod
218
  def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
219
  """Collects and concatenates audio chunks."""
 
245
  f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
246
  )
247
 
248
+ def restore_speech_timestamps(
249
+ self,
250
+ segments: List[dict],
251
+ speech_chunks: List[dict],
252
+ sampling_rate: Optional[int] = None,
253
+ ) -> List[dict]:
254
+ if sampling_rate is None:
255
+ sampling_rate = self.sampling_rate
256
+
257
+ ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
258
+
259
+ for segment in segments:
260
+ segment["start"] = ts_map.get_original_time(segment["start"])
261
+ segment["end"] = ts_map.get_original_time(segment["end"])
262
+
263
+ return segments
264
+
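A rough sketch of how the two-stage flow is meant to be used (mirroring the calls added to `whisper_base.run()` further down): trim the audio with VAD, transcribe the trimmed audio, then map the segment times back onto the original timeline. The file name and the placeholder segment are assumptions for illustration only:

```python
from faster_whisper.vad import VadOptions
from modules.vad.silero_vad import SileroVAD

vad = SileroVAD()

# 1. Remove non-speech; speech_chunks remembers where the kept audio came from.
trimmed_audio, speech_chunks = vad.run(audio="sample.wav", vad_parameters=VadOptions())

# 2. Transcribe trimmed_audio with any Whisper implementation (placeholder result here).
segments = [{"start": 0.0, "end": 2.5, "text": "hello"}]

# 3. Shift the timestamps back to the untrimmed timeline for subtitle output.
segments = vad.restore_speech_timestamps(segments=segments, speech_chunks=speech_chunks)
print(segments)
```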
modules/whisper/faster_whisper_inference.py CHANGED
@@ -5,6 +5,7 @@ import torch
5
  from typing import BinaryIO, Union, Tuple, List
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
 
8
  import ctranslate2
9
  import whisper
10
  import gradio as gr
@@ -13,31 +14,31 @@ from argparse import Namespace
13
  from modules.whisper.whisper_parameter import *
14
  from modules.whisper.whisper_base import WhisperBase
15
 
16
- import spaces
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str,
21
- output_dir: str,
22
- args: Namespace
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
- output_dir=output_dir,
27
- args=args
28
  )
 
 
 
29
  self.model_paths = self.get_model_paths()
30
  self.device = self.get_device()
31
  self.available_models = self.model_paths.keys()
32
- self.available_compute_types = ["float32"] # spaces bug
33
- self.current_compute_type = "float32" # spaces bug
34
- self.download_model(model_size="large-v2", model_dir=self.model_dir)
35
 
36
- #@spaces.GPU(duration=120)
37
  def transcribe(self,
38
  audio: Union[str, BinaryIO, np.ndarray],
 
39
  *whisper_params,
40
- progress: gr.Progress = gr.Progress(),
41
  ) -> Tuple[List[dict], float]:
42
  """
43
  transcribe method for faster-whisper.
@@ -65,7 +66,16 @@ class FasterWhisperInference(WhisperBase):
65
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
66
  self.update_model(params.model_size, params.compute_type, progress)
67
 
68
- print("transcribe:")
 
 
 
 
 
 
 
 
 
69
  segments, info = self.model.transcribe(
70
  audio=audio,
71
  language=params.lang,
@@ -76,7 +86,25 @@ class FasterWhisperInference(WhisperBase):
76
  best_of=params.best_of,
77
  patience=params.patience,
78
  temperature=params.temperature,
 
79
  compression_ratio_threshold=params.compression_ratio_threshold,
80
  )
81
  progress(0, desc="Loading audio..")
82
 
@@ -90,14 +118,12 @@ class FasterWhisperInference(WhisperBase):
90
  })
91
 
92
  elapsed_time = time.time() - start_time
93
- print("transcribe: finished")
94
  return segments_result, elapsed_time
95
 
96
- #@spaces.GPU(duration=120)
97
  def update_model(self,
98
  model_size: str,
99
  compute_type: str,
100
- progress: gr.Progress = gr.Progress(),
101
  ):
102
  """
103
  Update current model setting
@@ -113,7 +139,6 @@ class FasterWhisperInference(WhisperBase):
113
  Indicator to show progress directly in gradio.
114
  """
115
  progress(0, desc="Initializing Model..")
116
- print("update_model:")
117
  self.current_model_size = self.model_paths[model_size]
118
  self.current_compute_type = compute_type
119
  self.model = faster_whisper.WhisperModel(
@@ -122,7 +147,6 @@ class FasterWhisperInference(WhisperBase):
122
  download_root=self.model_dir,
123
  compute_type=self.current_compute_type
124
  )
125
- print("update_model: finished")
126
 
127
  def get_model_paths(self):
128
  """
@@ -149,22 +173,19 @@ class FasterWhisperInference(WhisperBase):
149
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
150
  return model_paths
151
 
152
- def get_available_compute_type(self):
153
- if self.device == "cuda":
154
- return ['float32', 'int8_float16', 'float16', 'int8', 'int8_float32']
155
- return ['int16', 'float32', 'int8', 'int8_float32']
156
-
157
- def get_device(self):
158
- # Because of huggingface spaces bug, just return cpu
159
- return "cpu"
160
-
161
  @staticmethod
162
- def download_model(model_size: str, model_dir: str):
163
- print(f"\nDownloading \"{model_size}\" to \"{model_dir}\"..\n")
164
- os.makedirs(model_dir, exist_ok=True)
165
- faster_whisper.download_model(
166
- size_or_id=model_size,
167
- cache_dir=model_dir
168
- )
169
-
170
 
 
 
 
 
 
 
 
 
 
 
5
  from typing import BinaryIO, Union, Tuple, List
6
  import faster_whisper
7
  from faster_whisper.vad import VadOptions
8
+ import ast
9
  import ctranslate2
10
  import whisper
11
  import gradio as gr
 
14
  from modules.whisper.whisper_parameter import *
15
  from modules.whisper.whisper_base import WhisperBase
16
 
 
17
 
18
  class FasterWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
+ diarization_model_dir=diarization_model_dir,
27
+ output_dir=output_dir
28
  )
29
+ self.model_dir = model_dir
30
+ os.makedirs(self.model_dir, exist_ok=True)
31
+
32
  self.model_paths = self.get_model_paths()
33
  self.device = self.get_device()
34
  self.available_models = self.model_paths.keys()
35
+ self.available_compute_types = ctranslate2.get_supported_compute_types(
36
+ "cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
 
37
 
 
38
  def transcribe(self,
39
  audio: Union[str, BinaryIO, np.ndarray],
40
+ progress: gr.Progress,
41
  *whisper_params,
 
42
  ) -> Tuple[List[dict], float]:
43
  """
44
  transcribe method for faster-whisper.
 
66
  if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
67
  self.update_model(params.model_size, params.compute_type, progress)
68
 
69
+ # None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
70
+ if not params.initial_prompt:
71
+ params.initial_prompt = None
72
+ if not params.prefix:
73
+ params.prefix = None
74
+ if not params.hotwords:
75
+ params.hotwords = None
76
+
77
+ params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
78
+
79
  segments, info = self.model.transcribe(
80
  audio=audio,
81
  language=params.lang,
 
86
  best_of=params.best_of,
87
  patience=params.patience,
88
  temperature=params.temperature,
89
+ initial_prompt=params.initial_prompt,
90
  compression_ratio_threshold=params.compression_ratio_threshold,
91
+ length_penalty=params.length_penalty,
92
+ repetition_penalty=params.repetition_penalty,
93
+ no_repeat_ngram_size=params.no_repeat_ngram_size,
94
+ prefix=params.prefix,
95
+ suppress_blank=params.suppress_blank,
96
+ suppress_tokens=params.suppress_tokens,
97
+ max_initial_timestamp=params.max_initial_timestamp,
98
+ word_timestamps=params.word_timestamps,
99
+ prepend_punctuations=params.prepend_punctuations,
100
+ append_punctuations=params.append_punctuations,
101
+ max_new_tokens=params.max_new_tokens,
102
+ chunk_length=params.chunk_length,
103
+ hallucination_silence_threshold=params.hallucination_silence_threshold,
104
+ hotwords=params.hotwords,
105
+ language_detection_threshold=params.language_detection_threshold,
106
+ language_detection_segments=params.language_detection_segments,
107
+ prompt_reset_on_temperature=params.prompt_reset_on_temperature,
108
  )
109
  progress(0, desc="Loading audio..")
110
 
 
118
  })
119
 
120
  elapsed_time = time.time() - start_time
 
121
  return segments_result, elapsed_time
122
 
 
123
  def update_model(self,
124
  model_size: str,
125
  compute_type: str,
126
+ progress: gr.Progress
127
  ):
128
  """
129
  Update current model setting
 
139
  Indicator to show progress directly in gradio.
140
  """
141
  progress(0, desc="Initializing Model..")
 
142
  self.current_model_size = self.model_paths[model_size]
143
  self.current_compute_type = compute_type
144
  self.model = faster_whisper.WhisperModel(
 
147
  download_root=self.model_dir,
148
  compute_type=self.current_compute_type
149
  )
 
150
 
151
  def get_model_paths(self):
152
  """
 
173
  model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
174
  return model_paths
176
  @staticmethod
177
+ def get_device():
178
+ if torch.cuda.is_available():
179
+ return "cuda"
180
+ else:
181
+ return "auto"
 
 
 
182
 
183
+ @staticmethod
184
+ def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
185
+ try:
186
+ suppress_tokens = ast.literal_eval(suppress_tokens_str)
187
+ if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
188
+ raise ValueError("Invalid Suppress Tokens. The value must be of type List[int]")
189
+ return suppress_tokens
190
+ except Exception as e:
191
+ raise ValueError("Invalid Suppress Tokens. The value must be of type List[int]")
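Because `suppress_tokens` now arrives from a `gr.Textbox()` as a string, it has to be parsed back into a list of ints before being handed to faster-whisper. A small sketch of the accepted input format; the example strings are assumptions (faster-whisper's own default is `[-1]`):

```python
import ast

# Each value is what a user might type into the Suppress Tokens textbox.
for raw in ("[-1]", "[50257, 50362]"):
    tokens = ast.literal_eval(raw)  # "[-1]" -> [-1]
    assert isinstance(tokens, list) and all(isinstance(t, int) for t in tokens)
    print(tokens)
```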
modules/whisper/insanely_fast_whisper_inference.py CHANGED
@@ -17,15 +17,18 @@ from modules.whisper.whisper_base import WhisperBase
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
- model_dir: str,
21
- output_dir: str,
22
- args: Namespace
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
- args=args
28
  )
 
 
 
29
  openai_models = whisper.available_models()
30
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
31
  self.available_models = openai_models + distil_models
 
17
 
18
  class InsanelyFastWhisperInference(WhisperBase):
19
  def __init__(self,
20
+ model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
21
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
22
+ output_dir: str = os.path.join("outputs"),
23
  ):
24
  super().__init__(
25
  model_dir=model_dir,
26
  output_dir=output_dir,
27
+ diarization_model_dir=diarization_model_dir
28
  )
29
+ self.model_dir = model_dir
30
+ os.makedirs(self.model_dir, exist_ok=True)
31
+
32
  openai_models = whisper.available_models()
33
  distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
34
  self.available_models = openai_models + distil_models
modules/whisper/whisper_Inference.py CHANGED
@@ -4,6 +4,7 @@ import time
4
  from typing import BinaryIO, Union, Tuple, List
5
  import numpy as np
6
  import torch
 
7
  from argparse import Namespace
8
 
9
  from modules.whisper.whisper_base import WhisperBase
@@ -12,14 +13,14 @@ from modules.whisper.whisper_parameter import *
12
 
13
  class WhisperInference(WhisperBase):
14
  def __init__(self,
15
- model_dir: str,
16
- output_dir: str,
17
- args: Namespace
18
  ):
19
  super().__init__(
20
  model_dir=model_dir,
21
  output_dir=output_dir,
22
- args=args
23
  )
24
 
25
  def transcribe(self,
 
4
  from typing import BinaryIO, Union, Tuple, List
5
  import numpy as np
6
  import torch
7
+ import os
8
  from argparse import Namespace
9
 
10
  from modules.whisper.whisper_base import WhisperBase
 
13
 
14
  class WhisperInference(WhisperBase):
15
  def __init__(self,
16
+ model_dir: str = os.path.join("models", "Whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
  ):
20
  super().__init__(
21
  model_dir=model_dir,
22
  output_dir=output_dir,
23
+ diarization_model_dir=diarization_model_dir
24
  )
25
 
26
  def transcribe(self,
modules/whisper/whisper_base.py CHANGED
@@ -6,13 +6,12 @@ from abc import ABC, abstractmethod
6
  from typing import BinaryIO, Union, Tuple, List
7
  import numpy as np
8
  from datetime import datetime
9
- from argparse import Namespace
10
  from faster_whisper.vad import VadOptions
11
  from dataclasses import astuple
12
- import spaces
13
 
14
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
15
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
 
16
  from modules.whisper.whisper_parameter import *
17
  from modules.diarize.diarizer import Diarizer
18
  from modules.vad.silero_vad import SileroVAD
@@ -20,51 +19,50 @@ from modules.vad.silero_vad import SileroVAD
20
 
21
  class WhisperBase(ABC):
22
  def __init__(self,
23
- model_dir: str,
24
- output_dir: str,
25
- args: Namespace
26
  ):
27
- self.model = None
28
- self.current_model_size = None
29
  self.model_dir = model_dir
30
  self.output_dir = output_dir
31
  os.makedirs(self.output_dir, exist_ok=True)
32
  os.makedirs(self.model_dir, exist_ok=True)
 
 
 
 
 
 
 
33
  self.available_models = whisper.available_models()
34
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
35
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
36
  self.device = self.get_device()
37
  self.available_compute_types = ["float16", "float32"]
38
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
39
- self.diarizer = Diarizer(
40
- model_dir=args.diarization_model_dir
41
- )
42
- self.vad = SileroVAD()
43
 
44
  @abstractmethod
45
- #@spaces.GPU(duration=120)
46
  def transcribe(self,
47
  audio: Union[str, BinaryIO, np.ndarray],
 
48
  *whisper_params,
49
- progress: gr.Progress = gr.Progress(),
50
  ):
 
51
  pass
52
 
53
  @abstractmethod
54
- @spaces.GPU(duration=120)
55
  def update_model(self,
56
  model_size: str,
57
  compute_type: str,
58
- progress: gr.Progress = gr.Progress(),
59
  ):
 
60
  pass
61
 
62
- # spaces is problematic
63
- #@spaces.GPU(duration=120)
64
  def run(self,
65
  audio: Union[str, BinaryIO, np.ndarray],
 
66
  *whisper_params,
67
- progress: gr.Progress = gr.Progress(),
68
  ) -> Tuple[List[dict], float]:
69
  """
70
  Run transcription with conditional pre-processing and post-processing.
@@ -89,33 +87,44 @@ class WhisperBase(ABC):
89
  """
90
  params = WhisperParameters.as_value(*whisper_params)
91
 
 
 
 
 
 
 
 
92
  if params.vad_filter:
 
 
 
 
93
  vad_options = VadOptions(
94
  threshold=params.threshold,
95
  min_speech_duration_ms=params.min_speech_duration_ms,
96
  max_speech_duration_s=params.max_speech_duration_s,
97
  min_silence_duration_ms=params.min_silence_duration_ms,
98
- window_size_samples=params.window_size_samples,
99
  speech_pad_ms=params.speech_pad_ms
100
  )
101
- self.vad.run(
 
102
  audio=audio,
103
  vad_parameters=vad_options,
104
  progress=progress
105
  )
106
 
107
- if params.lang == "Automatic Detection":
108
- params.lang = None
109
- else:
110
- language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
111
- params.lang = language_code_dict[params.lang]
112
-
113
  result, elapsed_time = self.transcribe(
114
  audio,
115
- *astuple(params),
116
- progress=progress
117
  )
118
 
 
 
 
 
 
 
119
  if params.is_diarize:
120
  result, elapsed_time_diarization = self.diarizer.run(
121
  audio=audio,
@@ -126,15 +135,14 @@ class WhisperBase(ABC):
126
  elapsed_time += elapsed_time_diarization
127
  return result, elapsed_time
128
 
129
- # spaces is problematic
130
- #@spaces.GPU(duration=120)
131
  def transcribe_file(self,
132
- files,
133
- file_format,
134
- add_timestamp,
135
- *whisper_params,
136
  progress=gr.Progress(),
137
- ):
 
138
  """
139
  Write subtitle file from Files
140
 
@@ -142,6 +150,9 @@ class WhisperBase(ABC):
142
  ----------
143
  files: list
144
  List of files to transcribe from gr.Files()
 
 
 
145
  file_format: str
146
  Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
147
  add_timestamp: bool
@@ -159,16 +170,19 @@ class WhisperBase(ABC):
159
  Output file path to return to gr.Files()
160
  """
161
  try:
 
 
 
 
162
  files_info = {}
163
  for file in files:
164
  transcribed_segments, time_for_task = self.run(
165
  file.name,
 
166
  *whisper_params,
167
- progress=progress
168
  )
169
 
170
  file_name, file_ext = os.path.splitext(os.path.basename(file.name))
171
- file_name = safe_filename(file_name)
172
  subtitle, file_path = self.generate_and_write_file(
173
  file_name=file_name,
174
  transcribed_segments=transcribed_segments,
@@ -176,7 +190,6 @@ class WhisperBase(ABC):
176
  file_format=file_format,
177
  output_dir=self.output_dir
178
  )
179
- print("generated sub finished: ")
180
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
181
 
182
  total_result = ''
@@ -195,16 +208,15 @@ class WhisperBase(ABC):
195
  except Exception as e:
196
  print(f"Error transcribing file: {e}")
197
  finally:
198
- # self.release_cuda_memory()
199
  if not files:
200
  self.remove_input_files([file.name for file in files])
201
 
202
- #@spaces.GPU(duration=120)
203
  def transcribe_mic(self,
204
  mic_audio: str,
205
  file_format: str,
 
206
  *whisper_params,
207
- progress: gr.Progress = gr.Progress(),
208
  ) -> list:
209
  """
210
  Write subtitle file from microphone
@@ -231,8 +243,8 @@ class WhisperBase(ABC):
231
  progress(0, desc="Loading Audio..")
232
  transcribed_segments, time_for_task = self.run(
233
  mic_audio,
 
234
  *whisper_params,
235
- progress=progress
236
  )
237
  progress(1, desc="Completed!")
238
 
@@ -252,13 +264,12 @@ class WhisperBase(ABC):
252
  self.release_cuda_memory()
253
  self.remove_input_files([mic_audio])
254
 
255
- #@spaces.GPU(duration=120)
256
  def transcribe_youtube(self,
257
  youtube_link: str,
258
  file_format: str,
259
  add_timestamp: bool,
 
260
  *whisper_params,
261
- progress: gr.Progress = gr.Progress(),
262
  ) -> list:
263
  """
264
  Write subtitle file from Youtube
@@ -290,8 +301,8 @@ class WhisperBase(ABC):
290
 
291
  transcribed_segments, time_for_task = self.run(
292
  audio,
 
293
  *whisper_params,
294
- progress=progress
295
  )
296
 
297
  progress(1, desc="Completed!")
@@ -318,13 +329,12 @@ class WhisperBase(ABC):
318
  else:
319
  file_path = get_ytaudio(yt)
320
 
321
- #self.release_cuda_memory()
322
  self.remove_input_files([file_path])
323
  except Exception as cleanup_error:
324
  pass
325
 
326
  @staticmethod
327
- @spaces.GPU(duration=120)
328
  def generate_and_write_file(file_name: str,
329
  transcribed_segments: list,
330
  add_timestamp: bool,
@@ -354,8 +364,8 @@ class WhisperBase(ABC):
354
  output_path: str
355
  output file path
356
  """
357
- timestamp = datetime.now().strftime("%m%d%H%M%S")
358
  if add_timestamp:
 
359
  output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
360
  else:
361
  output_path = os.path.join(output_dir, f"{file_name}")
@@ -363,17 +373,16 @@ class WhisperBase(ABC):
363
  if file_format == "SRT":
364
  content = get_srt(transcribed_segments)
365
  output_path += '.srt'
366
- write_file(content, output_path)
367
 
368
  elif file_format == "WebVTT":
369
  content = get_vtt(transcribed_segments)
370
  output_path += '.vtt'
371
- write_file(content, output_path)
372
 
373
  elif file_format == "txt":
374
  content = get_txt(transcribed_segments)
375
  output_path += '.txt'
376
- write_file(content, output_path)
 
377
  return content, output_path
378
 
379
  @staticmethod
@@ -403,12 +412,6 @@ class WhisperBase(ABC):
403
 
404
  return time_str.strip()
405
 
406
- @staticmethod
407
- def release_cuda_memory():
408
- if torch.cuda.is_available():
409
- torch.cuda.empty_cache()
410
- torch.cuda.reset_max_memory_allocated()
411
-
412
  @staticmethod
413
  def get_device():
414
  if torch.cuda.is_available():
@@ -418,6 +421,12 @@ class WhisperBase(ABC):
418
  else:
419
  return "cpu"
420
 
 
 
 
 
 
 
421
  @staticmethod
422
  def remove_input_files(file_paths: List[str]):
423
  if not file_paths:
 
6
  from typing import BinaryIO, Union, Tuple, List
7
  import numpy as np
8
  from datetime import datetime
 
9
  from faster_whisper.vad import VadOptions
10
  from dataclasses import astuple
 
11
 
12
  from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
13
  from modules.utils.youtube_manager import get_ytdata, get_ytaudio
14
+ from modules.utils.files_manager import get_media_files, format_gradio_files
15
  from modules.whisper.whisper_parameter import *
16
  from modules.diarize.diarizer import Diarizer
17
  from modules.vad.silero_vad import SileroVAD
 
19
 
20
  class WhisperBase(ABC):
21
  def __init__(self,
22
+ model_dir: str = os.path.join("models", "Whisper"),
23
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
24
+ output_dir: str = os.path.join("outputs"),
25
  ):
 
 
26
  self.model_dir = model_dir
27
  self.output_dir = output_dir
28
  os.makedirs(self.output_dir, exist_ok=True)
29
  os.makedirs(self.model_dir, exist_ok=True)
30
+ self.diarizer = Diarizer(
31
+ model_dir=diarization_model_dir
32
+ )
33
+ self.vad = SileroVAD()
34
+
35
+ self.model = None
36
+ self.current_model_size = None
37
  self.available_models = whisper.available_models()
38
  self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
39
  self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
40
  self.device = self.get_device()
41
  self.available_compute_types = ["float16", "float32"]
42
  self.current_compute_type = "float16" if self.device == "cuda" else "float32"
 
 
 
 
43
 
44
  @abstractmethod
 
45
  def transcribe(self,
46
  audio: Union[str, BinaryIO, np.ndarray],
47
+ progress: gr.Progress,
48
  *whisper_params,
 
49
  ):
50
+ """Inference whisper model to transcribe"""
51
  pass
52
 
53
  @abstractmethod
 
54
  def update_model(self,
55
  model_size: str,
56
  compute_type: str,
57
+ progress: gr.Progress
58
  ):
59
+ """Initialize whisper model"""
60
  pass
61
 
 
 
62
  def run(self,
63
  audio: Union[str, BinaryIO, np.ndarray],
64
+ progress: gr.Progress,
65
  *whisper_params,
 
66
  ) -> Tuple[List[dict], float]:
67
  """
68
  Run transcription with conditional pre-processing and post-processing.
 
87
  """
88
  params = WhisperParameters.as_value(*whisper_params)
89
 
90
+ if params.lang == "Automatic Detection":
91
+ params.lang = None
92
+ else:
93
+ language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
94
+ params.lang = language_code_dict[params.lang]
95
+
96
+ speech_chunks = None
97
  if params.vad_filter:
98
+ # Explicit value set for float('inf') from gr.Number()
99
+ if params.max_speech_duration_s >= 9999:
100
+ params.max_speech_duration_s = float('inf')
101
+
102
  vad_options = VadOptions(
103
  threshold=params.threshold,
104
  min_speech_duration_ms=params.min_speech_duration_ms,
105
  max_speech_duration_s=params.max_speech_duration_s,
106
  min_silence_duration_ms=params.min_silence_duration_ms,
 
107
  speech_pad_ms=params.speech_pad_ms
108
  )
109
+
110
+ audio, speech_chunks = self.vad.run(
111
  audio=audio,
112
  vad_parameters=vad_options,
113
  progress=progress
114
  )
115
 
 
 
 
 
 
 
116
  result, elapsed_time = self.transcribe(
117
  audio,
118
+ progress,
119
+ *astuple(params)
120
  )
121
 
122
+ if params.vad_filter:
123
+ result = self.vad.restore_speech_timestamps(
124
+ segments=result,
125
+ speech_chunks=speech_chunks,
126
+ )
127
+
128
  if params.is_diarize:
129
  result, elapsed_time_diarization = self.diarizer.run(
130
  audio=audio,
 
135
  elapsed_time += elapsed_time_diarization
136
  return result, elapsed_time
137
 
 
 
138
  def transcribe_file(self,
139
+ files: list,
140
+ input_folder_path: str,
141
+ file_format: str,
142
+ add_timestamp: bool,
143
  progress=gr.Progress(),
144
+ *whisper_params,
145
+ ) -> list:
146
  """
147
  Write subtitle file from Files
148
 
 
150
  ----------
151
  files: list
152
  List of files to transcribe from gr.Files()
153
+ input_folder_path: str
154
+ Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
155
+ this will be used instead.
156
  file_format: str
157
  Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
158
  add_timestamp: bool
 
170
  Output file path to return to gr.Files()
171
  """
172
  try:
173
+ if input_folder_path:
174
+ files = get_media_files(input_folder_path)
175
+ files = format_gradio_files(files)
176
+
177
  files_info = {}
178
  for file in files:
179
  transcribed_segments, time_for_task = self.run(
180
  file.name,
181
+ progress,
182
  *whisper_params,
 
183
  )
184
 
185
  file_name, file_ext = os.path.splitext(os.path.basename(file.name))
 
186
  subtitle, file_path = self.generate_and_write_file(
187
  file_name=file_name,
188
  transcribed_segments=transcribed_segments,
 
190
  file_format=file_format,
191
  output_dir=self.output_dir
192
  )
 
193
  files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
194
 
195
  total_result = ''
 
208
  except Exception as e:
209
  print(f"Error transcribing file: {e}")
210
  finally:
211
+ self.release_cuda_memory()
212
  if not files:
213
  self.remove_input_files([file.name for file in files])
214
 
 
215
  def transcribe_mic(self,
216
  mic_audio: str,
217
  file_format: str,
218
+ progress=gr.Progress(),
219
  *whisper_params,
 
220
  ) -> list:
221
  """
222
  Write subtitle file from microphone
 
243
  progress(0, desc="Loading Audio..")
244
  transcribed_segments, time_for_task = self.run(
245
  mic_audio,
246
+ progress,
247
  *whisper_params,
 
248
  )
249
  progress(1, desc="Completed!")
250
 
 
264
  self.release_cuda_memory()
265
  self.remove_input_files([mic_audio])
266
 
 
267
  def transcribe_youtube(self,
268
  youtube_link: str,
269
  file_format: str,
270
  add_timestamp: bool,
271
+ progress=gr.Progress(),
272
  *whisper_params,
 
273
  ) -> list:
274
  """
275
  Write subtitle file from Youtube
 
301
 
302
  transcribed_segments, time_for_task = self.run(
303
  audio,
304
+ progress,
305
  *whisper_params,
 
306
  )
307
 
308
  progress(1, desc="Completed!")
 
329
  else:
330
  file_path = get_ytaudio(yt)
331
 
332
+ self.release_cuda_memory()
333
  self.remove_input_files([file_path])
334
  except Exception as cleanup_error:
335
  pass
336
 
337
  @staticmethod
 
338
  def generate_and_write_file(file_name: str,
339
  transcribed_segments: list,
340
  add_timestamp: bool,
 
364
  output_path: str
365
  output file path
366
  """
 
367
  if add_timestamp:
368
+ timestamp = datetime.now().strftime("%m%d%H%M%S")
369
  output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
370
  else:
371
  output_path = os.path.join(output_dir, f"{file_name}")
 
373
  if file_format == "SRT":
374
  content = get_srt(transcribed_segments)
375
  output_path += '.srt'
 
376
 
377
  elif file_format == "WebVTT":
378
  content = get_vtt(transcribed_segments)
379
  output_path += '.vtt'
 
380
 
381
  elif file_format == "txt":
382
  content = get_txt(transcribed_segments)
383
  output_path += '.txt'
384
+
385
+ write_file(content, output_path)
386
  return content, output_path
387
 
388
  @staticmethod
 
412
 
413
  return time_str.strip()
414
 
 
 
 
 
 
 
415
  @staticmethod
416
  def get_device():
417
  if torch.cuda.is_available():
 
421
  else:
422
  return "cpu"
423
 
424
+ @staticmethod
425
+ def release_cuda_memory():
426
+ if torch.cuda.is_available():
427
+ torch.cuda.empty_cache()
428
+ torch.cuda.reset_max_memory_allocated()
429
+
430
  @staticmethod
431
  def remove_input_files(file_paths: List[str]):
432
  if not file_paths:
modules/whisper/whisper_factory.py ADDED
@@ -0,0 +1,81 @@
1
+ from typing import Optional
2
+ import os
3
+
4
+ from modules.whisper.faster_whisper_inference import FasterWhisperInference
5
+ from modules.whisper.whisper_Inference import WhisperInference
6
+ from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
7
+ from modules.whisper.whisper_base import WhisperBase
8
+
9
+
10
+ class WhisperFactory:
11
+ @staticmethod
12
+ def create_whisper_inference(
13
+ whisper_type: str,
14
+ whisper_model_dir: str = os.path.join("models", "Whisper"),
15
+ faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
16
+ insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
17
+ diarization_model_dir: str = os.path.join("models", "Diarization"),
18
+ output_dir: str = os.path.join("outputs"),
19
+ ) -> "WhisperBase":
20
+ """
21
+ Create a whisper inference class based on the provided whisper_type.
22
+
23
+ Parameters
24
+ ----------
25
+ whisper_type : str
26
+ The type of Whisper implementation to use. Supported values (case-insensitive):
27
+ - "faster-whisper": https://github.com/openai/whisper
28
+ - "whisper": https://github.com/openai/whisper
29
+ - "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
30
+ whisper_model_dir : str
31
+ Directory path for the Whisper model.
32
+ faster_whisper_model_dir : str
33
+ Directory path for the Faster Whisper model.
34
+ insanely_fast_whisper_model_dir : str
35
+ Directory path for the Insanely Fast Whisper model.
36
+ diarization_model_dir : str
37
+ Directory path for the diarization model.
38
+ output_dir : str
39
+ Directory path where output files will be saved.
40
+
41
+ Returns
42
+ -------
43
+ WhisperBase
44
+ An instance of the appropriate whisper inference class based on the whisper_type.
45
+ """
46
+ # Temporary fix for the bug: https://github.com/jhj0517/Whisper-WebUI/issues/144
47
+ os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
48
+
49
+ whisper_type = whisper_type.lower().strip()
50
+
51
+ faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
52
+ whisper_typos = ["whisper"]
53
+ insanely_fast_whisper_typos = [
54
+ "insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
55
+ "insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
56
+ ]
57
+
58
+ if whisper_type in faster_whisper_typos:
59
+ return FasterWhisperInference(
60
+ model_dir=faster_whisper_model_dir,
61
+ output_dir=output_dir,
62
+ diarization_model_dir=diarization_model_dir
63
+ )
64
+ elif whisper_type in whisper_typos:
65
+ return WhisperInference(
66
+ model_dir=whisper_model_dir,
67
+ output_dir=output_dir,
68
+ diarization_model_dir=diarization_model_dir
69
+ )
70
+ elif whisper_type in insanely_fast_whisper_typos:
71
+ return InsanelyFastWhisperInference(
72
+ model_dir=insanely_fast_whisper_model_dir,
73
+ output_dir=output_dir,
74
+ diarization_model_dir=diarization_model_dir
75
+ )
76
+ else:
77
+ return FasterWhisperInference(
78
+ model_dir=faster_whisper_model_dir,
79
+ output_dir=output_dir,
80
+ diarization_model_dir=diarization_model_dir
81
+ )
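A minimal usage sketch for the factory; the keyword values shown are just the defaults spelled out, and any unrecognized `whisper_type` string falls back to `FasterWhisperInference`:

```python
from modules.whisper.whisper_factory import WhisperFactory

whisper_inf = WhisperFactory.create_whisper_inference(
    whisper_type="faster-whisper",   # or "whisper" / "insanely-fast-whisper"
    output_dir="outputs",
)
print(type(whisper_inf).__name__)    # FasterWhisperInference
```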
modules/whisper/whisper_parameter.py CHANGED
@@ -15,6 +15,7 @@ class WhisperParameters:
15
  best_of: gr.Number
16
  patience: gr.Number
17
  condition_on_previous_text: gr.Checkbox
 
18
  initial_prompt: gr.Textbox
19
  temperature: gr.Slider
20
  compression_ratio_threshold: gr.Number
@@ -23,13 +24,28 @@ class WhisperParameters:
23
  min_speech_duration_ms: gr.Number
24
  max_speech_duration_s: gr.Number
25
  min_silence_duration_ms: gr.Number
26
- window_size_sample: gr.Number
27
  speech_pad_ms: gr.Number
28
  chunk_length_s: gr.Number
29
  batch_size: gr.Number
30
  is_diarize: gr.Checkbox
31
  hf_token: gr.Textbox
32
  diarization_device: gr.Dropdown
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  """
34
  A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
35
  This data class is used to mitigate the key-value problem between Gradio components and function parameters.
@@ -111,11 +127,6 @@ class WhisperParameters:
111
  This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
112
  before separating it
113
 
114
- window_size_samples: gr.Number
115
- This parameter is related with Silero VAD. Audio chunks of window_size_samples size are fed to the silero VAD model.
116
- WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
117
- Values other than these may affect model performance!!
118
-
119
  speech_pad_ms: gr.Number
120
  This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
121
 
@@ -135,6 +146,62 @@ class WhisperParameters:
135
 
136
  diarization_device: gr.Dropdown
137
  This parameter is related with whisperx. Device to run diarization model
138
  """
139
 
140
  def as_list(self) -> list:
@@ -159,33 +226,7 @@ class WhisperParameters:
159
  WhisperValues
160
  Data class that has values of parameters
161
  """
162
- return WhisperValues(
163
- model_size=args[0],
164
- lang=args[1],
165
- is_translate=args[2],
166
- beam_size=args[3],
167
- log_prob_threshold=args[4],
168
- no_speech_threshold=args[5],
169
- compute_type=args[6],
170
- best_of=args[7],
171
- patience=args[8],
172
- condition_on_previous_text=args[9],
173
- initial_prompt=args[10],
174
- temperature=args[11],
175
- compression_ratio_threshold=args[12],
176
- vad_filter=args[13],
177
- threshold=args[14],
178
- min_speech_duration_ms=args[15],
179
- max_speech_duration_s=args[16],
180
- min_silence_duration_ms=args[17],
181
- window_size_samples=args[18],
182
- speech_pad_ms=args[19],
183
- chunk_length_s=args[20],
184
- batch_size=args[21],
185
- is_diarize=args[22],
186
- hf_token=args[23],
187
- diarization_device=args[24]
188
- )
189
 
190
 
191
  @dataclass
@@ -200,6 +241,7 @@ class WhisperValues:
200
  best_of: int
201
  patience: float
202
  condition_on_previous_text: bool
 
203
  initial_prompt: Optional[str]
204
  temperature: float
205
  compression_ratio_threshold: float
@@ -208,13 +250,28 @@ class WhisperValues:
208
  min_speech_duration_ms: int
209
  max_speech_duration_s: float
210
  min_silence_duration_ms: int
211
- window_size_samples: int
212
  speech_pad_ms: int
213
  chunk_length_s: int
214
  batch_size: int
215
  is_diarize: bool
216
  hf_token: str
217
  diarization_device: str
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
218
  """
219
  A data class to use Whisper parameters.
220
- """
 
15
  best_of: gr.Number
16
  patience: gr.Number
17
  condition_on_previous_text: gr.Checkbox
18
+ prompt_reset_on_temperature: gr.Slider
19
  initial_prompt: gr.Textbox
20
  temperature: gr.Slider
21
  compression_ratio_threshold: gr.Number
 
24
  min_speech_duration_ms: gr.Number
25
  max_speech_duration_s: gr.Number
26
  min_silence_duration_ms: gr.Number
 
27
  speech_pad_ms: gr.Number
28
  chunk_length_s: gr.Number
29
  batch_size: gr.Number
30
  is_diarize: gr.Checkbox
31
  hf_token: gr.Textbox
32
  diarization_device: gr.Dropdown
33
+ length_penalty: gr.Number
34
+ repetition_penalty: gr.Number
35
+ no_repeat_ngram_size: gr.Number
36
+ prefix: gr.Textbox
37
+ suppress_blank: gr.Checkbox
38
+ suppress_tokens: gr.Textbox
39
+ max_initial_timestamp: gr.Number
40
+ word_timestamps: gr.Checkbox
41
+ prepend_punctuations: gr.Textbox
42
+ append_punctuations: gr.Textbox
43
+ max_new_tokens: gr.Number
44
+ chunk_length: gr.Number
45
+ hallucination_silence_threshold: gr.Number
46
+ hotwords: gr.Textbox
47
+ language_detection_threshold: gr.Number
48
+ language_detection_segments: gr.Number
49
  """
50
  A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
51
  This data class is used to mitigate the key-value problem between Gradio components and function parameters.
 
127
  This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
128
  before separating it
129
 
 
 
 
 
 
130
  speech_pad_ms: gr.Number
131
  This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
132
 
 
146
 
147
  diarization_device: gr.Dropdown
148
  This parameter is related with whisperx. Device to run diarization model
149
+
150
+ length_penalty:
151
+ This parameter is related to faster-whisper. Exponential length penalty constant.
152
+
153
+ repetition_penalty:
154
+ This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
155
+ (set > 1 to penalize).
156
+
157
+ no_repeat_ngram_size:
158
+ This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
159
+
160
+ prefix:
161
+ This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
162
+
163
+ suppress_blank:
164
+ This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
165
+
166
+ suppress_tokens:
167
+ This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
168
+ of symbols as defined in the model config.json file.
169
+
170
+ max_initial_timestamp:
171
+ This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
172
+
173
+ word_timestamps:
174
+ This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
175
+ and dynamic time warping, and include the timestamps for each word in each segment.
176
+
177
+ prepend_punctuations:
178
+ This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
179
+ with the next word.
180
+
181
+ append_punctuations:
182
+ This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
183
+ with the previous word.
184
+
185
+ max_new_tokens:
186
+ This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
187
+ the maximum will be set by the default max_length.
188
+
189
+ chunk_length:
190
+ This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
191
+ default chunk_length of the FeatureExtractor.
192
+
193
+ hallucination_silence_threshold:
194
+ This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
195
+ (in seconds) when a possible hallucination is detected.
196
+
197
+ hotwords:
198
+ This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
199
+
200
+ language_detection_threshold:
201
+ This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
202
+
203
+ language_detection_segments:
204
+ This parameter is related to faster-whisper. Number of segments to consider for the language detection.
205
  """
206
 
207
  def as_list(self) -> list:
 
226
  WhisperValues
227
  Data class that has values of parameters
228
  """
229
+ return WhisperValues(*args)
230
 
231
 
232
  @dataclass
 
241
  best_of: int
242
  patience: float
243
  condition_on_previous_text: bool
244
+ prompt_reset_on_temperature: float
245
  initial_prompt: Optional[str]
246
  temperature: float
247
  compression_ratio_threshold: float
 
250
  min_speech_duration_ms: int
251
  max_speech_duration_s: float
252
  min_silence_duration_ms: int
 
253
  speech_pad_ms: int
254
  chunk_length_s: int
255
  batch_size: int
256
  is_diarize: bool
257
  hf_token: str
258
  diarization_device: str
259
+ length_penalty: float
260
+ repetition_penalty: float
261
+ no_repeat_ngram_size: int
262
+ prefix: Optional[str]
263
+ suppress_blank: bool
264
+ suppress_tokens: Optional[str]
265
+ max_initial_timestamp: float
266
+ word_timestamps: bool
267
+ prepend_punctuations: Optional[str]
268
+ append_punctuations: Optional[str]
269
+ max_new_tokens: Optional[int]
270
+ chunk_length: Optional[int]
271
+ hallucination_silence_threshold: Optional[float]
272
+ hotwords: Optional[str]
273
+ language_detection_threshold: Optional[float]
274
+ language_detection_segments: int
275
  """
276
  A data class to use Whisper parameters.
277
+ """
notebook/whisper-webui.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "\n",
14
  "If you find this project useful, please consider supporting it:\n",
15
  "\n",
16
- "<a href=\"https://ko-fi.com/A0A7JSQRJ\" target=\"_blank\">\n",
17
  " <img src=\"https://storage.ko-fi.com/cdn/kofi2.png?v=3\" alt=\"Buy Me a Coffee at ko-fi.com\" height=\"36\">\n",
18
  "</a>\n",
19
  "\n",
@@ -53,9 +53,10 @@
53
  "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n",
54
  "%cd Whisper-WebUI\n",
55
  "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n",
56
- "!pip install faster-whisper==1.0.2\n",
57
- "!pip install gradio==4.14.0\n",
58
- "!pip install pytube\n",
 
59
  "!pip install tokenizers==0.19.1\n",
60
  "!pip install pyannote.audio==3.3.1"
61
  ]
@@ -70,7 +71,7 @@
70
  "\n",
71
  "USERNAME = '' #@param {type: \"string\"}\n",
72
  "PASSWORD = '' #@param {type: \"string\"}\n",
73
- "WHISPER_TYPE = 'faster-whisper' #@param {type: \"string\"}\n",
74
  "THEME = '' #@param {type: \"string\"}\n",
75
  "\n",
76
  "arguments = \"\"\n",
 
13
  "\n",
14
  "If you find this project useful, please consider supporting it:\n",
15
  "\n",
16
+ "<a href=\"https://ko-fi.com/jhj0517\" target=\"_blank\">\n",
17
  " <img src=\"https://storage.ko-fi.com/cdn/kofi2.png?v=3\" alt=\"Buy Me a Coffee at ko-fi.com\" height=\"36\">\n",
18
  "</a>\n",
19
  "\n",
 
53
  "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n",
54
  "%cd Whisper-WebUI\n",
55
  "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n",
56
+ "!pip install faster-whisper==1.0.3\n",
57
+ "!pip install gradio==4.29.0\n",
58
+ "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/220\n",
59
+ "!pip install pytubefix\n",
60
  "!pip install tokenizers==0.19.1\n",
61
  "!pip install pyannote.audio==3.3.1"
62
  ]
 
71
  "\n",
72
  "USERNAME = '' #@param {type: \"string\"}\n",
73
  "PASSWORD = '' #@param {type: \"string\"}\n",
74
+ "WHISPER_TYPE = 'faster-whisper' # @param [\"whisper\", \"faster-whisper\", \"insanely-fast-whisper\"]\n",
75
  "THEME = '' #@param {type: \"string\"}\n",
76
  "\n",
77
  "arguments = \"\"\n",
requirements.txt CHANGED
@@ -1,8 +1,13 @@
+ # Remove the --extra-index-url line below if you're not using Nvidia GPU.
+ # If you're using it, update url to your CUDA version (CUDA 12.1 is minimum requirement):
+ # For CUDA 12.1, use : https://download.pytorch.org/whl/cu121
+ # For CUDA 12.4, use : https://download.pytorch.org/whl/cu124
+
  --extra-index-url https://download.pytorch.org/whl/cu121
  torch
  git+https://github.com/jhj0517/jhj0517-whisper.git
- faster-whisper==1.0.2
- transformers
- pytube
- gradio
+ faster-whisper==1.0.3
+ transformers==4.42.3
+ gradio==4.29.0
+ pytubefix
  pyannote.audio==3.3.1
start-webui.bat CHANGED
@@ -1,7 +1,7 @@
  @echo off

  call venv\scripts\activate
- python app.py
+ python app.py %*

  echo "launching the app"
  pause