Commit · c3d5eb3
Parent(s): 7f8dd93
Upload stable_whisper.py

stable_whisper.py ADDED (+1493 -0)
@@ -0,0 +1,1493 @@
import ffmpeg
import whisper
import warnings
import numpy as np
import torch
from torch import Tensor
from torch.nn import functional as F
from torch.distributions import Categorical
from typing import List, Optional, Tuple, Union
from whisper.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
from whisper.decoding import DecodingOptions, DecodingResult
from whisper.tokenizer import LANGUAGES
from whisper.utils import exact_div, format_timestamp, compression_ratio
from whisper.model import Whisper
from whisper.decoding import DecodingTask, BeamSearchDecoder, GreedyDecoder
from whisper.tokenizer import Tokenizer, get_tokenizer
from types import MethodType
from itertools import chain, repeat
from copy import deepcopy
import os
import json


# no_caption was changed to no_speech in newer whisper commits
def get_new_attrs(obj_, attr: str):
    if attr == 'no_caption_probs':
        return getattr(obj_, attr) if hasattr(obj_, 'no_caption_probs') else getattr(obj_, 'no_speech_probs')
    elif attr == 'no_caption_prob':
        return getattr(obj_, attr) if hasattr(obj_, 'no_caption_prob') else getattr(obj_, 'no_speech_prob')
    elif attr == 'no_captions':
        return getattr(obj_, attr) if hasattr(obj_, 'no_captions') else getattr(obj_, 'no_speech')
    else:
        raise NotImplementedError(attr)


def check_ascending_sequence(seq: Union[List[Union[int, float]], np.ndarray], verbose=True) -> bool:
    """
    check if a sequence of numbers is in ascending order
    """
    is_ascending = True
    for idx, (i, j) in enumerate(zip(seq[:-1], seq[1:])):
        if i > j:
            is_ascending = False
            if verbose:
                print(f'[Index{idx}]:{i} > [Index{idx + 1}]:{j}')
            else:
                break

    return is_ascending


def check_ascending_sentence_ts(res: (dict, list)) -> bool:
    segs = res['segments'] if isinstance(res, dict) else res
    return check_ascending_sequence(list(chain.from_iterable((float(i['start']), float(i['end']))
                                                             for i in segs)))


def check_ascending_word_ts(res: (dict, list)) -> bool:
    cc = group_word_timestamps(res['segments'] if isinstance(res, dict) else res, ts_key='word_timestamps')
    return check_ascending_sequence((list(chain.from_iterable((float(i['start']), float(i['end']))
                                                              for i in cc))))


def is_equal_ts(a: (float, int, np.ndarray), b: (float, int, np.ndarray), rtol=1e-03):
    """
    check if timestamp a and timestamp b are equal within the relative tolerance (rtol)
    """
    return np.isclose(a, b, rtol=rtol)


def check_is_same_results(res0: (dict, list), res1: (dict, list), check_unstable=False) -> bool:
    """
    check if res0 and res1 have the same timestamps
    """
    if isinstance(res0, dict):
        res0 = res0['segments']
    if isinstance(res1, dict):
        res1 = res1['segments']
    ts_key = 'unstable_word_timestamps' if check_unstable else 'word_timestamps'
    inner_ts_key = 'timestamps' if check_unstable else 'timestamp'

    def _reduce(x):
        if isinstance(x, np.ndarray):
            return set(tuple(x)) == {True}
        return x

    # res0/res1 were already reduced to their segment lists above
    t = set(set(_reduce(is_equal_ts(a[inner_ts_key], b[inner_ts_key])) for a, b in zip(i[ts_key], j[ts_key])) == {True}
            for i, j in zip(res0, res1))
    return t == {True}


def to_srt(lines: List[dict], save_path: str = None, strip=False) -> str:
    """
    lines: List[dict]
        [{start: <start-timestamp-of-text>, end: <end-timestamp-of-text>, text: <str-of-text>}, ...]
    """

    def secs_to_hhmmss(secs: (float, int)):
        mm, ss = divmod(secs, 60)
        hh, mm = divmod(mm, 60)
        return f'{hh:0>2.0f}:{mm:0>2.0f}:{ss:0>6.3f}'.replace(".", ",")

    srt_str = '\n'.join(
        f'{i}\n'
        f'{secs_to_hhmmss(sub["start"])} --> {secs_to_hhmmss(sub["end"])}\n'
        f'{sub["text"].strip() if strip else sub["text"]}\n'
        for i, sub in enumerate(lines, 1))

    if save_path:
        with open(save_path, 'w', encoding='utf-8') as f:
            f.write(srt_str)
        print(f'Saved: {os.path.abspath(save_path)}')

    return srt_str
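
# Illustrative sketch (not part of the original file): feeding to_srt a single
# hypothetical cue shows the SRT framing produced by the f-strings above,
# verified by hand:
#
#   to_srt([dict(start=3723.5, end=3725.0, text=' example')])
#
# returns:
#
#   1
#   01:02:03,500 --> 01:02:05,000
#    example
#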


def group_word_timestamps(res: (dict, list), one_group=True, combine_compound=False,
                          ts_key='whole_word_timestamps', min_dur: float = None):

    if min_dur is None:
        min_dur = 0.0

    def group_ts(ts_: List[dict], start) -> List[dict]:
        first_group: List[dict] = []
        for w_ts in ts_:
            if first_group:
                if (not combine_compound or w_ts['word'].startswith(' ')) and \
                        (w_ts['timestamp'] - first_group[-1]['start']) >= min_dur and \
                        first_group[-1]['end'] < w_ts['timestamp']:
                    first_group.append(dict(start=first_group[-1]['end'],
                                            end=w_ts['timestamp'],
                                            text=w_ts['word']))
                else:
                    first_group[-1]['end'] = max(first_group[-1]['end'], w_ts['timestamp'])
                    first_group[-1]['text'] += w_ts['word']
            else:
                first_group.append(dict(start=start,
                                        end=w_ts['timestamp'],
                                        text=w_ts['word']))

        return first_group

    def group_zero_duration(first_group: List[dict]) -> List[dict]:
        final_group: List[dict] = []
        for ts_dict in first_group:
            if not final_group or (ts_dict['end'] - ts_dict['start']) > 0:
                final_group.append(ts_dict)
            else:
                final_group[-1]['end'] = ts_dict['end']
                final_group[-1]['text'] += ts_dict['text']

        return final_group

    segs: List[dict] = res['segments'] if isinstance(res, dict) else res
    assert set(ts_key in seg for seg in segs) == {True}, f'input contains missing {ts_key}'

    grouped = (group_ts(seg[ts_key], seg['start']) for seg in segs)
    return group_zero_duration(list(chain.from_iterable(grouped))) if one_group else list(grouped)
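
# Illustrative sketch (hypothetical timestamps, not from the original file): with
# combine_compound=False, each whole-word timestamp becomes its own group,
# anchored between the previous word's end and its own timestamp:
#
#   seg = dict(start=0.0, whole_word_timestamps=[
#       dict(word=' Hello', timestamp=0.5), dict(word=' world', timestamp=1.0)])
#   group_word_timestamps([seg])
#   # -> [{'start': 0.0, 'end': 0.5, 'text': ' Hello'},
#   #     {'start': 0.5, 'end': 1.0, 'text': ' world'}]
#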


def tighten_timestamps(res: dict, end_at_last_word=True, end_before_period=False, start_at_first_word=False) -> dict:
    res = deepcopy(res)
    for i in range(len(res['segments'])):
        if start_at_first_word:
            res['segments'][i]['start'] = res['segments'][i]['word_timestamps'][0]['timestamp']
        if end_before_period and \
                res['segments'][i]['word_timestamps'][-1]['word'] == '.' and \
                len(res['segments'][i]['word_timestamps']) > 1:
            res['segments'][i]['end'] = res['segments'][i]['word_timestamps'][-2]['timestamp']
        elif end_at_last_word:
            res['segments'][i]['end'] = res['segments'][i]['word_timestamps'][-1]['timestamp']

    return res


def results_to_srt(res: dict, srt_path, word_level=True, combine_compound=False,
                   end_at_last_word=True, end_before_period=False, start_at_first_word=True, strip=False):
    if word_level:
        results_to_word_srt(res, srt_path, combine_compound=combine_compound, strip=strip)
    else:
        results_to_sentence_srt(res, srt_path,
                                end_at_last_word=end_at_last_word,
                                end_before_period=end_before_period,
                                start_at_first_word=start_at_first_word,
                                strip=strip)


def results_to_sentence_srt(res: dict, srt_path,
                            end_at_last_word=False,
                            end_before_period=False,
                            start_at_first_word=False,
                            strip=False):
    """

    Parameters
    ----------
    res: dict
        results from modified model
    srt_path: str
        output path of srt
    end_at_last_word: bool
        set end-of-sentence to timestamp-of-last-token
    end_before_period: bool
        set end-of-sentence to timestamp-of-last-non-period-token
    start_at_first_word: bool
        set start-of-sentence to timestamp-of-first-token
    strip: bool
        perform strip() on each sentence

    """
    strict = any((end_at_last_word, end_before_period, start_at_first_word))
    segs = tighten_timestamps(res,
                              end_at_last_word=end_at_last_word,
                              end_before_period=end_before_period,
                              start_at_first_word=start_at_first_word)['segments'] \
        if strict else res['segments']

    # merge or re-anchor zero-duration segments so no cue has start == end
    max_idx = len(segs) - 1
    i = 1
    while i <= max_idx:
        if not (segs[i]['end'] - segs[i]['start']):
            if segs[i - 1]['end'] == segs[i]['end']:
                segs[i - 1]['text'] += (' ' + segs[i]['text'].strip())
                del segs[i]
                max_idx -= 1
                continue
            else:
                segs[i]['start'] = segs[i - 1]['end']
        i += 1

    to_srt(segs, srt_path, strip=strip)


def results_to_word_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
    """

    Parameters
    ----------
    res: dict
        results from modified model
    srt_path: str
        output path of srt
    combine_compound: bool
        concatenate words without in-between spacing
    strip: bool
        perform strip() on each word
    min_dur: float
        minimum duration for each word (i.e. concat the words if shorter than the specified value; Default 0.02)

    """
    to_srt(group_word_timestamps(res, combine_compound=combine_compound, min_dur=min_dur),
           srt_path, strip=strip)


def results_to_token_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
    """

    Parameters
    ----------
    res: dict
        results from modified model
    srt_path: str
        output path of srt
    combine_compound: bool
        concatenate tokens without in-between spacing
    strip: bool
        perform strip() on each token
    min_dur: float
        minimum duration for each token (i.e. concat the tokens if shorter than the specified value; Default 0.02)

    """
    to_srt(group_word_timestamps(res, combine_compound=combine_compound, ts_key='word_timestamps', min_dur=min_dur),
           srt_path, strip=strip)
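
# Illustrative usage sketch (not part of the original file; `result` is assumed to
# come from transcribe_word_level below, and the paths are hypothetical):
#
#   results_to_sentence_srt(result, 'audio.srt')      # one cue per segment
#   results_to_word_srt(result, 'audio_words.srt')    # one cue per whole word
#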


def _get_min_estimation(estimations: List[Union[list, np.ndarray]],
                        min_: (int, float) = None,
                        max_: (int, float) = None) -> np.ndarray:
    estimations = deepcopy(estimations)
    estimations = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, estimations))
    prev_min = min_ or 0
    curr_max = max_ or np.max(estimations[-1])

    min_est = []
    for curr_est in estimations:
        curr_min = curr_est[np.logical_and(curr_max > curr_est, curr_est > prev_min)]
        curr_min = np.min(curr_min) if curr_min.shape[0] else prev_min
        min_est.append(curr_min)
        prev_min = curr_min

    return np.array(min_est)


def _get_max_estimation(estimations: List[Union[list, np.ndarray]],
                        max_: (int, float) = None,
                        min_: (int, float) = None) -> np.ndarray:
    estimations = deepcopy(estimations)
    estimations = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, estimations))
    prev_max = max_ or np.max(estimations[-1])
    curr_min = np.min(estimations[0]) if min_ is None else min_

    max_est = []
    for curr_est in reversed(estimations):
        curr_max = curr_est[np.logical_and(prev_max > curr_est, curr_est > curr_min)]
        curr_max = np.max(curr_max) if curr_max.shape[0] else prev_max
        max_est.append(curr_max)
        prev_max = curr_max

    max_est.reverse()
    return np.array(max_est)


def _remove_overestimation(x: Union[np.ndarray, List[Union[int, float]]], alt_est: List[Union[list, np.ndarray]] = None,
                           max_: (int, float) = None, min_: (int, float) = None,
                           aggressive=False) -> np.ndarray:
    x = np.array(x) if isinstance(x, list) else deepcopy(x)
    if alt_est is not None:
        alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
    assert x.ndim == 1
    assert alt_est is None or len(alt_est) == x.shape[0]
    max_val = x[-1] if max_ is None else max_
    min_val = x[0] if min_ is None else min_

    def curr_max_min(val):
        if min_ is None:
            return val
        return max(min_, val)

    if min_ is not None:
        x[x < min_] = min_
    reduce_ = np.min if aggressive else np.mean
    for i in range(x.shape[-1] - 1, -1, -1):
        if x[i] > max_val or (i > 1 and x[i] < reduce_(x[:i])):  # spikes or dips
            if alt_est is None or alt_est[i] is None:
                x[i] = max_val
            else:
                tmp_min = min_val if i < 2 else curr_max_min(np.mean(x[:i]))
                alt_ = alt_est[i][np.logical_and(alt_est[i] < max_val, alt_est[i] > tmp_min)]
                x[i] = max_val if alt_.shape[0] == 0 else alt_[0]
        max_val = x[i]
    return x


def _remove_underestimation(x: Union[np.ndarray, List[Union[int, float]]],
                            alt_est: List[Union[list, np.ndarray]] = None,
                            min_: (int, float) = None, max_: (int, float) = None,
                            aggressive=False) -> np.ndarray:
    x = np.array(x) if isinstance(x, list) else deepcopy(x)
    if alt_est is not None:
        alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
    assert x.ndim == 1
    assert alt_est is None or len(alt_est) == x.shape[0]
    min_val = x[0] if min_ is None else min_
    max_val = x[-1] if max_ is None else max_

    def curr_min_max(val):
        if max_ is None:
            return val
        return min(max_, val)

    if max_ is not None:
        x[x > max_] = max_
    reduce_ = np.max if aggressive else np.mean
    max_i_reduce = x.shape[-1] - 2
    for i in range(0, x.shape[-1]):
        if x[i] < min_val or (i < max_i_reduce and x[i] > reduce_(x[i + 1:])):  # dips or spikes
            if alt_est is None or alt_est[i] is None:
                x[i] = min_val
            else:
                tmp_max = max_val if i >= max_i_reduce else curr_min_max(np.mean(x[i + 1:]))
                alt_ = alt_est[i][np.logical_and(alt_est[i] > min_val, alt_est[i] < tmp_max)]
                x[i] = min_val if alt_.shape[0] == 0 else alt_[0]
        min_val = x[i]
    return x


def _merge_max_min_estimation(mx: Union[np.ndarray, List[Union[int, float]]],
                              mn: Union[np.ndarray, List[Union[int, float]]],
                              alt_est: List[Union[list, np.ndarray]] = None) -> np.ndarray:
    mx = np.array(mx) if isinstance(mx, list) else deepcopy(mx)
    mn = np.array(mn) if isinstance(mn, list) else deepcopy(mn)
    if alt_est is not None:
        alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
    assert mx.ndim == 1 and mn.ndim == 1
    assert mx.shape[0] == mn.shape[0]
    assert alt_est is None or len(alt_est) == mx.shape[0]

    pref_mx = np.var(mx) > np.var(mn)
    if pref_mx:
        mn[0] = mx[0]
    prev_min = mn[0]
    for i in range(1, mn.shape[0]):
        if prev_min > mn[i]:
            if mn[i] > mx[i]:  # prev_min > mn[i] > mx[i]
                mn[i] = prev_min
            elif mx[i] > mn[i]:
                if prev_min > mx[i]:  # prev_min > mx[i] > mn[i]
                    mn[i] = prev_min
                else:  # mx[i] > prev_min > mn[i]
                    alt_ = alt_est[i][np.logical_and(alt_est[i] > prev_min, alt_est[i] < mx[i])]
                    mn[i] = (mx[i] if pref_mx else prev_min) if alt_.shape[0] == 0 else alt_[0]
            else:  # prev_min > mn[i] == mx[i]
                mn[i] = prev_min
        elif mn[i] > prev_min:
            # if prev_min > mx[i]:  # mn[i] > prev_min > mx[i]
            #     pass
            if mx[i] > prev_min:
                if mn[i] > mx[i]:  # mn[i] > mx[i] > prev_min
                    pass
                elif mx[i] > mn[i]:  # mx[i] > mn[i] > prev_min
                    alt_ = alt_est[i][np.logical_and(alt_est[i] > mn[i], alt_est[i] < mx[i])]
                    if alt_.shape[0]:
                        mn[i] = alt_[0]
                    elif pref_mx:
                        mn[i] = mx[i]
                # else:  # mx[i] == mn[i] > prev_min
                #     pass
            # else:  # mn[i] > mx[i] == prev_min
            #     pass
        else:  # mn[i] == prev_min
            if mx[i] > mn[i]:  # mx[i] > mn[i] == prev_min
                alt_ = alt_est[i][np.logical_and(alt_est[i] > mn[i], alt_est[i] < mx[i])]
                if alt_.shape[0]:
                    mn[i] = alt_[0]
                elif pref_mx:
                    mn[i] = mx[i]
            # elif mn[i] > mx[i]:  # mn[i] == prev_min > mx[i]
            #     pass
            # else:  # mn[i] == prev_min == mx[i]
            #     pass

        prev_min = mn[i]

    return mn


def _avg_merge_min_max(mx: Union[np.ndarray, List[Union[int, float]]],
                       mn: Union[np.ndarray, List[Union[int, float]]],
                       alt_timestamps: List[Union[List[Union[int, float]], np.ndarray]] = None,
                       max_: (int, float) = None, min_: (int, float) = None):
    mx = np.array(mx) if isinstance(mx, list) else deepcopy(mx)
    mn = np.array(mn) if isinstance(mn, list) else deepcopy(mn)
    assert mx.ndim == mn.ndim == 1
    assert mx.shape[0] == mn.shape[0]

    avg_ = (mx + mn) / 2

    if check_ascending_sequence(avg_, verbose=False):
        return avg_

    if not max_:
        max_ = max(mx[-1], mn[-1])
    if min_ is None:
        min_ = min(mn[0], mx[0])

    return _stabilize_timestamps(avg_, alt_timestamps, max_=max_, min_=min_)


def _stabilize_timestamps(timestamps: Union[np.ndarray, List[Union[int, float]]],
                          alt_timestamps: List[Union[List[Union[int, float]], np.ndarray]] = None,
                          max_: (int, float) = None, min_: (int, float) = None, aggressive=False) -> np.ndarray:
    mx = _remove_overestimation(timestamps, alt_est=alt_timestamps, max_=max_, min_=min_, aggressive=aggressive)
    mn = _remove_underestimation(timestamps, alt_est=alt_timestamps, max_=max_, min_=min_, aggressive=aggressive)
    return _merge_max_min_estimation(mx, mn, alt_timestamps)


def _stabilize_more_timestamps(timestamps: List[Union[list, np.ndarray]],
                               max_: (int, float) = None, min_: (int, float) = None, average=True) -> np.ndarray:
    mx = _get_max_estimation(timestamps, max_=max_, min_=min_)
    mn = _get_min_estimation(timestamps, max_=max_, min_=min_)
    if average:
        return _avg_merge_min_max(mx, mn, timestamps, max_=max_, min_=min_)
    return _merge_max_min_estimation(mx, mn, timestamps)
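
# Hand-traced sketch (hypothetical values, not from the original file): a spike in
# an otherwise ascending sequence is clipped using the alternative predictions.
# The over/under-estimation passes both flag 9.0 and swap in the in-range
# alternative 2.5, then the merge pass keeps the result chronological:
#
#   ts = [1.0, 2.0, 9.0, 3.0, 4.0]                  # 9.0 is a spike
#   alts = [[1.0], [2.0], [2.5], [3.0], [4.0]]      # alternative predictions
#   _stabilize_timestamps(ts, alts)
#   # -> array([1. , 2. , 2.5, 4. , 4. ])
#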


def stabilize_timestamps(segments: Union[List[dict], dict],
                         top_focus=False, aggressive=False, average=True) -> List[dict]:
    """

    Parameters
    ----------
    segments: Union[List[dict], dict]
        result['segments'] or result
    top_focus: bool
        adhere closely to the top predictions for word timestamps
    aggressive: bool
        only if top_focus=True,
        allow greater variation in word_timestamps/whole_word_timestamps
    average: bool
        only if top_focus=False,
        average min and max of unstable_word_timestamps to get word_timestamps/whole_word_timestamps

    """
    if isinstance(segments, dict):
        segments = segments['segments']
    if not segments:
        warnings.warn('No Segments Found')
        return []
    missing_ts_idx = set(map(lambda x: None if x[1].get('unstable_word_timestamps') else x[0],
                             enumerate(segments))) - {None}
    no_word_timestamps = len(missing_ts_idx) == len(segments)
    if not no_word_timestamps and missing_ts_idx:
        warnings.warn(f'Segments {list(missing_ts_idx)} are missing unstable_word_timestamps. '
                      f'Word-level timestamp stabilization will be skipped.')

    # group segments into sections that share the same offset (anchor points mark section ends)
    segments = deepcopy(segments)
    sectioned_segments: List[List] = [[]]
    for i, seg in enumerate(segments, 1):
        sectioned_segments[-1].append(seg)
        if seg['anchor_point']:
            if i < len(segments):
                sectioned_segments.append([])

    assert all(len(set(s['offset'] for s in segs)) == 1 for segs in sectioned_segments)

    sectioned_segments_timestamps = [dict(min_=segs[-1]['offset'],
                                          max_=segs[-1]['next_offset'],
                                          timestamps=list(chain.from_iterable((s['start'], s['end']) for s in segs)),
                                          alt_timestamps=list(chain.from_iterable((s['alt_start_timestamps'],
                                                                                   s['alt_end_timestamps'])
                                                                                  for s in segs)))
                                     for segs in sectioned_segments]

    sectioned_stab_timestamps = [_stabilize_timestamps(**kwargs).reshape(-1, 2) for kwargs in
                                 sectioned_segments_timestamps]

    for i in range(len(sectioned_segments)):
        for j in range(len(sectioned_segments[i])):
            sectioned_segments[i][j]['start'], sectioned_segments[i][j]['end'] = sectioned_stab_timestamps[i][j]

            if not missing_ts_idx:
                if top_focus:
                    top_word_ts = [ts_['timestamps'][0] for ts_ in
                                   sectioned_segments[i][j]['unstable_word_timestamps']]
                    alt_word_ts = [ts_['timestamps'][1:] for ts_ in
                                   sectioned_segments[i][j]['unstable_word_timestamps']]
                    temp_stab_word_ts = _stabilize_timestamps(top_word_ts, alt_word_ts,
                                                              max_=sectioned_segments[i][j]['end'],
                                                              min_=sectioned_segments[i][j]['start'],
                                                              aggressive=aggressive)
                else:
                    word_ts = [ts_['timestamps'] for ts_ in sectioned_segments[i][j]['unstable_word_timestamps']]
                    temp_stab_word_ts = _stabilize_more_timestamps(word_ts,
                                                                   max_=sectioned_segments[i][j]['end'],
                                                                   min_=sectioned_segments[i][j]['start'],
                                                                   average=average)

                temp_stab_word_ts = [{'word': sectioned_segments[i][j]['unstable_word_timestamps'][k]['word'],
                                      'token': sectioned_segments[i][j]['unstable_word_timestamps'][k]['token'],
                                      'timestamp': temp_stab_word_ts[k]}
                                     for k in range(temp_stab_word_ts.shape[0])]

                sectioned_segments[i][j]['word_timestamps'] = temp_stab_word_ts

    return list(chain.from_iterable(sectioned_segments))


def save_as_json(results, path):
    with open(path, 'w', encoding='utf-8') as f:
        json.dump(results, f)


def add_whole_word_ts(tokenizer: Tokenizer, segments: Union[List[dict], dict], merge_non_space: bool = None,
                      prepend_punctuations: Union[List[str], Tuple[str]] = None,
                      append_punctuations: Union[List[str], Tuple[str]] = None):
    merge_non_space = (tokenizer.language in ['en'] or tokenizer.language is None) \
        if merge_non_space is None else merge_non_space
    if prepend_punctuations is None:
        prepend_punctuations = r'“¿([{'
    if append_punctuations is None:
        append_punctuations = r'.。,,!!??::”)]}、'
    if isinstance(segments, dict):
        segments = segments['segments']
    if not segments:
        print('No segments found, whole-word timestamps cannot be added.')
        return

    missing_idx = set(-1 if seg.get('word_timestamps') else i for i, seg in enumerate(segments)) - {-1}

    if missing_idx:
        if len(missing_idx) == len(segments):
            print('No word_timestamps found, whole-word timestamps cannot be added.')
            return
        print(f'Some word_timestamps not found, '
              f'whole-word timestamps cannot be added to the following segments: {tuple(missing_idx)}')

    failed_idx = []

    for seg_idx, seg in enumerate(segments):
        if seg.get('word_timestamps'):
            prev_idx = 0
            remaining_text = seg['text']
            has_prepend = False
            whole_word_timestamps: List[dict] = []
            for wts_idx in range(1, len(seg['word_timestamps']) + 1):
                max_ts = seg['word_timestamps'][wts_idx - 1]['timestamp']
                tokens = [wts['token'] for wts in seg['word_timestamps'][prev_idx: wts_idx]]
                temp_whole_word = tokenizer.decode(tokens)
                if temp_whole_word == remaining_text[:len(temp_whole_word)]:
                    prev_idx = wts_idx
                    remaining_text = remaining_text[len(temp_whole_word):]
                    if (not merge_non_space or temp_whole_word.startswith(' ') or not whole_word_timestamps) and \
                            temp_whole_word not in append_punctuations and \
                            not has_prepend:
                        has_prepend = temp_whole_word.strip() in prepend_punctuations
                        whole_word_timestamps.append(dict(word=temp_whole_word, timestamp=max_ts))
                    else:
                        has_prepend = False
                        if whole_word_timestamps == []:
                            continue
                        whole_word_timestamps[-1]['word'] += temp_whole_word
                        whole_word_timestamps[-1]['timestamp'] = max_ts
            if remaining_text:
                failed_idx.append(seg_idx)
                whole_word_timestamps = []
            seg['whole_word_timestamps'] = whole_word_timestamps or None
        else:
            seg['whole_word_timestamps'] = None

    if failed_idx:
        print(f'Failed to add whole-word timestamps to the following segments: {tuple(failed_idx)}')
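
# Illustrative sketch (hypothetical tokens, not from the original file): sub-word
# tokens are merged back into whole words by matching the decoded prefix against
# the segment text, so ' trans' + 'cribe' ends up as one entry such as
#
#   {'word': ' transcribe', 'timestamp': <timestamp of the last merged token>}
#
# and appended punctuation (e.g. '.') is folded into the preceding word.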


def _load_audio_waveform(audio: Union[str, bytes, np.ndarray, torch.Tensor], h: int, w: int) -> np.ndarray:
    """

    Parameters
    ----------
    audio: Union[str, bytes, np.ndarray, torch.Tensor], shape = (*)
        The path to audio, or bytes of an audio file, or a NumPy array or Tensor containing the audio waveform in 16 kHz
    h: int
        Height of waveform image
    w: int
        Width of waveform image

    Returns
    -------
    Audio waveform image as a NumPy array, in uint8 dtype.
    """

    try:
        if isinstance(audio, str):
            stream = ffmpeg.input(audio, threads=0)
            inp = None

        else:
            if isinstance(audio, bytes):
                stream = ffmpeg.input('pipe:', threads=0)
                inp = audio
            else:
                warnings.warn('A resampled input causes an unexplained temporal shift in the waveform image '
                              'that will skew the timestamp suppression and may result in inaccurate timestamps.\n'
                              'Use audio_for_mask for transcribe() to provide the original audio track '
                              'as the path or bytes of the audio file.',
                              stacklevel=2)
                stream = ffmpeg.input('pipe:', threads=0, ac=1, format='s16le')
                if isinstance(audio, torch.Tensor):
                    audio = np.array(audio)
                inp = (audio * 32768.0).astype(np.int16).tobytes()

        waveform, err = (
            stream.filter('aformat', channel_layouts='mono')
            .filter('highpass', f='200').filter('lowpass', f='3000')
            .filter('showwavespic', s=f'{w}x{h}')
            .output('-', pix_fmt='gray', format='rawvideo')
            .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=inp)
        )
    except ffmpeg.Error as e:
        raise RuntimeError(f"Failed to load audio in waveform: {e.stderr.decode()}") from e
    else:
        if not waveform:
            partial_file = b'partial file' in err and b'Output file is empty' in err
            add_msg = '\nMetadata for decoding is likely at the end of the file; try using the path of the audio instead.' \
                if partial_file and isinstance(audio, bytes) else ''
            raise RuntimeError(f"Failed to load audio in waveform: {err.decode()}" + add_msg)
        return np.frombuffer(waveform, dtype=np.uint8).reshape(h, w)


def _remove_lower_quantile(waveform: np.ndarray,
                           upper_quantile: float = None,
                           lower_quantile: float = None,
                           lower_threshold: float = None) -> np.ndarray:
    """
    Removes the lower quantile of amplitude from the waveform image
    """
    if upper_quantile is None:
        upper_quantile = 0.85
    if lower_quantile is None:
        lower_quantile = 0.15
    if lower_threshold is None:
        lower_threshold = 0.15
    waveform = deepcopy(waveform)
    wave_sums = waveform.sum(0)
    mx = np.quantile(wave_sums, upper_quantile, -1)
    mn = np.quantile(wave_sums, lower_quantile, -1)
    mn_threshold = (mx - mn) * lower_threshold + mn
    waveform[:, wave_sums < mn_threshold] = 0
    return waveform


def _wave_to_ts_filter(waveform: np.ndarray, suppress_middle=True,
                       max_index: (list, int) = None) -> np.ndarray:
    """
    Returns a NumPy array mask of the sections with zero amplitude
    """
    assert waveform.ndim <= 2, f'waveform must have at most 2 dims but found {waveform.ndim}'
    if waveform.ndim == 1:
        wave_sum = waveform
    else:
        wave_sum = waveform.sum(-2)

    wave_filter = wave_sum.astype(bool)

    if not suppress_middle:
        nonzero_indices = wave_filter.nonzero()[0]
        wave_filter[nonzero_indices[0]:nonzero_indices[-1] + 1] = True
    if max_index is not None:
        wave_filter[max_index + 1:] = False

    return ~wave_filter
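
# Illustrative sketch (hypothetical values, not from the original file): columns of
# the waveform image that sum to zero become True in the returned mask, i.e.
# timestamp positions that land on silence get suppressed:
#
#   _wave_to_ts_filter(np.array([0, 3, 5, 0, 0, 2], dtype=np.uint8))
#   # -> array([ True, False, False,  True,  True, False])
#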


# modified version of whisper.transcribe.transcribe
def transcribe_word_level(
        model: "Whisper",
        audio: Union[str, np.ndarray, torch.Tensor],
        *,
        verbose: bool = False,
        temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
        compression_ratio_threshold: Optional[float] = 2.4,
        logprob_threshold: Optional[float] = -1.0,
        no_speech_threshold: Optional[float] = 0.6,
        condition_on_previous_text: bool = True,
        stab=True, top_focus=False, ts_num: int = 10,
        alpha: float = None, print_unstab=False,
        suppress_silence: bool = True,
        suppress_middle: bool = True,
        suppress_word_ts: bool = True,
        remove_background: bool = True,
        silence_threshold: float = 0.1,
        prepend_punctuations: Union[List[str], Tuple[str]] = None,
        append_punctuations: Union[List[str], Tuple[str]] = None,
        audio_for_mask: (str, bytes) = None,
        **decode_options):
    """
    Transcribe an audio file using Whisper

    Parameters
    ----------
    model: Whisper
        The Whisper model instance

    audio: Union[str, np.ndarray, torch.Tensor]
        The path to the audio file to open, or the audio waveform

    verbose: bool
        Whether to display the decoded text (with finalized timestamps) to the console

    temperature: Union[float, Tuple[float, ...]]
        Temperature for sampling. It can be a tuple of temperatures, which will be successively used
        upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.

    compression_ratio_threshold: float
        If the gzip compression ratio is above this value, treat as failed

    logprob_threshold: float
        If the average log probability over sampled tokens is below this value, treat as failed

    no_speech_threshold: float
        If the no_speech probability is higher than this value AND the average log probability
        over sampled tokens is below `logprob_threshold`, consider the segment as silent

    condition_on_previous_text: bool
        if True, the previous output of the model is provided as a prompt for the next window;
        disabling may make the text inconsistent across windows, but the model becomes less prone to
        getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.

    stab: bool
        Stabilize timestamps by cross-comparing timestamps and using additional top timestamp predictions
        to fill in when appropriate to ensure timestamps are chronological.

    top_focus: bool
        Adhere closely to the top predictions for token timestamps stabilization

    ts_num: int
        Number of top timestamp predictions to save for each word for postprocessing stabilization (default: 10).

    alpha: float
        Amount of noise to add to audio to produce slightly different results.
        audio_features *= torch.rand_like(audio_features) * alpha + 1

    print_unstab: bool
        Whether to display the text (without stabilized timestamps) being decoded to the console
        (i.e. behaves like verbose did before the model was modified)

    suppress_silence: bool
        Suppress timestamp tokens that are marked as silent

    suppress_middle: bool
        Suppress any silent timestamp tokens in the middle of the segment instead of only the beginning and ending

    suppress_word_ts: bool
        Suppress timestamp tokens of words that are marked as silent

    remove_background: bool
        Whether to remove background noise from the waveform so that it is marked silent.
        Determined by parameters that are part of decode_options (i.e. specify like the other options here):
            upper_quantile: float
                The upper quantile of amplitude to determine a max amplitude, mx (Default: 0.85)
            lower_quantile: float
                The lower quantile of amplitude to determine a min amplitude, mn (Default: 0.15)
            lower_threshold: float
                Suppress sections of the waveform where amplitude < lower_threshold*(mx-mn) + mn. (Default: 0.15)

    silence_threshold: float
        If an audio segment's average silence is >= silence_threshold,
        that segment will not have its background removed even if remove_background=True.
        e.g. 0.5 means the background will be removed only if less than half of the audio segment is silent

    prepend_punctuations: Union[List[str], Tuple[str]]
        Punctuations to prepend to the next word (Default: “¿([{)

    append_punctuations: Union[List[str], Tuple[str]]
        Punctuations to append to the previous word (Default: .。,,!!??::”)]}、)

    audio_for_mask: (str, bytes)
        Original audio track as the path or bytes of the audio file.
        Since resampled audio may shift the waveform image,
        this is an alternative to the 'audio' option for generating the suppression mask from the original audio.

    decode_options: dict
        Keyword arguments to construct `DecodingOptions` instances

    Returns
    -------
    A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
    the spoken language ("language"), which is detected when `decode_options["language"]` is None.
    """

    if 'no_captions_threshold' in decode_options:
        warnings.warn('no_captions_threshold is deprecated. '
                      'Please use no_speech_threshold instead.', DeprecationWarning, stacklevel=2)
        no_speech_threshold = decode_options.pop('no_captions_threshold')

    dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
    if model.device == torch.device("cpu"):
        if torch.cuda.is_available():
            warnings.warn("Performing inference on CPU when CUDA is available")
        if dtype == torch.float16:
            warnings.warn("FP16 is not supported on CPU; using FP32 instead")
            dtype = torch.float32

    if dtype == torch.float32:
        decode_options["fp16"] = False

    if 'max_initial_timestamp' not in decode_options:
        decode_options['max_initial_timestamp'] = None

    mel = log_mel_spectrogram(audio)

    if decode_options.get("language", None) is None:
        if verbose:
            print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
        segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
        _, probs = model.detect_language(segment)
        decode_options["language"] = max(probs, key=probs.get)
        print(f"Detected language: {LANGUAGES[decode_options['language']]}")

    mel = mel.unsqueeze(0)
    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)

    def decode_with_fallback(segment: torch.Tensor, suppress_ts_mask: Tensor = None) \
            -> Union[List[DecodingResult], tuple]:
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        kwargs = {**decode_options}
        t = temperatures[0]
        if t == 0:
            best_of = kwargs.pop("best_of", None)
        else:
            best_of = kwargs.get("best_of", None)

        options = DecodingOptions(**kwargs, temperature=t)
        results, ts_tokens, ts_logits_ = model.decode(segment, options, ts_num=ts_num, alpha=alpha,
                                                      suppress_ts_mask=suppress_ts_mask,
                                                      suppress_word_ts=suppress_word_ts)

        kwargs.pop("beam_size", None)  # no beam search for t > 0
        kwargs.pop("patience", None)  # no patience for t > 0
        kwargs["best_of"] = best_of  # enable best_of for t > 0
        for t in temperatures[1:]:
            needs_fallback = [
                compression_ratio_threshold is not None
                and result.compression_ratio > compression_ratio_threshold
                or logprob_threshold is not None
                and result.avg_logprob < logprob_threshold
                for result in results
            ]
            if any(needs_fallback):
                options = DecodingOptions(**kwargs, temperature=t)
                retries, r_ts_tokens, r_ts_logits = model.decode(segment[needs_fallback], options,
                                                                 ts_num=ts_num, alpha=alpha,
                                                                 suppress_ts_mask=suppress_ts_mask,
                                                                 suppress_word_ts=suppress_word_ts)
                for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
                    results[original_index] = retries[retry_index]
                    ts_tokens[original_index] = r_ts_tokens[retry_index]
                    ts_logits_[original_index] = r_ts_logits[retry_index]

        return results, ts_tokens, ts_logits_

    seek = 0
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
    all_tokens = []
    all_segments = []
    prompt_reset_since = 0

    initial_prompt = decode_options.pop("initial_prompt", None) or []
    if initial_prompt:
        initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt)

    def _to_list(x: (Tensor, None)):
        if x is None:
            return x
        return x.tolist()

    def add_segment(
            *, offset: float, start: float, end: float, text_tokens: Tensor, result: DecodingResult,
            start_timestamps: list = None, end_timestamps: list = None, word_timestamps: Tensor = None,
            start_ts_logits: list = None, end_ts_logits: list = None, word_ts_logits: Tensor = None
    ):
        no_eot_mask = text_tokens < tokenizer.eot
        text_tokens_no_eot = text_tokens[no_eot_mask]
        text = tokenizer.decode(text_tokens_no_eot)

        if len(text.strip()) == 0:  # skip empty text output
            return

        if word_timestamps is not None:
            assert word_timestamps.shape[0] == text_tokens.shape[0]
            if word_ts_logits is None:
                word_ts_fields = zip(text_tokens_no_eot, word_timestamps[no_eot_mask], repeat(None))
            else:
                assert word_ts_logits.shape[0] == text_tokens.shape[0]
                word_ts_fields = zip(text_tokens_no_eot, word_timestamps[no_eot_mask], word_ts_logits[no_eot_mask])

            word_timestamps = [dict(word=tokenizer.decode([token]),
                                    token=token.item(),
                                    timestamps=timestamps_.tolist(),
                                    timestamp_logits=_to_list(ts_logits_))
                               for token, timestamps_, ts_logits_ in word_ts_fields]

        all_segments.append(
            {
                "id": len(all_segments),
                "seek": seek,
                'offset': offset,  # offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
                "start": start,
                "end": end,
                "text": text,
                "tokens": result.tokens,
                "temperature": result.temperature,
                "avg_logprob": result.avg_logprob,
                "compression_ratio": result.compression_ratio,
                "no_speech_prob": get_new_attrs(result, 'no_caption_prob'),
                "alt_start_timestamps": start_timestamps,
                "start_ts_logits": start_ts_logits,
                "alt_end_timestamps": end_timestamps,
                "end_ts_logits": end_ts_logits,
                "unstable_word_timestamps": word_timestamps,
                'anchor_point': False
            }
        )
        if print_unstab or (verbose and not stab):
            print(f'[{format_timestamp(start)} --> {format_timestamp(end)}] "{text}"')
            if word_timestamps is not None:
                ts_str = (f' ->[{format_timestamp(ts_["timestamps"][0])}] "{ts_["word"].strip()}"' for ts_ in
                          word_timestamps)
                print('\n'.join(ts_str), end='\n\n')

    if suppress_silence:
        ts_scale = HOP_LENGTH / SAMPLE_RATE / time_precision
        wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))

        upper_quantile = decode_options.pop('upper_quantile', 0.85)
        lower_quantile = decode_options.pop('lower_quantile', 0.15)
        lower_threshold = decode_options.pop('lower_threshold', 0.15)

    while seek < mel.shape[-1]:
        timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
        remaining_duration = float((mel.shape[-1] - seek) * HOP_LENGTH / SAMPLE_RATE)
        segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
        segment_duration = min(float(segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE), remaining_duration)
        segment_max_ts = segment_duration / time_precision

        if suppress_silence:
            wf_seek = int(seek * ts_scale)
            segment_wf = wf[..., wf_seek:wf_seek + 1501]
            if remove_background and \
                    (1 - segment_wf.sum(0).clip(max=1).mean()) < silence_threshold:
                segment_wf = _remove_lower_quantile(segment_wf.astype(np.float32),
                                                    upper_quantile=upper_quantile,
                                                    lower_quantile=lower_quantile,
                                                    lower_threshold=lower_threshold)
            segment_wf = pad_or_trim(segment_wf, 1501)
            suppress_ts_mask = torch.from_numpy(_wave_to_ts_filter(segment_wf,
                                                                   suppress_middle=suppress_middle,
                                                                   max_index=int(segment_max_ts)))

            if suppress_ts_mask.all():  # segment is silent
                seek += segment.shape[-1]  # fast-forward to the next segment boundary
                continue
        else:
            suppress_ts_mask = None

        decode_options["prompt"] = all_tokens[prompt_reset_since:]
        result, finalized_ts_tokens, ts_logits = decode_with_fallback(segment,
                                                                      suppress_ts_mask=suppress_ts_mask)

        result = result[0]
        tokens = torch.tensor(result.tokens)
        finalized_ts_tokens = torch.tensor(finalized_ts_tokens[0])
        ts_logits = torch.tensor(ts_logits[0])

        if no_speech_threshold is not None:
            # no voice activity check
            should_skip = get_new_attrs(result, 'no_caption_prob') > no_speech_threshold
            if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                # don't skip if the logprob is high enough, despite the no_speech_prob
                should_skip = False

            if should_skip:
                seek += segment.shape[-1]  # fast-forward to the next segment boundary
                continue

        timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
        consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
        if len(consecutive) > 0:  # if the output contains two consecutive timestamp tokens
            last_slice = 0
            for current_slice in consecutive:
                sliced_tokens = tokens[last_slice:current_slice]
                sliced_ts_tokens = finalized_ts_tokens[last_slice:current_slice]
                sliced_ts_logits = ts_logits[last_slice:current_slice]
                start_timestamp_position = (
                    sliced_tokens[0].item() - tokenizer.timestamp_begin
                )
                end_timestamp_position = (
                    sliced_tokens[-1].item() - tokenizer.timestamp_begin
                )

                word_ts = timestamp_offset + sliced_ts_tokens * time_precision

                add_segment(
                    offset=timestamp_offset,
                    start=timestamp_offset + start_timestamp_position * time_precision,
                    end=min(timestamp_offset + end_timestamp_position * time_precision,
                            timestamp_offset + segment_duration),
                    text_tokens=sliced_tokens[1:-1],
                    result=result,
                    start_timestamps=word_ts[0].tolist(),
                    end_timestamps=word_ts[-1].tolist(),
                    word_timestamps=word_ts[1:-1],
                    start_ts_logits=sliced_ts_logits[0].tolist(),
                    end_ts_logits=sliced_ts_logits[-1].tolist(),
                    word_ts_logits=sliced_ts_logits[1:-1]
                )
                last_slice = current_slice
            last_timestamp_position = (
                min(tokens[last_slice - 1].item() - tokenizer.timestamp_begin, segment_max_ts)
            )
            seek += last_timestamp_position * input_stride
            all_tokens.extend(tokens[: last_slice + 1].tolist())
        else:
            duration = segment_duration
            timestamps = tokens[timestamp_tokens.nonzero().flatten()]
            if len(timestamps) > 0:
                # no consecutive timestamps but it has a timestamp; use the last one.
                # single timestamp at the end means no speech after the last timestamp.
                last_timestamp_position = min(timestamps[-1].item() - tokenizer.timestamp_begin, segment_max_ts)
                duration = last_timestamp_position * time_precision

            word_ts = timestamp_offset + finalized_ts_tokens * time_precision

            add_segment(
                offset=timestamp_offset,
                start=timestamp_offset,
                end=timestamp_offset + duration,
                text_tokens=tokens,
                result=result,
                word_timestamps=word_ts,
                word_ts_logits=ts_logits
            )

            seek += segment.shape[-1]
            all_tokens.extend(tokens.tolist())

        if all_segments:
            all_segments[-1]['anchor_point'] = True
            all_segments[-1]['next_offset'] = float(seek * HOP_LENGTH / SAMPLE_RATE)
        if not condition_on_previous_text or result.temperature > 0.5:
            # do not feed the prompt tokens if a high temperature was used
            prompt_reset_since = len(all_tokens)

        if len(all_segments) > 1 and all_segments[-1]['alt_start_timestamps'] is None:
            all_segments[-1]['alt_start_timestamps'] = all_segments[-2]['alt_end_timestamps']

    if stab:
        all_segments = stabilize_timestamps(all_segments, top_focus=top_focus)
        add_whole_word_ts(tokenizer, all_segments,
                          prepend_punctuations=prepend_punctuations,
                          append_punctuations=append_punctuations)
        if verbose:
            print('\nSTABILIZED\n')
            for seg_ in all_segments:
                print(f'[{format_timestamp(seg_["start"])} --> {format_timestamp(seg_["end"])}] "{seg_["text"]}"')
                if seg_['word_timestamps']:
                    ts_str = (f' ->[{format_timestamp(ts_["timestamp"])}] "{ts_["word"].strip()}"' for ts_ in
                              seg_['word_timestamps'])
                    print('\n'.join(ts_str), end='\n\n')

    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
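
# Illustrative usage sketch (not part of the original file). transcribe_word_level
# expects a model whose decode() has been patched to return the extra timestamp
# data (via the modified decoder classes below); the binding step is assumed to
# happen elsewhere in this file (not shown in this truncated view):
#
#   import whisper
#   model = whisper.load_model('base')
#   # ... after the model is modified to use the word-level decoders ...
#   result = transcribe_word_level(model, 'audio.mp3', suppress_silence=True)
#   results_to_sentence_srt(result, 'audio.srt')
#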


def _suppress_ts(ts_logits: Tensor, suppress_ts_mask: Tensor = None):
    if suppress_ts_mask is not None:
        ts_logits[:, suppress_ts_mask] = -np.inf


def _ts_topk(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor:
    temp_ts = torch.stack(torch.topk(ts_logits, k, dim=-1), 0).unsqueeze(-2)
    return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2)
1142 |
+
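
# A rough sketch of the shape bookkeeping above (not part of the original file):
# `_ts_topk` stacks the top-k timestamp logits with their token indices and appends
# one slot per decoding step along dim=-2:
#
#     ts_logits = torch.randn(4, 1501)            # (n_batch, n_timestamp_tokens)
#     ts = _ts_topk(ts_logits, k=3)               # (2, 4, 1, 3) -> [logits, indices]
#     ts = _ts_topk(ts_logits, k=3, prev_ts=ts)   # (2, 4, 2, 3) after a second step
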
# modified version of whisper.GreedyDecoder
class GreedyDecoderWordLevel(GreedyDecoder):
    def __init__(self, *args, **kwargs):
        self.ts_num: int = kwargs.pop('ts_num', 10)
        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
        self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
        super(GreedyDecoderWordLevel, self).__init__(*args, **kwargs)
        self.ts = None

    def _suppress_ts(self, logits: Tensor):
        _suppress_ts(logits[:, self.timestamp_begin:],
                     suppress_ts_mask=self.suppress_ts_mask)

    def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
        self.ts = ts

        self._suppress_ts(logits)

        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()

        logprobs = F.log_softmax(logits.float(), dim=-1)
        current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
        sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)

        next_tokens[tokens[:, -1] == self.eot] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)

        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
        # make sure each sequence has at least one EOT token at the end
        tokens = F.pad(tokens, (0, 1), value=self.eot)
        return tokens, sum_logprobs.tolist(), self.ts.transpose(1, 0)[None]

# modified version of whisper.BeamSearchDecoder
class BeamSearchDecoderWordLevel(BeamSearchDecoder):

    def __init__(self, *args, **kwargs):
        self.ts_num: int = kwargs.pop('ts_num', 10)
        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
        self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
        super(BeamSearchDecoderWordLevel, self).__init__(*args, **kwargs)
        self.ts = None
        self.finished_ts_ls = None

    def reset(self):
        self.finished_sequences = None
        self.finished_ts_ls = None

    def _suppress_ts(self, logits: Tensor):
        _suppress_ts(logits[:, self.timestamp_begin:],
                     suppress_ts_mask=self.suppress_ts_mask)

    def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
        if tokens.shape[0] % self.beam_size != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")

        self.ts = ts

        n_audio = tokens.shape[0] // self.beam_size
        if self.finished_sequences is None:  # for the first update
            self.finished_sequences = [{} for _ in range(n_audio)]
            self.finished_ts_ls = [{} for _ in range(n_audio)]

        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences, finished_ts_ls = [], [], [], []

        self._suppress_ts(logprobs)

        for i in range(n_audio):
            scores, sources, finished, finished_ts = {}, {}, {}, {}

            # STEP 1: calculate the cumulative log probabilities for possible candidates
            for j in range(self.beam_size):
                idx = i * self.beam_size + j
                prefix = tokens[idx].tolist()
                for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
                    new_logprob = (sum_logprobs[idx] + logprob).item()
                    sequence = tuple(prefix + [token.item()])
                    scores[sequence] = new_logprob
                    sources[sequence] = idx

            # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
            saved = 0
            for sequence in sorted(scores, key=scores.get, reverse=True):
                if sequence[-1] == self.eot:
                    finished[sequence] = scores[sequence]
                    finished_ts[sequence] = self.ts[:, sources[sequence]]
                else:
                    sum_logprobs[len(next_tokens)] = scores[sequence]
                    next_tokens.append(sequence)
                    source_indices.append(sources[sequence])

                    saved += 1
                    if saved == self.beam_size:
                        break

            finished_sequences.append(finished)
            finished_ts_ls.append(finished_ts)

        tokens = torch.tensor(next_tokens, device=tokens.device)
        self.inference.rearrange_kv_cache(source_indices)
        self.ts = self.ts[:, source_indices]

        # add newly finished sequences to self.finished_sequences
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished, \
                prev_ts_ls, new_ts_ls in \
                zip(self.finished_sequences, finished_sequences,
                    self.finished_ts_ls, finished_ts_ls):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break  # the candidate list is full
                previously_finished[seq] = newly_finished[seq]
                prev_ts_ls[seq] = new_ts_ls[seq]

        # mark as completed if every audio has enough finished samples
        completed = all(
            len(sequences) >= self.max_candidates for sequences in self.finished_sequences
        )
        return tokens, completed

    def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
        # collect all finished sequences, including patience, and add unfinished ones if not enough
        self.ts = self.ts.reshape(self.ts.shape[0], *preceding_tokens.shape[:2], *self.ts.shape[2:])
        sum_logprobs = sum_logprobs.cpu()
        for i, (sequences, ts_) in \
                enumerate(zip(self.finished_sequences, self.finished_ts_ls)):
            if len(sequences) < self.beam_size:  # when not enough sequences are finished
                for j in list(np.argsort(sum_logprobs[i]))[::-1]:
                    sequence = preceding_tokens[i, j].tolist() + [self.eot]
                    seq_tuple = tuple(sequence)
                    sequences[seq_tuple] = sum_logprobs[i][j].item()
                    ts_[seq_tuple] = self.ts[:, i, j]
                    if len(sequences) >= self.beam_size:
                        break

        tokens: List[List[Tensor]] = [
            [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
        ]
        sum_logprobs: List[List[float]] = [
            list(sequences.values()) for sequences in self.finished_sequences
        ]
        final_ts: List[List[Tensor]] = [
            list(sequences.values()) for sequences in self.finished_ts_ls
        ]
        return tokens, sum_logprobs, final_ts

class DecodingTaskWordLevel(DecodingTask):

    def __init__(self, *args, **kwargs):
        # `ts_num` may be passed explicitly as None; fall back to the default of 10 in that case
        self.ts_num: int = kwargs.pop('ts_num', None) or 10
        self.alpha: float = kwargs.pop('alpha', None)  # experimental
        self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
        self.suppress_word_ts: bool = kwargs.pop('suppress_word_ts', True)
        super(DecodingTaskWordLevel, self).__init__(*args, **kwargs)
        if hasattr(self.decoder, 'beam_size'):
            self.decoder = BeamSearchDecoderWordLevel(self.decoder.beam_size,
                                                      self.decoder.eot,
                                                      self.inference,
                                                      self.decoder.patience,
                                                      ts_num=self.ts_num,
                                                      suppress_ts_mask=self.suppress_ts_mask,
                                                      timestamp_begin=self.tokenizer.timestamp_begin)
        else:
            self.decoder = GreedyDecoderWordLevel(self.decoder.temperature,
                                                  self.decoder.eot,
                                                  ts_num=self.ts_num,
                                                  suppress_ts_mask=self.suppress_ts_mask,
                                                  timestamp_begin=self.tokenizer.timestamp_begin)

    # modified version of whisper.DecodingTask._main_loop
    def _main_loop(self, audio_features: Tensor, tokens: Tensor):
        assert audio_features.shape[0] == tokens.shape[0]
        n_batch = tokens.shape[0]
        sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch

        try:
            for i in range(self.sample_len):
                if self.alpha:
                    logits = self.inference.logits(tokens,
                                                   audio_features * (torch.rand_like(audio_features) * self.alpha + 1))
                else:
                    logits = self.inference.logits(tokens, audio_features)

                if i == 0 and get_new_attrs(self.tokenizer, 'no_captions') is not None:  # save no_speech_probs
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, get_new_attrs(self.tokenizer, 'no_captions')].tolist()

                # now we need to consider the logits at the last token only
                logits = logits[:, -1]

                ts_logits = torch.clone(logits[:, self.tokenizer.timestamp_begin:])
                if self.suppress_word_ts:
                    _suppress_ts(ts_logits, self.suppress_ts_mask)
                ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=self.decoder.ts)

                # apply the logit filters, e.g. for suppressing or applying penalty to repetitions
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)

                # expand the tokens tensor with the selected next tokens
                tokens, completed = self.decoder.update_with_ts(tokens, logits, sum_logprobs, ts)

                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()

        return tokens, sum_logprobs, no_speech_probs

    # modified version of whisper.DecodingTask.run
    @torch.no_grad()
    def run(self, mel: Tensor) \
            -> Union[List[DecodingResult], Tuple[List[DecodingResult], List[List[int]]]]:
        self.decoder.reset()
        tokenizer: Tokenizer = self.tokenizer
        n_audio: int = mel.shape[0]

        audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
        tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)

        # detect language if requested, overwriting the language token
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [
                DecodingResult(audio_features=features, language=language, language_probs=probs)
                for features, language, probs in zip(audio_features, languages, language_probs)
            ]

        # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
        audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)

        # call the main sampling loop
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)

        # reshape the tensors to have (n_audio, n_group) as the first two dimensions
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio

        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)

        # get the final candidates for each group, and slice between the first sampled token and EOT
        tokens, sum_logprobs, ts = self.decoder.finalize(tokens, sum_logprobs)
        tokens: List[List[Tensor]] = [
            [t[self.sample_begin: (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
        ]
        ts: List[List[Tensor]] = [[t[:, :tokens[i][j].shape[-1]] for j, t in enumerate(s)] for i, s in enumerate(ts)]

        # select the top-ranked sample in each group
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
        ts: List[List[int]] = [t[i].tolist() for i, t in zip(selected, ts)]
        texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]

        sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]

        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")

        return [
            DecodingResult(
                audio_features=features,
                language=language,
                tokens=tokens,
                text=text,
                avg_logprob=avg_logprob,
                **(dict(no_caption_prob=no_speech_prob) if hasattr(DecodingResult, 'no_caption_prob') else dict(
                    no_speech_prob=no_speech_prob)),
                temperature=self.options.temperature,
                compression_ratio=compression_ratio(text),
            )
            for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
        ], ts

# modified version of whisper.decoding.decode
@torch.no_grad()
def decode_word_level(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions(),
                      ts_num: int = None, alpha: float = None, suppress_ts_mask: Tensor = None,
                      suppress_word_ts=False) -> \
        Union[DecodingResult, List[DecodingResult], tuple]:
    """
    Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).

    Parameters
    ----------
    model: Whisper
        the Whisper model instance

    mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
        A tensor containing the Mel spectrogram(s)

    options: DecodingOptions
        A dataclass that contains all necessary options for decoding 30-second segments

    ts_num: int
        Number of additional top timestamp predictions to save for each word for postprocessing stabilization
        (default: 10)

    alpha: float
        Amount of noise to add to audio to produce slightly different results.
        audio_features *= torch.rand_like(audio_features) * alpha + 1

    suppress_ts_mask: list or Tensor
        Mask to suppress timestamp token(s) during decoding

    suppress_word_ts: bool
        Whether to also use suppress_ts_mask to suppress the timestamp tokens of words

    Returns
    -------
    result: Union[DecodingResult, List[DecodingResult]]
        The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    result, ts = DecodingTaskWordLevel(model, options,
                                       ts_num=ts_num,
                                       alpha=alpha,
                                       suppress_ts_mask=suppress_ts_mask,
                                       suppress_word_ts=suppress_word_ts).run(mel)

    if single:
        result = result[0]
        ts_tokens = ts[0][1]
        ts_logits = ts[0][0]
    else:
        ts_tokens = [ts_[1] for ts_ in ts]
        ts_logits = [ts_[0] for ts_ in ts]

    return result, ts_tokens, ts_logits
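
# Example (an illustrative sketch, not part of the original file): decoding a single
# padded 30-second segment directly. `result` is a DecodingResult; `ts_tokens` and
# `ts_logits` hold the top `ts_num` timestamp candidates for each decoded token:
#
#     mel = whisper.log_mel_spectrogram(whisper.pad_or_trim(audio))  # (80, 3000)
#     result, ts_tokens, ts_logits = decode_word_level(model, mel, DecodingOptions(), ts_num=10)
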
def modify_model(model: whisper.model.Whisper):
    # bind the word-level decode/transcribe functions as methods of the model instance
    model.decode = MethodType(decode_word_level, model)
    model.transcribe = MethodType(transcribe_word_level, model)
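

# Example usage (a minimal sketch, assuming this module is importable as `stable_whisper`):
#
#     import whisper
#     from stable_whisper import modify_model
#
#     model = whisper.load_model('base')
#     modify_model(model)  # model.transcribe now also returns word-level timestamps
#     result = model.transcribe('audio.mp3')
#     for segment in result['segments']:
#         print(segment['text'], segment['word_timestamps'])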