Revert "whisperX merged in"
This reverts commit 9c2d684d1a778e559b6b13bb90112712dbc20568.
- .gitignore +0 -2
- EXAMPLES.md +0 -37
- LICENSE +0 -27
- MANIFEST.in +0 -4
- figures/pipeline.png +0 -0
- requirements.txt +0 -10
- setup.py +0 -28
- whisperx/__init__.py +0 -3
- whisperx/__main__.py +0 -4
- whisperx/alignment.py +0 -548
- whisperx/asr.py +0 -429
- whisperx/diarize.py +0 -76
- whisperx/transcribe.py +0 -220
- whisperx/utils.py +0 -317
- whisperx/vad.py +0 -305
.gitignore
DELETED
@@ -1,2 +0,0 @@
whisperx.egg-info/
**/__pycache__/

EXAMPLES.md
DELETED
@@ -1,37 +0,0 @@
# More Examples

## Other Languages

For non-English ASR, it is best to use the `large` whisper model. Alignment models are automatically picked by the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).

Default models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}.

If the detected language is not in this list, you need to find a phoneme-based ASR model from the [huggingface model hub](https://huggingface.co/models) and test it on your data.

### French
    whisperx --model large --language fr examples/sample_fr_01.wav

https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov

### German
    whisperx --model large --language de examples/sample_de_01.wav

https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

### Italian
    whisperx --model large --language it examples/sample_it_01.wav

https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov

### Japanese
    whisperx --model large --language ja examples/sample_ja_01.wav

https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov

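For reference, the CLI calls above map onto the Python API that this revert removes. Below is a minimal sketch of the equivalent programmatic flow, assuming the reverted package is installed; the device string is an illustrative choice and the audio path is the French sample from the example above, neither is part of the commit itself.

import whisper
import whisperx

device = "cuda"
audio_file = "examples/sample_fr_01.wav"  # sample path taken from the French example above

# 1. transcribe with the original Whisper model (language is passed through to decoding)
model = whisper.load_model("large", device)
result = whisperx.transcribe(model, audio_file, language="fr")

# 2. load the default wav2vec2 alignment model for the language and force-align
align_model, metadata = whisperx.load_align_model("fr", device)
aligned = whisperx.align(result["segments"], align_model, metadata, audio_file, device,
                         extend_duration=2.0)  # ~2 s padding is recommended when not using VAD

print(aligned["word_segments"][:5])  # word-level timestamps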
LICENSE
DELETED
@@ -1,27 +0,0 @@
Copyright (c) 2022, Max Bain
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. All advertising materials mentioning features or use of this software
   must display the following acknowledgement:
   This product includes software developed by Max Bain.
4. Neither the name of Max Bain nor the
   names of its contributors may be used to endorse or promote products
   derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

MANIFEST.in
DELETED
@@ -1,4 +0,0 @@
include whisperx/assets/*
include whisperx/assets/gpt2/*
include whisperx/assets/multilingual/*
include whisperx/normalizers/english.json

figures/pipeline.png
DELETED
Binary file (123 kB)
requirements.txt
DELETED
@@ -1,10 +0,0 @@
numpy
pandas
torch >=1.9
torchaudio >=0.10,<1.0
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
openai-whisper==20230314

setup.py
DELETED
@@ -1,28 +0,0 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="whisperx",
    py_modules=["whisperx"],
    version="2.0",
    description="Time-Accurate Automatic Speech Recognition using Whisper.",
    readme="README.md",
    python_requires=">=3.8",
    author="Max Bain",
    url="https://github.com/m-bain/whisperx",
    license="MIT",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    entry_points={
        'console_scripts': ['whisperx=whisperx.transcribe:cli'],
    },
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)

whisperx/__init__.py
DELETED
@@ -1,3 +0,0 @@
from .transcribe import transcribe, transcribe_with_vad
from .alignment import load_align_model, align
from .vad import load_vad_model

whisperx/__main__.py
DELETED
@@ -1,4 +0,0 @@
from .transcribe import cli


cli()

whisperx/alignment.py
DELETED
@@ -1,548 +0,0 @@
"""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
from whisper.audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans


LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
}


def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata


def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    interpolate_method: str = "nearest",
):
    """
    Force align phoneme recognition predictions to known transcription

    Parameters
    ----------
    transcript: Iterator[dict]
        Transcript segments to align (dicts with "text", "start", "end")

    model: torch.nn.Module
        Alignment model (wav2vec2)

    align_model_metadata: dict
        Metadata returned by load_align_model (language, dictionary, pipeline type)

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    device: str
        cuda device

    extend_duration: float
        Amount to pad input segments by. If not using vad_filter then recommended to use 2 seconds

    start_from_previous: bool
        Whether to clip a segment's alignment start time to the end of the previously aligned segment

    interpolate_method: str ["nearest", "linear", "ignore"]
        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "ignore" drops that word from the interpolated output.

    Returns
    -------
    A dictionary containing the aligned segment-level details ("segments") and word-level timings ("word_segments").
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []

    prev_t2 = 0

    char_segments_arr = {
        "segment-idx": [],
        "subsegment-idx": [],
        "word-idx": [],
        "char": [],
        "start": [],
        "end": [],
        "score": [],
    }

    for sdx, segment in enumerate(transcript):
        while True:
            segment_align_success = False

            # strip spaces at beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters present in the phonetic dictionary

            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                per_word = transcription

            # first check that characters in transcription can be aligned (they are contained in align model's dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at beginning and end of transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]

            # we only pad if not using VAD filtering
            if "seg_text" not in segment:
                # pad according original timestamps
                t1 = max(segment["start"] - extend_duration, 0)
                t2 = min(segment["end"] + extend_duration, MAX_DURATION)

                # use prev_t2 as current t1 if it's later
                if start_from_previous and t1 < prev_t2:
                    t1 = prev_t2

            # check if timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)

            emission = emissions[0].cpu().detach()

            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
                break
            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = list(np.cumsum(seg_lens))
            sub_seg_idx = 0

            wdx = 0
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx + 1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_segments_arr["char"].append(char)
                char_segments_arr["start"].append(start)
                char_segments_arr["end"].append(end)
                char_segments_arr["score"].append(score)
                char_segments_arr["word-idx"].append(wdx)
                char_segments_arr["segment-idx"].append(sdx)
                char_segments_arr["subsegment-idx"].append(sub_seg_idx)

                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx + 1] == " " or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    wdx += 1

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    wdx = 0
                    sub_seg_idx += 1

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0

    char_segments_arr = pd.DataFrame(char_segments_arr)
    not_space = char_segments_arr["char"] != " "

    per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index=False)
    char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop=True)).reset_index()
    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
    per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
    per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
    char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])  # regroup

    word_segments_arr = {}

    # start of word is first char with a timestamp
    word_segments_arr["start"] = per_word_grp["start"].min().values
    # end of word is last char with a timestamp
    word_segments_arr["end"] = per_word_grp["end"].max().values
    # score of word is mean (excluding nan)
    word_segments_arr["score"] = per_word_grp["score"].mean().values

    word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
    word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values + 1
    word_segments_arr = pd.DataFrame(word_segments_arr)

    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
    segments_arr = {}
    segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
    segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
    segments_arr = pd.DataFrame(segments_arr)
    segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
    segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1

    # interpolate missing words / sub-segments
    if interpolate_method != "ignore":
        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
        # we still know which word timestamps are interpolated because their score == nan
        word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
        segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        # merge words & subsegments which are missing times
        word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])

        word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
        word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
        word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)

        seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
        segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
        segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
        segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
    else:
        word_segments_arr.dropna(inplace=True)
        segments_arr.dropna(inplace=True)

    # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
    segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
    segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
    segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
    segments_arr['subsegment-idx-end'].fillna(1, inplace=True)

    aligned_segments = []
    aligned_segments_word = []

    word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
    char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)

    for sdx, srow in segments_arr.iterrows():

        seg_idx = int(srow["segment-idx"])
        sub_start = int(srow["subsegment-idx-start"])
        sub_end = int(srow["subsegment-idx-end"])

        seg = transcript[seg_idx]
        text = "".join(seg["seg-text"][sub_start:sub_end])

        wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end - 1]
        wseg["start"].fillna(srow["start"], inplace=True)
        wseg["end"].fillna(srow["end"], inplace=True)
        wseg["segment-text-start"].fillna(0, inplace=True)
        wseg["segment-text-end"].fillna(len(text) - 1, inplace=True)

        cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end - 1]
        # fixes bug for single segment in transcript
        cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
        cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
        if 'level_1' in cseg: del cseg['level_1']
        if 'level_0' in cseg: del cseg['level_0']
        cseg.reset_index(inplace=True)
        aligned_segments.append(
            {
                "start": srow["start"],
                "end": srow["end"],
                "text": text,
                "word-segments": wseg,
                "char-segments": cseg
            }
        )

        def get_raw_text(word_row):
            return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"]) + 1]

        wdx = 0
        curr_text = get_raw_text(wseg.iloc[wdx])
        if len(wseg) > 1:
            for _, wrow in wseg.iloc[1:].iterrows():
                if wrow['start'] != wseg.iloc[wdx]['start']:
                    aligned_segments_word.append(
                        {
                            "text": curr_text.strip(),
                            "start": wseg.iloc[wdx]["start"],
                            "end": wseg.iloc[wdx]["end"],
                        }
                    )
                    curr_text = ""
                curr_text += " " + get_raw_text(wrow)
                wdx += 1
        aligned_segments_word.append(
            {
                "text": curr_text.strip(),
                "start": wseg.iloc[wdx]["start"],
                "end": wseg.iloc[wdx]["end"]
            }
        )

    return {"segments": aligned_segments, "word_segments": aligned_segments_word}


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float

def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
        # Score for token staying the same from time frame J-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from C-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # failed
        return None
    return path[::-1]

# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start

def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words

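The get_trellis / backtrack / merge_repeats trio above is the CTC forced-alignment core (adapted from the torchaudio tutorial). Below is a toy sketch of how those pieces fit together, using random log-probabilities in place of real wav2vec2 emissions and a made-up four-character dictionary; it assumes the three functions and their dataclasses have been copied out of the deleted module above.

import torch
# assumes get_trellis, backtrack, merge_repeats (plus Point/Segment) from alignment.py are in scope

vocab = {"|": 1, "c": 2, "a": 3, "t": 4}   # toy dictionary; index 0 is the CTC blank
transcript = "cat"
tokens = [vocab[c] for c in transcript]

# fake frame-wise log-probabilities of shape (num_frames, num_labels)
emission = torch.log_softmax(torch.randn(50, len(vocab) + 1), dim=-1)

trellis = get_trellis(emission, tokens)        # cumulative best alignment scores
path = backtrack(trellis, emission, tokens)    # most likely frame-to-token path
if path is None:
    print("backtrack failed (can happen with random emissions)")
else:
    for seg in merge_repeats(path, transcript):
        print(seg)   # per-character frame span and confidence, e.g. "c (0.12): [   11,    14)"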
whisperx/asr.py
DELETED
@@ -1,429 +0,0 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
import ffmpeg
from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    CHUNK_LENGTH,
    log_mel_spectrogram,
    pad_or_trim,
    load_audio
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
    exact_div,
    format_timestamp,
    make_safe,
)

if TYPE_CHECKING:
    from whisper.model import Whisper

from .vad import merge_chunks

def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor] = None,
    mel: np.ndarray = None,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    **decode_options,
):
    """
    Transcribe an audio file using Whisper.
    We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    mel: np.ndarray
        Mel spectrogram of audio segment.

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    word_timestamps: bool
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.

    prepend_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the next word

    append_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the previous word

    initial_prompt: Optional[str]
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    if mel is None:
        if audio is None:
            raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
        mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    if word_timestamps and task == "translate":
        warnings.warn("Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
    else:
        initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }

    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
        total=content_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < content_frames:
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            mel_segment = mel[:, seek : seek + N_FRAMES]
            segment_size = min(N_FRAMES, content_frames - seek)
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)
            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
                            end=time_offset + end_timestamp_pos * time_precision,
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result,
                    )
                )
                seek += segment_size

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                )
                word_end_timestamps = [
                    w["end"] for s in current_segments for w in s["words"]
                ]
                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                    seek_shift = round(
                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                    )
                    if seek_shift > 0:
                        seek = previous_seek + seek_shift

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )

            # update progress bar
            pbar.update(min(content_frames, seek) - previous_seek)

    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
        segments=all_segments,
        language=language,
    )


def transcribe_with_vad(
    model: "Whisper",
    audio: str,
    vad_pipeline,
    mel=None,
    verbose: Optional[bool] = None,
    **kwargs
):
    """
    Transcribe per VAD segment
    """

    vad_segments = vad_pipeline(audio)

    # if not torch.is_tensor(audio):
    #     if isinstance(audio, str):
    audio = load_audio(audio)
    audio = torch.from_numpy(audio)

    prev = 0
    output = {"segments": []}

    # merge segments to approx 30s inputs to make whisper most appropriate
    vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
    if len(vad_segments) == 0:
        return output

    print(">>Performing transcription...")
    for sdx, seg_t in enumerate(vad_segments):
        if verbose:
            print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
        seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
        audio = audio[local_f_start:]  # seek forward
        seg_audio = audio[:local_f_end - local_f_start]  # seek forward
        prev = seg_f_start
        local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
        # need to pad

        result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output["language"] = output["segments"][0]["language"]

    return output

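transcribe_with_vad above first runs the VAD pipeline, merges the detected speech regions into roughly 30-second chunks, and then calls the redefined transcribe on the mel slice of each chunk. A rough usage sketch follows; note that the load_vad_model call is an assumption (whisperx/vad.py is part of this revert but its diff is not shown in this section), and the audio path is a placeholder.

import whisper
from whisperx.asr import transcribe_with_vad
from whisperx.vad import load_vad_model

device = "cuda"
model = whisper.load_model("large", device)
vad_pipeline = load_vad_model(device)        # assumed call signature; vad.py is not shown here

result = transcribe_with_vad(model, "audio.wav", vad_pipeline, language="en")
for seg in result["segments"]:
    print(f"{seg['start']:7.2f} {seg['end']:7.2f}  {seg['text']}")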
whisperx/diarize.py
DELETED
@@ -1,76 +0,0 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline

class DiarizationPipeline:
    def __init__(
        self,
        model_name="pyannote/speaker-diarization@2.1",
        use_auth_token=None,
    ):
        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)

    def __call__(self, audio, min_speakers=None, max_speakers=None):
        segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
        diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
        diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
        return diarize_df

def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
    for seg in result_segments:
        wdf = seg['word-segments']
        if len(wdf['start'].dropna()) == 0:
            wdf['start'] = seg['start']
            wdf['end'] = seg['end']
        speakers = []
        for wdx, wrow in wdf.iterrows():
            if not np.isnan(wrow['start']):
                diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
                diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
                # remove no hit
                if not fill_nearest:
                    dia_tmp = diarize_df[diarize_df['intersection'] > 0]
                else:
                    dia_tmp = diarize_df
                if len(dia_tmp) == 0:
                    speaker = None
                else:
                    speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            else:
                speaker = None
            speakers.append(speaker)
        seg['word-segments']['speaker'] = speakers

        speaker_count = pd.Series(speakers).value_counts()
        if len(speaker_count) == 0:
            seg["speaker"] = "UNKNOWN"
        else:
            seg["speaker"] = speaker_count.index[0]

    # create word level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow['speaker']
                if speaker is None or speaker == np.nan:
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    # TODO: create segments but split words on new speaker

    return result_segments, word_seg

class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker

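A short sketch of chaining the diarization helpers above onto the alignment output. The audio path and Hugging Face token are placeholders, and `aligned` stands for the dictionary returned by align() in whisperx/alignment.py earlier in this diff.

from whisperx.diarize import DiarizationPipeline, assign_word_speakers

diarize_pipeline = DiarizationPipeline(use_auth_token="hf_xxx")   # placeholder token
diarize_df = diarize_pipeline("audio.wav", min_speakers=1, max_speakers=3)

# `aligned` is the output of whisperx.align(...); its segments carry per-word timings
segments, word_seg = assign_word_speakers(diarize_df, aligned["segments"])
for w in word_seg[:5]:
    print(w["start"], w["end"], w["text"])   # e.g. 0.52 0.71 [SPEAKER_00]: hello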
whisperx/transcribe.py
DELETED
@@ -1,220 +0,0 @@
import argparse
import os
import gc
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tempfile
import ffmpeg
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.audio import SAMPLE_RATE
from whisper.utils import (
    optional_float,
    optional_int,
    str2bool,
)

from .alignment import load_align_model, align
from .asr import transcribe, transcribe_with_vad
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import get_writer
from .vad import load_vad_model

def cli():
    from whisper import available_models

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    # alignment params
    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
    parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")

    # vad params
    parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")

    # diarization params
    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
    parser.add_argument("--min_speakers", default=None, type=int)
    parser.add_argument("--max_speakers", default=None, type=int)

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
    parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")

    parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
    parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
    parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
    parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")

    parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
    parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
    parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
    parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
    parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
    parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
    parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
    parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")

    parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
    # parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
    parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
    # fmt: on

    args = parser.parse_args().__dict__
    model_name: str = args.pop("model")
    model_dir: str = args.pop("model_dir")
    output_dir: str = args.pop("output_dir")
    output_format: str = args.pop("output_format")
    device: str = args.pop("device")
    # model_flush: bool = args.pop("model_flush")
    os.makedirs(output_dir, exist_ok=True)

    tmp_dir: str = args.pop("tmp_dir")
    if tmp_dir is not None:
        os.makedirs(tmp_dir, exist_ok=True)

    align_model: str = args.pop("align_model")
    align_extend: float = args.pop("align_extend")
    align_from_prev: bool = args.pop("align_from_prev")
    interpolate_method: str = args.pop("interpolate_method")
    no_align: bool = args.pop("no_align")

    hf_token: str = args.pop("hf_token")
    vad_filter: bool = args.pop("vad_filter")
    vad_onset: float = args.pop("vad_onset")
    vad_offset: float = args.pop("vad_offset")

    diarize: bool = args.pop("diarize")
    min_speakers: int = args.pop("min_speakers")
    max_speakers: int = args.pop("max_speakers")

    if vad_filter:
        from pyannote.audio import Pipeline
        from pyannote.audio import Model, Pipeline
        vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
    else:
        vad_model = None

    # if model_flush:
    #     print(">>Model flushing activated... Only loading model after ASR stage")
    #     del align_model
    #     align_model = ""


    if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
        if args["language"] is not None:
            warnings.warn(
                f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
            )
        args["language"] = "en"

    temperature = args.pop("temperature")
    if (increment := args.pop("temperature_increment_on_fallback")) is not None:
        temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
    else:
        temperature = [temperature]

    if (threads := args.pop("threads")) > 0:
        torch.set_num_threads(threads)

    from whisper import load_model

    writer = get_writer(output_format, output_dir)

    # Part 1: VAD & ASR Loop
    results = []
    tmp_results = []
    model = load_model(model_name, device=device, download_root=model_dir)
    for audio_path in args.pop("audio"):
        input_audio_path = audio_path
        tfile = None

        # >> VAD & ASR
        if vad_model is not None:
            if not audio_path.endswith(".wav"):
                print(">>VAD requires .wav format, converting to wav as a tempfile...")
                audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
                if tmp_dir is not None:
                    input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
                else:
                    input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
                ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
            print(">>Performing VAD...")
            result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
        else:
            print(">>Performing transcription...")
            result = transcribe(model, input_audio_path, temperature=temperature, **args)

        results.append((result, input_audio_path))

    # Unload Whisper and VAD
    del model
    del vad_model
    gc.collect()
    torch.cuda.empty_cache()

    # Part 2: Align Loop
    if not no_align:
        tmp_results = results
        results = []
        align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
        align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
        for result, input_audio_path in tmp_results:
            # >> Align
            if align_model is not None and len(result["segments"]) > 0:
                if result.get("language", "en") != align_metadata["language"]:
                    # load new language
                    print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
                    align_model, align_metadata = load_align_model(result["language"], device)
                print(">>Performing alignment...")
                result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
                               extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
            results.append((result, input_audio_path))

        # Unload align model
        del align_model
        gc.collect()
        torch.cuda.empty_cache()

    # >> Diarize
    if diarize:
        if hf_token is None:
            print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
        tmp_results = results
        results = []
        diarize_model = DiarizationPipeline(use_auth_token=hf_token)
        for result, input_audio_path in tmp_results:
            diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
            results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
            result = {"segments": results_segments, "word_segments": word_segments}
            results.append((result, input_audio_path))

    # >> Write
    for result, audio_path in results:
        writer(result, audio_path)

        # cleanup
        if input_audio_path != audio_path:
            os.remove(input_audio_path)

if __name__ == "__main__":
    cli()
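A compressed sketch (not part of the commit) of the staged flow that cli() above runs: ASR with VAD pre-segmentation, phoneme alignment, then writing. "sample.wav" is a placeholder and the keyword defaults of the imported functions are assumed to match the deleted modules:

# Placeholder sketch of the pipeline stages implemented by cli() above.
import torch
from whisper import load_model
from whisperx.vad import load_vad_model
from whisperx.asr import transcribe_with_vad
from whisperx.alignment import load_align_model, align
from whisperx.utils import get_writer

device = "cuda" if torch.cuda.is_available() else "cpu"
vad_model = load_vad_model(torch.device(device), 0.500, 0.363)                    # VAD model
model = load_model("small", device=device)
result = transcribe_with_vad(model, "sample.wav", vad_model, temperature=[0])     # Part 1: VAD & ASR
align_model, metadata = load_align_model(result.get("language", "en"), device)
result = align(result["segments"], align_model, metadata, "sample.wav", device)   # Part 2: alignment
get_writer("srt", ".")(result, "sample.wav")                                      # write outputs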
whisperx/utils.py
DELETED
@@ -1,317 +0,0 @@
import os
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np

def interpolate_nans(x, method='nearest'):
    if x.notnull().sum() > 1:
        return x.interpolate(method=method).ffill().bfill()
    else:
        return x.ffill().bfill()


def write_txt(transcript: Iterator[dict], file: TextIO):
    for segment in transcript:
        print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO):
    print("WEBVTT\n", file=file)
    for segment in transcript:
        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )

def write_tsv(transcript: Iterator[dict], file: TextIO):
    print("start", "end", "text", sep="\t", file=file)
    for segment in transcript:
        print(segment['start'], file=file, end="\t")
        print(segment['end'], file=file, end="\t")
        print(segment['text'].strip().replace("\t", " "), file=file, flush=True)


def write_srt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in SRT format.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def write_ass(transcript: Iterator[dict],
              file: TextIO,
              resolution: str = "word",
              color: str = None, underline=True,
              prefmt: str = None, suffmt: str = None,
              font: str = None, font_size: int = 24,
              strip=True, **kwargs):
    """
    Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
    Generate Advanced SubStation Alpha (ass) file from results to
    display both phrase-level & word-level timestamp simultaneously by:
     -using segment-level timestamps display phrases as usual
     -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
    Note: ass file is used in the same way as srt, vtt, etc.
    Parameters
    ----------
    transcript: dict
        results from modified model
    file: TextIO
        file object to write to
    resolution: str
        "word" or "char", timestamp resolution to highlight.
    color: str
        color code for a word at its corresponding timestamp
        <bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
    underline: bool
        whether to underline a word at its corresponding timestamp
    prefmt: str
        used to specify format for word-level timestamps (must be use with 'suffmt' and overrides 'color'&'underline')
        appears as such in the .ass file:
            Hi, {<prefmt>}how{<suffmt>} are you?
        reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
    suffmt: str
        used to specify format for word-level timestamps (must be use with 'prefmt' and overrides 'color'&'underline')
        appears as such in the .ass file:
            Hi, {<prefmt>}how{<suffmt>} are you?
        reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
    font: str
        word font (default: Arial)
    font_size: int
        word font size (default: 48)
    kwargs:
        used for format styles:
        'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
        'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
        'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'

    """

    fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
                      'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
                      'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
                      'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
                      'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}

    for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
        kwargs[k] = f'&H{kwargs[k]}'

    fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)

    if font:
        fmt_style_dict.update(Fontname=font)
    if font_size:
        fmt_style_dict.update(Fontsize=font_size)

    fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'

    styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'

    ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
              f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
              f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'

    if prefmt or suffmt:
        if suffmt:
            assert prefmt, 'prefmt must be used along with suffmt'
        else:
            suffmt = r'\r'
    else:
        if not color:
            color = 'HFF00'
        underline_code = r'\u1' if underline else ''

        prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
        suffmt = r'{\r}'

    def secs_to_hhmmss(secs: Tuple[float, int]):
        mm, ss = divmod(secs, 60)
        hh, mm = divmod(mm, 60)
        return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'


    def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
        if idx_0 == -1:
            text = chars
        else:
            text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
        return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
               f"Default,,0,0,0,,{text.strip() if strip else text}"

    if resolution == "word":
        resolution_key = "word-segments"
    elif resolution == "char":
        resolution_key = "char-segments"
    else:
        raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)

    ass_arr = []

    for segment in transcript:
        # if "12" in segment['text']:
        #     import pdb; pdb.set_trace()
        if resolution_key in segment:
            res_segs = pd.DataFrame(segment[resolution_key])
            prev = segment['start']
            if "speaker" in segment:
                speaker_str = f"[{segment['speaker']}]: "
            else:
                speaker_str = ""
            for cdx, crow in res_segs.iterrows():
                if not np.isnan(crow['start']):
                    if resolution == "char":
                        idx_0 = cdx
                        idx_1 = cdx + 1
                    elif resolution == "word":
                        idx_0 = int(crow["segment-text-start"])
                        idx_1 = int(crow["segment-text-end"])
                    # fill gap
                    if crow['start'] > prev:
                        filler_ts = {
                            "chars": speaker_str + segment['text'],
                            "start": prev,
                            "end": crow['start'],
                            "idx_0": -1,
                            "idx_1": -1
                        }

                        ass_arr.append(filler_ts)
                    # highlight current word
                    f_word_ts = {
                        "chars": speaker_str + segment['text'],
                        "start": crow['start'],
                        "end": crow['end'],
                        "idx_0": idx_0 + len(speaker_str),
                        "idx_1": idx_1 + len(speaker_str)
                    }
                    ass_arr.append(f_word_ts)
                    prev = crow['end']

    ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))

    file.write(ass_str)


from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp

class WriteASS(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resolution="word")

class WriteASSchar(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resolution="char")

class WritePickle(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        pd.DataFrame(result["segments"]).to_pickle(file)

class WriteSRTWord(ResultWriter):
    extension: str = "word.srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def iterate_result(self, result: dict):
        for segment in result["word_segments"]:
            segment_start = self.format_timestamp(segment["start"])
            segment_end = self.format_timestamp(segment["end"])
            segment_text = segment["text"].strip().replace("-->", "->")

            if word_timings := segment.get("words", None):
                all_words = [timing["word"] for timing in word_timings]
                all_words[0] = all_words[0].strip()  # remove the leading space, if any
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word["start"])
                    end = self.format_timestamp(this_word["end"])
                    if last != start:
                        yield last, start, segment_text

                    yield start, end, "".join(
                        [
                            f"<u>{word}</u>" if j == i else word
                            for j, word in enumerate(all_words)
                        ]
                    )
                    last = end

                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def write_result(self, result: dict, file: TextIO):
        if "word_segments" not in result:
            return
        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)

    def format_timestamp(self, seconds: float):
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )

def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "ass": WriteASS,
        "srt-word": WriteSRTWord,
        # "ass-char": WriteASSchar,
        # "pickle": WritePickle,
        # "json": WriteJSON,
    }

    writers_other = {
        "pkl": WritePickle,
        "ass-char": WriteASSchar
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, file: TextIO):
            for writer in all_writers:
                writer(result, file)

        return write_all

    if output_format in writers:
        return writers[output_format](output_dir)
    elif output_format in writers_other:
        return writers_other[output_format](output_dir)
    else:
        raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
whisperx/vad.py
DELETED
@@ -1,305 +0,0 @@
import os
import urllib
import pandas as pd
import numpy as np
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional

VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
    model_dir = torch.hub._get_torch_home()
    os.makedirs(model_dir, exist_ok = True)
    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")

    if not os.path.isfile(model_fp):
        with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
            with tqdm(
                total=int(source.info().get("Content-Length")),
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))

    model_bytes = open(model_fp, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not not match. Please retry loading the model."
        )

    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
    hyperparameters = {"onset": vad_onset,
                       "offset": vad_offset,
                       "min_duration_on": 0.1,
                       "min_duration_off": 0.1}
    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
    vad_pipeline.instantiate(hyperparameters)

    return vad_pipeline

class Binarize:
    """Binarize detection scores using hysteresis thresholding, with min-cut operation
    to ensure not segments are longer than max_duration.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration: float
        The maximum length of an active segment, divides segment at timestamp with lowest score.
    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Modified by Max Bain to include WhisperX's min-cut operation
    https://arxiv.org/abs/2303.00747

    Pyannote-audio
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float('inf')
    ):

        super().__init__()

        self.onset = onset
        self.offset = offset or onset

        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off

        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores
        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.
        Returns
        -------
        active : Annotation
            Binarized scores.
        """

        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()
        for k, k_scores in enumerate(scores.data.T):

            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        # if curr_duration > 15:
                        #     import pdb; pdb.set_trace()
                        search_after = len(curr_scores) // 2
                        # divide segment
                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                        min_score_t = curr_timestamps[min_score_div_idx]
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
                        curr_scores = curr_scores[min_score_div_idx+1:]
                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True
                curr_scores.append(y)
                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError(f"This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active


class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):

        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
        """Apply voice activity detection

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Hook called after each major step of the pipeline with the following
            signature: hook("step_name", step_artefact, file=file)

        Returns
        -------
        speech : Annotation
            Speech regions.
        """

        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        # apply segmentation model (only if needed)
        # output shape is (num_chunks, num_frames, 1)
        if self.training:
            if self.CACHED_SEGMENTATION in file:
                segmentations = file[self.CACHED_SEGMENTATION]
            else:
                segmentations = self._segmentation(file)
                file[self.CACHED_SEGMENTATION] = segmentations
        else:
            segmentations: SlidingWindowFeature = self._segmentation(file)

        return segmentations


def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):

    active = Annotation()
    for k, vad_t in enumerate(vad_arr):
        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
        active[region, k] = 1


    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
        active = active.support(collar=min_duration_off)

    # remove tracks shorter than min_duration_on
    if min_duration_on > 0:
        for segment, track in list(active.itertracks()):
            if segment.duration < min_duration_on:
                del active[segment, track]

    active = active.for_json()
    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
    return active_segs

def merge_chunks(segments, chunk_size):
    """
    Merge operation described in paper
    """
    curr_end = 0
    merged_segments = []
    seg_idxs = []
    speaker_idxs = []

    assert chunk_size > 0
    binarize = Binarize(max_duration=chunk_size)
    segments = binarize(segments)
    segments_list = []
    for speech_turn in segments.get_timeline():
        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

    if len(segments_list) == 0:
        print("No active speech found in audio")
        return []
    # assert segments_list, "segments_list is empty."
    # Make sur the starting point is the start of the segment.
    curr_start = segments_list[0].start

    for seg in segments_list:
        if seg.end - curr_start > chunk_size and curr_end-curr_start > 0:
            merged_segments.append({
                "start": curr_start,
                "end": curr_end,
                "segments": seg_idxs,
            })
            curr_start = seg.start
            seg_idxs = []
            speaker_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
        speaker_idxs.append(seg.speaker)
    # add final
    merged_segments.append({
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
    })
    return merged_segments
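A brief sketch (not part of the reverted file) of how the VAD utilities above fit together: load the segmentation pipeline, score an audio file, then binarize and merge into speech chunks. The audio path is a placeholder, and the 30-second chunk size is an assumption chosen to match Whisper's input window:

# Placeholder sketch driving the VAD utilities defined above.
import torch
from whisperx.vad import load_vad_model, merge_chunks

vad_pipeline = load_vad_model(torch.device("cpu"), vad_onset=0.500, vad_offset=0.363)
scores = vad_pipeline("sample.wav")           # VoiceActivitySegmentation.apply returns raw segmentation scores
chunks = merge_chunks(scores, chunk_size=30)  # Binarize (with min-cut) then merge into speech chunks of at most 30 s
# each chunk: {"start": ..., "end": ..., "segments": [(start, end), ...]}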