Spaces:
Running
Running
Merge branch 'main' of https://huggingface.co/spaces/aadnk/whisper-webui into main
Browse files- app.py +72 -31
- cli.py +17 -2
- config.json5 +10 -1
- src/config.py +11 -1
- src/utils.py +118 -8
- src/vad.py +8 -0
- src/whisper/whisperContainer.py +3 -2
app.py
CHANGED
@@ -1,4 +1,5 @@
|
|
1 |
from datetime import datetime
|
|
|
2 |
import math
|
3 |
from typing import Iterator, Union
|
4 |
import argparse
|
@@ -28,7 +29,7 @@ import ffmpeg
|
|
28 |
import gradio as gr
|
29 |
|
30 |
from src.download import ExceededMaximumDuration, download_url
|
31 |
-
from src.utils import slugify, write_srt, write_vtt
|
32 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
33 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
34 |
from src.whisper.whisperFactory import create_whisper_container
|
@@ -84,37 +85,49 @@ class WhisperTranscriber:
|
|
84 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
85 |
|
86 |
# Entry function for the simple tab
|
87 |
-
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
88 |
-
|
|
|
|
|
|
|
|
|
89 |
|
90 |
# Entry function for the simple tab progress
|
91 |
-
def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
92 |
-
|
|
|
|
|
93 |
|
94 |
-
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize,
|
95 |
|
96 |
-
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
|
|
97 |
|
98 |
# Entry function for the full tab
|
99 |
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
|
|
|
|
104 |
|
105 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
106 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
|
|
107 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
108 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
109 |
compression_ratio_threshold, logprob_threshold, no_speech_threshold)
|
110 |
|
111 |
# Entry function for the full tab with progress
|
112 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
|
117 |
-
|
|
|
|
|
118 |
|
119 |
# Handle temperature_increment_on_fallback
|
120 |
if temperature_increment_on_fallback is not None:
|
@@ -128,13 +141,15 @@ class WhisperTranscriber:
|
|
128 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
129 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
130 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
|
|
131 |
progress=progress)
|
132 |
|
133 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
134 |
-
vadOptions: VadOptions, progress: gr.Progress = None,
|
|
|
135 |
try:
|
136 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
137 |
-
|
138 |
try:
|
139 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
140 |
selectedModel = modelName if modelName is not None else "base"
|
@@ -185,7 +200,7 @@ class WhisperTranscriber:
|
|
185 |
# Update progress
|
186 |
current_progress += source_audio_duration
|
187 |
|
188 |
-
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory)
|
189 |
|
190 |
if len(sources) > 1:
|
191 |
# Add new line separators
|
@@ -359,7 +374,7 @@ class WhisperTranscriber:
|
|
359 |
|
360 |
return config
|
361 |
|
362 |
-
def write_result(self, result: dict, source_name: str, output_dir: str):
|
363 |
if not os.path.exists(output_dir):
|
364 |
os.makedirs(output_dir)
|
365 |
|
@@ -368,13 +383,15 @@ class WhisperTranscriber:
|
|
368 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
369 |
|
370 |
print("Max line width " + str(languageMaxLineWidth))
|
371 |
-
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
|
372 |
-
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)
|
|
|
373 |
|
374 |
output_files = []
|
375 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
376 |
output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
|
377 |
output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
|
|
|
378 |
|
379 |
return output_files, text, vtt
|
380 |
|
@@ -394,13 +411,13 @@ class WhisperTranscriber:
|
|
394 |
# 80 latin characters should fit on a 1080p/720p screen
|
395 |
return 80
|
396 |
|
397 |
-
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
|
398 |
segmentStream = StringIO()
|
399 |
|
400 |
if format == 'vtt':
|
401 |
-
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
402 |
elif format == 'srt':
|
403 |
-
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
|
404 |
else:
|
405 |
raise Exception("Unknown format " + format)
|
406 |
|
@@ -460,24 +477,34 @@ def create_ui(app_config: ApplicationConfig):
|
|
460 |
|
461 |
whisper_models = app_config.get_model_names()
|
462 |
|
463 |
-
|
464 |
gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
|
465 |
gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
|
466 |
gr.Text(label="URL (YouTube, etc.)"),
|
467 |
gr.File(label="Upload Files", file_count="multiple"),
|
468 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
469 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
|
|
|
|
|
|
|
470 |
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
|
471 |
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
|
472 |
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
|
473 |
-
|
474 |
-
|
|
|
|
|
|
|
475 |
]
|
476 |
|
477 |
is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
|
478 |
|
479 |
simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
|
480 |
-
description=ui_description, article=ui_article, inputs=
|
|
|
|
|
|
|
|
|
481 |
gr.File(label="Download"),
|
482 |
gr.Text(label="Transcription"),
|
483 |
gr.Text(label="Segments")
|
@@ -487,8 +514,17 @@ def create_ui(app_config: ApplicationConfig):
|
|
487 |
|
488 |
full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
|
489 |
description=full_description, article=ui_article, inputs=[
|
490 |
-
*
|
|
|
|
|
|
|
|
|
491 |
gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
|
|
|
|
|
|
|
|
|
|
|
492 |
gr.TextArea(label="Initial Prompt"),
|
493 |
gr.Number(label="Temperature", value=app_config.temperature),
|
494 |
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
@@ -501,7 +537,7 @@ def create_ui(app_config: ApplicationConfig):
|
|
501 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
502 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
503 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
504 |
-
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold)
|
505 |
], outputs=[
|
506 |
gr.File(label="Download"),
|
507 |
gr.Text(label="Transcription"),
|
@@ -560,9 +596,14 @@ if __name__ == '__main__':
|
|
560 |
help="the Whisper implementation to use")
|
561 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
562 |
help="the compute type to use for inference")
|
|
|
|
|
563 |
|
564 |
args = parser.parse_args().__dict__
|
565 |
|
566 |
updated_config = default_app_config.update(**args)
|
567 |
|
|
|
|
|
|
|
568 |
create_ui(app_config=updated_config)
|
|
|
1 |
from datetime import datetime
|
2 |
+
import json
|
3 |
import math
|
4 |
from typing import Iterator, Union
|
5 |
import argparse
|
|
|
29 |
import gradio as gr
|
30 |
|
31 |
from src.download import ExceededMaximumDuration, download_url
|
32 |
+
from src.utils import optional_int, slugify, write_srt, write_vtt
|
33 |
from src.vad import AbstractTranscription, NonSpeechStrategy, PeriodicTranscriptionConfig, TranscriptionConfig, VadPeriodicTranscription, VadSileroTranscription
|
34 |
from src.whisper.abstractWhisperContainer import AbstractWhisperContainer
|
35 |
from src.whisper.whisperFactory import create_whisper_container
|
|
|
85 |
print("[Auto parallel] Using GPU devices " + str(self.parallel_device_list) + " and " + str(self.vad_cpu_cores) + " CPU cores for VAD/transcription.")
|
86 |
|
87 |
# Entry function for the simple tab
|
88 |
+
def transcribe_webui_simple(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
89 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
90 |
+
word_timestamps: bool = False, highlight_words: bool = False):
|
91 |
+
return self.transcribe_webui_simple_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
92 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
93 |
+
word_timestamps, highlight_words)
|
94 |
|
95 |
# Entry function for the simple tab progress
|
96 |
+
def transcribe_webui_simple_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
97 |
+
vad, vadMergeWindow, vadMaxMergeSize,
|
98 |
+
word_timestamps: bool = False, highlight_words: bool = False,
|
99 |
+
progress=gr.Progress()):
|
100 |
|
101 |
+
vadOptions = VadOptions(vad, vadMergeWindow, vadMaxMergeSize, self.app_config.vad_padding, self.app_config.vad_prompt_window, self.app_config.vad_initial_prompt_mode)
|
102 |
|
103 |
+
return self.transcribe_webui(modelName, languageName, urlData, multipleFiles, microphoneData, task, vadOptions,
|
104 |
+
word_timestamps=word_timestamps, highlight_words=highlight_words, progress=progress)
|
105 |
|
106 |
# Entry function for the full tab
|
107 |
def transcribe_webui_full(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
108 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
109 |
+
# Word timestamps
|
110 |
+
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
111 |
+
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
112 |
+
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
113 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float):
|
114 |
|
115 |
return self.transcribe_webui_full_progress(modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
116 |
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
117 |
+
word_timestamps, highlight_words, prepend_punctuations, append_punctuations,
|
118 |
initial_prompt, temperature, best_of, beam_size, patience, length_penalty, suppress_tokens,
|
119 |
condition_on_previous_text, fp16, temperature_increment_on_fallback,
|
120 |
compression_ratio_threshold, logprob_threshold, no_speech_threshold)
|
121 |
|
122 |
# Entry function for the full tab with progress
|
123 |
def transcribe_webui_full_progress(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
124 |
+
vad, vadMergeWindow, vadMaxMergeSize, vadPadding, vadPromptWindow, vadInitialPromptMode,
|
125 |
+
# Word timestamps
|
126 |
+
word_timestamps: bool, highlight_words: bool, prepend_punctuations: str, append_punctuations: str,
|
127 |
+
initial_prompt: str, temperature: float, best_of: int, beam_size: int, patience: float, length_penalty: float, suppress_tokens: str,
|
128 |
+
condition_on_previous_text: bool, fp16: bool, temperature_increment_on_fallback: float,
|
129 |
+
compression_ratio_threshold: float, logprob_threshold: float, no_speech_threshold: float,
|
130 |
+
progress=gr.Progress()):
|
131 |
|
132 |
# Handle temperature_increment_on_fallback
|
133 |
if temperature_increment_on_fallback is not None:
|
|
|
141 |
initial_prompt=initial_prompt, temperature=temperature, best_of=best_of, beam_size=beam_size, patience=patience, length_penalty=length_penalty, suppress_tokens=suppress_tokens,
|
142 |
condition_on_previous_text=condition_on_previous_text, fp16=fp16,
|
143 |
compression_ratio_threshold=compression_ratio_threshold, logprob_threshold=logprob_threshold, no_speech_threshold=no_speech_threshold,
|
144 |
+
word_timestamps=word_timestamps, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, highlight_words=highlight_words,
|
145 |
progress=progress)
|
146 |
|
147 |
def transcribe_webui(self, modelName, languageName, urlData, multipleFiles, microphoneData, task,
|
148 |
+
vadOptions: VadOptions, progress: gr.Progress = None, highlight_words: bool = False,
|
149 |
+
**decodeOptions: dict):
|
150 |
try:
|
151 |
sources = self.__get_source(urlData, multipleFiles, microphoneData)
|
152 |
+
|
153 |
try:
|
154 |
selectedLanguage = languageName.lower() if len(languageName) > 0 else None
|
155 |
selectedModel = modelName if modelName is not None else "base"
|
|
|
200 |
# Update progress
|
201 |
current_progress += source_audio_duration
|
202 |
|
203 |
+
source_download, source_text, source_vtt = self.write_result(result, filePrefix, outputDirectory, highlight_words)
|
204 |
|
205 |
if len(sources) > 1:
|
206 |
# Add new line separators
|
|
|
374 |
|
375 |
return config
|
376 |
|
377 |
+
def write_result(self, result: dict, source_name: str, output_dir: str, highlight_words: bool = False):
|
378 |
if not os.path.exists(output_dir):
|
379 |
os.makedirs(output_dir)
|
380 |
|
|
|
383 |
languageMaxLineWidth = self.__get_max_line_width(language)
|
384 |
|
385 |
print("Max line width " + str(languageMaxLineWidth))
|
386 |
+
vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth, highlight_words=highlight_words)
|
387 |
+
srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth, highlight_words=highlight_words)
|
388 |
+
json_result = json.dumps(result, indent=4, ensure_ascii=False)
|
389 |
|
390 |
output_files = []
|
391 |
output_files.append(self.__create_file(srt, output_dir, source_name + "-subs.srt"));
|
392 |
output_files.append(self.__create_file(vtt, output_dir, source_name + "-subs.vtt"));
|
393 |
output_files.append(self.__create_file(text, output_dir, source_name + "-transcript.txt"));
|
394 |
+
output_files.append(self.__create_file(json_result, output_dir, source_name + "-result.json"));
|
395 |
|
396 |
return output_files, text, vtt
|
397 |
|
|
|
411 |
# 80 latin characters should fit on a 1080p/720p screen
|
412 |
return 80
|
413 |
|
414 |
+
def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int, highlight_words: bool = False) -> str:
|
415 |
segmentStream = StringIO()
|
416 |
|
417 |
if format == 'vtt':
|
418 |
+
write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
419 |
elif format == 'srt':
|
420 |
+
write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth, highlight_words=highlight_words)
|
421 |
else:
|
422 |
raise Exception("Unknown format " + format)
|
423 |
|
|
|
477 |
|
478 |
whisper_models = app_config.get_model_names()
|
479 |
|
480 |
+
common_inputs = lambda : [
|
481 |
gr.Dropdown(choices=whisper_models, value=app_config.default_model_name, label="Model"),
|
482 |
gr.Dropdown(choices=sorted(get_language_names()), label="Language", value=app_config.language),
|
483 |
gr.Text(label="URL (YouTube, etc.)"),
|
484 |
gr.File(label="Upload Files", file_count="multiple"),
|
485 |
gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
|
486 |
gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task),
|
487 |
+
]
|
488 |
+
|
489 |
+
common_vad_inputs = lambda : [
|
490 |
gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD"),
|
491 |
gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window),
|
492 |
gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size),
|
493 |
+
]
|
494 |
+
|
495 |
+
common_word_timestamps_inputs = lambda : [
|
496 |
+
gr.Checkbox(label="Word Timestamps", value=app_config.word_timestamps),
|
497 |
+
gr.Checkbox(label="Word Timestamps - Highlight Words", value=app_config.highlight_words),
|
498 |
]
|
499 |
|
500 |
is_queue_mode = app_config.queue_concurrency_count is not None and app_config.queue_concurrency_count > 0
|
501 |
|
502 |
simple_transcribe = gr.Interface(fn=ui.transcribe_webui_simple_progress if is_queue_mode else ui.transcribe_webui_simple,
|
503 |
+
description=ui_description, article=ui_article, inputs=[
|
504 |
+
*common_inputs(),
|
505 |
+
*common_vad_inputs(),
|
506 |
+
*common_word_timestamps_inputs(),
|
507 |
+
], outputs=[
|
508 |
gr.File(label="Download"),
|
509 |
gr.Text(label="Transcription"),
|
510 |
gr.Text(label="Segments")
|
|
|
514 |
|
515 |
full_transcribe = gr.Interface(fn=ui.transcribe_webui_full_progress if is_queue_mode else ui.transcribe_webui_full,
|
516 |
description=full_description, article=ui_article, inputs=[
|
517 |
+
*common_inputs(),
|
518 |
+
|
519 |
+
*common_vad_inputs(),
|
520 |
+
gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding),
|
521 |
+
gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window),
|
522 |
gr.Dropdown(choices=["prepend_first_segment", "prepend_all_segments"], value=app_config.vad_initial_prompt_mode, label="VAD - Initial Prompt Mode"),
|
523 |
+
|
524 |
+
*common_word_timestamps_inputs(),
|
525 |
+
gr.Text(label="Word Timestamps - Prepend Punctuations", value=app_config.prepend_punctuations),
|
526 |
+
gr.Text(label="Word Timestamps - Append Punctuations", value=app_config.append_punctuations),
|
527 |
+
|
528 |
gr.TextArea(label="Initial Prompt"),
|
529 |
gr.Number(label="Temperature", value=app_config.temperature),
|
530 |
gr.Number(label="Best Of - Non-zero temperature", value=app_config.best_of, precision=0),
|
|
|
537 |
gr.Number(label="Temperature increment on fallback", value=app_config.temperature_increment_on_fallback),
|
538 |
gr.Number(label="Compression ratio threshold", value=app_config.compression_ratio_threshold),
|
539 |
gr.Number(label="Logprob threshold", value=app_config.logprob_threshold),
|
540 |
+
gr.Number(label="No speech threshold", value=app_config.no_speech_threshold),
|
541 |
], outputs=[
|
542 |
gr.File(label="Download"),
|
543 |
gr.Text(label="Transcription"),
|
|
|
596 |
help="the Whisper implementation to use")
|
597 |
parser.add_argument("--compute_type", type=str, default=default_app_config.compute_type, choices=["default", "auto", "int8", "int8_float16", "int16", "float16", "float32"], \
|
598 |
help="the compute type to use for inference")
|
599 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
600 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
601 |
|
602 |
args = parser.parse_args().__dict__
|
603 |
|
604 |
updated_config = default_app_config.update(**args)
|
605 |
|
606 |
+
if (threads := args.pop("threads")) > 0:
|
607 |
+
torch.set_num_threads(threads)
|
608 |
+
|
609 |
create_ui(app_config=updated_config)
|
cli.py
CHANGED
@@ -95,6 +95,17 @@ def cli():
|
|
95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
97 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
98 |
args = parser.parse_args().__dict__
|
99 |
model_name: str = args.pop("model")
|
100 |
model_dir: str = args.pop("model_dir")
|
@@ -102,6 +113,9 @@ def cli():
|
|
102 |
device: str = args.pop("device")
|
103 |
os.makedirs(output_dir, exist_ok=True)
|
104 |
|
|
|
|
|
|
|
105 |
whisper_implementation = args.pop("whisper_implementation")
|
106 |
print(f"Using {whisper_implementation} for Whisper")
|
107 |
|
@@ -126,6 +140,7 @@ def cli():
|
|
126 |
auto_parallel = args.pop("auto_parallel")
|
127 |
|
128 |
compute_type = args.pop("compute_type")
|
|
|
129 |
|
130 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
131 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
@@ -133,7 +148,7 @@ def cli():
|
|
133 |
|
134 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
135 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
136 |
-
|
137 |
if (transcriber._has_parallel_devices()):
|
138 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
139 |
|
@@ -158,7 +173,7 @@ def cli():
|
|
158 |
|
159 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
160 |
|
161 |
-
transcriber.write_result(result, source_name, output_dir)
|
162 |
|
163 |
transcriber.close()
|
164 |
|
|
|
95 |
parser.add_argument("--no_speech_threshold", type=optional_float, default=app_config.no_speech_threshold, \
|
96 |
help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
97 |
|
98 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=app_config.word_timestamps,
|
99 |
+
help="(experimental) extract word-level timestamps and refine the results based on them")
|
100 |
+
parser.add_argument("--prepend_punctuations", type=str, default=app_config.prepend_punctuations,
|
101 |
+
help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
102 |
+
parser.add_argument("--append_punctuations", type=str, default=app_config.append_punctuations,
|
103 |
+
help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
104 |
+
parser.add_argument("--highlight_words", type=str2bool, default=app_config.highlight_words,
|
105 |
+
help="(requires --word_timestamps True) underline each word as it is spoken in srt and vtt")
|
106 |
+
parser.add_argument("--threads", type=optional_int, default=0,
|
107 |
+
help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
108 |
+
|
109 |
args = parser.parse_args().__dict__
|
110 |
model_name: str = args.pop("model")
|
111 |
model_dir: str = args.pop("model_dir")
|
|
|
113 |
device: str = args.pop("device")
|
114 |
os.makedirs(output_dir, exist_ok=True)
|
115 |
|
116 |
+
if (threads := args.pop("threads")) > 0:
|
117 |
+
torch.set_num_threads(threads)
|
118 |
+
|
119 |
whisper_implementation = args.pop("whisper_implementation")
|
120 |
print(f"Using {whisper_implementation} for Whisper")
|
121 |
|
|
|
140 |
auto_parallel = args.pop("auto_parallel")
|
141 |
|
142 |
compute_type = args.pop("compute_type")
|
143 |
+
highlight_words = args.pop("highlight_words")
|
144 |
|
145 |
transcriber = WhisperTranscriber(delete_uploaded_files=False, vad_cpu_cores=vad_cpu_cores, app_config=app_config)
|
146 |
transcriber.set_parallel_devices(args.pop("vad_parallel_devices"))
|
|
|
148 |
|
149 |
model = create_whisper_container(whisper_implementation=whisper_implementation, model_name=model_name,
|
150 |
device=device, compute_type=compute_type, download_root=model_dir, models=app_config.models)
|
151 |
+
|
152 |
if (transcriber._has_parallel_devices()):
|
153 |
print("Using parallel devices:", transcriber.parallel_device_list)
|
154 |
|
|
|
173 |
|
174 |
result = transcriber.transcribe_file(model, source_path, temperature=temperature, vadOptions=vadOptions, **args)
|
175 |
|
176 |
+
transcriber.write_result(result, source_name, output_dir, highlight_words)
|
177 |
|
178 |
transcriber.close()
|
179 |
|
config.json5
CHANGED
@@ -128,5 +128,14 @@
|
|
128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
129 |
"logprob_threshold": -1.0,
|
130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
131 |
-
"no_speech_threshold": 0.6
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
132 |
}
|
|
|
128 |
// If the average log probability is lower than this value, treat the decoding as failed
|
129 |
"logprob_threshold": -1.0,
|
130 |
// If the probability of the <no-speech> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence
|
131 |
+
"no_speech_threshold": 0.6,
|
132 |
+
|
133 |
+
// (experimental) extract word-level timestamps and refine the results based on them
|
134 |
+
"word_timestamps": false,
|
135 |
+
// if word_timestamps is True, merge these punctuation symbols with the next word
|
136 |
+
"prepend_punctuations": "\"\'“¿([{-",
|
137 |
+
// if word_timestamps is True, merge these punctuation symbols with the previous word
|
138 |
+
"append_punctuations": "\"\'.。,,!!??::”)]}、",
|
139 |
+
// (requires --word_timestamps True) underline each word as it is spoken in srt and vtt
|
140 |
+
"highlight_words": false,
|
141 |
}
|
src/config.py
CHANGED
@@ -58,7 +58,11 @@ class ApplicationConfig:
|
|
58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
59 |
compute_type: str = "float16",
|
60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
61 |
-
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6
|
|
|
|
|
|
|
|
|
62 |
|
63 |
self.models = models
|
64 |
|
@@ -104,6 +108,12 @@ class ApplicationConfig:
|
|
104 |
self.logprob_threshold = logprob_threshold
|
105 |
self.no_speech_threshold = no_speech_threshold
|
106 |
|
|
|
|
|
|
|
|
|
|
|
|
|
107 |
def get_model_names(self):
|
108 |
return [ x.name for x in self.models ]
|
109 |
|
|
|
58 |
condition_on_previous_text: bool = True, fp16: bool = True,
|
59 |
compute_type: str = "float16",
|
60 |
temperature_increment_on_fallback: float = 0.2, compression_ratio_threshold: float = 2.4,
|
61 |
+
logprob_threshold: float = -1.0, no_speech_threshold: float = 0.6,
|
62 |
+
# Word timestamp settings
|
63 |
+
word_timestamps: bool = False, prepend_punctuations: str = "\"\'“¿([{-",
|
64 |
+
append_punctuations: str = "\"\'.。,,!!??::”)]}、",
|
65 |
+
highlight_words: bool = False):
|
66 |
|
67 |
self.models = models
|
68 |
|
|
|
108 |
self.logprob_threshold = logprob_threshold
|
109 |
self.no_speech_threshold = no_speech_threshold
|
110 |
|
111 |
+
# Word timestamp settings
|
112 |
+
self.word_timestamps = word_timestamps
|
113 |
+
self.prepend_punctuations = prepend_punctuations
|
114 |
+
self.append_punctuations = append_punctuations
|
115 |
+
self.highlight_words = highlight_words
|
116 |
+
|
117 |
def get_model_names(self):
|
118 |
return [ x.name for x in self.models ]
|
119 |
|
src/utils.py
CHANGED
@@ -3,7 +3,7 @@ import unicodedata
|
|
3 |
import re
|
4 |
|
5 |
import zlib
|
6 |
-
from typing import Iterator, TextIO
|
7 |
import tqdm
|
8 |
|
9 |
import urllib3
|
@@ -56,10 +56,14 @@ def write_txt(transcript: Iterator[dict], file: TextIO):
|
|
56 |
print(segment['text'].strip(), file=file, flush=True)
|
57 |
|
58 |
|
59 |
-
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
|
|
|
|
|
|
60 |
print("WEBVTT\n", file=file)
|
61 |
-
|
62 |
-
|
|
|
63 |
|
64 |
print(
|
65 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
@@ -68,8 +72,8 @@ def write_vtt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
68 |
flush=True,
|
69 |
)
|
70 |
|
71 |
-
|
72 |
-
|
73 |
"""
|
74 |
Write a transcript to a file in SRT format.
|
75 |
Example usage:
|
@@ -81,8 +85,10 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
81 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
82 |
write_srt(result["segments"], file=srt)
|
83 |
"""
|
84 |
-
|
85 |
-
|
|
|
|
|
86 |
|
87 |
# write srt lines
|
88 |
print(
|
@@ -94,6 +100,110 @@ def write_srt(transcript: Iterator[dict], file: TextIO, maxLineWidth=None):
|
|
94 |
flush=True,
|
95 |
)
|
96 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
97 |
def process_text(text: str, maxLineWidth=None):
|
98 |
if (maxLineWidth is None or maxLineWidth < 0):
|
99 |
return text
|
|
|
3 |
import re
|
4 |
|
5 |
import zlib
|
6 |
+
from typing import Iterator, TextIO, Union
|
7 |
import tqdm
|
8 |
|
9 |
import urllib3
|
|
|
56 |
print(segment['text'].strip(), file=file, flush=True)
|
57 |
|
58 |
|
59 |
+
def write_vtt(transcript: Iterator[dict], file: TextIO,
|
60 |
+
maxLineWidth=None, highlight_words: bool = False):
|
61 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
62 |
+
|
63 |
print("WEBVTT\n", file=file)
|
64 |
+
|
65 |
+
for segment in iterator:
|
66 |
+
text = segment['text'].replace('-->', '->')
|
67 |
|
68 |
print(
|
69 |
f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
|
|
|
72 |
flush=True,
|
73 |
)
|
74 |
|
75 |
+
def write_srt(transcript: Iterator[dict], file: TextIO,
|
76 |
+
maxLineWidth=None, highlight_words: bool = False):
|
77 |
"""
|
78 |
Write a transcript to a file in SRT format.
|
79 |
Example usage:
|
|
|
85 |
with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
|
86 |
write_srt(result["segments"], file=srt)
|
87 |
"""
|
88 |
+
iterator = __subtitle_preprocessor_iterator(transcript, maxLineWidth, highlight_words)
|
89 |
+
|
90 |
+
for i, segment in enumerate(iterator, start=1):
|
91 |
+
text = segment['text'].replace('-->', '->')
|
92 |
|
93 |
# write srt lines
|
94 |
print(
|
|
|
100 |
flush=True,
|
101 |
)
|
102 |
|
103 |
+
def __subtitle_preprocessor_iterator(transcript: Iterator[dict], maxLineWidth: int = None, highlight_words: bool = False):
    """
    Turn raw transcript segments into subtitle cues.

    Yields dicts with 'start', 'end' and 'text' keys. When a segment carries
    word-level timestamps and highlight_words is True, one cue is emitted per
    word with the active word wrapped in <u>...</u>, plus filler cues for any
    gaps; otherwise a single cue covers the whole segment. Text is wrapped to
    maxLineWidth characters per line when that value is non-negative.
    """
    for seg in transcript:
        word_entries = seg.get('words', [])

        if len(word_entries) == 0:
            # No word-level timestamps: pass the segment through, optionally
            # re-wrapping its text to the requested line width.
            if maxLineWidth is None or maxLineWidth < 0:
                yield seg
            else:
                yield {
                    'start': seg['start'],
                    'end': seg['end'],
                    'text': process_text(seg['text'].strip(), maxLineWidth)
                }
            continue

        seg_start = seg['start']
        seg_end = seg['end']

        raw_words = [entry["word"] for entry in word_entries]
        full_text = __join_words(raw_words, maxLineWidth)

        if not highlight_words:
            # Plain subtitle spanning the whole segment.
            yield {
                'start': seg_start,
                'end': seg_end,
                'text': full_text
            }
            continue

        # Karaoke-style output: step through the words, underlining each in turn.
        cursor = seg_start

        for idx, entry in enumerate(word_entries):
            word_start = entry['start']
            word_end = entry['end']

            if cursor != word_start:
                # Gap before this word: show the un-highlighted text.
                yield {
                    'start': cursor,
                    'end': word_start,
                    'text': full_text
                }

            # Cue for the current word, wrapped in <u>...</u> (any leading
            # whitespace is kept outside the tag).
            yield {
                'start': word_start,
                'end': word_end,
                'text': __join_words(
                    [
                        {
                            "word": re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                    if j == idx
                                    else word,
                            # The HTML tags <u> and </u> are not displayed,
                            # so they should not be counted in the word length.
                            "length": len(word)
                        } for j, word in enumerate(raw_words)
                    ], maxLineWidth)
            }
            cursor = word_end

        if cursor != seg_end:
            # Tail after the last word: show the un-highlighted text.
            yield {
                'start': cursor,
                'end': seg_end,
                'text': full_text
            }
+
|
176 |
+
def __join_words(words: Iterator[Union[str, dict]], maxLineWidth: int = None):
    """
    Join subtitle words into a single string, optionally wrapping at maxLineWidth.

    Each entry is either a plain string, or a dict with a 'word' key (the text
    to emit) and a 'length' key (the number of characters the word counts for
    when wrapping — this may differ from len(word) when the text contains
    markup such as <u> tags that are not rendered on screen).

    Parameters:
        words: iterable of str or {'word': str, 'length': int} entries.
        maxLineWidth: maximum characters per line; None or negative disables
            wrapping entirely.

    Returns:
        The joined text; when wrapping is enabled, lines are separated by '\\n'.
    """
    def _entry_parts(entry):
        # Normalize an entry to (text, display_length).
        if isinstance(entry, dict):
            return entry['word'], entry['length']
        return entry, len(entry)

    if maxLineWidth is None or maxLineWidth < 0:
        # Fix: this path previously did " ".join(words) on the raw entries,
        # which raised TypeError when dict entries were passed (as the word
        # highlighting code does with maxLineWidth=None). Extract the text
        # from each entry before joining.
        return " ".join(_entry_parts(entry)[0] for entry in words)

    lines = []
    current_line = ""
    current_length = 0

    for entry in words:
        word, word_length = _entry_parts(entry)

        # Start a new line when this word would overflow the current one.
        if current_length > 0 and current_length + word_length > maxLineWidth:
            lines.append(current_line)
            current_line = ""
            current_length = 0

        current_length += word_length
        # The word will be prefixed with a space by Whisper, so we don't need
        # to add one here.
        current_line += word

    if len(current_line) > 0:
        lines.append(current_line)

    return "\n".join(lines)
|
206 |
+
|
207 |
def process_text(text: str, maxLineWidth=None):
|
208 |
if (maxLineWidth is None or maxLineWidth < 0):
|
209 |
return text
|
src/vad.py
CHANGED
@@ -404,6 +404,14 @@ class AbstractTranscription(ABC):
|
|
404 |
# Add to start and end
|
405 |
new_segment['start'] = segment_start + adjust_seconds
|
406 |
new_segment['end'] = segment_end + adjust_seconds
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
407 |
result.append(new_segment)
|
408 |
return result
|
409 |
|
|
|
404 |
# Add to start and end
|
405 |
new_segment['start'] = segment_start + adjust_seconds
|
406 |
new_segment['end'] = segment_end + adjust_seconds
|
407 |
+
|
408 |
+
# Handle words
|
409 |
+
if ('words' in new_segment):
|
410 |
+
for word in new_segment['words']:
|
411 |
+
# Adjust start and end
|
412 |
+
word['start'] = word['start'] + adjust_seconds
|
413 |
+
word['end'] = word['end'] + adjust_seconds
|
414 |
+
|
415 |
result.append(new_segment)
|
416 |
return result
|
417 |
|
src/whisper/whisperContainer.py
CHANGED
@@ -203,8 +203,9 @@ class WhisperCallback(AbstractWhisperCallback):
|
|
203 |
|
204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
205 |
|
206 |
-
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
-
)
|
|
|
|
203 |
|
204 |
initial_prompt = self._get_initial_prompt(self.initial_prompt, self.initial_prompt_mode, prompt, segment_index)
|
205 |
|
206 |
+
result = model.transcribe(audio, \
|
207 |
language=self.language if self.language else detected_language, task=self.task, \
|
208 |
initial_prompt=initial_prompt, \
|
209 |
**decodeOptions
|
210 |
+
)
|
211 |
+
return result
|