Merge master
- .gitignore +4 -0
- app.py +220 -276
- docker-compose.yaml +31 -0
- modules/diarize/audio_loader.py +26 -8
- modules/diarize/diarize_pipeline.py +2 -0
- modules/diarize/diarizer.py +3 -7
- modules/translation/deepl_api.py +15 -14
- modules/translation/nllb_inference.py +25 -9
- modules/translation/translation_base.py +21 -27
- modules/utils/files_manager.py +39 -0
- modules/utils/subtitle_manager.py +1 -3
- modules/utils/youtube_manager.py +1 -1
- modules/vad/silero_vad.py +41 -20
- modules/whisper/faster_whisper_inference.py +55 -34
- modules/whisper/insanely_fast_whisper_inference.py +7 -4
- modules/whisper/whisper_Inference.py +5 -4
- modules/whisper/whisper_base.py +66 -57
- modules/whisper/whisper_factory.py +81 -0
- modules/whisper/whisper_parameter.py +92 -35
- notebook/whisper-webui.ipynb +6 -5
- requirements.txt +9 -4
- start-webui.bat +1 -1
.gitignore
CHANGED
@@ -1,3 +1,7 @@
+*.wav
+*.png
+*.mp4
+*.mp3
 venv/
 ui/__pycache__/
 outputs/
app.py
CHANGED
@@ -1,7 +1,8 @@
 import os
 import argparse
+import gradio as gr
 
-from modules.whisper.
+from modules.whisper.whisper_factory import WhisperFactory
 from modules.whisper.faster_whisper_inference import FasterWhisperInference
 from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
 from modules.translation.nllb_inference import NLLBInference
@@ -15,68 +16,150 @@ class App:
     def __init__(self, args):
         self.args = args
         self.app = gr.Blocks(css=CSS, theme=self.args.theme)
-        self.whisper_inf =
+        self.whisper_inf = WhisperFactory.create_whisper_inference(
+            whisper_type=self.args.whisper_type,
+            whisper_model_dir=self.args.whisper_model_dir,
+            faster_whisper_model_dir=self.args.faster_whisper_model_dir,
+            insanely_fast_whisper_model_dir=self.args.insanely_fast_whisper_model_dir,
             output_dir=self.args.output_dir,
-            args=self.args
         )
         print(f"Use \"{self.args.whisper_type}\" implementation")
         print(f"Device \"{self.whisper_inf.device}\" is detected")
         self.nllb_inf = NLLBInference(
             model_dir=self.args.nllb_model_dir,
-            output_dir=self.args.output_dir
+            output_dir=os.path.join(self.args.output_dir, "translations")
         )
         self.deepl_api = DeepLAPI(
-            output_dir=self.args.output_dir
+            output_dir=os.path.join(self.args.output_dir, "translations")
         )
 
-    def
     [… the rest of the old method (old lines 33-79) was removed; its text did not survive in this view …]
+    def create_whisper_parameters(self):
+        with gr.Row():
+            dd_model = gr.Dropdown(choices=self.whisper_inf.available_models, value="large-v2", label="Model")
+            dd_lang = gr.Dropdown(choices=["Automatic Detection"] + self.whisper_inf.available_langs,
+                                  value="Automatic Detection", label="Language")
+            dd_file_format = gr.Dropdown(choices=["SRT", "WebVTT", "txt"], value="SRT", label="File Format")
+        with gr.Row():
+            cb_translate = gr.Checkbox(value=False, label="Translate to English?", interactive=True)
+        with gr.Row():
+            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename", interactive=True)
+        with gr.Accordion("Advanced Parameters", open=False):
+            nb_beam_size = gr.Number(label="Beam Size", value=5, precision=0, interactive=True,
+                                     info="Beam size to use for decoding.")
+            nb_log_prob_threshold = gr.Number(label="Log Probability Threshold", value=-1.0, interactive=True,
+                                              info="If the average log probability over sampled tokens is below this value, treat as failed.")
+            nb_no_speech_threshold = gr.Number(label="No Speech Threshold", value=0.6, interactive=True,
+                                               info="If the no speech probability is higher than this value AND the average log probability over sampled tokens is below 'Log Prob Threshold', consider the segment as silent.")
+            dd_compute_type = gr.Dropdown(label="Compute Type", choices=self.whisper_inf.available_compute_types,
+                                          value=self.whisper_inf.current_compute_type, interactive=True,
+                                          info="Select the type of computation to perform.")
+            nb_best_of = gr.Number(label="Best Of", value=5, interactive=True,
+                                   info="Number of candidates when sampling with non-zero temperature.")
+            nb_patience = gr.Number(label="Patience", value=1, interactive=True,
+                                    info="Beam search patience factor.")
+            cb_condition_on_previous_text = gr.Checkbox(label="Condition On Previous Text", value=True, interactive=True,
+                                                        info="Condition on previous text during decoding.")
+            sld_prompt_reset_on_temperature = gr.Slider(label="Prompt Reset On Temperature", value=0.5,
+                                                        minimum=0, maximum=1, step=0.01, interactive=True,
+                                                        info="Resets prompt if temperature is above this value."
+                                                             " Arg has effect only if 'Condition On Previous Text' is True.")
+            tb_initial_prompt = gr.Textbox(label="Initial Prompt", value=None, interactive=True,
+                                           info="Initial prompt to use for decoding.")
+            sd_temperature = gr.Slider(label="Temperature", value=0, step=0.01, maximum=1.0, interactive=True,
+                                       info="Temperature for sampling. It can be a tuple of temperatures, which will be successively used upon failures according to either `Compression Ratio Threshold` or `Log Prob Threshold`.")
+            nb_compression_ratio_threshold = gr.Number(label="Compression Ratio Threshold", value=2.4, interactive=True,
+                                                       info="If the gzip compression ratio is above this value, treat as failed.")
+            with gr.Group(visible=isinstance(self.whisper_inf, FasterWhisperInference)):
+                nb_length_penalty = gr.Number(label="Length Penalty", value=1,
+                                              info="Exponential length penalty constant.")
+                nb_repetition_penalty = gr.Number(label="Repetition Penalty", value=1,
+                                                  info="Penalty applied to the score of previously generated tokens (set > 1 to penalize).")
+                nb_no_repeat_ngram_size = gr.Number(label="No Repeat N-gram Size", value=0, precision=0,
+                                                    info="Prevent repetitions of n-grams with this size (set 0 to disable).")
+                tb_prefix = gr.Textbox(label="Prefix", value=lambda: None,
+                                       info="Optional text to provide as a prefix for the first window.")
+                cb_suppress_blank = gr.Checkbox(label="Suppress Blank", value=True,
+                                                info="Suppress blank outputs at the beginning of the sampling.")
+                tb_suppress_tokens = gr.Textbox(label="Suppress Tokens", value="[-1]",
+                                                info="List of token IDs to suppress. -1 will suppress a default set of symbols as defined in the model config.json file.")
+                nb_max_initial_timestamp = gr.Number(label="Max Initial Timestamp", value=1.0,
+                                                     info="The initial timestamp cannot be later than this.")
+                cb_word_timestamps = gr.Checkbox(label="Word Timestamps", value=False,
+                                                 info="Extract word-level timestamps using the cross-attention pattern and dynamic time warping, and include the timestamps for each word in each segment.")
+                tb_prepend_punctuations = gr.Textbox(label="Prepend Punctuations", value="\"'“¿([{-",
+                                                     info="If 'Word Timestamps' is True, merge these punctuation symbols with the next word.")
+                tb_append_punctuations = gr.Textbox(label="Append Punctuations", value="\"'.。,,!!??::”)]}、",
+                                                    info="If 'Word Timestamps' is True, merge these punctuation symbols with the previous word.")
+                nb_max_new_tokens = gr.Number(label="Max New Tokens", value=lambda: None, precision=0,
+                                              info="Maximum number of new tokens to generate per-chunk. If not set, the maximum will be set by the default max_length.")
+                nb_chunk_length = gr.Number(label="Chunk Length", value=lambda: None, precision=0,
+                                            info="The length of audio segments. If it is not None, it will overwrite the default chunk_length of the FeatureExtractor.")
+                nb_hallucination_silence_threshold = gr.Number(label="Hallucination Silence Threshold (sec)", value=lambda: None,
+                                                               info="When 'Word Timestamps' is True, skip silent periods longer than this threshold (in seconds) when a possible hallucination is detected.")
+                tb_hotwords = gr.Textbox(label="Hotwords", value=None,
+                                         info="Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.")
+                nb_language_detection_threshold = gr.Number(label="Language Detection Threshold", value=None,
+                                                            info="If the maximum probability of the language tokens is higher than this value, the language is detected.")
+                nb_language_detection_segments = gr.Number(label="Language Detection Segments", value=1, precision=0,
+                                                           info="Number of segments to consider for the language detection.")
+            with gr.Group(visible=isinstance(self.whisper_inf, InsanelyFastWhisperInference)):
+                nb_chunk_length_s = gr.Number(label="Chunk Lengths (sec)", value=30, precision=0)
+                nb_batch_size = gr.Number(label="Batch Size", value=24, precision=0)
 
+        with gr.Accordion("VAD", open=False):
+            cb_vad_filter = gr.Checkbox(label="Enable Silero VAD Filter", value=False, interactive=True)
+            sd_threshold = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label="Speech Threshold", value=0.5,
+                                     info="Lower it to be more sensitive to small sounds.")
+            nb_min_speech_duration_ms = gr.Number(label="Minimum Speech Duration (ms)", precision=0, value=250,
+                                                  info="Final speech chunks shorter than this time are thrown out")
+            nb_max_speech_duration_s = gr.Number(label="Maximum Speech Duration (s)", value=9999,
+                                                 info="Maximum duration of speech chunks in \"seconds\". Chunks longer"
+                                                      " than this time will be split at the timestamp of the last silence that"
+                                                      " lasts more than 100ms (if any), to prevent aggressive cutting.")
+            nb_min_silence_duration_ms = gr.Number(label="Minimum Silence Duration (ms)", precision=0, value=2000,
+                                                   info="In the end of each speech chunk wait for this time"
+                                                        " before separating it")
+            nb_speech_pad_ms = gr.Number(label="Speech Padding (ms)", precision=0, value=400,
+                                         info="Final speech chunks are padded by this time each side")
 
+        with gr.Accordion("Diarization", open=False):
+            cb_diarize = gr.Checkbox(label="Enable Diarization")
+            tb_hf_token = gr.Text(label="HuggingFace Token", value="",
+                                  info="This is only needed the first time you download the model. If you already have models, you don't need to enter. To download the model, you must manually go to \"https://huggingface.co/pyannote/speaker-diarization-3.1\" and agree to their requirement.")
+            dd_diarization_device = gr.Dropdown(label="Device",
+                                                choices=self.whisper_inf.diarizer.get_available_device(),
+                                                value=self.whisper_inf.diarizer.get_device())
+
+        dd_model.change(fn=self.on_change_models, inputs=[dd_model], outputs=[cb_translate])
+
+        return (
+            WhisperParameters(
+                model_size=dd_model, lang=dd_lang, is_translate=cb_translate, beam_size=nb_beam_size,
+                log_prob_threshold=nb_log_prob_threshold, no_speech_threshold=nb_no_speech_threshold,
+                compute_type=dd_compute_type, best_of=nb_best_of, patience=nb_patience,
+                condition_on_previous_text=cb_condition_on_previous_text, initial_prompt=tb_initial_prompt,
+                temperature=sd_temperature, compression_ratio_threshold=nb_compression_ratio_threshold,
+                vad_filter=cb_vad_filter, threshold=sd_threshold, min_speech_duration_ms=nb_min_speech_duration_ms,
+                max_speech_duration_s=nb_max_speech_duration_s, min_silence_duration_ms=nb_min_silence_duration_ms,
+                speech_pad_ms=nb_speech_pad_ms, chunk_length_s=nb_chunk_length_s, batch_size=nb_batch_size,
+                is_diarize=cb_diarize, hf_token=tb_hf_token, diarization_device=dd_diarization_device,
+                length_penalty=nb_length_penalty, repetition_penalty=nb_repetition_penalty,
+                no_repeat_ngram_size=nb_no_repeat_ngram_size, prefix=tb_prefix, suppress_blank=cb_suppress_blank,
+                suppress_tokens=tb_suppress_tokens, max_initial_timestamp=nb_max_initial_timestamp,
+                word_timestamps=cb_word_timestamps, prepend_punctuations=tb_prepend_punctuations,
+                append_punctuations=tb_append_punctuations, max_new_tokens=nb_max_new_tokens, chunk_length=nb_chunk_length,
+                hallucination_silence_threshold=nb_hallucination_silence_threshold, hotwords=tb_hotwords,
+                language_detection_threshold=nb_language_detection_threshold,
+                language_detection_segments=nb_language_detection_segments,
+                prompt_reset_on_temperature=sld_prompt_reset_on_temperature
+            ),
+            dd_file_format,
+            cb_timestamp
+        )
 
     def launch(self):
         with self.app:
@@ -85,84 +168,28 @@ class App:
             gr.Markdown(MARKDOWN, elem_id="md_project")
             with gr.Tabs():
                 with gr.TabItem("File"):  # tab1
-                    with gr.
+                    with gr.Column():
                         input_file = gr.Files(type="filepath", label="Upload File here")
+                        tb_input_folder = gr.Textbox(label="Input Folder Path (Optional)",
+                                                     info="Optional: Specify the folder path where the input files are located, if you prefer to use local files instead of uploading them."
+                                                          " Leave this field empty if you do not wish to use a local path.",
+                                                     visible=self.args.colab,
+                                                     value="")
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                     [… the Model/Language/File Format dropdowns and the "Advanced Parameters", "VAD", "Diarization" and "Insanely Fast Whisper Parameters" widgets defined inline for this tab are removed here …]
                     with gr.Row():
                         btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                     with gr.Row():
-                        tb_indicator = gr.Textbox(label="Output", scale=
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                         files_subtitles = gr.Files(label="Downloadable output file", scale=3, interactive=False)
+                        btn_openfolder = gr.Button('📂', scale=1)
 
-                    params = [input_file, dd_file_format, cb_timestamp]
                     [… the inline WhisperParameters(model_size=dd_model, lang=dd_lang, …) construction for this tab is removed here …]
+                    params = [input_file, tb_input_folder, dd_file_format, cb_timestamp]
 
                     btn_run.click(fn=self.whisper_inf.transcribe_file,
-                                  inputs=params+whisper_params.as_list(),
+                                  inputs=params + whisper_params.as_list(),
                                   outputs=[tb_indicator, files_subtitles])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
                 with gr.TabItem("Youtube"):  # tab2
                     with gr.Row():
@@ -173,164 +200,44 @@ class App:
                     with gr.Column():
                         tb_title = gr.Label(label="Youtube Title")
                         tb_description = gr.Textbox(label="Youtube Description", max_lines=15)
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                     [… the duplicated Language/File Format dropdowns, "Translate to English?" and timestamp checkboxes, and the "Advanced Parameters", "VAD", "Diarization" and "Insanely Fast Whisper Parameters" accordions for this tab are removed here …]
                     with gr.Row():
                         btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                     with gr.Row():
-                        tb_indicator = gr.Textbox(label="Output", scale=
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                         files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                        btn_openfolder = gr.Button('📂', scale=1)
 
                     params = [tb_youtubelink, dd_file_format, cb_timestamp]
                     [… the inline WhisperParameters(…) construction for this tab is removed here …]
 
                     btn_run.click(fn=self.whisper_inf.transcribe_youtube,
                                   inputs=params + whisper_params.as_list(),
                                   outputs=[tb_indicator, files_subtitles])
                     tb_youtubelink.change(get_ytmetas, inputs=[tb_youtubelink],
                                           outputs=[img_thumbnail, tb_title, tb_description])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
                 with gr.TabItem("Mic"):  # tab3
                     with gr.Row():
                         mic_input = gr.Microphone(label="Record with Mic", type="filepath", interactive=True)
+
+                    whisper_params, dd_file_format, cb_timestamp = self.create_whisper_parameters()
+
                     [… the duplicated parameter widgets for this tab are removed here …]
                     with gr.Row():
                         btn_run = gr.Button("GENERATE SUBTITLE FILE", variant="primary")
                     with gr.Row():
-                        tb_indicator = gr.Textbox(label="Output", scale=
+                        tb_indicator = gr.Textbox(label="Output", scale=5)
                         files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                        btn_openfolder = gr.Button('📂', scale=1)
 
                     params = [mic_input, dd_file_format]
                     [… the inline WhisperParameters(…) construction for this tab is removed here …]
 
                     btn_run.click(fn=self.whisper_inf.transcribe_mic,
                                   inputs=params + whisper_params.as_list(),
                                   outputs=[tb_indicator, files_subtitles])
+                    btn_openfolder.click(fn=lambda: self.open_folder("outputs"), inputs=None, outputs=None)
 
                 with gr.TabItem("T2T Translation"):  # tab 4
                     with gr.Row():
@@ -350,17 +257,25 @@ class App:
                                                             self.deepl_api.available_target_langs.keys()))
                         with gr.Row():
                             cb_deepl_ispro = gr.Checkbox(label="Pro User?", value=False)
+                        with gr.Row():
+                            cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
+                                                       interactive=True)
                         with gr.Row():
                             btn_run = gr.Button("TRANSLATE SUBTITLE FILE", variant="primary")
                         with gr.Row():
                             tb_indicator = gr.Textbox(label="Output", scale=5)
                             files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                            btn_openfolder = gr.Button('📂', scale=1)
 
                         btn_run.click(fn=self.deepl_api.translate_deepl,
                                       inputs=[tb_authkey, file_subs, dd_deepl_sourcelang, dd_deepl_targetlang,
-                                              cb_deepl_ispro],
+                                              cb_deepl_ispro, cb_timestamp],
                                       outputs=[tb_indicator, files_subtitles])
 
+                        btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                                             inputs=None,
+                                             outputs=None)
+
                     with gr.TabItem("NLLB"):  # sub tab2
                         with gr.Row():
                             dd_nllb_model = gr.Dropdown(label="Model", value="facebook/nllb-200-1.3B",
@@ -369,6 +284,8 @@ class App:
                                                              choices=self.nllb_inf.available_source_langs)
                             dd_nllb_targetlang = gr.Dropdown(label="Target Language",
                                                              choices=self.nllb_inf.available_target_langs)
+                        with gr.Row():
+                            nb_max_length = gr.Number(label="Max Length Per Line", value=200, precision=0)
                         with gr.Row():
                             cb_timestamp = gr.Checkbox(value=True, label="Add a timestamp to the end of the filename",
                                                        interactive=True)
@@ -377,33 +294,53 @@ class App:
                         with gr.Row():
                             tb_indicator = gr.Textbox(label="Output", scale=5)
                             files_subtitles = gr.Files(label="Downloadable output file", scale=3)
+                            btn_openfolder = gr.Button('📂', scale=1)
                         with gr.Column():
                             md_vram_table = gr.HTML(NLLB_VRAM_TABLE, elem_id="md_nllb_vram_table")
 
                         btn_run.click(fn=self.nllb_inf.translate_file,
-                                      inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang,
+                                      inputs=[file_subs, dd_nllb_model, dd_nllb_sourcelang, dd_nllb_targetlang,
+                                              nb_max_length, cb_timestamp],
                                       outputs=[tb_indicator, files_subtitles])
 
+                        btn_openfolder.click(fn=lambda: self.open_folder(os.path.join("outputs", "translations")),
+                                             inputs=None,
+                                             outputs=None)
+
         # Launch the app with optional gradio settings
         [… the old launch call (old lines 388-401) was removed; its text did not survive in this view …]
+        args = self.args
+
+        self.app.queue(
+            api_open=args.api_open
+        ).launch(
+            share=args.share,
+            server_name=args.server_name,
+            server_port=args.server_port,
+            auth=(args.username, args.password) if args.username and args.password else None,
+            root_path=args.root_path,
+            inbrowser=args.inbrowser
+        )
+
+    @staticmethod
+    def open_folder(folder_path: str):
+        if os.path.exists(folder_path):
+            os.system(f"start {folder_path}")
+        else:
+            print(f"The folder {folder_path} does not exist.")
+
+    @staticmethod
+    def on_change_models(model_size: str):
+        translatable_model = ["large", "large-v1", "large-v2", "large-v3"]
+        if model_size not in translatable_model:
+            return gr.Checkbox(visible=False, value=False, interactive=False)
+        else:
+            return gr.Checkbox(visible=True, value=False, label="Translate to English?", interactive=True)
 
 
 # Create the parser for command-line arguments
 parser = argparse.ArgumentParser()
-parser.add_argument('--whisper_type', type=str, default="faster-whisper",
+parser.add_argument('--whisper_type', type=str, default="faster-whisper",
+                    help='A type of the whisper implementation between: ["whisper", "faster-whisper", "insanely-fast-whisper"]')
 parser.add_argument('--share', type=bool, default=False, nargs='?', const=True, help='Gradio share value')
 parser.add_argument('--server_name', type=str, default=None, help='Gradio server host')
 parser.add_argument('--server_port', type=int, default=None, help='Gradio server port')
@@ -412,12 +349,19 @@ parser.add_argument('--username', type=str, default=None, help='Gradio authentic
 parser.add_argument('--password', type=str, default=None, help='Gradio authentication password')
 parser.add_argument('--theme', type=str, default=None, help='Gradio Blocks theme')
 parser.add_argument('--colab', type=bool, default=False, nargs='?', const=True, help='Is colab user or not')
-parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='
-parser.add_argument('--
-parser.add_argument('--
-parser.add_argument('--
+parser.add_argument('--api_open', type=bool, default=False, nargs='?', const=True, help='Enable api or not in Gradio')
+parser.add_argument('--inbrowser', type=bool, default=True, nargs='?', const=True, help='Whether to automatically start Gradio app or not')
+parser.add_argument('--whisper_model_dir', type=str, default=os.path.join("models", "Whisper"),
+                    help='Directory path of the whisper model')
+parser.add_argument('--faster_whisper_model_dir', type=str, default=os.path.join("models", "Whisper", "faster-whisper"),
+                    help='Directory path of the faster-whisper model')
+parser.add_argument('--insanely_fast_whisper_model_dir', type=str,
+                    default=os.path.join("models", "Whisper", "insanely-fast-whisper"),
+                    help='Directory path of the insanely-fast-whisper model')
+parser.add_argument('--diarization_model_dir', type=str, default=os.path.join("models", "Diarization"),
+                    help='Directory path of the diarization model')
+parser.add_argument('--nllb_model_dir', type=str, default=os.path.join("models", "NLLB"),
+                    help='Directory path of the Facebook NLLB model')
 parser.add_argument('--output_dir', type=str, default=os.path.join("outputs"), help='Directory path of the outputs')
 _args = parser.parse_args()
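One practical note on the new `open_folder` helper above: `os.system(f"start {folder_path}")` relies on the Windows `start` command, so the 📂 buttons only open a window on Windows hosts. A cross-platform variant could look like the sketch below; this is an illustration under that assumption, not code from this commit.

    import os
    import platform
    import subprocess

    def open_folder(folder_path: str):
        # Sketch of a portable folder opener (not part of the commit above).
        if not os.path.exists(folder_path):
            print(f"The folder {folder_path} does not exist.")
            return
        if platform.system() == "Windows":
            os.startfile(folder_path)                    # native Windows call, no shell needed
        elif platform.system() == "Darwin":
            subprocess.run(["open", folder_path])        # macOS Finder
        else:
            subprocess.run(["xdg-open", folder_path])    # most Linux desktops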
docker-compose.yaml
ADDED
@@ -0,0 +1,31 @@
+version: '3.8'
+
+services:
+  app:
+    build: .
+    image: jhj0517/whisper-webui:latest
+
+    volumes:
+      # Update paths to mount models and output paths to your custom paths like this, e.g:
+      # - C:/whisper-models/custom-path:/Whisper-WebUI/models
+      # - C:/whisper-webui-outputs/custom-path:/Whisper-WebUI/outputs
+      - /Whisper-WebUI/models
+      - /Whisper-WebUI/outputs
+
+    ports:
+      - "7860:7860"
+
+    stdin_open: true
+    tty: true
+
+    entrypoint: ["python", "app.py", "--server_port", "7860", "--server_name", "0.0.0.0",]
+
+    # If you're not using nvidia GPU, Update device to match yours.
+    # See more info at : https://docs.docker.com/compose/compose-file/deploy/#driver
+    deploy:
+      resources:
+        reservations:
+          devices:
+            - driver: nvidia
+              count: all
+              capabilities: [ gpu ]
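Assuming Docker with the NVIDIA container toolkit is installed, the service defined above can be started with `docker compose up` from the repository root; the WebUI then listens on port 7860, and the commented volume entries show how to map the model and output directories to host paths.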
modules/diarize/audio_loader.py
CHANGED
@@ -1,7 +1,11 @@
+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/audio.py
+
 import os
 import subprocess
 from functools import lru_cache
 from typing import Optional, Union
+from scipy.io.wavfile import write
+import tempfile
 
 import numpy as np
 import torch
@@ -24,32 +28,43 @@ FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)  # 10ms per audio frame
 TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)  # 20ms per audio token
 
 
-def load_audio(file: str, sr: int = SAMPLE_RATE):
+def load_audio(file: Union[str, np.ndarray], sr: int = SAMPLE_RATE) -> np.ndarray:
     """
-    Open an audio file
+    Open an audio file or process a numpy array containing audio data as mono waveform, resampling as necessary.
 
     Parameters
     ----------
-    file: str
-        The audio file to open
+    file: Union[str, np.ndarray]
+        The audio file to open or a numpy array containing the audio data.
 
     sr: int
-        The sample rate to resample the audio if necessary
+        The sample rate to resample the audio if necessary.
 
     Returns
     -------
     A NumPy array containing the audio waveform, in float32 dtype.
     """
+    if isinstance(file, np.ndarray):
+        if file.dtype != np.float32:
+            file = file.astype(np.float32)
+        if file.ndim > 1:
+            file = np.mean(file, axis=1)
+
+        temp_file = tempfile.NamedTemporaryFile(delete=False, suffix=".wav")
+        write(temp_file.name, SAMPLE_RATE, (file * 32768).astype(np.int16))
+        temp_file_path = temp_file.name
+        temp_file.close()
+    else:
+        temp_file_path = file
+
     try:
-        # Launches a subprocess to decode audio while down-mixing and resampling as necessary.
-        # Requires the ffmpeg CLI to be installed.
         cmd = [
             "ffmpeg",
             "-nostdin",
             "-threads",
             "0",
             "-i",
-            file,
+            temp_file_path,
             "-f",
             "s16le",
             "-ac",
@@ -63,6 +78,9 @@ def load_audio(file: str, sr: int = SAMPLE_RATE):
         out = subprocess.run(cmd, capture_output=True, check=True).stdout
     except subprocess.CalledProcessError as e:
         raise RuntimeError(f"Failed to load audio: {e.stderr.decode()}") from e
+    finally:
+        if isinstance(file, np.ndarray):
+            os.remove(temp_file_path)
 
     return np.frombuffer(out, np.int16).flatten().astype(np.float32) / 32768.0
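A minimal sketch of how the reworked `load_audio` can be called after this change; the file name is a placeholder and the ffmpeg CLI is assumed to be on PATH.

    import numpy as np
    from modules.diarize.audio_loader import load_audio

    # Path input works as before: returns mono float32 samples at the module's SAMPLE_RATE.
    waveform = load_audio("example.wav")

    # New here: a numpy array is accepted too. It is down-mixed to mono if needed, written to a
    # temporary 16-bit WAV, decoded back through ffmpeg, and the temp file is removed afterwards.
    waveform_again = load_audio(waveform)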
modules/diarize/diarize_pipeline.py
CHANGED
@@ -1,3 +1,5 @@
+# Adapted from https://github.com/m-bain/whisperX/blob/main/whisperx/diarize.py
+
 import numpy as np
 import pandas as pd
 import os
modules/diarize/diarizer.py
CHANGED
@@ -1,9 +1,9 @@
 import os
 import torch
-from typing import List
+from typing import List, Union, BinaryIO
+import numpy as np
 import time
 import logging
-import spaces
 
 from modules.diarize.diarize_pipeline import DiarizationPipeline, assign_word_speakers
 from modules.diarize.audio_loader import load_audio
@@ -20,9 +20,8 @@ class Diarizer:
         os.makedirs(self.model_dir, exist_ok=True)
         self.pipe = None
 
-    @spaces.GPU
     def run(self,
-            audio: str,
+            audio: Union[str, BinaryIO, np.ndarray],
             transcribed_result: List[dict],
             use_auth_token: str,
             device: str
@@ -75,7 +74,6 @@ class Diarizer:
         elapsed_time = time.time() - start_time
         return diarized_result["segments"], elapsed_time
 
-    @spaces.GPU
     def update_pipe(self,
                     use_auth_token: str,
                     device: str
@@ -113,7 +111,6 @@ class Diarizer:
         logger.disabled = False
 
     @staticmethod
-    @spaces.GPU
     def get_device():
         if torch.cuda.is_available():
             return "cuda"
@@ -123,7 +120,6 @@ class Diarizer:
         return "cpu"
 
     @staticmethod
-    @spaces.GPU
     def get_available_device():
         devices = ["cpu"]
         if torch.cuda.is_available():
modules/translation/deepl_api.py
CHANGED
@@ -83,7 +83,7 @@ DEEPL_AVAILABLE_SOURCE_LANGS = {
 
 class DeepLAPI:
     def __init__(self,
-                 output_dir: str
+                 output_dir: str = os.path.join("outputs", "translations")
                  ):
         self.api_interval = 1
         self.max_text_batch_size = 50
@@ -97,6 +97,7 @@ class DeepLAPI:
                          source_lang: str,
                          target_lang: str,
                          is_pro: bool,
+                         add_timestamp: bool,
                          progress=gr.Progress()) -> list:
         """
         Translate subtitle files using DeepL API
@@ -112,6 +113,8 @@ class DeepLAPI:
             Target language of the file to transcribe from gr.Dropdown()
         is_pro: str
             Boolean value that is about pro user or not from gr.Checkbox().
+        add_timestamp: bool
+            Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
         progress: gr.Progress
             Indicator to show progress directly in gradio.
 
@@ -141,11 +144,6 @@ class DeepLAPI:
                     progress(batch_end / len(parsed_dicts), desc="Translating..")
 
                 subtitle = get_serialized_srt(parsed_dicts)
-                timestamp = datetime.now().strftime("%m%d%H%M%S")
-
-                file_name = file_name[:-9]
-                output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.srt")
-                write_file(subtitle, output_path)
 
             elif file_ext == ".vtt":
                 parsed_dicts = parse_vtt(file_path=file_path)
@@ -161,22 +159,25 @@ class DeepLAPI:
                     progress(batch_end / len(parsed_dicts), desc="Translating..")
 
                 subtitle = get_serialized_vtt(parsed_dicts)
+
+            if add_timestamp:
                 timestamp = datetime.now().strftime("%m%d%H%M%S")
+                file_name += f"-{timestamp}"
 
-            files_info[file_name] = subtitle
+            output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
+            write_file(subtitle, output_path)
+
+            files_info[file_name] = {"subtitle": subtitle, "path": output_path}
+
         total_result = ''
-        for file_name,
+        for file_name, info in files_info.items():
             total_result += '------------------------------------\n'
             total_result += f'{file_name}\n\n'
-            total_result += f'{subtitle}'
-
+            total_result += f'{info["subtitle"]}'
         gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
+
+        output_file_paths = [item["path"] for key, item in files_info.items()]
+        return [gr_str, output_file_paths]
 
     def request_deepl_translate(self,
                                 auth_key: str,
modules/translation/nllb_inference.py
CHANGED
@@ -1,15 +1,14 @@
|
|
1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
2 |
import gradio as gr
|
3 |
import os
|
4 |
-
import spaces
|
5 |
|
6 |
from modules.translation.translation_base import TranslationBase
|
7 |
|
8 |
|
9 |
class NLLBInference(TranslationBase):
|
10 |
def __init__(self,
|
11 |
-
model_dir: str,
|
12 |
-
output_dir: str
|
13 |
):
|
14 |
super().__init__(
|
15 |
model_dir=model_dir,
|
@@ -21,14 +20,16 @@ class NLLBInference(TranslationBase):
|
|
21 |
self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
|
22 |
self.pipeline = None
|
23 |
|
24 |
-
@spaces.GPU(duration=120)
|
25 |
def translate(self,
|
26 |
-
text: str
|
|
|
27 |
):
|
28 |
-
result = self.pipeline(
|
|
|
|
|
|
|
29 |
return result[0]['translation_text']
|
30 |
|
31 |
-
@spaces.GPU(duration=120)
|
32 |
def update_model(self,
|
33 |
model_size: str,
|
34 |
src_lang: str,
|
@@ -39,10 +40,13 @@ class NLLBInference(TranslationBase):
|
|
39 |
print("\nInitializing NLLB Model..\n")
|
40 |
progress(0, desc="Initializing NLLB Model..")
|
41 |
self.current_model_size = model_size
|
|
|
42 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
43 |
-
cache_dir=self.model_dir
|
|
|
44 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
45 |
-
cache_dir=os.path.join(self.model_dir, "tokenizers")
|
|
|
46 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
47 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
48 |
self.pipeline = pipeline("translation",
|
@@ -52,6 +56,18 @@ class NLLBInference(TranslationBase):
|
|
52 |
tgt_lang=tgt_lang,
|
53 |
device=self.device)
|
54 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
55 |
NLLB_AVAILABLE_LANGS = {
|
56 |
"Acehnese (Arabic script)": "ace_Arab",
|
57 |
"Acehnese (Latin script)": "ace_Latn",
|
|
|
1 |
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
|
2 |
import gradio as gr
|
3 |
import os
|
|
|
4 |
|
5 |
from modules.translation.translation_base import TranslationBase
|
6 |
|
7 |
|
8 |
class NLLBInference(TranslationBase):
|
9 |
def __init__(self,
|
10 |
+
model_dir: str = os.path.join("models", "NLLB"),
|
11 |
+
output_dir: str = os.path.join("outputs", "translations")
|
12 |
):
|
13 |
super().__init__(
|
14 |
model_dir=model_dir,
|
|
|
20 |
self.available_target_langs = list(NLLB_AVAILABLE_LANGS.keys())
|
21 |
self.pipeline = None
|
22 |
|
|
|
23 |
def translate(self,
|
24 |
+
text: str,
|
25 |
+
max_length: int
|
26 |
):
|
27 |
+
result = self.pipeline(
|
28 |
+
text,
|
29 |
+
max_length=max_length
|
30 |
+
)
|
31 |
return result[0]['translation_text']
|
32 |
|
|
|
33 |
def update_model(self,
|
34 |
model_size: str,
|
35 |
src_lang: str,
|
|
|
40 |
print("\nInitializing NLLB Model..\n")
|
41 |
progress(0, desc="Initializing NLLB Model..")
|
42 |
self.current_model_size = model_size
|
43 |
+
local_files_only = self.is_model_exists(self.current_model_size)
|
44 |
self.model = AutoModelForSeq2SeqLM.from_pretrained(pretrained_model_name_or_path=model_size,
|
45 |
+
cache_dir=self.model_dir,
|
46 |
+
local_files_only=local_files_only)
|
47 |
self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
|
48 |
+
cache_dir=os.path.join(self.model_dir, "tokenizers"),
|
49 |
+
local_files_only=local_files_only)
|
50 |
src_lang = NLLB_AVAILABLE_LANGS[src_lang]
|
51 |
tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
|
52 |
self.pipeline = pipeline("translation",
|
|
|
56 |
tgt_lang=tgt_lang,
|
57 |
device=self.device)
|
58 |
|
59 |
+
def is_model_exists(self,
|
60 |
+
model_size: str):
|
61 |
+
"""Check if model exists or not (Only facebook model)"""
|
62 |
+
prefix = "models--facebook--"
|
63 |
+
_id, model_size_name = model_size.split("/")
|
64 |
+
model_dir_name = prefix + model_size_name
|
65 |
+
model_dir_path = os.path.join(self.model_dir, model_dir_name)
|
66 |
+
if os.path.exists(model_dir_path) and os.listdir(model_dir_path):
|
67 |
+
return True
|
68 |
+
return False
|
69 |
+
|
70 |
+
|
71 |
NLLB_AVAILABLE_LANGS = {
|
72 |
"Acehnese (Arabic script)": "ace_Arab",
|
73 |
"Acehnese (Latin script)": "ace_Latn",
|
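A small sketch of the local-cache check that is_model_exists adds above; the model id is only an example, and the folder name follows the Hugging Face cache convention used in the diff:

import os

def cached_locally(model_dir: str, model_size: str) -> bool:
    # "facebook/nllb-200-distilled-600M" -> "models--facebook--nllb-200-distilled-600M"
    _owner, model_name = model_size.split("/")
    model_path = os.path.join(model_dir, f"models--facebook--{model_name}")
    return os.path.exists(model_path) and bool(os.listdir(model_path))

# True only after the model has been downloaded once into the NLLB model directory
print(cached_locally(os.path.join("models", "NLLB"), "facebook/nllb-200-distilled-600M"))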
modules/translation/translation_base.py
CHANGED
@@ -4,7 +4,6 @@ import gradio as gr
|
|
4 |
from abc import ABC, abstractmethod
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
7 |
-
import spaces
|
8 |
|
9 |
from modules.whisper.whisper_parameter import *
|
10 |
from modules.utils.subtitle_manager import *
|
@@ -12,8 +11,9 @@ from modules.utils.subtitle_manager import *
|
|
12 |
|
13 |
class TranslationBase(ABC):
|
14 |
def __init__(self,
|
15 |
-
model_dir: str,
|
16 |
-
output_dir: str)
|
|
|
17 |
super().__init__()
|
18 |
self.model = None
|
19 |
self.model_dir = model_dir
|
@@ -24,14 +24,13 @@ class TranslationBase(ABC):
|
|
24 |
self.device = self.get_device()
|
25 |
|
26 |
@abstractmethod
|
27 |
-
@spaces.GPU(duration=120)
|
28 |
def translate(self,
|
29 |
-
text: str
|
|
|
30 |
):
|
31 |
pass
|
32 |
|
33 |
@abstractmethod
|
34 |
-
@spaces.GPU(duration=120)
|
35 |
def update_model(self,
|
36 |
model_size: str,
|
37 |
src_lang: str,
|
@@ -40,12 +39,12 @@ class TranslationBase(ABC):
|
|
40 |
):
|
41 |
pass
|
42 |
|
43 |
-
@spaces.GPU(duration=120)
|
44 |
def translate_file(self,
|
45 |
fileobjs: list,
|
46 |
model_size: str,
|
47 |
src_lang: str,
|
48 |
tgt_lang: str,
|
|
|
49 |
add_timestamp: bool,
|
50 |
progress=gr.Progress()) -> list:
|
51 |
"""
|
@@ -61,6 +60,8 @@ class TranslationBase(ABC):
|
|
61 |
Source language of the file to translate from gr.Dropdown()
|
62 |
tgt_lang: str
|
63 |
Target language of the file to translate from gr.Dropdown()
|
|
|
|
|
64 |
add_timestamp: bool
|
65 |
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
66 |
progress: gr.Progress
|
@@ -88,50 +89,44 @@ class TranslationBase(ABC):
|
|
88 |
total_progress = len(parsed_dicts)
|
89 |
for index, dic in enumerate(parsed_dicts):
|
90 |
progress(index / total_progress, desc="Translating..")
|
91 |
-
translated_text = self.translate(dic["sentence"])
|
92 |
dic["sentence"] = translated_text
|
93 |
subtitle = get_serialized_srt(parsed_dicts)
|
94 |
|
95 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
96 |
-
if add_timestamp:
|
97 |
-
output_path = os.path.join("outputs", "", f"{file_name}-{timestamp}.srt")
|
98 |
-
else:
|
99 |
-
output_path = os.path.join("outputs", "", f"{file_name}.srt")
|
100 |
-
|
101 |
elif file_ext == ".vtt":
|
102 |
parsed_dicts = parse_vtt(file_path=file_path)
|
103 |
total_progress = len(parsed_dicts)
|
104 |
for index, dic in enumerate(parsed_dicts):
|
105 |
progress(index / total_progress, desc="Translating..")
|
106 |
-
translated_text = self.translate(dic["sentence"])
|
107 |
dic["sentence"] = translated_text
|
108 |
subtitle = get_serialized_vtt(parsed_dicts)
|
109 |
|
|
|
110 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
111 |
-
|
112 |
-
output_path = os.path.join(self.output_dir, "", f"{file_name}-{timestamp}.vtt")
|
113 |
-
else:
|
114 |
-
output_path = os.path.join(self.output_dir, "", f"{file_name}.vtt")
|
115 |
|
|
|
116 |
write_file(subtitle, output_path)
|
117 |
-
|
|
|
118 |
|
119 |
total_result = ''
|
120 |
-
for file_name,
|
121 |
total_result += '------------------------------------\n'
|
122 |
total_result += f'{file_name}\n\n'
|
123 |
-
total_result += f'{subtitle}'
|
124 |
-
|
125 |
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
|
126 |
-
|
|
|
|
|
|
|
127 |
except Exception as e:
|
128 |
print(f"Error: {str(e)}")
|
129 |
finally:
|
130 |
self.release_cuda_memory()
|
131 |
-
self.remove_input_files([fileobj.name for fileobj in fileobjs])
|
132 |
|
133 |
@staticmethod
|
134 |
-
@spaces.GPU(duration=120)
|
135 |
def get_device():
|
136 |
if torch.cuda.is_available():
|
137 |
return "cuda"
|
@@ -141,7 +136,6 @@ class TranslationBase(ABC):
|
|
141 |
return "cpu"
|
142 |
|
143 |
@staticmethod
|
144 |
-
@spaces.GPU(duration=120)
|
145 |
def release_cuda_memory():
|
146 |
if torch.cuda.is_available():
|
147 |
torch.cuda.empty_cache()
|
|
|
4 |
from abc import ABC, abstractmethod
|
5 |
from typing import List
|
6 |
from datetime import datetime
|
|
|
7 |
|
8 |
from modules.whisper.whisper_parameter import *
|
9 |
from modules.utils.subtitle_manager import *
|
|
|
11 |
|
12 |
class TranslationBase(ABC):
|
13 |
def __init__(self,
|
14 |
+
model_dir: str = os.path.join("models", "NLLB"),
|
15 |
+
output_dir: str = os.path.join("outputs", "translations")
|
16 |
+
):
|
17 |
super().__init__()
|
18 |
self.model = None
|
19 |
self.model_dir = model_dir
|
|
|
24 |
self.device = self.get_device()
|
25 |
|
26 |
@abstractmethod
|
|
|
27 |
def translate(self,
|
28 |
+
text: str,
|
29 |
+
max_length: int
|
30 |
):
|
31 |
pass
|
32 |
|
33 |
@abstractmethod
|
|
|
34 |
def update_model(self,
|
35 |
model_size: str,
|
36 |
src_lang: str,
|
|
|
39 |
):
|
40 |
pass
|
41 |
|
|
|
42 |
def translate_file(self,
|
43 |
fileobjs: list,
|
44 |
model_size: str,
|
45 |
src_lang: str,
|
46 |
tgt_lang: str,
|
47 |
+
max_length: int,
|
48 |
add_timestamp: bool,
|
49 |
progress=gr.Progress()) -> list:
|
50 |
"""
|
|
|
60 |
Source language of the file to translate from gr.Dropdown()
|
61 |
tgt_lang: str
|
62 |
Target language of the file to translate from gr.Dropdown()
|
63 |
+
max_length: int
|
64 |
+
Max length per line to translate
|
65 |
add_timestamp: bool
|
66 |
Boolean value from gr.Checkbox() that determines whether to add a timestamp at the end of the filename.
|
67 |
progress: gr.Progress
|
|
|
89 |
total_progress = len(parsed_dicts)
|
90 |
for index, dic in enumerate(parsed_dicts):
|
91 |
progress(index / total_progress, desc="Translating..")
|
92 |
+
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
93 |
dic["sentence"] = translated_text
|
94 |
subtitle = get_serialized_srt(parsed_dicts)
|
95 |
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
elif file_ext == ".vtt":
|
97 |
parsed_dicts = parse_vtt(file_path=file_path)
|
98 |
total_progress = len(parsed_dicts)
|
99 |
for index, dic in enumerate(parsed_dicts):
|
100 |
progress(index / total_progress, desc="Translating..")
|
101 |
+
translated_text = self.translate(dic["sentence"], max_length=max_length)
|
102 |
dic["sentence"] = translated_text
|
103 |
subtitle = get_serialized_vtt(parsed_dicts)
|
104 |
|
105 |
+
if add_timestamp:
|
106 |
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
107 |
+
file_name += f"-{timestamp}"
|
|
|
|
|
|
|
108 |
|
109 |
+
output_path = os.path.join(self.output_dir, f"{file_name}{file_ext}")
|
110 |
write_file(subtitle, output_path)
|
111 |
+
|
112 |
+
files_info[file_name] = {"subtitle": subtitle, "path": output_path}
|
113 |
|
114 |
total_result = ''
|
115 |
+
for file_name, info in files_info.items():
|
116 |
total_result += '------------------------------------\n'
|
117 |
total_result += f'{file_name}\n\n'
|
118 |
+
total_result += f'{info["subtitle"]}'
|
|
|
119 |
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
|
120 |
+
|
121 |
+
output_file_paths = [item["path"] for key, item in files_info.items()]
|
122 |
+
return [gr_str, output_file_paths]
|
123 |
+
|
124 |
except Exception as e:
|
125 |
print(f"Error: {str(e)}")
|
126 |
finally:
|
127 |
self.release_cuda_memory()
|
|
|
128 |
|
129 |
@staticmethod
|
|
|
130 |
def get_device():
|
131 |
if torch.cuda.is_available():
|
132 |
return "cuda"
|
|
|
136 |
return "cpu"
|
137 |
|
138 |
@staticmethod
|
|
|
139 |
def release_cuda_memory():
|
140 |
if torch.cuda.is_available():
|
141 |
torch.cuda.empty_cache()
|
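A sketch of the result shape that translate_file now builds before returning to Gradio (file names, paths, and subtitle text are made up):

files_info = {
    "lecture01": {"subtitle": "1\n00:00:00,000 --> 00:00:02,000\nHello\n\n", "path": "outputs/translations/lecture01.srt"},
}
total_result = ''
for file_name, info in files_info.items():
    total_result += '------------------------------------\n'
    total_result += f'{file_name}\n\n'
    total_result += f'{info["subtitle"]}'
gr_str = f"Done! Subtitle is in the outputs/translation folder.\n\n{total_result}"
output_file_paths = [item["path"] for item in files_info.values()]
print([gr_str, output_file_paths])  # the pair handed back to the Gradio outputs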
modules/utils/files_manager.py
ADDED
@@ -0,0 +1,39 @@
|
1 |
+
import os
|
2 |
+
import fnmatch
|
3 |
+
|
4 |
+
from gradio.utils import NamedString
|
5 |
+
|
6 |
+
|
7 |
+
def get_media_files(folder_path, include_sub_directory=False):
|
8 |
+
video_extensions = ['*.mp4', '*.mkv', '*.flv', '*.avi', '*.mov', '*.wmv']
|
9 |
+
audio_extensions = ['*.mp3', '*.wav', '*.aac', '*.flac', '*.ogg', '*.m4a']
|
10 |
+
media_extensions = video_extensions + audio_extensions
|
11 |
+
|
12 |
+
media_files = []
|
13 |
+
|
14 |
+
if include_sub_directory:
|
15 |
+
for root, _, files in os.walk(folder_path):
|
16 |
+
for extension in media_extensions:
|
17 |
+
media_files.extend(
|
18 |
+
os.path.join(root, file) for file in fnmatch.filter(files, extension)
|
19 |
+
if os.path.exists(os.path.join(root, file))
|
20 |
+
)
|
21 |
+
else:
|
22 |
+
for extension in media_extensions:
|
23 |
+
media_files.extend(
|
24 |
+
os.path.join(folder_path, file) for file in fnmatch.filter(os.listdir(folder_path), extension)
|
25 |
+
if os.path.isfile(os.path.join(folder_path, file)) and os.path.exists(os.path.join(folder_path, file))
|
26 |
+
)
|
27 |
+
|
28 |
+
return media_files
|
29 |
+
|
30 |
+
|
31 |
+
def format_gradio_files(files: list):
|
32 |
+
if not files:
|
33 |
+
return files
|
34 |
+
|
35 |
+
gradio_files = []
|
36 |
+
for file in files:
|
37 |
+
gradio_files.append(NamedString(file))
|
38 |
+
return gradio_files
|
39 |
+
|
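Possible usage of the new helpers (the folder name and file paths are illustrative):

from modules.utils.files_manager import get_media_files, format_gradio_files

media_paths = get_media_files("inputs", include_sub_directory=True)  # e.g. ["inputs/talk.mp4", "inputs/raw/take1.wav"]
gradio_files = format_gradio_files(media_paths)                      # each path wrapped in gradio.utils.NamedString
print(media_paths, gradio_files)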
modules/utils/subtitle_manager.py
CHANGED
@@ -1,7 +1,5 @@
|
|
1 |
import re
|
2 |
|
3 |
-
# Zero GPU
|
4 |
-
import spaces
|
5 |
|
6 |
def timeformat_srt(time):
|
7 |
hours = time // 3600
|
@@ -119,7 +117,7 @@ def get_serialized_vtt(dicts):
|
|
119 |
output += f'{dic["sentence"]}\n\n'
|
120 |
return output
|
121 |
|
122 |
-
|
123 |
def safe_filename(name):
|
124 |
from app import _args
|
125 |
INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
|
|
|
1 |
import re
|
2 |
|
|
|
|
|
3 |
|
4 |
def timeformat_srt(time):
|
5 |
hours = time // 3600
|
|
|
117 |
output += f'{dic["sentence"]}\n\n'
|
118 |
return output
|
119 |
|
120 |
+
|
121 |
def safe_filename(name):
|
122 |
from app import _args
|
123 |
INVALID_FILENAME_CHARS = r'[<>:"/\\|?*\x00-\x1f]'
|
modules/utils/youtube_manager.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1 |
-
from
|
2 |
import os
|
3 |
|
4 |
|
|
|
1 |
+
from pytubefix import YouTube
|
2 |
import os
|
3 |
|
4 |
|
modules/vad/silero_vad.py
CHANGED
@@ -1,21 +1,25 @@
|
|
1 |
-
from faster_whisper.
|
|
|
|
|
2 |
import numpy as np
|
3 |
-
from typing import BinaryIO, Union, List, Optional
|
4 |
import warnings
|
5 |
import faster_whisper
|
|
|
6 |
import gradio as gr
|
7 |
-
import spaces
|
8 |
|
9 |
|
10 |
class SileroVAD:
|
11 |
def __init__(self):
|
12 |
self.sampling_rate = 16000
|
|
|
|
|
13 |
|
14 |
-
@spaces.GPU
|
15 |
def run(self,
|
16 |
audio: Union[str, BinaryIO, np.ndarray],
|
17 |
vad_parameters: VadOptions,
|
18 |
-
progress: gr.Progress = gr.Progress()
|
|
|
19 |
"""
|
20 |
Run VAD
|
21 |
|
@@ -30,8 +34,10 @@ class SileroVAD:
|
|
30 |
|
31 |
Returns
|
32 |
----------
|
33 |
-
|
34 |
Pre-processed audio with VAD
|
|
|
|
|
35 |
"""
|
36 |
|
37 |
sampling_rate = self.sampling_rate
|
@@ -54,11 +60,10 @@ class SileroVAD:
|
|
54 |
audio = self.collect_chunks(audio, speech_chunks)
|
55 |
duration_after_vad = audio.shape[0] / sampling_rate
|
56 |
|
57 |
-
return audio
|
58 |
|
59 |
-
@staticmethod
|
60 |
-
@spaces.GPU
|
61 |
def get_speech_timestamps(
|
|
|
62 |
audio: np.ndarray,
|
63 |
vad_options: Optional[VadOptions] = None,
|
64 |
progress: gr.Progress = gr.Progress(),
|
@@ -75,6 +80,10 @@ class SileroVAD:
|
|
75 |
Returns:
|
76 |
List of dicts containing begin and end samples of each speech chunk.
|
77 |
"""
|
|
|
|
|
|
|
|
|
78 |
if vad_options is None:
|
79 |
vad_options = VadOptions(**kwargs)
|
80 |
|
@@ -82,15 +91,8 @@ class SileroVAD:
|
|
82 |
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
83 |
max_speech_duration_s = vad_options.max_speech_duration_s
|
84 |
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
85 |
-
window_size_samples =
|
86 |
speech_pad_ms = vad_options.speech_pad_ms
|
87 |
-
|
88 |
-
if window_size_samples not in [512, 1024, 1536]:
|
89 |
-
warnings.warn(
|
90 |
-
"Unusual window_size_samples! Supported window_size_samples:\n"
|
91 |
-
" - [512, 1024, 1536] for 16000 sampling_rate"
|
92 |
-
)
|
93 |
-
|
94 |
sampling_rate = 16000
|
95 |
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
96 |
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
@@ -104,8 +106,7 @@ class SileroVAD:
|
|
104 |
|
105 |
audio_length_samples = len(audio)
|
106 |
|
107 |
-
|
108 |
-
state = model.get_initial_state(batch_size=1)
|
109 |
|
110 |
speech_probs = []
|
111 |
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
@@ -114,7 +115,7 @@ class SileroVAD:
|
|
114 |
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
|
115 |
if len(chunk) < window_size_samples:
|
116 |
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
117 |
-
speech_prob, state = model(chunk, state, sampling_rate)
|
118 |
speech_probs.append(speech_prob)
|
119 |
|
120 |
triggered = False
|
@@ -210,6 +211,9 @@ class SileroVAD:
|
|
210 |
|
211 |
return speeches
|
212 |
|
|
|
|
|
|
|
213 |
@staticmethod
|
214 |
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
215 |
"""Collects and concatenates audio chunks."""
|
@@ -241,3 +245,20 @@ class SileroVAD:
|
|
241 |
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
242 |
)
|
243 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from https://github.com/SYSTRAN/faster-whisper/blob/master/faster_whisper/vad.py
|
2 |
+
|
3 |
+
from faster_whisper.vad import VadOptions, get_vad_model
|
4 |
import numpy as np
|
5 |
+
from typing import BinaryIO, Union, List, Optional, Tuple
|
6 |
import warnings
|
7 |
import faster_whisper
|
8 |
+
from faster_whisper.transcribe import SpeechTimestampsMap, Segment
|
9 |
import gradio as gr
|
|
|
10 |
|
11 |
|
12 |
class SileroVAD:
|
13 |
def __init__(self):
|
14 |
self.sampling_rate = 16000
|
15 |
+
self.window_size_samples = 512
|
16 |
+
self.model = None
|
17 |
|
|
|
18 |
def run(self,
|
19 |
audio: Union[str, BinaryIO, np.ndarray],
|
20 |
vad_parameters: VadOptions,
|
21 |
+
progress: gr.Progress = gr.Progress()
|
22 |
+
) -> Tuple[np.ndarray, List[dict]]:
|
23 |
"""
|
24 |
Run VAD
|
25 |
|
|
|
34 |
|
35 |
Returns
|
36 |
----------
|
37 |
+
np.ndarray
|
38 |
Pre-processed audio with VAD
|
39 |
+
List[dict]
|
40 |
+
Chunks of speeches to be used to restore the timestamps later
|
41 |
"""
|
42 |
|
43 |
sampling_rate = self.sampling_rate
|
|
|
60 |
audio = self.collect_chunks(audio, speech_chunks)
|
61 |
duration_after_vad = audio.shape[0] / sampling_rate
|
62 |
|
63 |
+
return audio, speech_chunks
|
64 |
|
|
|
|
|
65 |
def get_speech_timestamps(
|
66 |
+
self,
|
67 |
audio: np.ndarray,
|
68 |
vad_options: Optional[VadOptions] = None,
|
69 |
progress: gr.Progress = gr.Progress(),
|
|
|
80 |
Returns:
|
81 |
List of dicts containing begin and end samples of each speech chunk.
|
82 |
"""
|
83 |
+
|
84 |
+
if self.model is None:
|
85 |
+
self.update_model()
|
86 |
+
|
87 |
if vad_options is None:
|
88 |
vad_options = VadOptions(**kwargs)
|
89 |
|
|
|
91 |
min_speech_duration_ms = vad_options.min_speech_duration_ms
|
92 |
max_speech_duration_s = vad_options.max_speech_duration_s
|
93 |
min_silence_duration_ms = vad_options.min_silence_duration_ms
|
94 |
+
window_size_samples = self.window_size_samples
|
95 |
speech_pad_ms = vad_options.speech_pad_ms
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
sampling_rate = 16000
|
97 |
min_speech_samples = sampling_rate * min_speech_duration_ms / 1000
|
98 |
speech_pad_samples = sampling_rate * speech_pad_ms / 1000
|
|
|
106 |
|
107 |
audio_length_samples = len(audio)
|
108 |
|
109 |
+
state, context = self.model.get_initial_states(batch_size=1)
|
|
|
110 |
|
111 |
speech_probs = []
|
112 |
for current_start_sample in range(0, audio_length_samples, window_size_samples):
|
|
|
115 |
chunk = audio[current_start_sample: current_start_sample + window_size_samples]
|
116 |
if len(chunk) < window_size_samples:
|
117 |
chunk = np.pad(chunk, (0, int(window_size_samples - len(chunk))))
|
118 |
+
speech_prob, state, context = self.model(chunk, state, context, sampling_rate)
|
119 |
speech_probs.append(speech_prob)
|
120 |
|
121 |
triggered = False
|
|
|
211 |
|
212 |
return speeches
|
213 |
|
214 |
+
def update_model(self):
|
215 |
+
self.model = get_vad_model()
|
216 |
+
|
217 |
@staticmethod
|
218 |
def collect_chunks(audio: np.ndarray, chunks: List[dict]) -> np.ndarray:
|
219 |
"""Collects and concatenates audio chunks."""
|
|
|
245 |
f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
|
246 |
)
|
247 |
|
248 |
+
def restore_speech_timestamps(
|
249 |
+
self,
|
250 |
+
segments: List[dict],
|
251 |
+
speech_chunks: List[dict],
|
252 |
+
sampling_rate: Optional[int] = None,
|
253 |
+
) -> List[dict]:
|
254 |
+
if sampling_rate is None:
|
255 |
+
sampling_rate = self.sampling_rate
|
256 |
+
|
257 |
+
ts_map = SpeechTimestampsMap(speech_chunks, sampling_rate)
|
258 |
+
|
259 |
+
for segment in segments:
|
260 |
+
segment["start"] = ts_map.get_original_time(segment["start"])
|
261 |
+
segment["end"] = ts_map.get_original_time(segment["end"])
|
262 |
+
|
263 |
+
return segments
|
264 |
+
|
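An illustrative round trip through the timestamp restoration added above; it mirrors restore_speech_timestamps, and the chunk and segment values are made up:

from faster_whisper.transcribe import SpeechTimestampsMap

speech_chunks = [{"start": 16000, "end": 48000}, {"start": 80000, "end": 112000}]  # samples at 16 kHz
segments = [{"start": 0.5, "end": 1.5}, {"start": 2.5, "end": 3.5}]                # times within the VAD-trimmed audio

ts_map = SpeechTimestampsMap(speech_chunks, 16000)
for segment in segments:
    segment["start"] = ts_map.get_original_time(segment["start"])
    segment["end"] = ts_map.get_original_time(segment["end"])
print(segments)  # timestamps are mapped back onto the original, untrimmed timeline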
modules/whisper/faster_whisper_inference.py
CHANGED
@@ -5,6 +5,7 @@ import torch
|
|
5 |
from typing import BinaryIO, Union, Tuple, List
|
6 |
import faster_whisper
|
7 |
from faster_whisper.vad import VadOptions
|
|
|
8 |
import ctranslate2
|
9 |
import whisper
|
10 |
import gradio as gr
|
@@ -13,31 +14,31 @@ from argparse import Namespace
|
|
13 |
from modules.whisper.whisper_parameter import *
|
14 |
from modules.whisper.whisper_base import WhisperBase
|
15 |
|
16 |
-
import spaces
|
17 |
|
18 |
class FasterWhisperInference(WhisperBase):
|
19 |
def __init__(self,
|
20 |
-
model_dir: str,
|
21 |
-
|
22 |
-
|
23 |
):
|
24 |
super().__init__(
|
25 |
model_dir=model_dir,
|
26 |
-
|
27 |
-
|
28 |
)
|
|
|
|
|
|
|
29 |
self.model_paths = self.get_model_paths()
|
30 |
self.device = self.get_device()
|
31 |
self.available_models = self.model_paths.keys()
|
32 |
-
self.available_compute_types =
|
33 |
-
|
34 |
-
self.download_model(model_size="large-v2", model_dir=self.model_dir)
|
35 |
|
36 |
-
#@spaces.GPU(duration=120)
|
37 |
def transcribe(self,
|
38 |
audio: Union[str, BinaryIO, np.ndarray],
|
|
|
39 |
*whisper_params,
|
40 |
-
progress: gr.Progress = gr.Progress(),
|
41 |
) -> Tuple[List[dict], float]:
|
42 |
"""
|
43 |
transcribe method for faster-whisper.
|
@@ -65,7 +66,16 @@ class FasterWhisperInference(WhisperBase):
|
|
65 |
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
66 |
self.update_model(params.model_size, params.compute_type, progress)
|
67 |
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
segments, info = self.model.transcribe(
|
70 |
audio=audio,
|
71 |
language=params.lang,
|
@@ -76,7 +86,25 @@ class FasterWhisperInference(WhisperBase):
|
|
76 |
best_of=params.best_of,
|
77 |
patience=params.patience,
|
78 |
temperature=params.temperature,
|
|
|
79 |
compression_ratio_threshold=params.compression_ratio_threshold,
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
)
|
81 |
progress(0, desc="Loading audio..")
|
82 |
|
@@ -90,14 +118,12 @@ class FasterWhisperInference(WhisperBase):
|
|
90 |
})
|
91 |
|
92 |
elapsed_time = time.time() - start_time
|
93 |
-
print("transcribe: finished")
|
94 |
return segments_result, elapsed_time
|
95 |
|
96 |
-
#@spaces.GPU(duration=120)
|
97 |
def update_model(self,
|
98 |
model_size: str,
|
99 |
compute_type: str,
|
100 |
-
progress: gr.Progress
|
101 |
):
|
102 |
"""
|
103 |
Update current model setting
|
@@ -113,7 +139,6 @@ class FasterWhisperInference(WhisperBase):
|
|
113 |
Indicator to show progress directly in gradio.
|
114 |
"""
|
115 |
progress(0, desc="Initializing Model..")
|
116 |
-
print("update_model:")
|
117 |
self.current_model_size = self.model_paths[model_size]
|
118 |
self.current_compute_type = compute_type
|
119 |
self.model = faster_whisper.WhisperModel(
|
@@ -122,7 +147,6 @@ class FasterWhisperInference(WhisperBase):
|
|
122 |
download_root=self.model_dir,
|
123 |
compute_type=self.current_compute_type
|
124 |
)
|
125 |
-
print("update_model: finished")
|
126 |
|
127 |
def get_model_paths(self):
|
128 |
"""
|
@@ -149,22 +173,19 @@ class FasterWhisperInference(WhisperBase):
|
|
149 |
model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
|
150 |
return model_paths
|
151 |
|
152 |
-
def get_available_compute_type(self):
|
153 |
-
if self.device == "cuda":
|
154 |
-
return ['float32', 'int8_float16', 'float16', 'int8', 'int8_float32']
|
155 |
-
return ['int16', 'float32', 'int8', 'int8_float32']
|
156 |
-
|
157 |
-
def get_device(self):
|
158 |
-
# Because of huggingface spaces bug, just return cpu
|
159 |
-
return "cpu"
|
160 |
-
|
161 |
@staticmethod
|
162 |
-
def
|
163 |
-
|
164 |
-
|
165 |
-
|
166 |
-
|
167 |
-
cache_dir=model_dir
|
168 |
-
)
|
169 |
-
|
170 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
5 |
from typing import BinaryIO, Union, Tuple, List
|
6 |
import faster_whisper
|
7 |
from faster_whisper.vad import VadOptions
|
8 |
+
import ast
|
9 |
import ctranslate2
|
10 |
import whisper
|
11 |
import gradio as gr
|
|
|
14 |
from modules.whisper.whisper_parameter import *
|
15 |
from modules.whisper.whisper_base import WhisperBase
|
16 |
|
|
|
17 |
|
18 |
class FasterWhisperInference(WhisperBase):
|
19 |
def __init__(self,
|
20 |
+
model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
|
21 |
+
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
22 |
+
output_dir: str = os.path.join("outputs"),
|
23 |
):
|
24 |
super().__init__(
|
25 |
model_dir=model_dir,
|
26 |
+
diarization_model_dir=diarization_model_dir,
|
27 |
+
output_dir=output_dir
|
28 |
)
|
29 |
+
self.model_dir = model_dir
|
30 |
+
os.makedirs(self.model_dir, exist_ok=True)
|
31 |
+
|
32 |
self.model_paths = self.get_model_paths()
|
33 |
self.device = self.get_device()
|
34 |
self.available_models = self.model_paths.keys()
|
35 |
+
self.available_compute_types = ctranslate2.get_supported_compute_types(
|
36 |
+
"cuda") if self.device == "cuda" else ctranslate2.get_supported_compute_types("cpu")
|
|
|
37 |
|
|
|
38 |
def transcribe(self,
|
39 |
audio: Union[str, BinaryIO, np.ndarray],
|
40 |
+
progress: gr.Progress,
|
41 |
*whisper_params,
|
|
|
42 |
) -> Tuple[List[dict], float]:
|
43 |
"""
|
44 |
transcribe method for faster-whisper.
|
|
|
66 |
if params.model_size != self.current_model_size or self.model is None or self.current_compute_type != params.compute_type:
|
67 |
self.update_model(params.model_size, params.compute_type, progress)
|
68 |
|
69 |
+
# None parameters with Textboxes: https://github.com/gradio-app/gradio/issues/8723
|
70 |
+
if not params.initial_prompt:
|
71 |
+
params.initial_prompt = None
|
72 |
+
if not params.prefix:
|
73 |
+
params.prefix = None
|
74 |
+
if not params.hotwords:
|
75 |
+
params.hotwords = None
|
76 |
+
|
77 |
+
params.suppress_tokens = self.format_suppress_tokens_str(params.suppress_tokens)
|
78 |
+
|
79 |
segments, info = self.model.transcribe(
|
80 |
audio=audio,
|
81 |
language=params.lang,
|
|
|
86 |
best_of=params.best_of,
|
87 |
patience=params.patience,
|
88 |
temperature=params.temperature,
|
89 |
+
initial_prompt=params.initial_prompt,
|
90 |
compression_ratio_threshold=params.compression_ratio_threshold,
|
91 |
+
length_penalty=params.length_penalty,
|
92 |
+
repetition_penalty=params.repetition_penalty,
|
93 |
+
no_repeat_ngram_size=params.no_repeat_ngram_size,
|
94 |
+
prefix=params.prefix,
|
95 |
+
suppress_blank=params.suppress_blank,
|
96 |
+
suppress_tokens=params.suppress_tokens,
|
97 |
+
max_initial_timestamp=params.max_initial_timestamp,
|
98 |
+
word_timestamps=params.word_timestamps,
|
99 |
+
prepend_punctuations=params.prepend_punctuations,
|
100 |
+
append_punctuations=params.append_punctuations,
|
101 |
+
max_new_tokens=params.max_new_tokens,
|
102 |
+
chunk_length=params.chunk_length,
|
103 |
+
hallucination_silence_threshold=params.hallucination_silence_threshold,
|
104 |
+
hotwords=params.hotwords,
|
105 |
+
language_detection_threshold=params.language_detection_threshold,
|
106 |
+
language_detection_segments=params.language_detection_segments,
|
107 |
+
prompt_reset_on_temperature=params.prompt_reset_on_temperature,
|
108 |
)
|
109 |
progress(0, desc="Loading audio..")
|
110 |
|
|
|
118 |
})
|
119 |
|
120 |
elapsed_time = time.time() - start_time
|
|
|
121 |
return segments_result, elapsed_time
|
122 |
|
|
|
123 |
def update_model(self,
|
124 |
model_size: str,
|
125 |
compute_type: str,
|
126 |
+
progress: gr.Progress
|
127 |
):
|
128 |
"""
|
129 |
Update current model setting
|
|
|
139 |
Indicator to show progress directly in gradio.
|
140 |
"""
|
141 |
progress(0, desc="Initializing Model..")
|
|
|
142 |
self.current_model_size = self.model_paths[model_size]
|
143 |
self.current_compute_type = compute_type
|
144 |
self.model = faster_whisper.WhisperModel(
|
|
|
147 |
download_root=self.model_dir,
|
148 |
compute_type=self.current_compute_type
|
149 |
)
|
|
|
150 |
|
151 |
def get_model_paths(self):
|
152 |
"""
|
|
|
173 |
model_paths[model_name] = os.path.join(webui_dir, self.model_dir, model_name)
|
174 |
return model_paths
|
175 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
176 |
@staticmethod
|
177 |
+
def get_device():
|
178 |
+
if torch.cuda.is_available():
|
179 |
+
return "cuda"
|
180 |
+
else:
|
181 |
+
return "auto"
|
|
|
|
|
|
|
182 |
|
183 |
+
@staticmethod
|
184 |
+
def format_suppress_tokens_str(suppress_tokens_str: str) -> List[int]:
|
185 |
+
try:
|
186 |
+
suppress_tokens = ast.literal_eval(suppress_tokens_str)
|
187 |
+
if not isinstance(suppress_tokens, list) or not all(isinstance(item, int) for item in suppress_tokens):
|
188 |
+
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
189 |
+
return suppress_tokens
|
190 |
+
except Exception as e:
|
191 |
+
raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
|
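A short sketch of how the suppress-tokens textbox value is validated, following format_suppress_tokens_str above (the input strings are examples):

import ast
from typing import List

def parse_suppress_tokens(raw: str) -> List[int]:
    tokens = ast.literal_eval(raw)
    if not isinstance(tokens, list) or not all(isinstance(t, int) for t in tokens):
        raise ValueError("Invalid Suppress Tokens. The value must be type of List[int]")
    return tokens

print(parse_suppress_tokens("[-1]"))         # [-1]
print(parse_suppress_tokens("[0, 11, 42]"))  # [0, 11, 42]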
modules/whisper/insanely_fast_whisper_inference.py
CHANGED
@@ -17,15 +17,18 @@ from modules.whisper.whisper_base import WhisperBase
|
|
17 |
|
18 |
class InsanelyFastWhisperInference(WhisperBase):
|
19 |
def __init__(self,
|
20 |
-
model_dir: str,
|
21 |
-
|
22 |
-
|
23 |
):
|
24 |
super().__init__(
|
25 |
model_dir=model_dir,
|
26 |
output_dir=output_dir,
|
27 |
-
|
28 |
)
|
|
|
|
|
|
|
29 |
openai_models = whisper.available_models()
|
30 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
31 |
self.available_models = openai_models + distil_models
|
|
|
17 |
|
18 |
class InsanelyFastWhisperInference(WhisperBase):
|
19 |
def __init__(self,
|
20 |
+
model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
|
21 |
+
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
22 |
+
output_dir: str = os.path.join("outputs"),
|
23 |
):
|
24 |
super().__init__(
|
25 |
model_dir=model_dir,
|
26 |
output_dir=output_dir,
|
27 |
+
diarization_model_dir=diarization_model_dir
|
28 |
)
|
29 |
+
self.model_dir = model_dir
|
30 |
+
os.makedirs(self.model_dir, exist_ok=True)
|
31 |
+
|
32 |
openai_models = whisper.available_models()
|
33 |
distil_models = ["distil-large-v2", "distil-large-v3", "distil-medium.en", "distil-small.en"]
|
34 |
self.available_models = openai_models + distil_models
|
modules/whisper/whisper_Inference.py
CHANGED
@@ -4,6 +4,7 @@ import time
|
|
4 |
from typing import BinaryIO, Union, Tuple, List
|
5 |
import numpy as np
|
6 |
import torch
|
|
|
7 |
from argparse import Namespace
|
8 |
|
9 |
from modules.whisper.whisper_base import WhisperBase
|
@@ -12,14 +13,14 @@ from modules.whisper.whisper_parameter import *
|
|
12 |
|
13 |
class WhisperInference(WhisperBase):
|
14 |
def __init__(self,
|
15 |
-
model_dir: str,
|
16 |
-
|
17 |
-
|
18 |
):
|
19 |
super().__init__(
|
20 |
model_dir=model_dir,
|
21 |
output_dir=output_dir,
|
22 |
-
|
23 |
)
|
24 |
|
25 |
def transcribe(self,
|
|
|
4 |
from typing import BinaryIO, Union, Tuple, List
|
5 |
import numpy as np
|
6 |
import torch
|
7 |
+
import os
|
8 |
from argparse import Namespace
|
9 |
|
10 |
from modules.whisper.whisper_base import WhisperBase
|
|
|
13 |
|
14 |
class WhisperInference(WhisperBase):
|
15 |
def __init__(self,
|
16 |
+
model_dir: str = os.path.join("models", "Whisper"),
|
17 |
+
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
18 |
+
output_dir: str = os.path.join("outputs"),
|
19 |
):
|
20 |
super().__init__(
|
21 |
model_dir=model_dir,
|
22 |
output_dir=output_dir,
|
23 |
+
diarization_model_dir=diarization_model_dir
|
24 |
)
|
25 |
|
26 |
def transcribe(self,
|
modules/whisper/whisper_base.py
CHANGED
@@ -6,13 +6,12 @@ from abc import ABC, abstractmethod
|
|
6 |
from typing import BinaryIO, Union, Tuple, List
|
7 |
import numpy as np
|
8 |
from datetime import datetime
|
9 |
-
from argparse import Namespace
|
10 |
from faster_whisper.vad import VadOptions
|
11 |
from dataclasses import astuple
|
12 |
-
import spaces
|
13 |
|
14 |
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
15 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
|
|
16 |
from modules.whisper.whisper_parameter import *
|
17 |
from modules.diarize.diarizer import Diarizer
|
18 |
from modules.vad.silero_vad import SileroVAD
|
@@ -20,51 +19,50 @@ from modules.vad.silero_vad import SileroVAD
|
|
20 |
|
21 |
class WhisperBase(ABC):
|
22 |
def __init__(self,
|
23 |
-
model_dir: str,
|
24 |
-
|
25 |
-
|
26 |
):
|
27 |
-
self.model = None
|
28 |
-
self.current_model_size = None
|
29 |
self.model_dir = model_dir
|
30 |
self.output_dir = output_dir
|
31 |
os.makedirs(self.output_dir, exist_ok=True)
|
32 |
os.makedirs(self.model_dir, exist_ok=True)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
33 |
self.available_models = whisper.available_models()
|
34 |
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
|
35 |
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
|
36 |
self.device = self.get_device()
|
37 |
self.available_compute_types = ["float16", "float32"]
|
38 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
39 |
-
self.diarizer = Diarizer(
|
40 |
-
model_dir=args.diarization_model_dir
|
41 |
-
)
|
42 |
-
self.vad = SileroVAD()
|
43 |
|
44 |
@abstractmethod
|
45 |
-
#@spaces.GPU(duration=120)
|
46 |
def transcribe(self,
|
47 |
audio: Union[str, BinaryIO, np.ndarray],
|
|
|
48 |
*whisper_params,
|
49 |
-
progress: gr.Progress = gr.Progress(),
|
50 |
):
|
|
|
51 |
pass
|
52 |
|
53 |
@abstractmethod
|
54 |
-
@spaces.GPU(duration=120)
|
55 |
def update_model(self,
|
56 |
model_size: str,
|
57 |
compute_type: str,
|
58 |
-
progress: gr.Progress
|
59 |
):
|
|
|
60 |
pass
|
61 |
|
62 |
-
# spaces is problematic
|
63 |
-
#@spaces.GPU(duration=120)
|
64 |
def run(self,
|
65 |
audio: Union[str, BinaryIO, np.ndarray],
|
|
|
66 |
*whisper_params,
|
67 |
-
progress: gr.Progress = gr.Progress(),
|
68 |
) -> Tuple[List[dict], float]:
|
69 |
"""
|
70 |
Run transcription with conditional pre-processing and post-processing.
|
@@ -89,33 +87,44 @@ class WhisperBase(ABC):
|
|
89 |
"""
|
90 |
params = WhisperParameters.as_value(*whisper_params)
|
91 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
92 |
if params.vad_filter:
|
|
|
|
|
|
|
|
|
93 |
vad_options = VadOptions(
|
94 |
threshold=params.threshold,
|
95 |
min_speech_duration_ms=params.min_speech_duration_ms,
|
96 |
max_speech_duration_s=params.max_speech_duration_s,
|
97 |
min_silence_duration_ms=params.min_silence_duration_ms,
|
98 |
-
window_size_samples=params.window_size_samples,
|
99 |
speech_pad_ms=params.speech_pad_ms
|
100 |
)
|
101 |
-
|
|
|
102 |
audio=audio,
|
103 |
vad_parameters=vad_options,
|
104 |
progress=progress
|
105 |
)
|
106 |
|
107 |
-
if params.lang == "Automatic Detection":
|
108 |
-
params.lang = None
|
109 |
-
else:
|
110 |
-
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
111 |
-
params.lang = language_code_dict[params.lang]
|
112 |
-
|
113 |
result, elapsed_time = self.transcribe(
|
114 |
audio,
|
115 |
-
|
116 |
-
|
117 |
)
|
118 |
|
|
|
|
|
|
|
|
|
|
|
|
|
119 |
if params.is_diarize:
|
120 |
result, elapsed_time_diarization = self.diarizer.run(
|
121 |
audio=audio,
|
@@ -126,15 +135,14 @@ class WhisperBase(ABC):
|
|
126 |
elapsed_time += elapsed_time_diarization
|
127 |
return result, elapsed_time
|
128 |
|
129 |
-
# spaces is problematic
|
130 |
-
#@spaces.GPU(duration=120)
|
131 |
def transcribe_file(self,
|
132 |
-
files,
|
133 |
-
|
134 |
-
|
135 |
-
|
136 |
progress=gr.Progress(),
|
137 |
-
|
|
|
138 |
"""
|
139 |
Write subtitle file from Files
|
140 |
|
@@ -142,6 +150,9 @@ class WhisperBase(ABC):
|
|
142 |
----------
|
143 |
files: list
|
144 |
List of files to transcribe from gr.Files()
|
|
|
|
|
|
|
145 |
file_format: str
|
146 |
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
|
147 |
add_timestamp: bool
|
@@ -159,16 +170,19 @@ class WhisperBase(ABC):
|
|
159 |
Output file path to return to gr.Files()
|
160 |
"""
|
161 |
try:
|
|
|
|
|
|
|
|
|
162 |
files_info = {}
|
163 |
for file in files:
|
164 |
transcribed_segments, time_for_task = self.run(
|
165 |
file.name,
|
|
|
166 |
*whisper_params,
|
167 |
-
progress=progress
|
168 |
)
|
169 |
|
170 |
file_name, file_ext = os.path.splitext(os.path.basename(file.name))
|
171 |
-
file_name = safe_filename(file_name)
|
172 |
subtitle, file_path = self.generate_and_write_file(
|
173 |
file_name=file_name,
|
174 |
transcribed_segments=transcribed_segments,
|
@@ -176,7 +190,6 @@ class WhisperBase(ABC):
|
|
176 |
file_format=file_format,
|
177 |
output_dir=self.output_dir
|
178 |
)
|
179 |
-
print("generated sub finished: ")
|
180 |
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
|
181 |
|
182 |
total_result = ''
|
@@ -195,16 +208,15 @@ class WhisperBase(ABC):
|
|
195 |
except Exception as e:
|
196 |
print(f"Error transcribing file: {e}")
|
197 |
finally:
|
198 |
-
|
199 |
if not files:
|
200 |
self.remove_input_files([file.name for file in files])
|
201 |
|
202 |
-
#@spaces.GPU(duration=120)
|
203 |
def transcribe_mic(self,
|
204 |
mic_audio: str,
|
205 |
file_format: str,
|
|
|
206 |
*whisper_params,
|
207 |
-
progress: gr.Progress = gr.Progress(),
|
208 |
) -> list:
|
209 |
"""
|
210 |
Write subtitle file from microphone
|
@@ -231,8 +243,8 @@ class WhisperBase(ABC):
|
|
231 |
progress(0, desc="Loading Audio..")
|
232 |
transcribed_segments, time_for_task = self.run(
|
233 |
mic_audio,
|
|
|
234 |
*whisper_params,
|
235 |
-
progress=progress
|
236 |
)
|
237 |
progress(1, desc="Completed!")
|
238 |
|
@@ -252,13 +264,12 @@ class WhisperBase(ABC):
|
|
252 |
self.release_cuda_memory()
|
253 |
self.remove_input_files([mic_audio])
|
254 |
|
255 |
-
#@spaces.GPU(duration=120)
|
256 |
def transcribe_youtube(self,
|
257 |
youtube_link: str,
|
258 |
file_format: str,
|
259 |
add_timestamp: bool,
|
|
|
260 |
*whisper_params,
|
261 |
-
progress: gr.Progress = gr.Progress(),
|
262 |
) -> list:
|
263 |
"""
|
264 |
Write subtitle file from Youtube
|
@@ -290,8 +301,8 @@ class WhisperBase(ABC):
|
|
290 |
|
291 |
transcribed_segments, time_for_task = self.run(
|
292 |
audio,
|
|
|
293 |
*whisper_params,
|
294 |
-
progress=progress
|
295 |
)
|
296 |
|
297 |
progress(1, desc="Completed!")
|
@@ -318,13 +329,12 @@ class WhisperBase(ABC):
|
|
318 |
else:
|
319 |
file_path = get_ytaudio(yt)
|
320 |
|
321 |
-
|
322 |
self.remove_input_files([file_path])
|
323 |
except Exception as cleanup_error:
|
324 |
pass
|
325 |
|
326 |
@staticmethod
|
327 |
-
@spaces.GPU(duration=120)
|
328 |
def generate_and_write_file(file_name: str,
|
329 |
transcribed_segments: list,
|
330 |
add_timestamp: bool,
|
@@ -354,8 +364,8 @@ class WhisperBase(ABC):
|
|
354 |
output_path: str
|
355 |
output file path
|
356 |
"""
|
357 |
-
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
358 |
if add_timestamp:
|
|
|
359 |
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
360 |
else:
|
361 |
output_path = os.path.join(output_dir, f"{file_name}")
|
@@ -363,17 +373,16 @@ class WhisperBase(ABC):
|
|
363 |
if file_format == "SRT":
|
364 |
content = get_srt(transcribed_segments)
|
365 |
output_path += '.srt'
|
366 |
-
write_file(content, output_path)
|
367 |
|
368 |
elif file_format == "WebVTT":
|
369 |
content = get_vtt(transcribed_segments)
|
370 |
output_path += '.vtt'
|
371 |
-
write_file(content, output_path)
|
372 |
|
373 |
elif file_format == "txt":
|
374 |
content = get_txt(transcribed_segments)
|
375 |
output_path += '.txt'
|
376 |
-
|
|
|
377 |
return content, output_path
|
378 |
|
379 |
@staticmethod
|
@@ -403,12 +412,6 @@ class WhisperBase(ABC):
|
|
403 |
|
404 |
return time_str.strip()
|
405 |
|
406 |
-
@staticmethod
|
407 |
-
def release_cuda_memory():
|
408 |
-
if torch.cuda.is_available():
|
409 |
-
torch.cuda.empty_cache()
|
410 |
-
torch.cuda.reset_max_memory_allocated()
|
411 |
-
|
412 |
@staticmethod
|
413 |
def get_device():
|
414 |
if torch.cuda.is_available():
|
@@ -418,6 +421,12 @@ class WhisperBase(ABC):
|
|
418 |
else:
|
419 |
return "cpu"
|
420 |
|
|
|
|
|
|
|
|
|
|
|
|
|
421 |
@staticmethod
|
422 |
def remove_input_files(file_paths: List[str]):
|
423 |
if not file_paths:
|
|
|
6 |
from typing import BinaryIO, Union, Tuple, List
|
7 |
import numpy as np
|
8 |
from datetime import datetime
|
|
|
9 |
from faster_whisper.vad import VadOptions
|
10 |
from dataclasses import astuple
|
|
|
11 |
|
12 |
from modules.utils.subtitle_manager import get_srt, get_vtt, get_txt, write_file, safe_filename
|
13 |
from modules.utils.youtube_manager import get_ytdata, get_ytaudio
|
14 |
+
from modules.utils.files_manager import get_media_files, format_gradio_files
|
15 |
from modules.whisper.whisper_parameter import *
|
16 |
from modules.diarize.diarizer import Diarizer
|
17 |
from modules.vad.silero_vad import SileroVAD
|
|
|
19 |
|
20 |
class WhisperBase(ABC):
|
21 |
def __init__(self,
|
22 |
+
model_dir: str = os.path.join("models", "Whisper"),
|
23 |
+
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
24 |
+
output_dir: str = os.path.join("outputs"),
|
25 |
):
|
|
|
|
|
26 |
self.model_dir = model_dir
|
27 |
self.output_dir = output_dir
|
28 |
os.makedirs(self.output_dir, exist_ok=True)
|
29 |
os.makedirs(self.model_dir, exist_ok=True)
|
30 |
+
self.diarizer = Diarizer(
|
31 |
+
model_dir=diarization_model_dir
|
32 |
+
)
|
33 |
+
self.vad = SileroVAD()
|
34 |
+
|
35 |
+
self.model = None
|
36 |
+
self.current_model_size = None
|
37 |
self.available_models = whisper.available_models()
|
38 |
self.available_langs = sorted(list(whisper.tokenizer.LANGUAGES.values()))
|
39 |
self.translatable_models = ["large", "large-v1", "large-v2", "large-v3"]
|
40 |
self.device = self.get_device()
|
41 |
self.available_compute_types = ["float16", "float32"]
|
42 |
self.current_compute_type = "float16" if self.device == "cuda" else "float32"
|
|
|
|
|
|
|
|
|
43 |
|
44 |
@abstractmethod
|
|
|
45 |
def transcribe(self,
|
46 |
audio: Union[str, BinaryIO, np.ndarray],
|
47 |
+
progress: gr.Progress,
|
48 |
*whisper_params,
|
|
|
49 |
):
|
50 |
+
"""Inference whisper model to transcribe"""
|
51 |
pass
|
52 |
|
53 |
@abstractmethod
|
|
|
54 |
def update_model(self,
|
55 |
model_size: str,
|
56 |
compute_type: str,
|
57 |
+
progress: gr.Progress
|
58 |
):
|
59 |
+
"""Initialize whisper model"""
|
60 |
pass
|
61 |
|
|
|
|
|
62 |
def run(self,
|
63 |
audio: Union[str, BinaryIO, np.ndarray],
|
64 |
+
progress: gr.Progress,
|
65 |
*whisper_params,
|
|
|
66 |
) -> Tuple[List[dict], float]:
|
67 |
"""
|
68 |
Run transcription with conditional pre-processing and post-processing.
|
|
|
87 |
"""
|
88 |
params = WhisperParameters.as_value(*whisper_params)
|
89 |
|
90 |
+
if params.lang == "Automatic Detection":
|
91 |
+
params.lang = None
|
92 |
+
else:
|
93 |
+
language_code_dict = {value: key for key, value in whisper.tokenizer.LANGUAGES.items()}
|
94 |
+
params.lang = language_code_dict[params.lang]
|
95 |
+
|
96 |
+
speech_chunks = None
|
97 |
if params.vad_filter:
|
98 |
+
# Explicit value set for float('inf') from gr.Number()
|
99 |
+
if params.max_speech_duration_s >= 9999:
|
100 |
+
params.max_speech_duration_s = float('inf')
|
101 |
+
|
102 |
vad_options = VadOptions(
|
103 |
threshold=params.threshold,
|
104 |
min_speech_duration_ms=params.min_speech_duration_ms,
|
105 |
max_speech_duration_s=params.max_speech_duration_s,
|
106 |
min_silence_duration_ms=params.min_silence_duration_ms,
|
|
|
107 |
speech_pad_ms=params.speech_pad_ms
|
108 |
)
|
109 |
+
|
110 |
+
audio, speech_chunks = self.vad.run(
|
111 |
audio=audio,
|
112 |
vad_parameters=vad_options,
|
113 |
progress=progress
|
114 |
)
|
115 |
|
|
|
|
|
|
|
|
|
|
|
|
|
116 |
result, elapsed_time = self.transcribe(
|
117 |
audio,
|
118 |
+
progress,
|
119 |
+
*astuple(params)
|
120 |
)
|
121 |
|
122 |
+
if params.vad_filter:
|
123 |
+
result = self.vad.restore_speech_timestamps(
|
124 |
+
segments=result,
|
125 |
+
speech_chunks=speech_chunks,
|
126 |
+
)
|
127 |
+
|
128 |
if params.is_diarize:
|
129 |
result, elapsed_time_diarization = self.diarizer.run(
|
130 |
audio=audio,
|
|
|
135 |
elapsed_time += elapsed_time_diarization
|
136 |
return result, elapsed_time
|
137 |
|
|
|
|
|
138 |
def transcribe_file(self,
|
139 |
+
files: list,
|
140 |
+
input_folder_path: str,
|
141 |
+
file_format: str,
|
142 |
+
add_timestamp: bool,
|
143 |
progress=gr.Progress(),
|
144 |
+
*whisper_params,
|
145 |
+
) -> list:
|
146 |
"""
|
147 |
Write subtitle file from Files
|
148 |
|
|
|
150 |
----------
|
151 |
files: list
|
152 |
List of files to transcribe from gr.Files()
|
153 |
+
input_folder_path: str
|
154 |
+
Input folder path to transcribe from gr.Textbox(). If this is provided, `files` will be ignored and
|
155 |
+
this will be used instead.
|
156 |
file_format: str
|
157 |
Subtitle File format to write from gr.Dropdown(). Supported format: [SRT, WebVTT, txt]
|
158 |
add_timestamp: bool
|
|
|
170 |
Output file path to return to gr.Files()
|
171 |
"""
|
172 |
try:
|
173 |
+
if input_folder_path:
|
174 |
+
files = get_media_files(input_folder_path)
|
175 |
+
files = format_gradio_files(files)
|
176 |
+
|
177 |
files_info = {}
|
178 |
for file in files:
|
179 |
transcribed_segments, time_for_task = self.run(
|
180 |
file.name,
|
181 |
+
progress,
|
182 |
*whisper_params,
|
|
|
183 |
)
|
184 |
|
185 |
file_name, file_ext = os.path.splitext(os.path.basename(file.name))
|
|
|
186 |
subtitle, file_path = self.generate_and_write_file(
|
187 |
file_name=file_name,
|
188 |
transcribed_segments=transcribed_segments,
|
|
|
190 |
file_format=file_format,
|
191 |
output_dir=self.output_dir
|
192 |
)
|
|
|
193 |
files_info[file_name] = {"subtitle": subtitle, "time_for_task": time_for_task, "path": file_path}
|
194 |
|
195 |
total_result = ''
|
|
|
208 |
except Exception as e:
|
209 |
print(f"Error transcribing file: {e}")
|
210 |
finally:
|
211 |
+
self.release_cuda_memory()
|
212 |
if not files:
|
213 |
self.remove_input_files([file.name for file in files])
|
214 |
|
|
|
215 |
def transcribe_mic(self,
|
216 |
mic_audio: str,
|
217 |
file_format: str,
|
218 |
+
progress=gr.Progress(),
|
219 |
*whisper_params,
|
|
|
220 |
) -> list:
|
221 |
"""
|
222 |
Write subtitle file from microphone
|
|
|
243 |
progress(0, desc="Loading Audio..")
|
244 |
transcribed_segments, time_for_task = self.run(
|
245 |
mic_audio,
|
246 |
+
progress,
|
247 |
*whisper_params,
|
|
|
248 |
)
|
249 |
progress(1, desc="Completed!")
|
250 |
|
|
|
264 |
self.release_cuda_memory()
|
265 |
self.remove_input_files([mic_audio])
|
266 |
|
|
|
267 |
def transcribe_youtube(self,
|
268 |
youtube_link: str,
|
269 |
file_format: str,
|
270 |
add_timestamp: bool,
|
271 |
+
progress=gr.Progress(),
|
272 |
*whisper_params,
|
|
|
273 |
) -> list:
|
274 |
"""
|
275 |
Write subtitle file from Youtube
|
|
|
301 |
|
302 |
transcribed_segments, time_for_task = self.run(
|
303 |
audio,
|
304 |
+
progress,
|
305 |
*whisper_params,
|
|
|
306 |
)
|
307 |
|
308 |
progress(1, desc="Completed!")
|
|
|
329 |
else:
|
330 |
file_path = get_ytaudio(yt)
|
331 |
|
332 |
+
self.release_cuda_memory()
|
333 |
self.remove_input_files([file_path])
|
334 |
except Exception as cleanup_error:
|
335 |
pass
|
336 |
|
337 |
@staticmethod
|
|
|
338 |
def generate_and_write_file(file_name: str,
|
339 |
transcribed_segments: list,
|
340 |
add_timestamp: bool,
|
|
|
364 |
output_path: str
|
365 |
output file path
|
366 |
"""
|
|
|
367 |
if add_timestamp:
|
368 |
+
timestamp = datetime.now().strftime("%m%d%H%M%S")
|
369 |
output_path = os.path.join(output_dir, f"{file_name}-{timestamp}")
|
370 |
else:
|
371 |
output_path = os.path.join(output_dir, f"{file_name}")
|
|
|
373 |
if file_format == "SRT":
|
374 |
content = get_srt(transcribed_segments)
|
375 |
output_path += '.srt'
|
|
|
376 |
|
377 |
elif file_format == "WebVTT":
|
378 |
content = get_vtt(transcribed_segments)
|
379 |
output_path += '.vtt'
|
|
|
380 |
|
381 |
elif file_format == "txt":
|
382 |
content = get_txt(transcribed_segments)
|
383 |
output_path += '.txt'
|
384 |
+
|
385 |
+
write_file(content, output_path)
|
386 |
return content, output_path
|
387 |
|
388 |
@staticmethod
|
|
|
412 |
|
413 |
return time_str.strip()
|
414 |
|
|
|
|
|
|
|
|
|
|
|
|
|
415 |
@staticmethod
|
416 |
def get_device():
|
417 |
if torch.cuda.is_available():
|
|
|
421 |
else:
|
422 |
return "cpu"
|
423 |
|
424 |
+
@staticmethod
|
425 |
+
def release_cuda_memory():
|
426 |
+
if torch.cuda.is_available():
|
427 |
+
torch.cuda.empty_cache()
|
428 |
+
torch.cuda.reset_max_memory_allocated()
|
429 |
+
|
430 |
@staticmethod
|
431 |
def remove_input_files(file_paths: List[str]):
|
432 |
if not file_paths:
|
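A sketch of the new folder-input branch of transcribe_file (the folder path is illustrative): when input_folder_path is given, the uploaded file list is ignored and the folder is scanned instead.

from modules.utils.files_manager import get_media_files, format_gradio_files

input_folder_path = "inputs/batch"
files = format_gradio_files(get_media_files(input_folder_path))  # replaces the gr.Files() upload list
print(files)  # each entry is then passed to self.run(file.name, progress, *whisper_params)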
modules/whisper/whisper_factory.py
ADDED
@@ -0,0 +1,81 @@
|
1 |
+
from typing import Optional
|
2 |
+
import os
|
3 |
+
|
4 |
+
from modules.whisper.faster_whisper_inference import FasterWhisperInference
|
5 |
+
from modules.whisper.whisper_Inference import WhisperInference
|
6 |
+
from modules.whisper.insanely_fast_whisper_inference import InsanelyFastWhisperInference
|
7 |
+
from modules.whisper.whisper_base import WhisperBase
|
8 |
+
|
9 |
+
|
10 |
+
class WhisperFactory:
|
11 |
+
@staticmethod
|
12 |
+
def create_whisper_inference(
|
13 |
+
whisper_type: str,
|
14 |
+
whisper_model_dir: str = os.path.join("models", "Whisper"),
|
15 |
+
faster_whisper_model_dir: str = os.path.join("models", "Whisper", "faster-whisper"),
|
16 |
+
insanely_fast_whisper_model_dir: str = os.path.join("models", "Whisper", "insanely-fast-whisper"),
|
17 |
+
diarization_model_dir: str = os.path.join("models", "Diarization"),
|
18 |
+
output_dir: str = os.path.join("outputs"),
|
19 |
+
) -> "WhisperBase":
|
20 |
+
"""
|
21 |
+
Create a whisper inference class based on the provided whisper_type.
|
22 |
+
|
23 |
+
Parameters
|
24 |
+
----------
|
25 |
+
whisper_type : str
|
26 |
+
The type of Whisper implementation to use. Supported values (case-insensitive):
|
27 |
+
- "faster-whisper": https://github.com/SYSTRAN/faster-whisper
|
28 |
+
- "whisper": https://github.com/openai/whisper
|
29 |
+
- "insanely-fast-whisper": https://github.com/Vaibhavs10/insanely-fast-whisper
|
30 |
+
whisper_model_dir : str
|
31 |
+
Directory path for the Whisper model.
|
32 |
+
faster_whisper_model_dir : str
|
33 |
+
Directory path for the Faster Whisper model.
|
34 |
+
insanely_fast_whisper_model_dir : str
|
35 |
+
Directory path for the Insanely Fast Whisper model.
|
36 |
+
diarization_model_dir : str
|
37 |
+
Directory path for the diarization model.
|
38 |
+
output_dir : str
|
39 |
+
Directory path where output files will be saved.
|
40 |
+
|
41 |
+
Returns
|
42 |
+
-------
|
43 |
+
WhisperBase
|
44 |
+
An instance of the appropriate whisper inference class based on the whisper_type.
|
45 |
+
"""
|
46 |
+
# Temporal fix of the bug : https://github.com/jhj0517/Whisper-WebUI/issues/144
|
47 |
+
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'
|
48 |
+
|
49 |
+
whisper_type = whisper_type.lower().strip()
|
50 |
+
|
51 |
+
faster_whisper_typos = ["faster_whisper", "faster-whisper", "fasterwhisper"]
|
52 |
+
whisper_typos = ["whisper"]
|
53 |
+
insanely_fast_whisper_typos = [
|
54 |
+
"insanely_fast_whisper", "insanely-fast-whisper", "insanelyfastwhisper",
|
55 |
+
"insanely_faster_whisper", "insanely-faster-whisper", "insanelyfasterwhisper"
|
56 |
+
]
|
57 |
+
|
58 |
+
if whisper_type in faster_whisper_typos:
|
59 |
+
return FasterWhisperInference(
|
60 |
+
model_dir=faster_whisper_model_dir,
|
61 |
+
output_dir=output_dir,
|
62 |
+
diarization_model_dir=diarization_model_dir
|
63 |
+
)
|
64 |
+
elif whisper_type in whisper_typos:
|
65 |
+
return WhisperInference(
|
66 |
+
model_dir=whisper_model_dir,
|
67 |
+
output_dir=output_dir,
|
68 |
+
diarization_model_dir=diarization_model_dir
|
69 |
+
)
|
70 |
+
elif whisper_type in insanely_fast_whisper_typos:
|
71 |
+
return InsanelyFastWhisperInference(
|
72 |
+
model_dir=insanely_fast_whisper_model_dir,
|
73 |
+
output_dir=output_dir,
|
74 |
+
diarization_model_dir=diarization_model_dir
|
75 |
+
)
|
76 |
+
else:
|
77 |
+
return FasterWhisperInference(
|
78 |
+
model_dir=faster_whisper_model_dir,
|
79 |
+
output_dir=output_dir,
|
80 |
+
diarization_model_dir=diarization_model_dir
|
81 |
+
)
|
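Possible usage of the factory (the whisper_type value is one of the accepted spellings handled above; the other arguments keep their defaults):

from modules.whisper.whisper_factory import WhisperFactory

whisper_inf = WhisperFactory.create_whisper_inference(
    whisper_type="faster-whisper",
    output_dir="outputs",
)
print(type(whisper_inf).__name__)  # FasterWhisperInference; unrecognized strings also fall back to it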
modules/whisper/whisper_parameter.py
CHANGED
@@ -15,6 +15,7 @@ class WhisperParameters:
|
|
15 |
best_of: gr.Number
|
16 |
patience: gr.Number
|
17 |
condition_on_previous_text: gr.Checkbox
|
|
|
18 |
initial_prompt: gr.Textbox
|
19 |
temperature: gr.Slider
|
20 |
compression_ratio_threshold: gr.Number
|
@@ -23,13 +24,28 @@ class WhisperParameters:
|
|
23 |
min_speech_duration_ms: gr.Number
|
24 |
max_speech_duration_s: gr.Number
|
25 |
min_silence_duration_ms: gr.Number
|
26 |
-
window_size_sample: gr.Number
|
27 |
speech_pad_ms: gr.Number
|
28 |
chunk_length_s: gr.Number
|
29 |
batch_size: gr.Number
|
30 |
is_diarize: gr.Checkbox
|
31 |
hf_token: gr.Textbox
|
32 |
diarization_device: gr.Dropdown
|
33 |
"""
|
34 |
A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
|
35 |
This data class is used to mitigate the key-value problem between Gradio components and function parameters.
|
@@ -111,11 +127,6 @@ class WhisperParameters:
|
|
111 |
This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
|
112 |
before separating it
|
113 |
|
114 |
-
window_size_samples: gr.Number
|
115 |
-
This parameter is related with Silero VAD. Audio chunks of window_size_samples size are fed to the silero VAD model.
|
116 |
-
WARNING! Silero VAD models were trained using 512, 1024, 1536 samples for 16000 sample rate.
|
117 |
-
Values other than these may affect model performance!!
|
118 |
-
|
119 |
speech_pad_ms: gr.Number
|
120 |
This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
|
121 |
|
@@ -135,6 +146,62 @@ class WhisperParameters:
|
|
135 |
|
136 |
diarization_device: gr.Dropdown
|
137 |
This parameter is related with whisperx. Device to run diarization model
|
|
|
|
|
|
|
|
|
|
138 |
"""
|
139 |
|
140 |
def as_list(self) -> list:
|
@@ -159,33 +226,7 @@ class WhisperParameters:
|
|
159 |
WhisperValues
|
160 |
Data class that has values of parameters
|
161 |
"""
|
162 |
-
return WhisperValues(
|
163 |
-
model_size=args[0],
|
164 |
-
lang=args[1],
|
165 |
-
is_translate=args[2],
|
166 |
-
beam_size=args[3],
|
167 |
-
log_prob_threshold=args[4],
|
168 |
-
no_speech_threshold=args[5],
|
169 |
-
compute_type=args[6],
|
170 |
-
best_of=args[7],
|
171 |
-
patience=args[8],
|
172 |
-
condition_on_previous_text=args[9],
|
173 |
-
initial_prompt=args[10],
|
174 |
-
temperature=args[11],
|
175 |
-
compression_ratio_threshold=args[12],
|
176 |
-
vad_filter=args[13],
|
177 |
-
threshold=args[14],
|
178 |
-
min_speech_duration_ms=args[15],
|
179 |
-
max_speech_duration_s=args[16],
|
180 |
-
min_silence_duration_ms=args[17],
|
181 |
-
window_size_samples=args[18],
|
182 |
-
speech_pad_ms=args[19],
|
183 |
-
chunk_length_s=args[20],
|
184 |
-
batch_size=args[21],
|
185 |
-
is_diarize=args[22],
|
186 |
-
hf_token=args[23],
|
187 |
-
diarization_device=args[24]
|
188 |
-
)
|
189 |
|
190 |
|
191 |
@dataclass
|
@@ -200,6 +241,7 @@ class WhisperValues:
|
|
200 |
best_of: int
|
201 |
patience: float
|
202 |
condition_on_previous_text: bool
|
|
|
203 |
initial_prompt: Optional[str]
|
204 |
temperature: float
|
205 |
compression_ratio_threshold: float
|
@@ -208,13 +250,28 @@ class WhisperValues:
|
|
208 |
min_speech_duration_ms: int
|
209 |
max_speech_duration_s: float
|
210 |
min_silence_duration_ms: int
|
211 |
-
window_size_samples: int
|
212 |
speech_pad_ms: int
|
213 |
chunk_length_s: int
|
214 |
batch_size: int
|
215 |
is_diarize: bool
|
216 |
hf_token: str
|
217 |
diarization_device: str
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
218 |
"""
|
219 |
A data class to use Whisper parameters.
|
220 |
-
"""
|
|
|
15 | best_of: gr.Number
16 | patience: gr.Number
17 | condition_on_previous_text: gr.Checkbox
18 | + prompt_reset_on_temperature: gr.Slider
19 | initial_prompt: gr.Textbox
20 | temperature: gr.Slider
21 | compression_ratio_threshold: gr.Number
24 | min_speech_duration_ms: gr.Number
25 | max_speech_duration_s: gr.Number
26 | min_silence_duration_ms: gr.Number
27 | speech_pad_ms: gr.Number
28 | chunk_length_s: gr.Number
29 | batch_size: gr.Number
30 | is_diarize: gr.Checkbox
31 | hf_token: gr.Textbox
32 | diarization_device: gr.Dropdown
33 | + length_penalty: gr.Number
34 | + repetition_penalty: gr.Number
35 | + no_repeat_ngram_size: gr.Number
36 | + prefix: gr.Textbox
37 | + suppress_blank: gr.Checkbox
38 | + suppress_tokens: gr.Textbox
39 | + max_initial_timestamp: gr.Number
40 | + word_timestamps: gr.Checkbox
41 | + prepend_punctuations: gr.Textbox
42 | + append_punctuations: gr.Textbox
43 | + max_new_tokens: gr.Number
44 | + chunk_length: gr.Number
45 | + hallucination_silence_threshold: gr.Number
46 | + hotwords: gr.Textbox
47 | + language_detection_threshold: gr.Number
48 | + language_detection_segments: gr.Number
49 | """
50 | A data class for Gradio components of the Whisper Parameters. Use "before" Gradio pre-processing.
51 | This data class is used to mitigate the key-value problem between Gradio components and function parameters.
127 | This parameter is related with Silero VAD. In the end of each speech chunk wait for min_silence_duration_ms
128 | before separating it
129 | 
130 | speech_pad_ms: gr.Number
131 |     This parameter is related with Silero VAD. Final speech chunks are padded by speech_pad_ms each side
132 | 
146 | 
147 | diarization_device: gr.Dropdown
148 |     This parameter is related with whisperx. Device to run diarization model
149 | + 
150 | + length_penalty:
151 | +     This parameter is related to faster-whisper. Exponential length penalty constant.
152 | + 
153 | + repetition_penalty:
154 | +     This parameter is related to faster-whisper. Penalty applied to the score of previously generated tokens
155 | +     (set > 1 to penalize).
156 | + 
157 | + no_repeat_ngram_size:
158 | +     This parameter is related to faster-whisper. Prevent repetitions of n-grams with this size (set 0 to disable).
159 | + 
160 | + prefix:
161 | +     This parameter is related to faster-whisper. Optional text to provide as a prefix for the first window.
162 | + 
163 | + suppress_blank:
164 | +     This parameter is related to faster-whisper. Suppress blank outputs at the beginning of the sampling.
165 | + 
166 | + suppress_tokens:
167 | +     This parameter is related to faster-whisper. List of token IDs to suppress. -1 will suppress a default set
168 | +     of symbols as defined in the model config.json file.
169 | + 
170 | + max_initial_timestamp:
171 | +     This parameter is related to faster-whisper. The initial timestamp cannot be later than this.
172 | + 
173 | + word_timestamps:
174 | +     This parameter is related to faster-whisper. Extract word-level timestamps using the cross-attention pattern
175 | +     and dynamic time warping, and include the timestamps for each word in each segment.
176 | + 
177 | + prepend_punctuations:
178 | +     This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
179 | +     with the next word.
180 | + 
181 | + append_punctuations:
182 | +     This parameter is related to faster-whisper. If word_timestamps is True, merge these punctuation symbols
183 | +     with the previous word.
184 | + 
185 | + max_new_tokens:
186 | +     This parameter is related to faster-whisper. Maximum number of new tokens to generate per-chunk. If not set,
187 | +     the maximum will be set by the default max_length.
188 | + 
189 | + chunk_length:
190 | +     This parameter is related to faster-whisper. The length of audio segments. If it is not None, it will overwrite the
191 | +     default chunk_length of the FeatureExtractor.
192 | + 
193 | + hallucination_silence_threshold:
194 | +     This parameter is related to faster-whisper. When word_timestamps is True, skip silent periods longer than this threshold
195 | +     (in seconds) when a possible hallucination is detected.
196 | + 
197 | + hotwords:
198 | +     This parameter is related to faster-whisper. Hotwords/hint phrases to provide the model with. Has no effect if prefix is not None.
199 | + 
200 | + language_detection_threshold:
201 | +     This parameter is related to faster-whisper. If the maximum probability of the language tokens is higher than this value, the language is detected.
202 | + 
203 | + language_detection_segments:
204 | +     This parameter is related to faster-whisper. Number of segments to consider for the language detection.
205 | """
206 | 
207 | def as_list(self) -> list:
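The new fields documented above correspond to decoding options exposed by faster-whisper. As a hedged illustration only (not code from this repository, and assuming faster-whisper==1.0.3 keeps these keyword names on `WhisperModel.transcribe()`), forwarding a few of them looks roughly like this:

```python
# Hedged sketch: passing the faster-whisper-related options documented above.
# Assumes faster-whisper==1.0.3 is installed and "audio.wav" exists locally.
from faster_whisper import WhisperModel

model = WhisperModel("tiny", device="cpu", compute_type="int8")
segments, info = model.transcribe(
    "audio.wav",
    beam_size=1,
    length_penalty=1.0,        # exponential length penalty constant
    repetition_penalty=1.0,    # values > 1 penalize previously generated tokens
    no_repeat_ngram_size=0,    # 0 disables n-gram blocking
    suppress_blank=True,       # suppress blank outputs at the start of sampling
    suppress_tokens=[-1],      # -1 applies the default suppression set from config.json
    max_initial_timestamp=1.0,
    word_timestamps=False,
)
print(info.language, info.language_probability)
for segment in segments:
    print(f"[{segment.start:.2f} -> {segment.end:.2f}] {segment.text}")
```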
226 | WhisperValues
227 |     Data class that has values of parameters
228 | """
229 | + return WhisperValues(*args)
230 | 
231 | 
232 | @dataclass
241 | best_of: int
242 | patience: float
243 | condition_on_previous_text: bool
244 | + prompt_reset_on_temperature: float
245 | initial_prompt: Optional[str]
246 | temperature: float
247 | compression_ratio_threshold: float
250 | min_speech_duration_ms: int
251 | max_speech_duration_s: float
252 | min_silence_duration_ms: int
253 | speech_pad_ms: int
254 | chunk_length_s: int
255 | batch_size: int
256 | is_diarize: bool
257 | hf_token: str
258 | diarization_device: str
259 | + length_penalty: float
260 | + repetition_penalty: float
261 | + no_repeat_ngram_size: int
262 | + prefix: Optional[str]
263 | + suppress_blank: bool
264 | + suppress_tokens: Optional[str]
265 | + max_initial_timestamp: float
266 | + word_timestamps: bool
267 | + prepend_punctuations: Optional[str]
268 | + append_punctuations: Optional[str]
269 | + max_new_tokens: Optional[int]
270 | + chunk_length: Optional[int]
271 | + hallucination_silence_threshold: Optional[float]
272 | + hotwords: Optional[str]
273 | + language_detection_threshold: Optional[float]
274 | + language_detection_segments: int
275 | """
276 | A data class to use Whisper parameters.
277 | + """
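The switch from the hand-written keyword mapping to `return WhisperValues(*args)` works because the dataclass field order now matches the order of the flat value tuple Gradio produces. A minimal, self-contained sketch of the same pattern, using a hypothetical trimmed stand-in class rather than the project's own:

```python
# Minimal sketch of the positional-mapping pattern: Gradio hands over component values as one
# flat *args tuple, and a dataclass whose field order matches can be rebuilt with cls(*args).
from dataclasses import dataclass, astuple
from typing import Optional

@dataclass
class DemoValues:  # hypothetical stand-in, not the repo's WhisperValues
    model_size: str
    lang: str
    is_translate: bool
    beam_size: int
    initial_prompt: Optional[str]

def as_value(*args) -> DemoValues:
    # Mirrors `return WhisperValues(*args)`: argument order must match field order exactly.
    return DemoValues(*args)

values = as_value("large-v2", "english", False, 1, None)
print(values.beam_size)   # 1
print(astuple(values))    # round-trips back to the flat tuple Gradio expects
```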
notebook/whisper-webui.ipynb
CHANGED
@@ -13,7 +13,7 @@
13 | "\n",
14 | "If you find this project useful, please consider supporting it:\n",
15 | "\n",
16 | - "<a href=\"https://ko-fi.com/
17 | " <img src=\"https://storage.ko-fi.com/cdn/kofi2.png?v=3\" alt=\"Buy Me a Coffee at ko-fi.com\" height=\"36\">\n",
18 | "</a>\n",
19 | "\n",
@@ -53,9 +53,10 @@
53 | "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n",
54 | "%cd Whisper-WebUI\n",
55 | "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n",
56 | - "!pip install faster-whisper==1.0.
57 | - "!pip install gradio==4.
58 | - "
59 | "!pip install tokenizers==0.19.1\n",
60 | "!pip install pyannote.audio==3.3.1"
61 | ]
@@ -70,7 +71,7 @@
70 | "\n",
71 | "USERNAME = '' #@param {type: \"string\"}\n",
72 | "PASSWORD = '' #@param {type: \"string\"}\n",
73 | - "WHISPER_TYPE = 'faster-whisper'
74 | "THEME = '' #@param {type: \"string\"}\n",
75 | "\n",
76 | "arguments = \"\"\n",

13 | "\n",
14 | "If you find this project useful, please consider supporting it:\n",
15 | "\n",
16 | + "<a href=\"https://ko-fi.com/jhj0517\" target=\"_blank\">\n",
17 | " <img src=\"https://storage.ko-fi.com/cdn/kofi2.png?v=3\" alt=\"Buy Me a Coffee at ko-fi.com\" height=\"36\">\n",
18 | "</a>\n",
19 | "\n",

53 | "!git clone https://github.com/jhj0517/Whisper-WebUI.git\n",
54 | "%cd Whisper-WebUI\n",
55 | "!pip install git+https://github.com/jhj0517/jhj0517-whisper.git\n",
56 | + "!pip install faster-whisper==1.0.3\n",
57 | + "!pip install gradio==4.29.0\n",
58 | + "# Temporal bug fix from https://github.com/jhj0517/Whisper-WebUI/issues/220\n",
59 | + "!pip install pytubefix\n",
60 | "!pip install tokenizers==0.19.1\n",
61 | "!pip install pyannote.audio==3.3.1"
62 | ]

71 | "\n",
72 | "USERNAME = '' #@param {type: \"string\"}\n",
73 | "PASSWORD = '' #@param {type: \"string\"}\n",
74 | + "WHISPER_TYPE = 'faster-whisper' # @param [\"whisper\", \"faster-whisper\", \"insanely-fast-whisper\"]\n",
75 | "THEME = '' #@param {type: \"string\"}\n",
76 | "\n",
77 | "arguments = \"\"\n",
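pytubefix is installed here as a temporary workaround for the YouTube download breakage tracked in issue #220. It advertises a pytube-compatible API, so fetching an audio-only stream would look roughly like the hedged sketch below (placeholder URL and file names; this is not the repo's youtube_manager code, and the calls assume pytubefix mirrors pytube's interface):

```python
# Hedged sketch: downloading the audio of a YouTube video with pytubefix.
# Assumes pytubefix keeps pytube's API; the URL below is only a placeholder.
from pytubefix import YouTube

yt = YouTube("https://www.youtube.com/watch?v=VIDEO_ID")      # hypothetical URL
audio_stream = yt.streams.filter(only_audio=True).first()      # pick an audio-only stream
file_path = audio_stream.download(output_path="outputs", filename="downloaded_audio.mp4")
print(yt.title, file_path)
```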
requirements.txt
CHANGED
@@ -1,8 +1,13 @@
1 | --extra-index-url https://download.pytorch.org/whl/cu121
2 | torch
3 | git+https://github.com/jhj0517/jhj0517-whisper.git
4 | - faster-whisper==1.0.
5 | - transformers
6 | - 
7 | - 
8 | pyannote.audio==3.3.1

1 | + # Remove the --extra-index-url line below if you're not using Nvidia GPU.
2 | + # If you're using it, update url to your CUDA version (CUDA 12.1 is minimum requirement):
3 | + # For CUDA 12.1, use : https://download.pytorch.org/whl/cu121
4 | + # For CUDA 12.4, use : https://download.pytorch.org/whl/cu124
5 | + 
6 | --extra-index-url https://download.pytorch.org/whl/cu121
7 | torch
8 | git+https://github.com/jhj0517/jhj0517-whisper.git
9 | + faster-whisper==1.0.3
10 | + transformers==4.42.3
11 | + gradio==4.29.0
12 | + pytubefix
13 | pyannote.audio==3.3.1
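After installing with the chosen `--extra-index-url`, a quick sanity check (plain Python, nothing repo-specific) confirms that a CUDA-enabled torch wheel was actually picked up rather than a CPU-only build:

```python
# Verify that the CUDA wheel selected via --extra-index-url provides GPU support.
import torch

print(torch.__version__)          # e.g. "2.x.x+cu121" for a CUDA 12.1 wheel
print(torch.cuda.is_available())  # True if a compatible NVIDIA GPU + driver is visible
if torch.cuda.is_available():
    print(torch.cuda.get_device_name(0))
```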
start-webui.bat
CHANGED
@@ -1,7 +1,7 @@
1 | @echo off
2 | 
3 | call venv\scripts\activate
4 | - python app.py
5 | 
6 | echo "launching the app"
7 | pause

1 | @echo off
2 | 
3 | call venv\scripts\activate
4 | + python app.py %*
5 | 
6 | echo "launching the app"
7 | pause
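With `%*`, any arguments passed to start-webui.bat are now forwarded to app.py instead of being dropped, so command-line options can be supplied through the launcher, for example `start-webui.bat --whisper_type faster-whisper` (flag name taken from app.py; other flags pass through the same way).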