whisperX merged in
- .gitignore +2 -0
- EXAMPLES.md +37 -0
- LICENSE +27 -0
- MANIFEST.in +4 -0
- figures/pipeline.png +0 -0
- requirements.txt +10 -0
- setup.py +28 -0
- whisperx/__init__.py +3 -0
- whisperx/__main__.py +4 -0
- whisperx/alignment.py +548 -0
- whisperx/asr.py +429 -0
- whisperx/diarize.py +76 -0
- whisperx/transcribe.py +220 -0
- whisperx/utils.py +317 -0
- whisperx/vad.py +305 -0
.gitignore
ADDED
@@ -0,0 +1,2 @@
whisperx.egg-info/
**/__pycache__/
EXAMPLES.md
ADDED
@@ -0,0 +1,37 @@
# More Examples

## Other Languages

For non-English ASR, it is best to use the `large` Whisper model. Alignment models are picked automatically for the chosen language from the default [lists](https://github.com/m-bain/whisperX/blob/main/whisperx/alignment.py#L18).

Default alignment models are currently provided and tested for {en, fr, de, es, it, ja, zh, nl}.

If the detected language is not in this list, you need to find a phoneme-based ASR model from the [huggingface model hub](https://huggingface.co/models) and test it on your data.

### French
whisperx --model large --language fr examples/sample_fr_01.wav

https://user-images.githubusercontent.com/36994049/208298804-31c49d6f-6787-444e-a53f-e93c52706752.mov

### German
whisperx --model large --language de examples/sample_de_01.wav

https://user-images.githubusercontent.com/36994049/208298811-e36002ba-3698-4731-97d4-0aebd07e0eb3.mov

### Italian
whisperx --model large --language it examples/sample_it_01.wav

https://user-images.githubusercontent.com/36994049/208298819-6f462b2c-8cae-4c54-b8e1-90855794efc7.mov

### Japanese
whisperx --model large --language ja examples/sample_ja_01.wav

https://user-images.githubusercontent.com/19920981/208731743-311f2360-b73b-4c60-809d-aaf3cd7e06f4.mov
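For a language outside the tested set, the same choice can be made from Python through `load_align_model` (defined in whisperx/alignment.py below), which is what the `--align_model` CLI flag feeds into. A minimal sketch; the checkpoint name for the unlisted language is a placeholder for illustration, not a tested model:

import torch
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"

# Supported language: the default torchaudio/HF wav2vec2 model is selected automatically.
align_model, metadata = whisperx.load_align_model("fr", device)

# Unlisted language: point at any phoneme-level wav2vec2 checkpoint from the HF hub.
# "some-org/wav2vec2-large-xlsr-53-xx" is a hypothetical name, not a real model.
align_model, metadata = whisperx.load_align_model(
    "xx", device, model_name="some-org/wav2vec2-large-xlsr-53-xx"
)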
LICENSE
ADDED
@@ -0,0 +1,27 @@
Copyright (c) 2022, Max Bain
All rights reserved.

Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:
1. Redistributions of source code must retain the above copyright
   notice, this list of conditions and the following disclaimer.
2. Redistributions in binary form must reproduce the above copyright
   notice, this list of conditions and the following disclaimer in the
   documentation and/or other materials provided with the distribution.
3. All advertising materials mentioning features or use of this software
   must display the following acknowledgement:
   This product includes software developed by Max Bain.
4. Neither the name of Max Bain nor the
   names of its contributors may be used to endorse or promote products
   derived from this software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDER ''AS IS'' AND ANY
EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE
USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
MANIFEST.in
ADDED
@@ -0,0 +1,4 @@
include whisperx/assets/*
include whisperx/assets/gpt2/*
include whisperx/assets/multilingual/*
include whisperx/normalizers/english.json
figures/pipeline.png
ADDED
requirements.txt
ADDED
@@ -0,0 +1,10 @@
numpy
pandas
torch >=1.9
torchaudio >=0.10,<1.0
tqdm
more-itertools
transformers>=4.19.0
ffmpeg-python==0.2.0
pyannote.audio
openai-whisper==20230314
setup.py
ADDED
@@ -0,0 +1,28 @@
import os

import pkg_resources
from setuptools import setup, find_packages

setup(
    name="whisperx",
    py_modules=["whisperx"],
    version="2.0",
    description="Time-Accurate Automatic Speech Recognition using Whisper.",
    readme="README.md",
    python_requires=">=3.8",
    author="Max Bain",
    url="https://github.com/m-bain/whisperx",
    license="MIT",
    packages=find_packages(exclude=["tests*"]),
    install_requires=[
        str(r)
        for r in pkg_resources.parse_requirements(
            open(os.path.join(os.path.dirname(__file__), "requirements.txt"))
        )
    ],
    entry_points={
        'console_scripts': ['whisperx=whisperx.transcribe:cli'],
    },
    include_package_data=True,
    extras_require={'dev': ['pytest']},
)
whisperx/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .transcribe import transcribe, transcribe_with_vad
from .alignment import load_align_model, align
from .vad import load_vad_model
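These re-exports are the package's public surface. A rough end-to-end sketch of how they fit together, assuming an openai-whisper checkpoint loaded with `whisper.load_model` and an illustrative local file `audio.wav`:

import torch
import whisper
import whisperx

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("small", device)

# 1. Whisper transcription (segment-level timestamps)
result = whisperx.transcribe(model, "audio.wav", language="en")

# 2. Phoneme-based forced alignment (word-level timestamps)
align_model, metadata = whisperx.load_align_model(result["language"], device)
aligned = whisperx.align(result["segments"], align_model, metadata, "audio.wav", device)

print(aligned["word_segments"][:5])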
whisperx/__main__.py
ADDED
@@ -0,0 +1,4 @@
from .transcribe import cli


cli()
whisperx/alignment.py
ADDED
@@ -0,0 +1,548 @@
"""
Forced Alignment with Whisper
C. Max Bain
"""
import numpy as np
import pandas as pd
from typing import List, Union, Iterator, TYPE_CHECKING
from transformers import Wav2Vec2ForCTC, Wav2Vec2Processor
import torchaudio
import torch
from dataclasses import dataclass
from whisper.audio import SAMPLE_RATE, load_audio
from .utils import interpolate_nans


LANGUAGES_WITHOUT_SPACES = ["ja", "zh"]

DEFAULT_ALIGN_MODELS_TORCH = {
    "en": "WAV2VEC2_ASR_BASE_960H",
    "fr": "VOXPOPULI_ASR_BASE_10K_FR",
    "de": "VOXPOPULI_ASR_BASE_10K_DE",
    "es": "VOXPOPULI_ASR_BASE_10K_ES",
    "it": "VOXPOPULI_ASR_BASE_10K_IT",
}

DEFAULT_ALIGN_MODELS_HF = {
    "ja": "jonatasgrosman/wav2vec2-large-xlsr-53-japanese",
    "zh": "jonatasgrosman/wav2vec2-large-xlsr-53-chinese-zh-cn",
    "nl": "jonatasgrosman/wav2vec2-large-xlsr-53-dutch",
    "uk": "Yehor/wav2vec2-xls-r-300m-uk-with-small-lm",
    "pt": "jonatasgrosman/wav2vec2-large-xlsr-53-portuguese",
    "ar": "jonatasgrosman/wav2vec2-large-xlsr-53-arabic",
    "ru": "jonatasgrosman/wav2vec2-large-xlsr-53-russian",
    "pl": "jonatasgrosman/wav2vec2-large-xlsr-53-polish",
    "hu": "jonatasgrosman/wav2vec2-large-xlsr-53-hungarian",
    "fi": "jonatasgrosman/wav2vec2-large-xlsr-53-finnish",
    "fa": "jonatasgrosman/wav2vec2-large-xlsr-53-persian",
    "el": "jonatasgrosman/wav2vec2-large-xlsr-53-greek",
    "tr": "mpoyraz/wav2vec2-xls-r-300m-cv7-turkish",
}


def load_align_model(language_code, device, model_name=None):
    if model_name is None:
        # use default model
        if language_code in DEFAULT_ALIGN_MODELS_TORCH:
            model_name = DEFAULT_ALIGN_MODELS_TORCH[language_code]
        elif language_code in DEFAULT_ALIGN_MODELS_HF:
            model_name = DEFAULT_ALIGN_MODELS_HF[language_code]
        else:
            print(f"There is no default alignment model set for this language ({language_code}).\
                Please find a wav2vec2.0 model finetuned on this language in https://huggingface.co/models, then pass the model name in --align_model [MODEL_NAME]")
            raise ValueError(f"No default align-model for language: {language_code}")

    if model_name in torchaudio.pipelines.__all__:
        pipeline_type = "torchaudio"
        bundle = torchaudio.pipelines.__dict__[model_name]
        align_model = bundle.get_model().to(device)
        labels = bundle.get_labels()
        align_dictionary = {c.lower(): i for i, c in enumerate(labels)}
    else:
        try:
            processor = Wav2Vec2Processor.from_pretrained(model_name)
            align_model = Wav2Vec2ForCTC.from_pretrained(model_name)
        except Exception as e:
            print(e)
            print(f"Error loading model from huggingface, check https://huggingface.co/models for finetuned wav2vec2.0 models")
            raise ValueError(f'The chosen align_model "{model_name}" could not be found in huggingface (https://huggingface.co/models) or torchaudio (https://pytorch.org/audio/stable/pipelines.html#id14)')
        pipeline_type = "huggingface"
        align_model = align_model.to(device)
        labels = processor.tokenizer.get_vocab()
        align_dictionary = {char.lower(): code for char, code in processor.tokenizer.get_vocab().items()}

    align_metadata = {"language": language_code, "dictionary": align_dictionary, "type": pipeline_type}

    return align_model, align_metadata


def align(
    transcript: Iterator[dict],
    model: torch.nn.Module,
    align_model_metadata: dict,
    audio: Union[str, np.ndarray, torch.Tensor],
    device: str,
    extend_duration: float = 0.0,
    start_from_previous: bool = True,
    interpolate_method: str = "nearest",
):
    """
    Force align phoneme recognition predictions to known transcription

    Parameters
    ----------
    transcript: Iterator[dict]
        Transcript segments to align, as produced by Whisper (each with "text", "start", "end")

    model: torch.nn.Module
        Alignment model (wav2vec2)

    align_model_metadata: dict
        Metadata returned by `load_align_model` (language, dictionary, pipeline type)

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    device: str
        cuda device

    extend_duration: float
        Amount (in seconds) to pad input segments by. If not using --vad_filter, 2 seconds is recommended

    start_from_previous: bool
        Whether to clip a segment's (padded) start time to the end of the previously aligned segment

    interpolate_method: str ["nearest", "linear", "ignore"]
        Method to assign timestamps to non-aligned words. Words are not able to be aligned when none of the characters occur in the align model dictionary.
        "nearest" copies timestamp of nearest word within the segment. "linear" is linear interpolation. "ignore" removes that word from the output.

    Returns
    -------
    A dictionary containing the aligned segment-level details ("segments") and word-level details ("word_segments").
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(audio)
        audio = torch.from_numpy(audio)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)

    MAX_DURATION = audio.shape[1] / SAMPLE_RATE

    model_dictionary = align_model_metadata["dictionary"]
    model_lang = align_model_metadata["language"]
    model_type = align_model_metadata["type"]

    aligned_segments = []

    prev_t2 = 0

    char_segments_arr = {
        "segment-idx": [],
        "subsegment-idx": [],
        "word-idx": [],
        "char": [],
        "start": [],
        "end": [],
        "score": [],
    }

    for sdx, segment in enumerate(transcript):
        while True:
            segment_align_success = False

            # strip spaces at beginning / end, but keep track of the amount.
            num_leading = len(segment["text"]) - len(segment["text"].lstrip())
            num_trailing = len(segment["text"]) - len(segment["text"].rstrip())
            transcription = segment["text"]

            # TODO: convert number tokenizer / symbols to phonetic words for alignment.
            # e.g. "$300" -> "three hundred dollars"
            # currently "$300" is ignored since no characters present in the phonetic dictionary

            # split into words
            if model_lang not in LANGUAGES_WITHOUT_SPACES:
                per_word = transcription.split(" ")
            else:
                per_word = transcription

            # first check that characters in transcription can be aligned (they are contained in align model's dictionary)
            clean_char, clean_cdx = [], []
            for cdx, char in enumerate(transcription):
                char_ = char.lower()
                # wav2vec2 models use "|" character to represent spaces
                if model_lang not in LANGUAGES_WITHOUT_SPACES:
                    char_ = char_.replace(" ", "|")

                # ignore whitespace at beginning and end of transcript
                if cdx < num_leading:
                    pass
                elif cdx > len(transcription) - num_trailing - 1:
                    pass
                elif char_ in model_dictionary.keys():
                    clean_char.append(char_)
                    clean_cdx.append(cdx)

            clean_wdx = []
            for wdx, wrd in enumerate(per_word):
                if any([c in model_dictionary.keys() for c in wrd]):
                    clean_wdx.append(wdx)

            # if no characters are in the dictionary, then we skip this segment...
            if len(clean_char) == 0:
                print(f'Failed to align segment ("{segment["text"]}"): no characters in this segment found in model dictionary, resorting to original...')
                break

            transcription_cleaned = "".join(clean_char)
            tokens = [model_dictionary[c] for c in transcription_cleaned]

            # we only pad if not using VAD filtering
            if "seg_text" not in segment:
                # pad according to original timestamps
                t1 = max(segment["start"] - extend_duration, 0)
                t2 = min(segment["end"] + extend_duration, MAX_DURATION)

            # use prev_t2 as current t1 if it's later
            if start_from_previous and t1 < prev_t2:
                t1 = prev_t2

            # check if timestamp range is still valid
            if t1 >= MAX_DURATION:
                print("Failed to align segment: original start time longer than audio duration, skipping...")
                break
            if t2 - t1 < 0.02:
                print("Failed to align segment: duration smaller than 0.02s time precision")
                break

            f1 = int(t1 * SAMPLE_RATE)
            f2 = int(t2 * SAMPLE_RATE)

            waveform_segment = audio[:, f1:f2]

            with torch.inference_mode():
                if model_type == "torchaudio":
                    emissions, _ = model(waveform_segment.to(device))
                elif model_type == "huggingface":
                    emissions = model(waveform_segment.to(device)).logits
                else:
                    raise NotImplementedError(f"Align model of type {model_type} not supported.")
                emissions = torch.log_softmax(emissions, dim=-1)

            emission = emissions[0].cpu().detach()

            trellis = get_trellis(emission, tokens)
            path = backtrack(trellis, emission, tokens)
            if path is None:
                print(f'Failed to align segment ("{segment["text"]}"): backtrack failed, resorting to original...')
                break
            char_segments = merge_repeats(path, transcription_cleaned)
            # word_segments = merge_words(char_segments)

            # sub-segments
            if "seg-text" not in segment:
                segment["seg-text"] = [transcription]

            seg_lens = [0] + [len(x) for x in segment["seg-text"]]
            seg_lens_cumsum = list(np.cumsum(seg_lens))
            sub_seg_idx = 0

            wdx = 0
            duration = t2 - t1
            ratio = duration * waveform_segment.size(0) / (trellis.size(0) - 1)
            for cdx, char in enumerate(transcription + " "):
                is_last = False
                if cdx == len(transcription):
                    break
                elif cdx + 1 == len(transcription):
                    is_last = True

                start, end, score = None, None, None
                if cdx in clean_cdx:
                    char_seg = char_segments[clean_cdx.index(cdx)]
                    start = char_seg.start * ratio + t1
                    end = char_seg.end * ratio + t1
                    score = char_seg.score

                char_segments_arr["char"].append(char)
                char_segments_arr["start"].append(start)
                char_segments_arr["end"].append(end)
                char_segments_arr["score"].append(score)
                char_segments_arr["word-idx"].append(wdx)
                char_segments_arr["segment-idx"].append(sdx)
                char_segments_arr["subsegment-idx"].append(sub_seg_idx)

                # word-level info
                if model_lang in LANGUAGES_WITHOUT_SPACES:
                    # character == word
                    wdx += 1
                elif is_last or transcription[cdx + 1] == " " or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    wdx += 1

                if is_last or cdx == seg_lens_cumsum[sub_seg_idx + 1] - 1:
                    wdx = 0
                    sub_seg_idx += 1

            prev_t2 = segment["end"]

            segment_align_success = True
            # end while True loop
            break

        # reset prev_t2 due to drifting issues
        if not segment_align_success:
            prev_t2 = 0

    char_segments_arr = pd.DataFrame(char_segments_arr)
    not_space = char_segments_arr["char"] != " "

    per_seg_grp = char_segments_arr.groupby(["segment-idx", "subsegment-idx"], as_index=False)
    char_segments_arr = per_seg_grp.apply(lambda x: x.reset_index(drop=True)).reset_index()
    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])
    per_subseg_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx"])
    per_seg_grp = char_segments_arr[not_space].groupby(["segment-idx"])
    char_segments_arr["local-char-idx"] = char_segments_arr.groupby(["segment-idx", "subsegment-idx"]).cumcount()
    per_word_grp = char_segments_arr[not_space].groupby(["segment-idx", "subsegment-idx", "word-idx"])  # regroup

    word_segments_arr = {}

    # start of word is first char with a timestamp
    word_segments_arr["start"] = per_word_grp["start"].min().values
    # end of word is last char with a timestamp
    word_segments_arr["end"] = per_word_grp["end"].max().values
    # score of word is mean (excluding nan)
    word_segments_arr["score"] = per_word_grp["score"].mean().values

    word_segments_arr["segment-text-start"] = per_word_grp["local-char-idx"].min().astype(int).values
    word_segments_arr["segment-text-end"] = per_word_grp["local-char-idx"].max().astype(int).values + 1
    word_segments_arr = pd.DataFrame(word_segments_arr)

    word_segments_arr[["segment-idx", "subsegment-idx", "word-idx"]] = per_word_grp["local-char-idx"].min().reset_index()[["segment-idx", "subsegment-idx", "word-idx"]].astype(int)
    segments_arr = {}
    segments_arr["start"] = per_subseg_grp["start"].min().reset_index()["start"]
    segments_arr["end"] = per_subseg_grp["end"].max().reset_index()["end"]
    segments_arr = pd.DataFrame(segments_arr)
    segments_arr[["segment-idx", "subsegment-idx-start"]] = per_subseg_grp["start"].min().reset_index()[["segment-idx", "subsegment-idx"]]
    segments_arr["subsegment-idx-end"] = segments_arr["subsegment-idx-start"] + 1

    # interpolate missing words / sub-segments
    if interpolate_method != "ignore":
        wrd_subseg_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx"], group_keys=False)
        wrd_seg_grp = word_segments_arr.groupby(["segment-idx"], group_keys=False)
        # we still know which word timestamps are interpolated because their score == nan
        word_segments_arr["start"] = wrd_subseg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_subseg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        word_segments_arr["start"] = wrd_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        word_segments_arr["end"] = wrd_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        sub_seg_grp = segments_arr.groupby(["segment-idx"], group_keys=False)
        segments_arr['start'] = sub_seg_grp['start'].apply(lambda group: interpolate_nans(group, method=interpolate_method))
        segments_arr['end'] = sub_seg_grp['end'].apply(lambda group: interpolate_nans(group, method=interpolate_method))

        # merge words & subsegments which are missing times
        word_grp = word_segments_arr.groupby(["segment-idx", "subsegment-idx", "end"])

        word_segments_arr["segment-text-start"] = word_grp["segment-text-start"].transform(min)
        word_segments_arr["segment-text-end"] = word_grp["segment-text-end"].transform(max)
        word_segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx", "end"], inplace=True)

        seg_grp_dup = segments_arr.groupby(["segment-idx", "start", "end"])
        segments_arr["subsegment-idx-start"] = seg_grp_dup["subsegment-idx-start"].transform(min)
        segments_arr["subsegment-idx-end"] = seg_grp_dup["subsegment-idx-end"].transform(max)
        segments_arr.drop_duplicates(subset=["segment-idx", "subsegment-idx-start", "subsegment-idx-end"], inplace=True)
    else:
        word_segments_arr.dropna(inplace=True)
        segments_arr.dropna(inplace=True)

    # if some segments still have missing timestamps (usually because all numerals / symbols), then use original timestamps...
    segments_arr['start'].fillna(pd.Series([x['start'] for x in transcript]), inplace=True)
    segments_arr['end'].fillna(pd.Series([x['end'] for x in transcript]), inplace=True)
    segments_arr['subsegment-idx-start'].fillna(0, inplace=True)
    segments_arr['subsegment-idx-end'].fillna(1, inplace=True)


    aligned_segments = []
    aligned_segments_word = []

    word_segments_arr.set_index(["segment-idx", "subsegment-idx"], inplace=True)
    char_segments_arr.set_index(["segment-idx", "subsegment-idx", "word-idx"], inplace=True)

    for sdx, srow in segments_arr.iterrows():

        seg_idx = int(srow["segment-idx"])
        sub_start = int(srow["subsegment-idx-start"])
        sub_end = int(srow["subsegment-idx-end"])

        seg = transcript[seg_idx]
        text = "".join(seg["seg-text"][sub_start:sub_end])

        wseg = word_segments_arr.loc[seg_idx].loc[sub_start:sub_end - 1]
        wseg["start"].fillna(srow["start"], inplace=True)
        wseg["end"].fillna(srow["end"], inplace=True)
        wseg["segment-text-start"].fillna(0, inplace=True)
        wseg["segment-text-end"].fillna(len(text) - 1, inplace=True)

        cseg = char_segments_arr.loc[seg_idx].loc[sub_start:sub_end - 1]
        # fixes bug for single segment in transcript
        cseg['segment-text-start'] = cseg['level_1'] if 'level_1' in cseg else 0
        cseg['segment-text-end'] = cseg['level_1'] + 1 if 'level_1' in cseg else 1
        if 'level_1' in cseg: del cseg['level_1']
        if 'level_0' in cseg: del cseg['level_0']
        cseg.reset_index(inplace=True)
        aligned_segments.append(
            {
                "start": srow["start"],
                "end": srow["end"],
                "text": text,
                "word-segments": wseg,
                "char-segments": cseg
            }
        )

        def get_raw_text(word_row):
            return seg["seg-text"][word_row.name][int(word_row["segment-text-start"]):int(word_row["segment-text-end"]) + 1]

        wdx = 0
        curr_text = get_raw_text(wseg.iloc[wdx])
        if len(wseg) > 1:
            for _, wrow in wseg.iloc[1:].iterrows():
                if wrow['start'] != wseg.iloc[wdx]['start']:
                    aligned_segments_word.append(
                        {
                            "text": curr_text.strip(),
                            "start": wseg.iloc[wdx]["start"],
                            "end": wseg.iloc[wdx]["end"],
                        }
                    )
                    curr_text = ""
                curr_text += " " + get_raw_text(wrow)
                wdx += 1
        aligned_segments_word.append(
            {
                "text": curr_text.strip(),
                "start": wseg.iloc[wdx]["start"],
                "end": wseg.iloc[wdx]["end"]
            }
        )


    return {"segments": aligned_segments, "word_segments": aligned_segments_word}


"""
source: https://pytorch.org/tutorials/intermediate/forced_alignment_with_torchaudio_tutorial.html
"""
def get_trellis(emission, tokens, blank_id=0):
    num_frame = emission.size(0)
    num_tokens = len(tokens)

    # Trellis has extra dimensions for both time axis and tokens.
    # The extra dim for tokens represents <SoS> (start-of-sentence)
    # The extra dim for time axis is for simplification of the code.
    trellis = torch.empty((num_frame + 1, num_tokens + 1))
    trellis[0, 0] = 0
    trellis[1:, 0] = torch.cumsum(emission[:, 0], 0)
    trellis[0, -num_tokens:] = -float("inf")
    trellis[-num_tokens:, 0] = float("inf")

    for t in range(num_frame):
        trellis[t + 1, 1:] = torch.maximum(
            # Score for staying at the same token
            trellis[t, 1:] + emission[t, blank_id],
            # Score for changing to the next token
            trellis[t, :-1] + emission[t, tokens],
        )
    return trellis

@dataclass
class Point:
    token_index: int
    time_index: int
    score: float

def backtrack(trellis, emission, tokens, blank_id=0):
    # Note:
    # j and t are indices for trellis, which has extra dimensions
    # for time and tokens at the beginning.
    # When referring to time frame index `T` in trellis,
    # the corresponding index in emission is `T-1`.
    # Similarly, when referring to token index `J` in trellis,
    # the corresponding index in transcript is `J-1`.
    j = trellis.size(1) - 1
    t_start = torch.argmax(trellis[:, j]).item()

    path = []
    for t in range(t_start, 0, -1):
        # 1. Figure out if the current position was stay or change
        # Note (again):
        # `emission[J-1]` is the emission at time frame `J` of trellis dimension.
        # Score for token staying the same from time frame J-1 to T.
        stayed = trellis[t - 1, j] + emission[t - 1, blank_id]
        # Score for token changing from C-1 at T-1 to J at T.
        changed = trellis[t - 1, j - 1] + emission[t - 1, tokens[j - 1]]

        # 2. Store the path with frame-wise probability.
        prob = emission[t - 1, tokens[j - 1] if changed > stayed else 0].exp().item()
        # Return token index and time index in non-trellis coordinate.
        path.append(Point(j - 1, t - 1, prob))

        # 3. Update the token
        if changed > stayed:
            j -= 1
            if j == 0:
                break
    else:
        # failed
        return None
    return path[::-1]

# Merge the labels
@dataclass
class Segment:
    label: str
    start: int
    end: int
    score: float

    def __repr__(self):
        return f"{self.label}\t({self.score:4.2f}): [{self.start:5d}, {self.end:5d})"

    @property
    def length(self):
        return self.end - self.start

def merge_repeats(path, transcript):
    i1, i2 = 0, 0
    segments = []
    while i1 < len(path):
        while i2 < len(path) and path[i1].token_index == path[i2].token_index:
            i2 += 1
        score = sum(path[k].score for k in range(i1, i2)) / (i2 - i1)
        segments.append(
            Segment(
                transcript[path[i1].token_index],
                path[i1].time_index,
                path[i2 - 1].time_index + 1,
                score,
            )
        )
        i1 = i2
    return segments

def merge_words(segments, separator="|"):
    words = []
    i1, i2 = 0, 0
    while i1 < len(segments):
        if i2 >= len(segments) or segments[i2].label == separator:
            if i1 != i2:
                segs = segments[i1:i2]
                word = "".join([seg.label for seg in segs])
                score = sum(seg.score * seg.length for seg in segs) / sum(seg.length for seg in segs)
                words.append(Segment(word, segments[i1].start, segments[i2 - 1].end, score))
            i1 = i2 + 1
            i2 = i1
        else:
            i2 += 1
    return words
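The get_trellis / backtrack / merge_repeats helpers follow the CTC forced-alignment recipe from the torchaudio tutorial cited above: accumulate log-probabilities over (frame, token) pairs, trace the best monotonic path back from the final token, then collapse repeated frames per token. A tiny self-contained illustration with random emissions (the vocabulary, token ids and transcript are made up for demonstration):

import torch
from whisperx.alignment import get_trellis, backtrack, merge_repeats

torch.manual_seed(0)

# 40 frames over a 5-symbol vocabulary; index 0 is the CTC blank.
emission = torch.log_softmax(torch.randn(40, 5), dim=-1)
tokens = [1, 2, 3, 2]      # toy dictionary indices for the transcript "abcb"
transcript = "abcb"

trellis = get_trellis(emission, tokens)        # (frames+1, tokens+1) cumulative scores
path = backtrack(trellis, emission, tokens)    # may be None if no valid path is found
if path is not None:
    for seg in merge_repeats(path, transcript):
        print(seg)                             # per-character frame span and mean score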
whisperx/asr.py
ADDED
@@ -0,0 +1,429 @@
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tqdm
import ffmpeg
from whisper.audio import (
    FRAMES_PER_SECOND,
    HOP_LENGTH,
    N_FRAMES,
    N_SAMPLES,
    SAMPLE_RATE,
    CHUNK_LENGTH,
    log_mel_spectrogram,
    pad_or_trim,
    load_audio
)
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.timing import add_word_timestamps
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE, get_tokenizer
from whisper.utils import (
    exact_div,
    format_timestamp,
    make_safe,
)

if TYPE_CHECKING:
    from whisper.model import Whisper

from .vad import merge_chunks

def transcribe(
    model: "Whisper",
    audio: Union[str, np.ndarray, torch.Tensor] = None,
    mel: np.ndarray = None,
    verbose: Optional[bool] = None,
    temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
    compression_ratio_threshold: Optional[float] = 2.4,
    logprob_threshold: Optional[float] = -1.0,
    no_speech_threshold: Optional[float] = 0.6,
    condition_on_previous_text: bool = True,
    initial_prompt: Optional[str] = None,
    word_timestamps: bool = False,
    prepend_punctuations: str = "\"'“¿([{-",
    append_punctuations: str = "\"'.。,,!!??::”)]}、",
    **decode_options,
):
    """
    Transcribe an audio file using Whisper.
    We redefine the Whisper transcribe function to allow mel input (for sequential slicing of audio)

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    mel: np.ndarray
        Mel spectrogram of audio segment.

    verbose: bool
        Whether to display the text being decoded to the console. If True, displays all the details,
        If False, displays minimal details. If None, does not display anything

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    word_timestamps: bool
        Extract word-level timestamps using the cross-attention pattern and dynamic time warping,
        and include the timestamps for each word in each segment.

    prepend_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the next word

    append_punctuations: str
        If word_timestamps is True, merge these punctuation symbols with the previous word

    initial_prompt: Optional[str]
        Optional text to provide as a prompt for the first window. This can be used to provide, or
        "prompt-engineer" a context for transcription, e.g. custom vocabularies or proper nouns
        to make it more likely to predict those words correctly.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """
    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    # Pad 30-seconds of silence to the input audio, for slicing
    if mel is None:
        if audio is None:
            raise ValueError("Transcribe needs either audio or mel as input, currently both are none.")
        mel = log_mel_spectrogram(audio, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES

    if decode_options.get("language", None) is None:
        if not model.is_multilingual:
            decode_options["language"] = "en"
        else:
            if verbose:
                print(
                    "Detecting language using up to the first 30 seconds. Use `--language` to specify the language"
                )
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )

    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    if word_timestamps and task == "translate":
        warnings.warn("Word-level timestamps on translations may not be reliable.")

    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        decode_result = None

        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # disable beam_size and patience when t > 0
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else:
                # disable best_of when t == 0
                kwargs.pop("best_of", None)

            options = DecodingOptions(**kwargs, temperature=t)
            decode_result = model.decode(segment, options)

            needs_fallback = False
            if (
                compression_ratio_threshold is not None
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
                needs_fallback = True  # average log probability is too low

            if not needs_fallback:
                break

        return decode_result

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
    else:
        initial_prompt_tokens = []

    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
            "seek": seek,
            "start": start,
            "end": end,
            "text": tokenizer.decode(text_tokens),
            "tokens": tokens,
            "temperature": result.temperature,
            "avg_logprob": result.avg_logprob,
            "compression_ratio": result.compression_ratio,
            "no_speech_prob": result.no_speech_prob,
        }


    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
        total=content_frames, unit="frames", disable=verbose is not False
    ) as pbar:
        while seek < content_frames:
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            mel_segment = mel[:, seek : seek + N_FRAMES]
            segment_size = min(N_FRAMES, content_frames - seek)
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)

            decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result: DecodingResult = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)
            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False

                if should_skip:
                    seek += segment_size  # fast-forward to the next segment boundary
                    continue

            previous_seek = seek
            current_segments = []

            timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]

            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                # if the output contains two consecutive timestamp tokens
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))

                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
                            end=time_offset + end_timestamp_pos * time_precision,
                            tokens=sliced_tokens,
                            result=result,
                        )
                    )
                    last_slice = current_slice

                if single_timestamp_ending:
                    # single timestamp at the end means no speech after the last timestamp.
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    duration = last_timestamp_pos * time_precision

                current_segments.append(
                    new_segment(
                        start=time_offset,
                        end=time_offset + duration,
                        tokens=tokens,
                        result=result,
                    )
                )
                seek += segment_size

            if not condition_on_previous_text or result.temperature > 0.5:
                # do not feed the prompt tokens if a high temperature was used
                prompt_reset_since = len(all_tokens)

            if word_timestamps:
                add_word_timestamps(
                    segments=current_segments,
                    model=model,
                    tokenizer=tokenizer,
                    mel=mel_segment,
                    num_frames=segment_size,
                    prepend_punctuations=prepend_punctuations,
                    append_punctuations=append_punctuations,
                )
                word_end_timestamps = [
                    w["end"] for s in current_segments for w in s["words"]
                ]
                if not single_timestamp_ending and len(word_end_timestamps) > 0:
                    seek_shift = round(
                        (word_end_timestamps[-1] - time_offset) * FRAMES_PER_SECOND
                    )
                    if seek_shift > 0:
                        seek = previous_seek + seek_shift

            if verbose:
                for segment in current_segments:
                    start, end, text = segment["start"], segment["end"], segment["text"]
                    line = f"[{format_timestamp(start)} --> {format_timestamp(end)}] {text}"
                    print(make_safe(line))

            # if a segment is instantaneous or does not contain text, clear it
            for i, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []

            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                ]
            )
            all_tokens.extend(
                [token for segment in current_segments for token in segment["tokens"]]
            )

            # update progress bar
            pbar.update(min(content_frames, seek) - previous_seek)


    return dict(
        text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]),
        segments=all_segments,
        language=language,
    )


def transcribe_with_vad(
    model: "Whisper",
    audio: str,
    vad_pipeline,
    mel=None,
    verbose: Optional[bool] = None,
    **kwargs
):
    """
    Transcribe per VAD segment
    """

    vad_segments = vad_pipeline(audio)

    # if not torch.is_tensor(audio):
    #     if isinstance(audio, str):
    audio = load_audio(audio)
    audio = torch.from_numpy(audio)

    prev = 0
    output = {"segments": []}

    # merge segments to approx 30s inputs to make whisper most appropriate
    vad_segments = merge_chunks(vad_segments, chunk_size=CHUNK_LENGTH)
    if len(vad_segments) == 0:
        return output

    print(">>Performing transcription...")
    for sdx, seg_t in enumerate(vad_segments):
        if verbose:
            print(f"~~ Transcribing VAD chunk: ({format_timestamp(seg_t['start'])} --> {format_timestamp(seg_t['end'])}) ~~")
        seg_f_start, seg_f_end = int(seg_t["start"] * SAMPLE_RATE), int(seg_t["end"] * SAMPLE_RATE)
        local_f_start, local_f_end = seg_f_start - prev, seg_f_end - prev
        audio = audio[local_f_start:]  # seek forward
        seg_audio = audio[:local_f_end - local_f_start]  # seek forward
        prev = seg_f_start
        local_mel = log_mel_spectrogram(seg_audio, padding=N_SAMPLES)
        # need to pad

        result = transcribe(model, audio, mel=local_mel, verbose=verbose, **kwargs)
        seg_t["text"] = result["text"]
        output["segments"].append(
            {
                "start": seg_t["start"],
                "end": seg_t["end"],
                "language": result["language"],
                "text": result["text"],
                "seg-text": [x["text"] for x in result["segments"]],
                "seg-start": [x["start"] for x in result["segments"]],
                "seg-end": [x["end"] for x in result["segments"]],
            }
        )

    output["language"] = output["segments"][0]["language"]

    return output
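This `transcribe` mirrors the upstream whisper loop but also accepts a precomputed mel segment, which is what `transcribe_with_vad` relies on to feed one merged ~30 s VAD chunk at a time. A hedged usage sketch of the plain (non-VAD) path, with an illustrative file name; the VAD pipeline argument for `transcribe_with_vad` is built by `whisperx.vad.load_vad_model`, whose options live in whisperx/vad.py (not shown in this diff):

import torch
import whisper
from whisperx.asr import transcribe

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("small", device)

# mel is computed internally from the audio file; extra kwargs go to DecodingOptions.
result = transcribe(model, "audio.wav", temperature=0, beam_size=5, verbose=True)
for seg in result["segments"]:
    print(f"[{seg['start']:.2f} --> {seg['end']:.2f}] {seg['text']}")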
whisperx/diarize.py
ADDED
@@ -0,0 +1,76 @@
import numpy as np
import pandas as pd
from pyannote.audio import Pipeline

class DiarizationPipeline:
    def __init__(
        self,
        model_name="pyannote/speaker-diarization@2.1",
        use_auth_token=None,
    ):
        self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token)

    def __call__(self, audio, min_speakers=None, max_speakers=None):
        segments = self.model(audio, min_speakers=min_speakers, max_speakers=max_speakers)
        diarize_df = pd.DataFrame(segments.itertracks(yield_label=True))
        diarize_df['start'] = diarize_df[0].apply(lambda x: x.start)
        diarize_df['end'] = diarize_df[0].apply(lambda x: x.end)
        return diarize_df

def assign_word_speakers(diarize_df, result_segments, fill_nearest=False):
    for seg in result_segments:
        wdf = seg['word-segments']
        if len(wdf['start'].dropna()) == 0:
            wdf['start'] = seg['start']
            wdf['end'] = seg['end']
        speakers = []
        for wdx, wrow in wdf.iterrows():
            if not np.isnan(wrow['start']):
                diarize_df['intersection'] = np.minimum(diarize_df['end'], wrow['end']) - np.maximum(diarize_df['start'], wrow['start'])
                diarize_df['union'] = np.maximum(diarize_df['end'], wrow['end']) - np.minimum(diarize_df['start'], wrow['start'])
                # remove no hit
                if not fill_nearest:
                    dia_tmp = diarize_df[diarize_df['intersection'] > 0]
                else:
                    dia_tmp = diarize_df
                if len(dia_tmp) == 0:
                    speaker = None
                else:
                    speaker = dia_tmp.sort_values("intersection", ascending=False).iloc[0][2]
            else:
                speaker = None
            speakers.append(speaker)
        seg['word-segments']['speaker'] = speakers

        speaker_count = pd.Series(speakers).value_counts()
        if len(speaker_count) == 0:
            seg["speaker"] = "UNKNOWN"
        else:
            seg["speaker"] = speaker_count.index[0]

    # create word level segments for .srt
    word_seg = []
    for seg in result_segments:
        wseg = pd.DataFrame(seg["word-segments"])
        for wdx, wrow in wseg.iterrows():
            if wrow["start"] is not None:
                speaker = wrow['speaker']
                if speaker is None or speaker == np.nan:
                    speaker = "UNKNOWN"
                word_seg.append(
                    {
                        "start": wrow["start"],
                        "end": wrow["end"],
                        "text": f"[{speaker}]: " + seg["text"][int(wrow["segment-text-start"]):int(wrow["segment-text-end"])]
                    }
                )

    # TODO: create segments but split words on new speaker
    return result_segments, word_seg

class Segment:
    def __init__(self, start, end, speaker=None):
        self.start = start
        self.end = end
        self.speaker = speaker
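`assign_word_speakers` expects the aligned segments produced by `whisperx.align` (each carrying a `word-segments` DataFrame) plus the diarization DataFrame returned by `DiarizationPipeline`. A sketch, assuming a Hugging Face token with access to the pyannote model (the token string and audio path are placeholders) and `aligned` being the output of `whisperx.align` as in the earlier sketch:

from whisperx.diarize import DiarizationPipeline, assign_word_speakers

diarize_pipeline = DiarizationPipeline(use_auth_token="hf_...")  # placeholder token
diarize_df = diarize_pipeline("audio.wav", min_speakers=1, max_speakers=2)

# Attaches a majority speaker per segment and returns word-level entries for .srt output.
segments, word_seg = assign_word_speakers(diarize_df, aligned["segments"])
for seg in segments:
    print(seg["speaker"], seg["text"])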
whisperx/transcribe.py
ADDED
@@ -0,0 +1,220 @@
import argparse
import os
import gc
import warnings
from typing import TYPE_CHECKING, Optional, Tuple, Union
import numpy as np
import torch
import tempfile
import ffmpeg
from whisper.tokenizer import LANGUAGES, TO_LANGUAGE_CODE
from whisper.audio import SAMPLE_RATE
from whisper.utils import (
    optional_float,
    optional_int,
    str2bool,
)

from .alignment import load_align_model, align
from .asr import transcribe, transcribe_with_vad
from .diarize import DiarizationPipeline, assign_word_speakers
from .utils import get_writer
from .vad import load_vad_model

def cli():
    from whisper import available_models

    # fmt: off
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument("audio", nargs="+", type=str, help="audio file(s) to transcribe")
    parser.add_argument("--model", default="small", choices=available_models(), help="name of the Whisper model to use")
    parser.add_argument("--model_dir", type=str, default=None, help="the path to save model files; uses ~/.cache/whisper by default")
    parser.add_argument("--device", default="cuda" if torch.cuda.is_available() else "cpu", help="device to use for PyTorch inference")
    parser.add_argument("--output_dir", "-o", type=str, default=".", help="directory to save the outputs")
    parser.add_argument("--output_format", "-f", type=str, default="all", choices=["all", "srt", "srt-word", "vtt", "txt", "tsv", "ass", "ass-char", "pickle", "vad"], help="format of the output file; if not specified, all available formats will be produced")
    parser.add_argument("--verbose", type=str2bool, default=True, help="whether to print out the progress and debug messages")

    parser.add_argument("--task", type=str, default="transcribe", choices=["transcribe", "translate"], help="whether to perform X->X speech recognition ('transcribe') or X->English translation ('translate')")
    parser.add_argument("--language", type=str, default=None, choices=sorted(LANGUAGES.keys()) + sorted([k.title() for k in TO_LANGUAGE_CODE.keys()]), help="language spoken in the audio, specify None to perform language detection")

    # alignment params
    parser.add_argument("--align_model", default=None, help="Name of phoneme-level ASR model to do alignment")
    parser.add_argument("--align_extend", default=2, type=float, help="Seconds before and after to extend the whisper segments for alignment (if not using VAD).")
    parser.add_argument("--align_from_prev", default=True, type=bool, help="Whether to clip the alignment start time of current segment to the end time of the last aligned word of the previous segment (if not using VAD)")
    parser.add_argument("--interpolate_method", default="nearest", choices=["nearest", "linear", "ignore"], help="For word .srt, method to assign timestamps to non-aligned words, or merge them into neighbouring.")
    parser.add_argument("--no_align", action='store_true', help="Do not perform phoneme alignment")

    # vad params
    parser.add_argument("--vad_filter", type=str2bool, default=True, help="Whether to pre-segment audio with VAD, highly recommended! Produces more accurate alignment + timestamp see WhisperX paper https://arxiv.org/abs/2303.00747")
    parser.add_argument("--vad_onset", type=float, default=0.500, help="Onset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected")
    parser.add_argument("--vad_offset", type=float, default=0.363, help="Offset threshold for VAD (see pyannote.audio), reduce this if speech is not being detected.")

    # diarization params
    parser.add_argument("--diarize", action="store_true", help="Apply diarization to assign speaker labels to each segment/word")
    parser.add_argument("--min_speakers", default=None, type=int)
    parser.add_argument("--max_speakers", default=None, type=int)

    parser.add_argument("--temperature", type=float, default=0, help="temperature to use for sampling")
    parser.add_argument("--best_of", type=optional_int, default=5, help="number of candidates when sampling with non-zero temperature")
    parser.add_argument("--beam_size", type=optional_int, default=5, help="number of beams in beam search, only applicable when temperature is zero")
    parser.add_argument("--patience", type=float, default=None, help="optional patience value to use in beam decoding, as in https://arxiv.org/abs/2204.05424, the default (1.0) is equivalent to conventional beam search")
61 |
+
parser.add_argument("--length_penalty", type=float, default=None, help="optional token length penalty coefficient (alpha) as in https://arxiv.org/abs/1609.08144, uses simple length normalization by default")
|
62 |
+
|
63 |
+
parser.add_argument("--suppress_tokens", type=str, default="-1", help="comma-separated list of token ids to suppress during sampling; '-1' will suppress most special characters except common punctuations")
|
64 |
+
parser.add_argument("--initial_prompt", type=str, default=None, help="optional text to provide as a prompt for the first window.")
|
65 |
+
parser.add_argument("--condition_on_previous_text", type=str2bool, default=False, help="if True, provide the previous output of the model as a prompt for the next window; disabling may make the text inconsistent across windows, but the model becomes less prone to getting stuck in a failure loop")
|
66 |
+
parser.add_argument("--fp16", type=str2bool, default=True, help="whether to perform inference in fp16; True by default")
|
67 |
+
|
68 |
+
parser.add_argument("--temperature_increment_on_fallback", type=optional_float, default=0.2, help="temperature to increase when falling back when the decoding fails to meet either of the thresholds below")
|
69 |
+
parser.add_argument("--compression_ratio_threshold", type=optional_float, default=2.4, help="if the gzip compression ratio is higher than this value, treat the decoding as failed")
|
70 |
+
parser.add_argument("--logprob_threshold", type=optional_float, default=-1.0, help="if the average log probability is lower than this value, treat the decoding as failed")
|
71 |
+
parser.add_argument("--no_speech_threshold", type=optional_float, default=0.6, help="if the probability of the <|nospeech|> token is higher than this value AND the decoding has failed due to `logprob_threshold`, consider the segment as silence")
|
72 |
+
parser.add_argument("--word_timestamps", type=str2bool, default=False, help="(experimental) extract word-level timestamps and refine the results based on them")
|
73 |
+
parser.add_argument("--prepend_punctuations", type=str, default="\"\'“¿([{-", help="if word_timestamps is True, merge these punctuation symbols with the next word")
|
74 |
+
parser.add_argument("--append_punctuations", type=str, default="\"\'.。,,!!??::”)]}、", help="if word_timestamps is True, merge these punctuation symbols with the previous word")
|
75 |
+
parser.add_argument("--threads", type=optional_int, default=0, help="number of threads used by torch for CPU inference; supercedes MKL_NUM_THREADS/OMP_NUM_THREADS")
|
76 |
+
|
77 |
+
parser.add_argument("--hf_token", type=str, default=None, help="Hugging Face Access Token to access PyAnnote gated models")
|
78 |
+
# parser.add_argument("--model_flush", action="store_true", help="Flush memory from each model after use, reduces GPU requirement but slower processing >1 audio file.")
|
79 |
+
parser.add_argument("--tmp_dir", default=None, help="Temporary directory to write audio file if input if not .wav format (only for VAD).")
|
80 |
+
# fmt: on
|
81 |
+
|
82 |
+
args = parser.parse_args().__dict__
|
83 |
+
model_name: str = args.pop("model")
|
84 |
+
model_dir: str = args.pop("model_dir")
|
85 |
+
output_dir: str = args.pop("output_dir")
|
86 |
+
output_format: str = args.pop("output_format")
|
87 |
+
device: str = args.pop("device")
|
88 |
+
# model_flush: bool = args.pop("model_flush")
|
89 |
+
os.makedirs(output_dir, exist_ok=True)
|
90 |
+
|
91 |
+
tmp_dir: str = args.pop("tmp_dir")
|
92 |
+
if tmp_dir is not None:
|
93 |
+
os.makedirs(tmp_dir, exist_ok=True)
|
94 |
+
|
95 |
+
align_model: str = args.pop("align_model")
|
96 |
+
align_extend: float = args.pop("align_extend")
|
97 |
+
align_from_prev: bool = args.pop("align_from_prev")
|
98 |
+
interpolate_method: str = args.pop("interpolate_method")
|
99 |
+
no_align: bool = args.pop("no_align")
|
100 |
+
|
101 |
+
hf_token: str = args.pop("hf_token")
|
102 |
+
vad_filter: bool = args.pop("vad_filter")
|
103 |
+
vad_onset: float = args.pop("vad_onset")
|
104 |
+
vad_offset: float = args.pop("vad_offset")
|
105 |
+
|
106 |
+
diarize: bool = args.pop("diarize")
|
107 |
+
min_speakers: int = args.pop("min_speakers")
|
108 |
+
max_speakers: int = args.pop("max_speakers")
|
109 |
+
|
110 |
+
if vad_filter:
|
111 |
+
from pyannote.audio import Pipeline
|
112 |
+
from pyannote.audio import Model, Pipeline
|
113 |
+
vad_model = load_vad_model(torch.device(device), vad_onset, vad_offset, use_auth_token=hf_token)
|
114 |
+
else:
|
115 |
+
vad_model = None
|
116 |
+
|
117 |
+
# if model_flush:
|
118 |
+
# print(">>Model flushing activated... Only loading model after ASR stage")
|
119 |
+
# del align_model
|
120 |
+
# align_model = ""
|
121 |
+
|
122 |
+
|
123 |
+
if model_name.endswith(".en") and args["language"] not in {"en", "English"}:
|
124 |
+
if args["language"] is not None:
|
125 |
+
warnings.warn(
|
126 |
+
f"{model_name} is an English-only model but receipted '{args['language']}'; using English instead."
|
127 |
+
)
|
128 |
+
args["language"] = "en"
|
129 |
+
|
130 |
+
temperature = args.pop("temperature")
|
131 |
+
if (increment := args.pop("temperature_increment_on_fallback")) is not None:
|
132 |
+
temperature = tuple(np.arange(temperature, 1.0 + 1e-6, increment))
|
133 |
+
else:
|
134 |
+
temperature = [temperature]
|
135 |
+
|
136 |
+
if (threads := args.pop("threads")) > 0:
|
137 |
+
torch.set_num_threads(threads)
|
138 |
+
|
139 |
+
from whisper import load_model
|
140 |
+
|
141 |
+
writer = get_writer(output_format, output_dir)
|
142 |
+
|
143 |
+
# Part 1: VAD & ASR Loop
|
144 |
+
results = []
|
145 |
+
tmp_results = []
|
146 |
+
model = load_model(model_name, device=device, download_root=model_dir)
|
147 |
+
for audio_path in args.pop("audio"):
|
148 |
+
input_audio_path = audio_path
|
149 |
+
tfile = None
|
150 |
+
|
151 |
+
# >> VAD & ASR
|
152 |
+
if vad_model is not None:
|
153 |
+
if not audio_path.endswith(".wav"):
|
154 |
+
print(">>VAD requires .wav format, converting to wav as a tempfile...")
|
155 |
+
audio_basename = os.path.splitext(os.path.basename(audio_path))[0]
|
156 |
+
if tmp_dir is not None:
|
157 |
+
input_audio_path = os.path.join(tmp_dir, audio_basename + ".wav")
|
158 |
+
else:
|
159 |
+
input_audio_path = os.path.join(os.path.dirname(audio_path), audio_basename + ".wav")
|
160 |
+
ffmpeg.input(audio_path, threads=0).output(input_audio_path, ac=1, ar=SAMPLE_RATE).run(cmd=["ffmpeg"])
|
161 |
+
print(">>Performing VAD...")
|
162 |
+
result = transcribe_with_vad(model, input_audio_path, vad_model, temperature=temperature, **args)
|
163 |
+
else:
|
164 |
+
print(">>Performing transcription...")
|
165 |
+
result = transcribe(model, input_audio_path, temperature=temperature, **args)
|
166 |
+
|
167 |
+
results.append((result, input_audio_path))
|
168 |
+
|
169 |
+
# Unload Whisper and VAD
|
170 |
+
del model
|
171 |
+
del vad_model
|
172 |
+
gc.collect()
|
173 |
+
torch.cuda.empty_cache()
|
174 |
+
|
175 |
+
# Part 2: Align Loop
|
176 |
+
if not no_align:
|
177 |
+
tmp_results = results
|
178 |
+
results = []
|
179 |
+
align_language = args["language"] if args["language"] is not None else "en" # default to loading english if not specified
|
180 |
+
align_model, align_metadata = load_align_model(align_language, device, model_name=align_model)
|
181 |
+
for result, input_audio_path in tmp_results:
|
182 |
+
# >> Align
|
183 |
+
if align_model is not None and len(result["segments"]) > 0:
|
184 |
+
if result.get("language", "en") != align_metadata["language"]:
|
185 |
+
# load new language
|
186 |
+
print(f"New language found ({result['language']})! Previous was ({align_metadata['language']}), loading new alignment model for new language...")
|
187 |
+
align_model, align_metadata = load_align_model(result["language"], device)
|
188 |
+
print(">>Performing alignment...")
|
189 |
+
result = align(result["segments"], align_model, align_metadata, input_audio_path, device,
|
190 |
+
extend_duration=align_extend, start_from_previous=align_from_prev, interpolate_method=interpolate_method)
|
191 |
+
results.append((result, input_audio_path))
|
192 |
+
|
193 |
+
# Unload align model
|
194 |
+
del align_model
|
195 |
+
gc.collect()
|
196 |
+
torch.cuda.empty_cache()
|
197 |
+
|
198 |
+
# >> Diarize
|
199 |
+
if diarize:
|
200 |
+
if hf_token is None:
|
201 |
+
print("Warning, no --hf_token used, needs to be saved in environment variable, otherwise will throw error loading diarization model...")
|
202 |
+
tmp_results = results
|
203 |
+
results = []
|
204 |
+
diarize_model = DiarizationPipeline(use_auth_token=hf_token)
|
205 |
+
for result, input_audio_path in tmp_results:
|
206 |
+
diarize_segments = diarize_model(input_audio_path, min_speakers=min_speakers, max_speakers=max_speakers)
|
207 |
+
results_segments, word_segments = assign_word_speakers(diarize_segments, result["segments"])
|
208 |
+
result = {"segments": results_segments, "word_segments": word_segments}
|
209 |
+
results.append((result, input_audio_path))
|
210 |
+
|
211 |
+
# >> Write
|
212 |
+
for result, audio_path in results:
|
213 |
+
writer(result, audio_path)
|
214 |
+
|
215 |
+
# cleanup
|
216 |
+
if input_audio_path != audio_path:
|
217 |
+
os.remove(input_audio_path)
|
218 |
+
|
219 |
+
if __name__ == "__main__":
|
220 |
+
cli()
|
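The same three stages that `cli()` chains together (VAD-gated ASR, then alignment, then writing) can also be sketched programmatically. The audio path and model choices below are placeholders and the keyword defaults are assumed from the functions imported above, so treat this as an illustration rather than a supported API.

import torch
import whisper
from whisperx.vad import load_vad_model
from whisperx.asr import transcribe_with_vad
from whisperx.alignment import load_align_model, align
from whisperx.utils import get_writer

device = "cuda" if torch.cuda.is_available() else "cpu"
model = whisper.load_model("small", device=device)
vad_model = load_vad_model(torch.device(device), vad_onset=0.500, vad_offset=0.363)

# Stage 1: VAD-segmented transcription (expects 16 kHz mono .wav, as cli() enforces via ffmpeg)
result = transcribe_with_vad(model, "audio.wav", vad_model, temperature=[0], language="en", task="transcribe")

# Stage 2: phoneme-level alignment with a language-specific model
align_model, align_metadata = load_align_model("en", device)
result = align(result["segments"], align_model, align_metadata, "audio.wav", device)

# Stage 3: write subtitles to the current directory
get_writer("srt", ".")(result, "audio.wav")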
whisperx/utils.py
ADDED
@@ -0,0 +1,317 @@
import os
import zlib
from typing import Callable, TextIO, Iterator, Tuple
import pandas as pd
import numpy as np

def interpolate_nans(x, method='nearest'):
    if x.notnull().sum() > 1:
        return x.interpolate(method=method).ffill().bfill()
    else:
        return x.ffill().bfill()


def write_txt(transcript: Iterator[dict], file: TextIO):
    for segment in transcript:
        print(segment['text'].strip(), file=file, flush=True)


def write_vtt(transcript: Iterator[dict], file: TextIO):
    print("WEBVTT\n", file=file)
    for segment in transcript:
        print(
            f"{format_timestamp(segment['start'])} --> {format_timestamp(segment['end'])}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )

def write_tsv(transcript: Iterator[dict], file: TextIO):
    print("start", "end", "text", sep="\t", file=file)
    for segment in transcript:
        print(segment['start'], file=file, end="\t")
        print(segment['end'], file=file, end="\t")
        print(segment['text'].strip().replace("\t", " "), file=file, flush=True)


def write_srt(transcript: Iterator[dict], file: TextIO):
    """
    Write a transcript to a file in SRT format.

    Example usage:
        from pathlib import Path
        from whisper.utils import write_srt

        result = transcribe(model, audio_path, temperature=temperature, **args)

        # save SRT
        audio_basename = Path(audio_path).stem
        with open(Path(output_dir) / (audio_basename + ".srt"), "w", encoding="utf-8") as srt:
            write_srt(result["segments"], file=srt)
    """
    for i, segment in enumerate(transcript, start=1):
        # write srt lines
        print(
            f"{i}\n"
            f"{format_timestamp(segment['start'], always_include_hours=True, decimal_marker=',')} --> "
            f"{format_timestamp(segment['end'], always_include_hours=True, decimal_marker=',')}\n"
            f"{segment['text'].strip().replace('-->', '->')}\n",
            file=file,
            flush=True,
        )


def write_ass(transcript: Iterator[dict],
              file: TextIO,
              resolution: str = "word",
              color: str = None, underline=True,
              prefmt: str = None, suffmt: str = None,
              font: str = None, font_size: int = 24,
              strip=True, **kwargs):
    """
    Credit: https://github.com/jianfch/stable-ts/blob/ff79549bd01f764427879f07ecd626c46a9a430a/stable_whisper/text_output.py
    Generate Advanced SubStation Alpha (ass) file from results to
    display both phrase-level & word-level timestamp simultaneously by:
     -using segment-level timestamps display phrases as usual
     -using word-level timestamps change formats (e.g. color/underline) of the word in the displayed segment
    Note: ass file is used in the same way as srt, vtt, etc.
    Parameters
    ----------
    transcript: dict
        results from modified model
    file: TextIO
        file object to write to
    resolution: str
        "word" or "char", timestamp resolution to highlight.
    color: str
        color code for a word at its corresponding timestamp
        <bbggrr> reverse order hexadecimal RGB value (e.g. FF0000 is full intensity blue. Default: 00FF00)
    underline: bool
        whether to underline a word at its corresponding timestamp
    prefmt: str
        used to specify format for word-level timestamps (must be used with 'suffmt' and overrides 'color'&'underline')
        appears as such in the .ass file:
            Hi, {<prefmt>}how{<suffmt>} are you?
        reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
    suffmt: str
        used to specify format for word-level timestamps (must be used with 'prefmt' and overrides 'color'&'underline')
        appears as such in the .ass file:
            Hi, {<prefmt>}how{<suffmt>} are you?
        reference [Appendix A: Style override codes] in http://www.tcax.org/docs/ass-specs.htm
    font: str
        word font (default: Arial)
    font_size: int
        word font size (default: 48)
    kwargs:
        used for format styles:
        'Name', 'Fontname', 'Fontsize', 'PrimaryColour', 'SecondaryColour', 'OutlineColour', 'BackColour', 'Bold',
        'Italic', 'Underline', 'StrikeOut', 'ScaleX', 'ScaleY', 'Spacing', 'Angle', 'BorderStyle', 'Outline',
        'Shadow', 'Alignment', 'MarginL', 'MarginR', 'MarginV', 'Encoding'

    """

    fmt_style_dict = {'Name': 'Default', 'Fontname': 'Arial', 'Fontsize': '48', 'PrimaryColour': '&Hffffff',
                      'SecondaryColour': '&Hffffff', 'OutlineColour': '&H0', 'BackColour': '&H0', 'Bold': '0',
                      'Italic': '0', 'Underline': '0', 'StrikeOut': '0', 'ScaleX': '100', 'ScaleY': '100',
                      'Spacing': '0', 'Angle': '0', 'BorderStyle': '1', 'Outline': '1', 'Shadow': '0',
                      'Alignment': '2', 'MarginL': '10', 'MarginR': '10', 'MarginV': '10', 'Encoding': '0'}

    for k, v in filter(lambda x: 'colour' in x[0].lower() and not str(x[1]).startswith('&H'), kwargs.items()):
        kwargs[k] = f'&H{kwargs[k]}'

    fmt_style_dict.update((k, v) for k, v in kwargs.items() if k in fmt_style_dict)

    if font:
        fmt_style_dict.update(Fontname=font)
    if font_size:
        fmt_style_dict.update(Fontsize=font_size)

    fmts = f'Format: {", ".join(map(str, fmt_style_dict.keys()))}'

    styles = f'Style: {",".join(map(str, fmt_style_dict.values()))}'

    ass_str = f'[Script Info]\nScriptType: v4.00+\nPlayResX: 384\nPlayResY: 288\nScaledBorderAndShadow: yes\n\n' \
              f'[V4+ Styles]\n{fmts}\n{styles}\n\n' \
              f'[Events]\nFormat: Layer, Start, End, Style, Name, MarginL, MarginR, MarginV, Effect, Text\n\n'

    if prefmt or suffmt:
        if suffmt:
            assert prefmt, 'prefmt must be used along with suffmt'
        else:
            suffmt = r'\r'
    else:
        if not color:
            color = 'HFF00'
        underline_code = r'\u1' if underline else ''

        prefmt = r'{\1c&' + f'{color.upper()}&{underline_code}' + '}'
        suffmt = r'{\r}'

    def secs_to_hhmmss(secs: Tuple[float, int]):
        mm, ss = divmod(secs, 60)
        hh, mm = divmod(mm, 60)
        return f'{hh:0>1.0f}:{mm:0>2.0f}:{ss:0>2.2f}'


    def dialogue(chars: str, start: float, end: float, idx_0: int, idx_1: int) -> str:
        if idx_0 == -1:
            text = chars
        else:
            text = f'{chars[:idx_0]}{prefmt}{chars[idx_0:idx_1]}{suffmt}{chars[idx_1:]}'
        return f"Dialogue: 0,{secs_to_hhmmss(start)},{secs_to_hhmmss(end)}," \
               f"Default,,0,0,0,,{text.strip() if strip else text}"

    if resolution == "word":
        resolution_key = "word-segments"
    elif resolution == "char":
        resolution_key = "char-segments"
    else:
        raise ValueError(".ass resolution should be 'word' or 'char', not ", resolution)

    ass_arr = []

    for segment in transcript:
        # if "12" in segment['text']:
        #     import pdb; pdb.set_trace()
        if resolution_key in segment:
            res_segs = pd.DataFrame(segment[resolution_key])
            prev = segment['start']
            if "speaker" in segment:
                speaker_str = f"[{segment['speaker']}]: "
            else:
                speaker_str = ""
            for cdx, crow in res_segs.iterrows():
                if not np.isnan(crow['start']):
                    if resolution == "char":
                        idx_0 = cdx
                        idx_1 = cdx + 1
                    elif resolution == "word":
                        idx_0 = int(crow["segment-text-start"])
                        idx_1 = int(crow["segment-text-end"])
                    # fill gap
                    if crow['start'] > prev:
                        filler_ts = {
                            "chars": speaker_str + segment['text'],
                            "start": prev,
                            "end": crow['start'],
                            "idx_0": -1,
                            "idx_1": -1
                        }

                        ass_arr.append(filler_ts)
                    # highlight current word
                    f_word_ts = {
                        "chars": speaker_str + segment['text'],
                        "start": crow['start'],
                        "end": crow['end'],
                        "idx_0": idx_0 + len(speaker_str),
                        "idx_1": idx_1 + len(speaker_str)
                    }
                    ass_arr.append(f_word_ts)
                    prev = crow['end']

    ass_str += '\n'.join(map(lambda x: dialogue(**x), ass_arr))

    file.write(ass_str)


from whisper.utils import SubtitlesWriter, ResultWriter, WriteTXT, WriteVTT, WriteSRT, WriteTSV, WriteJSON, format_timestamp

class WriteASS(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resolution="word")

class WriteASSchar(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        write_ass(result["segments"], file, resolution="char")

class WritePickle(ResultWriter):
    extension: str = "ass"

    def write_result(self, result: dict, file: TextIO):
        pd.DataFrame(result["segments"]).to_pickle(file)

class WriteSRTWord(ResultWriter):
    extension: str = "word.srt"
    always_include_hours: bool = True
    decimal_marker: str = ","

    def iterate_result(self, result: dict):
        for segment in result["word_segments"]:
            segment_start = self.format_timestamp(segment["start"])
            segment_end = self.format_timestamp(segment["end"])
            segment_text = segment["text"].strip().replace("-->", "->")

            if word_timings := segment.get("words", None):
                all_words = [timing["word"] for timing in word_timings]
                all_words[0] = all_words[0].strip()  # remove the leading space, if any
                last = segment_start
                for i, this_word in enumerate(word_timings):
                    start = self.format_timestamp(this_word["start"])
                    end = self.format_timestamp(this_word["end"])
                    if last != start:
                        yield last, start, segment_text

                    yield start, end, "".join(
                        [
                            f"<u>{word}</u>" if j == i else word
                            for j, word in enumerate(all_words)
                        ]
                    )
                    last = end

                if last != segment_end:
                    yield last, segment_end, segment_text
            else:
                yield segment_start, segment_end, segment_text

    def write_result(self, result: dict, file: TextIO):
        if "word_segments" not in result:
            return
        for i, (start, end, text) in enumerate(self.iterate_result(result), start=1):
            print(f"{i}\n{start} --> {end}\n{text}\n", file=file, flush=True)

    def format_timestamp(self, seconds: float):
        return format_timestamp(
            seconds=seconds,
            always_include_hours=self.always_include_hours,
            decimal_marker=self.decimal_marker,
        )

def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO], None]:
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
        "srt": WriteSRT,
        "tsv": WriteTSV,
        "ass": WriteASS,
        "srt-word": WriteSRTWord,
        # "ass-char": WriteASSchar,
        # "pickle": WritePickle,
        # "json": WriteJSON,
    }

    writers_other = {
        "pkl": WritePickle,
        "ass-char": WriteASSchar
    }

    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]

        def write_all(result: dict, file: TextIO):
            for writer in all_writers:
                writer(result, file)

        return write_all

    if output_format in writers:
        return writers[output_format](output_dir)
    elif output_format in writers_other:
        return writers_other[output_format](output_dir)
    else:
        raise ValueError(f"Output format '{output_format}' not supported, choose from {writers.keys()} and {writers_other.keys()}")
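A small usage sketch of the writer interface defined above; the `result` dict is a hand-made stand-in for what the ASR/alignment stages return, and the audio filename is a placeholder.

from whisperx.utils import get_writer

result = {"segments": [{"start": 0.0, "end": 1.5, "text": "Hello world."}]}
writer = get_writer("srt", ".")   # returns a WriteSRT instance bound to the output directory
writer(result, "audio.wav")       # expected to write ./audio.srt using whisper.utils.format_timestamp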
whisperx/vad.py
ADDED
@@ -0,0 +1,305 @@
import os
import urllib
import pandas as pd
import numpy as np
import torch
import hashlib
from tqdm import tqdm
from typing import Optional, Callable, Union, Text
from pyannote.audio.core.io import AudioFile
from pyannote.core import Annotation, Segment, SlidingWindowFeature
from pyannote.audio.pipelines.utils import PipelineModel
from pyannote.audio import Model
from pyannote.audio.pipelines import VoiceActivityDetection
from .diarize import Segment as SegmentX
from typing import List, Tuple, Optional

VAD_SEGMENTATION_URL = "https://whisperx.s3.eu-west-2.amazonaws.com/model_weights/segmentation/0b5b3216d60a2d32fc086b47ea8c67589aaeb26b7e07fcbe620d6d0b83e209ea/pytorch_model.bin"

def load_vad_model(device, vad_onset, vad_offset, use_auth_token=None):
    model_dir = torch.hub._get_torch_home()
    os.makedirs(model_dir, exist_ok = True)
    model_fp = os.path.join(model_dir, "whisperx-vad-segmentation.bin")
    if os.path.exists(model_fp) and not os.path.isfile(model_fp):
        raise RuntimeError(f"{model_fp} exists and is not a regular file")

    if not os.path.isfile(model_fp):
        with urllib.request.urlopen(VAD_SEGMENTATION_URL) as source, open(model_fp, "wb") as output:
            with tqdm(
                total=int(source.info().get("Content-Length")),
                ncols=80,
                unit="iB",
                unit_scale=True,
                unit_divisor=1024,
            ) as loop:
                while True:
                    buffer = source.read(8192)
                    if not buffer:
                        break

                    output.write(buffer)
                    loop.update(len(buffer))

    model_bytes = open(model_fp, "rb").read()
    if hashlib.sha256(model_bytes).hexdigest() != VAD_SEGMENTATION_URL.split('/')[-2]:
        raise RuntimeError(
            "Model has been downloaded but the SHA256 checksum does not match. Please retry loading the model."
        )

    vad_model = Model.from_pretrained(model_fp, use_auth_token=use_auth_token)
    hyperparameters = {"onset": vad_onset,
                       "offset": vad_offset,
                       "min_duration_on": 0.1,
                       "min_duration_off": 0.1}
    vad_pipeline = VoiceActivitySegmentation(segmentation=vad_model, device=torch.device(device))
    vad_pipeline.instantiate(hyperparameters)

    return vad_pipeline

class Binarize:
    """Binarize detection scores using hysteresis thresholding, with a min-cut operation
    to ensure that no active segment is longer than max_duration.

    Parameters
    ----------
    onset : float, optional
        Onset threshold. Defaults to 0.5.
    offset : float, optional
        Offset threshold. Defaults to `onset`.
    min_duration_on : float, optional
        Remove active regions shorter than that many seconds. Defaults to 0s.
    min_duration_off : float, optional
        Fill inactive regions shorter than that many seconds. Defaults to 0s.
    pad_onset : float, optional
        Extend active regions by moving their start time by that many seconds.
        Defaults to 0s.
    pad_offset : float, optional
        Extend active regions by moving their end time by that many seconds.
        Defaults to 0s.
    max_duration: float
        The maximum length of an active segment, divides segment at timestamp with lowest score.
    Reference
    ---------
    Gregory Gelly and Jean-Luc Gauvain. "Minimum Word Error Training of
    RNN-based Voice Activity Detection", InterSpeech 2015.

    Modified by Max Bain to include WhisperX's min-cut operation
    https://arxiv.org/abs/2303.00747

    Pyannote-audio
    """

    def __init__(
        self,
        onset: float = 0.5,
        offset: Optional[float] = None,
        min_duration_on: float = 0.0,
        min_duration_off: float = 0.0,
        pad_onset: float = 0.0,
        pad_offset: float = 0.0,
        max_duration: float = float('inf')
    ):

        super().__init__()

        self.onset = onset
        self.offset = offset or onset

        self.pad_onset = pad_onset
        self.pad_offset = pad_offset

        self.min_duration_on = min_duration_on
        self.min_duration_off = min_duration_off

        self.max_duration = max_duration

    def __call__(self, scores: SlidingWindowFeature) -> Annotation:
        """Binarize detection scores
        Parameters
        ----------
        scores : SlidingWindowFeature
            Detection scores.
        Returns
        -------
        active : Annotation
            Binarized scores.
        """

        num_frames, num_classes = scores.data.shape
        frames = scores.sliding_window
        timestamps = [frames[i].middle for i in range(num_frames)]

        # annotation meant to store 'active' regions
        active = Annotation()
        for k, k_scores in enumerate(scores.data.T):

            label = k if scores.labels is None else scores.labels[k]

            # initial state
            start = timestamps[0]
            is_active = k_scores[0] > self.onset
            curr_scores = [k_scores[0]]
            curr_timestamps = [start]
            for t, y in zip(timestamps[1:], k_scores[1:]):
                # currently active
                if is_active:
                    curr_duration = t - start
                    if curr_duration > self.max_duration:
                        # if curr_duration > 15:
                        #     import pdb; pdb.set_trace()
                        search_after = len(curr_scores) // 2
                        # divide segment
                        min_score_div_idx = search_after + np.argmin(curr_scores[search_after:])
                        min_score_t = curr_timestamps[min_score_div_idx]
                        region = Segment(start - self.pad_onset, min_score_t + self.pad_offset)
                        active[region, k] = label
                        start = curr_timestamps[min_score_div_idx]
                        curr_scores = curr_scores[min_score_div_idx+1:]
                        curr_timestamps = curr_timestamps[min_score_div_idx+1:]
                    # switching from active to inactive
                    elif y < self.offset:
                        region = Segment(start - self.pad_onset, t + self.pad_offset)
                        active[region, k] = label
                        start = t
                        is_active = False
                        curr_scores = []
                        curr_timestamps = []
                # currently inactive
                else:
                    # switching from inactive to active
                    if y > self.onset:
                        start = t
                        is_active = True
                curr_scores.append(y)
                curr_timestamps.append(t)

            # if active at the end, add final region
            if is_active:
                region = Segment(start - self.pad_onset, t + self.pad_offset)
                active[region, k] = label

        # because of padding, some active regions might be overlapping: merge them.
        # also: fill same speaker gaps shorter than min_duration_off
        if self.pad_offset > 0.0 or self.pad_onset > 0.0 or self.min_duration_off > 0.0:
            if self.max_duration < float("inf"):
                raise NotImplementedError(f"This would break current max_duration param")
            active = active.support(collar=self.min_duration_off)

        # remove tracks shorter than min_duration_on
        if self.min_duration_on > 0:
            for segment, track in list(active.itertracks()):
                if segment.duration < self.min_duration_on:
                    del active[segment, track]

        return active


class VoiceActivitySegmentation(VoiceActivityDetection):
    def __init__(
        self,
        segmentation: PipelineModel = "pyannote/segmentation",
        fscore: bool = False,
        use_auth_token: Union[Text, None] = None,
        **inference_kwargs,
    ):

        super().__init__(segmentation=segmentation, fscore=fscore, use_auth_token=use_auth_token, **inference_kwargs)

    def apply(self, file: AudioFile, hook: Optional[Callable] = None) -> Annotation:
        """Apply voice activity detection

        Parameters
        ----------
        file : AudioFile
            Processed file.
        hook : callable, optional
            Hook called after each major step of the pipeline with the following
            signature: hook("step_name", step_artefact, file=file)

        Returns
        -------
        speech : Annotation
            Speech regions.
        """

        # setup hook (e.g. for debugging purposes)
        hook = self.setup_hook(file, hook=hook)

        # apply segmentation model (only if needed)
        # output shape is (num_chunks, num_frames, 1)
        if self.training:
            if self.CACHED_SEGMENTATION in file:
                segmentations = file[self.CACHED_SEGMENTATION]
            else:
                segmentations = self._segmentation(file)
                file[self.CACHED_SEGMENTATION] = segmentations
        else:
            segmentations: SlidingWindowFeature = self._segmentation(file)

        return segmentations


def merge_vad(vad_arr, pad_onset=0.0, pad_offset=0.0, min_duration_off=0.0, min_duration_on=0.0):

    active = Annotation()
    for k, vad_t in enumerate(vad_arr):
        region = Segment(vad_t[0] - pad_onset, vad_t[1] + pad_offset)
        active[region, k] = 1


    if pad_offset > 0.0 or pad_onset > 0.0 or min_duration_off > 0.0:
        active = active.support(collar=min_duration_off)

    # remove tracks shorter than min_duration_on
    if min_duration_on > 0:
        for segment, track in list(active.itertracks()):
            if segment.duration < min_duration_on:
                del active[segment, track]

    active = active.for_json()
    active_segs = pd.DataFrame([x['segment'] for x in active['content']])
    return active_segs

def merge_chunks(segments, chunk_size):
    """
    Merge operation described in paper
    """
    curr_end = 0
    merged_segments = []
    seg_idxs = []
    speaker_idxs = []

    assert chunk_size > 0
    binarize = Binarize(max_duration=chunk_size)
    segments = binarize(segments)
    segments_list = []
    for speech_turn in segments.get_timeline():
        segments_list.append(SegmentX(speech_turn.start, speech_turn.end, "UNKNOWN"))

    if len(segments_list) == 0:
        print("No active speech found in audio")
        return []
    # assert segments_list, "segments_list is empty."
    # Make sure the starting point is the start of the segment.
    curr_start = segments_list[0].start

    for seg in segments_list:
        if seg.end - curr_start > chunk_size and curr_end - curr_start > 0:
            merged_segments.append({
                "start": curr_start,
                "end": curr_end,
                "segments": seg_idxs,
            })
            curr_start = seg.start
            seg_idxs = []
            speaker_idxs = []
        curr_end = seg.end
        seg_idxs.append((seg.start, seg.end))
        speaker_idxs.append(seg.speaker)
    # add final
    merged_segments.append({
        "start": curr_start,
        "end": curr_end,
        "segments": seg_idxs,
    })
    return merged_segments
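Finally, a rough sketch of how the VAD pieces above fit together on their own. The audio path is a placeholder, and chunk_size=30 is chosen to mirror Whisper's 30-second input window; the ASR stage is assumed to consume the merged chunks in a similar way.

import torch
from whisperx.vad import load_vad_model, merge_chunks

vad_pipeline = load_vad_model(torch.device("cpu"), vad_onset=0.500, vad_offset=0.363)
speech_scores = vad_pipeline("audio.wav")            # SlidingWindowFeature of frame-level speech scores
chunks = merge_chunks(speech_scores, chunk_size=30)  # min-cut at low-score points so no chunk exceeds 30 s
for chunk in chunks:
    print(f"{chunk['start']:.2f} -> {chunk['end']:.2f}  ({len(chunk['segments'])} speech turns)")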