Spaces:
Runtime error
Runtime error
File size: 10,306 Bytes
8d120bf 05a2178 3fadc6e 8d120bf 3fadc6e 8d120bf 05a2178 883c794 c52f09b 7ce6041 4514e2e 7ce6041 533d92e 6a308c6 93c4867 05a2178 883c794 05a2178 f288ceb 71950a8 883c794 05a2178 883c794 533d92e 883c794 fdd892b 883c794 533d92e fdd892b 883c794 533d92e bbbf06e fdd892b f288ceb bc0cb58 f288ceb 084aa80 7f502b4 084aa80 bc0cb58 084aa80 bc0cb58 084aa80 7f502b4 bc0cb58 bbbf06e 084aa80 bbbf06e f288ceb bbbf06e fdd892b 71950a8 fdd892b 883c794 533d92e fdd892b 883c794 6a308c6 fdd892b 6a308c6 fdd892b 883c794 71950a8 fdd892b 71950a8 fdd892b 883c794 fdd892b 3fadc6e 883c794 fdd892b 883c794 fdd892b 3fadc6e 8f5637c fdd892b 3fadc6e fdd892b 8d120bf 883c794 6a308c6 883c794 3fadc6e 883c794 3fadc6e 883c794 7ce6041 883c794 05a2178 883c794 05a2178 71950a8 05a2178 084aa80 38cc8a7 71950a8 93c4867 883c794 084aa80 883c794 71950a8 8d120bf 71950a8 bc0cb58 b1d4eff 7f502b4 3fadc6e 8d120bf 3fadc6e 8d120bf 3fadc6e 7ce6041 d5154e9 05a2178 71950a8 883c794 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 |
from typing import Iterator
from io import StringIO
import os
import pathlib
import tempfile
# External programs
import whisper
import ffmpeg
# UI
import gradio as gr
from src.download import ExceededMaximumDuration, download_url
from src.utils import slugify, write_srt, write_vtt
from src.vad import VadPeriodicTranscription, VadSileroTranscription
# Processing limits. Set a limit to -1 to disable it.
DEFAULT_INPUT_AUDIO_MAX_DURATION = 600  # longest accepted input audio, in seconds

# Remove every uploaded source file after it has been processed, to save disk space.
DELETE_UPLOADED_FILES = True

# Gradio seems to truncate file names without keeping the extension, so the file
# prefix is truncated by this app itself instead.
MAX_FILE_PREFIX_LENGTH = 17
# Language names offered in the UI. The selected name is lower-cased and passed
# to Whisper as its `language` option. The order is kept exactly as authored
# (the UI displays a sorted copy).
LANGUAGES = [
    "English", "Chinese", "German", "Spanish", "Russian",
    "Korean", "French", "Japanese", "Portuguese", "Turkish",
    "Polish", "Catalan", "Dutch", "Arabic", "Swedish",
    "Italian", "Indonesian", "Hindi", "Finnish", "Vietnamese",
    "Hebrew", "Ukrainian", "Greek", "Malay", "Czech",
    "Romanian", "Danish", "Hungarian", "Tamil", "Norwegian",
    "Thai", "Urdu", "Croatian", "Bulgarian", "Lithuanian",
    "Latin", "Maori", "Malayalam", "Welsh", "Slovak",
    "Telugu", "Persian", "Latvian", "Bengali", "Serbian",
    "Azerbaijani", "Slovenian", "Kannada", "Estonian", "Macedonian",
    "Breton", "Basque", "Icelandic", "Armenian", "Nepali",
    "Mongolian", "Bosnian", "Kazakh", "Albanian", "Swahili",
    "Galician", "Marathi", "Punjabi", "Sinhala", "Khmer",
    "Shona", "Yoruba", "Somali", "Afrikaans", "Occitan",
    "Georgian", "Belarusian", "Tajik", "Sindhi", "Gujarati",
    "Amharic", "Yiddish", "Lao", "Uzbek", "Faroese",
    "Haitian Creole", "Pashto", "Turkmen", "Nynorsk", "Maltese",
    "Sanskrit", "Luxembourgish", "Myanmar", "Tibetan", "Tagalog",
    "Malagasy", "Assamese", "Tatar", "Hawaiian", "Lingala",
    "Hausa", "Bashkir", "Javanese", "Sundanese",
]
class WhisperTranscriber:
    """Transcribe audio from a URL, an uploaded file or the microphone with Whisper.

    Loaded Whisper models are cached by model name on the instance. The Silero
    VAD model is created on first use and then shared by later requests.
    """

    def __init__(self, inputAudioMaxDuration: float = DEFAULT_INPUT_AUDIO_MAX_DURATION, deleteUploadedFiles: bool = DELETE_UPLOADED_FILES):
        """Create a transcriber.

        Args:
            inputAudioMaxDuration: Maximum accepted audio length in seconds. For
                local files a value <= 0 disables the check. For URLs the value
                is forwarded to ``download_url``.
            deleteUploadedFiles: Whether to delete the source audio file once it
                has been processed (or rejected by the duration check).
        """
        self.model_cache = dict()  # model name -> loaded Whisper model
        self.vad_model = None  # shared VadSileroTranscription, created lazily
        self.inputAudioMaxDuration = inputAudioMaxDuration
        self.deleteUploadedFiles = deleteUploadedFiles

    def transcribe_file(self, modelName, languageName, urlData, uploadFile, microphoneData, task, vad, vadMergeWindow, vadMaxMergeSize, vadPadding):
        """Transcribe (or translate) one audio source and write subtitle/transcript files.

        The source is ``urlData`` when it is non-empty. Otherwise it is
        ``uploadFile``, falling back to ``microphoneData``.

        Args:
            modelName: Whisper model name. ``None`` falls back to ``"base"``.
            languageName: Language name such as ``"English"``. ``None`` or ``""``
                is passed to Whisper as ``language=None``.
            urlData: URL to download audio from, or empty.
            uploadFile: Path of an uploaded audio file, or ``None``.
            microphoneData: Path of a microphone recording, or ``None``.
            task: Passed to ``model.transcribe`` as ``task``.
            vad: ``"silero-vad"`` (transcribe speech and the gaps between it),
                ``"silero-vad-skip-gaps"`` (speech only) or ``"periodic-vad"``.
                Any other value transcribes the source in a single call.
            vadMergeWindow: Silero ``max_silent_period``.
            vadMaxMergeSize: Silero ``max_merge_size``, or the period length for
                ``"periodic-vad"``.
            vadPadding: Padding applied to both sides of every Silero segment.

        Returns:
            ``(files, text, vtt)``: paths of the generated ``.srt``, ``.vtt`` and
            ``.txt`` files, the transcript text and the WebVTT subtitles. If the
            source exceeds the maximum duration, ``([], error_message, "[ERROR]")``.

        Raises:
            ValueError: If no audio source was provided.
        """
        try:
            source, sourceName = self.__get_source(urlData, uploadFile, microphoneData)

            try:
                # An untouched dropdown yields None, so test truthiness rather than
                # calling len() on it. None and "" both mean "no language given".
                selectedLanguage = languageName.lower() if languageName else None
                selectedModel = modelName if modelName is not None else "base"

                model = self.model_cache.get(selectedModel)
                if model is None:
                    model = whisper.load_model(selectedModel)
                    self.model_cache[selectedModel] = model

                def whisperCallable(audio):
                    # Callable for processing an audio file. It is also handed to
                    # the VAD wrappers below, which decide what audio to pass in.
                    return model.transcribe(audio, language=selectedLanguage, task=task)

                if vad in ('silero-vad', 'silero-vad-skip-gaps'):
                    # Silero VAD. 'silero-vad' also transcribes the gaps between
                    # speech; 'silero-vad-skip-gaps' does not.
                    if self.vad_model is None:
                        self.vad_model = VadSileroTranscription()
                    sileroVad = VadSileroTranscription(transcribe_non_speech=(vad == 'silero-vad'),
                                                       max_silent_period=vadMergeWindow, max_merge_size=vadMaxMergeSize,
                                                       segment_padding_left=vadPadding, segment_padding_right=vadPadding,
                                                       copy=self.vad_model)
                    result = sileroVad.transcribe(source, whisperCallable)
                elif vad == 'periodic-vad':
                    # Very simple VAD: mark each fixed period of vadMaxMergeSize as
                    # speech. This makes it less likely that Whisper enters an
                    # infinite loop, but it may create a break in the middle of a
                    # sentence, causing some artifacts.
                    periodicVad = VadPeriodicTranscription(periodic_duration=vadMaxMergeSize)
                    result = periodicVad.transcribe(source, whisperCallable)
                else:
                    # No VAD: transcribe the whole source in one call
                    result = whisperCallable(source)

                text = result["text"]
                language = result["language"]
                languageMaxLineWidth = self.__get_max_line_width(language)
                print("Max line width " + str(languageMaxLineWidth))

                vtt = self.__get_subs(result["segments"], "vtt", languageMaxLineWidth)
                srt = self.__get_subs(result["segments"], "srt", languageMaxLineWidth)

                # Files that can be downloaded, written to a fresh temp directory
                downloadDirectory = tempfile.mkdtemp()
                filePrefix = slugify(sourceName, allow_unicode=True)
                download = [
                    self.__create_file(srt, downloadDirectory, filePrefix + "-subs.srt"),
                    self.__create_file(vtt, downloadDirectory, filePrefix + "-subs.vtt"),
                    self.__create_file(text, downloadDirectory, filePrefix + "-transcript.txt"),
                ]
                return download, text, vtt
            finally:
                # Cleanup source (best-effort, so it cannot discard the result)
                if self.deleteUploadedFiles:
                    self.__delete_source(source)
        except ExceededMaximumDuration as e:
            return [], ("[ERROR]: Maximum remote video length is " + str(e.maxDuration) + "s, file was " + str(e.videoDuration) + "s"), "[ERROR]"

    def clear_cache(self):
        """Drop all cached Whisper models (the VAD model is kept)."""
        self.model_cache = dict()

    def __get_source(self, urlData, uploadFile, microphoneData):
        """Resolve the audio input to a local file path.

        Returns:
            ``(source, sourceName)``: the local path, and a name made of the file
            stem truncated to ``MAX_FILE_PREFIX_LENGTH`` characters plus the suffix.

        Raises:
            ValueError: If no source was provided.
            ExceededMaximumDuration: If a local file is longer than
                ``inputAudioMaxDuration``. This is presumably also raised by
                ``download_url`` for URLs.
        """
        if urlData:
            # Download from YouTube (or any other URL download_url accepts)
            source = download_url(urlData, self.inputAudioMaxDuration)
        else:
            # File input: an upload takes precedence over a microphone recording
            source = uploadFile if uploadFile is not None else microphoneData
            if not source:
                raise ValueError("No audio provided: enter a URL, upload a file or record from the microphone")

            if self.inputAudioMaxDuration > 0:
                try:
                    # Calculate audio length
                    audioDuration = ffmpeg.probe(source)["format"]["duration"]
                    if float(audioDuration) > self.inputAudioMaxDuration:
                        raise ExceededMaximumDuration(videoDuration=audioDuration, maxDuration=self.inputAudioMaxDuration, message="Video is too long")
                except BaseException:
                    # transcribe_file only cleans up sources returned from here, so
                    # a rejected upload must be deleted now or it stays on disk.
                    if self.deleteUploadedFiles:
                        self.__delete_source(source)
                    raise

        file_path = pathlib.Path(source)
        sourceName = file_path.stem[:MAX_FILE_PREFIX_LENGTH] + file_path.suffix
        return source, sourceName

    def __delete_source(self, source: str) -> None:
        """Delete a source audio file. An OSError is printed instead of raised."""
        print("Deleting source file " + source)
        try:
            os.remove(source)
        except OSError as e:
            print("Could not delete source file " + source + ": " + str(e))

    def __get_max_line_width(self, language: str) -> int:
        """Return the subtitle line width (in characters) to use for ``language``."""
        if language and language.lower() in ("japanese", "ja", "chinese", "zh"):
            # Chinese characters and kana are wider, so limit line length to 40 characters
            return 40
        # TODO: Add more languages
        # 80 latin characters should fit on a 1080p/720p screen
        return 80

    def __get_subs(self, segments: Iterator[dict], format: str, maxLineWidth: int) -> str:
        """Render ``segments`` as ``"vtt"`` or ``"srt"`` text.

        Raises:
            ValueError: If ``format`` is neither ``"vtt"`` nor ``"srt"``.
        """
        segmentStream = StringIO()
        if format == 'vtt':
            write_vtt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        elif format == 'srt':
            write_srt(segments, file=segmentStream, maxLineWidth=maxLineWidth)
        else:
            raise ValueError("Unknown format " + format)
        return segmentStream.getvalue()

    def __create_file(self, text: str, directory: str, fileName: str) -> str:
        """Write ``text`` as UTF-8 to ``directory/fileName`` and return that path."""
        path = os.path.join(directory, fileName)
        with open(path, 'w', encoding="utf-8") as file:
            file.write(text)
        return path
def create_ui(inputAudioMaxDuration, share=False, server_name: str = None):
    """Build the Gradio interface around a WhisperTranscriber and launch it.

    Args:
        inputAudioMaxDuration: Maximum audio length (seconds) handed to the
            transcriber. When it is > 0 it is also shown in the description.
        share: Forwarded to ``demo.launch(share=...)``.
        server_name: Forwarded to ``demo.launch(server_name=...)``.
    """
    transcriber = WhisperTranscriber(inputAudioMaxDuration)

    description = (
        "Whisper is a general-purpose speech recognition model. It is trained on a large dataset of diverse "
        " audio and is also a multi-task model that can perform multilingual speech recognition "
        " as well as speech translation and language identification. "
        "\n\n\n\nFor longer audio files (>10 minutes), it is recommended that you select Silero VAD (Voice Activity Detector) in the VAD option."
    )
    if inputAudioMaxDuration > 0:
        description += "\n\n" + "Max audio file length: " + str(inputAudioMaxDuration) + " s"

    article = "Read the [documentation here](https://huggingface.co/spaces/aadnk/whisper-webui/blob/main/docs/options.md)"

    # Positional order must match WhisperTranscriber.transcribe_file's parameters.
    inputs = [
        gr.Dropdown(choices=["tiny", "base", "small", "medium", "large"], value="medium", label="Model"),
        gr.Dropdown(choices=sorted(LANGUAGES), label="Language"),
        gr.Text(label="URL (YouTube, etc.)"),
        gr.Audio(source="upload", type="filepath", label="Upload Audio"),
        gr.Audio(source="microphone", type="filepath", label="Microphone Input"),
        gr.Dropdown(choices=["transcribe", "translate"], label="Task"),
        gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "periodic-vad"], label="VAD"),
        gr.Number(label="VAD - Merge Window (s)", precision=0, value=5),
        gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=150),
        gr.Number(label="VAD - Padding (s)", precision=None, value=1),
    ]
    # One output per element of transcribe_file's returned tuple.
    outputs = [
        gr.File(label="Download"),
        gr.Text(label="Transcription"),
        gr.Text(label="Segments"),
    ]

    demo = gr.Interface(fn=transcriber.transcribe_file, description=description, article=article,
                        inputs=inputs, outputs=outputs)
    demo.launch(share=share, server_name=server_name)
if __name__ == '__main__':
    # Launch the web UI with the default input-length limit.
    create_ui(DEFAULT_INPUT_AUDIO_MAX_DURATION)