robertselvam commited on
Commit
c3d5eb3
·
1 Parent(s): 7f8dd93

Upload stable_whisper.py

Browse files
Files changed (1) hide show
  1. stable_whisper.py +1493 -0
stable_whisper.py ADDED
@@ -0,0 +1,1493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import ffmpeg
3
+ import whisper
4
+ import warnings
5
+ import numpy as np
6
+ import torch
7
+ from torch import Tensor
8
+ from torch.nn import functional as F
9
+ from torch.distributions import Categorical
10
+ from typing import List, Optional, Tuple, Union
11
+ from whisper.audio import SAMPLE_RATE, N_FRAMES, HOP_LENGTH, pad_or_trim, log_mel_spectrogram
12
+ from whisper.decoding import DecodingOptions, DecodingResult
13
+ from whisper.tokenizer import LANGUAGES
14
+ from whisper.utils import exact_div, format_timestamp, compression_ratio
15
+ from whisper.model import Whisper
16
+ from whisper.decoding import DecodingTask, BeamSearchDecoder, GreedyDecoder
17
+ from whisper.tokenizer import Tokenizer, get_tokenizer
18
+ from types import MethodType
19
+ from itertools import chain, repeat
20
+ from copy import deepcopy
21
+ import os
22
+ import json
23
+
24
+
25
+ # no_caption changed to no_speech newer commits
26
+ def get_new_attrs(obj_, attr: str):
27
+ if attr == 'no_caption_probs':
28
+ return getattr(obj_, attr) if hasattr(obj_, 'no_caption_probs') else getattr(obj_, 'no_speech_probs')
29
+ elif attr == 'no_caption_prob':
30
+ return getattr(obj_, attr) if hasattr(obj_, 'no_caption_prob') else getattr(obj_, 'no_speech_prob')
31
+ elif attr == 'no_captions':
32
+ return getattr(obj_, attr) if hasattr(obj_, 'no_captions') else getattr(obj_, 'no_speech')
33
+ else:
34
+ raise NotImplementedError(attr)
35
+
36
+
37
+ def check_ascending_sequence(seq: Union[List[Union[int, float]], np.ndarray], verbose=True) -> bool:
38
+ """
39
+ check if a sequence of numbers are in ascending order
40
+ """
41
+ is_ascending = True
42
+ for idx, (i, j) in enumerate(zip(seq[:-1], seq[1:])):
43
+ if i > j:
44
+ is_ascending = False
45
+ if verbose:
46
+ print(f'[Index{idx}]:{i} > [Index{idx + 1}]:{j}')
47
+ else:
48
+ break
49
+
50
+ return is_ascending
51
+
52
+
53
+ def check_ascending_sentence_ts(res: (dict, list)) -> bool:
54
+ segs = res['segments'] if isinstance(res, dict) else res
55
+ return check_ascending_sequence(list(chain.from_iterable((float(i['start']), float(i['end']))
56
+ for i in segs)))
57
+
58
+
59
+ def check_ascending_word_ts(res: (dict, list)) -> bool:
60
+ cc = group_word_timestamps(res['segments'] if isinstance(res, dict) else res, ts_key='word_timestamps')
61
+ return check_ascending_sequence((list(chain.from_iterable((float(i['start']), float(i['end']))
62
+ for i in cc))))
63
+
64
+
65
+ def is_equal_ts(a: (float, int, np.ndarray), b: (float, int, np.ndarray), rtol=1e-03):
66
+ """
67
+ check if timestamp a and timestamp b are equal within the relative tolerance (rtol)
68
+ """
69
+ return np.isclose(a, b, rtol=rtol)
70
+
71
+
72
+ def check_is_same_results(res0: (dict, list), res1: (dict, list), check_unstable=False) -> bool:
73
+ """
74
+ check if res0 and res1 have same timestamps
75
+ """
76
+ if isinstance(res0, dict):
77
+ res0 = res0['segments']
78
+ if isinstance(res1, dict):
79
+ res1 = res1['segments']
80
+ ts_key = 'unstable_word_timestamps' if check_unstable else 'word_timestamps'
81
+ inner_ts_key = 'timestamps' if check_unstable else 'timestamp'
82
+
83
+ def _reduce(x):
84
+ if isinstance(x, np.ndarray):
85
+ return set(tuple(x)) == {True}
86
+ return x
87
+
88
+ t = set(set(_reduce(is_equal_ts(a[inner_ts_key], b[inner_ts_key])) for a, b in zip(i[ts_key], j[ts_key])) == {True}
89
+ for i, j in zip(res0['segments'], res1['segments']))
90
+ return t == {True}
91
+
92
+
93
+ def to_srt(lines: List[dict], save_path: str = None, strip=False) -> str:
94
+ """
95
+ lines: List[dict]
96
+ [{start:<start-timestamp-of-text>, end:<end-timestamp-of-text>, text:<str-of-text>}, ...]
97
+ """
98
+
99
+ def secs_to_hhmmss(secs: (float, int)):
100
+ mm, ss = divmod(secs, 60)
101
+ hh, mm = divmod(mm, 60)
102
+ return f'{hh:0>2.0f}:{mm:0>2.0f}:{ss:0>6.3f}'.replace(".", ",")
103
+
104
+ srt_str = '\n'.join(
105
+ f'{i}\n'
106
+ f'{secs_to_hhmmss(sub["start"])} --> {secs_to_hhmmss(sub["end"])}\n'
107
+ f'{sub["text"].strip() if strip else sub["text"]}\n'
108
+ for i, sub in enumerate(lines, 1))
109
+
110
+ if save_path:
111
+ with open(save_path, 'w', encoding='utf-8') as f:
112
+ f.write(srt_str)
113
+ print(f'Saved: {os.path.abspath(save_path)}')
114
+
115
+ return srt_str
116
+
117
+
118
+ def group_word_timestamps(res: (dict, list), one_group=True, combine_compound=False,
119
+ ts_key='whole_word_timestamps', min_dur: float = None):
120
+
121
+ if min_dur is None:
122
+ min_dur = 0.0
123
+
124
+ def group_ts(ts_: List[dict], start) -> List[dict]:
125
+ first_group: List[dict] = []
126
+ for w_ts in ts_:
127
+ if first_group:
128
+ if (not combine_compound or w_ts['word'].startswith(' ')) and \
129
+ (w_ts['timestamp'] - first_group[-1]['start']) >= min_dur and \
130
+ first_group[-1]['end'] < w_ts['timestamp']:
131
+ first_group.append(dict(start=first_group[-1]['end'],
132
+ end=w_ts['timestamp'],
133
+ text=w_ts['word']))
134
+ else:
135
+ first_group[-1]['end'] = max(first_group[-1]['end'], w_ts['timestamp'])
136
+ first_group[-1]['text'] += w_ts['word']
137
+ else:
138
+ first_group.append(dict(start=start,
139
+ end=w_ts['timestamp'],
140
+ text=w_ts['word']))
141
+
142
+ return first_group
143
+
144
+ def group_zero_duration(first_group: List[dict]) -> List[dict]:
145
+ final_group: List[dict] = []
146
+ for ts_dict in first_group:
147
+ if not final_group or (ts_dict['end'] - ts_dict['start']) > 0:
148
+ final_group.append(ts_dict)
149
+ else:
150
+ final_group[-1]['end'] = ts_dict['end']
151
+ final_group[-1]['text'] += ts_dict['text']
152
+
153
+ return final_group
154
+
155
+ segs: List[dict] = res['segments'] if isinstance(res, dict) else res
156
+ assert set(ts_key in seg for seg in segs) == {True}, f'input contains missing {ts_key}'
157
+
158
+ grouped = (group_ts(seg[ts_key], seg['start']) for seg in segs)
159
+ return group_zero_duration(list(chain.from_iterable(grouped))) if one_group else list(grouped)
160
+
161
+
162
+ def tighten_timestamps(res: dict, end_at_last_word=True, end_before_period=False, start_at_first_word=False) -> dict:
163
+ res = deepcopy(res)
164
+ for i in range(len(res['segments'])):
165
+ if start_at_first_word:
166
+ res['segments'][i]['start'] = res['segments'][i]['word_timestamps'][0]['timestamp']
167
+ if end_before_period and \
168
+ res['segments'][i]['word_timestamps'][-1] == '.' and \
169
+ len(res['segments'][i]['word_timestamps']) > 1:
170
+ res['segments'][i]['end'] = res['segments'][i]['word_timestamps'][-2]['timestamp']
171
+ elif end_at_last_word:
172
+ res['segments'][i]['end'] = res['segments'][i]['word_timestamps'][-1]['timestamp']
173
+
174
+ return res
175
+
176
+
177
+ def results_to_srt(res: dict, srt_path, word_level=True, combine_compound=False,
178
+ end_at_last_word=True, end_before_period=False, start_at_first_word=True, strip=False):
179
+ if word_level:
180
+ results_to_word_srt(res, srt_path, combine_compound=combine_compound, strip=strip)
181
+ else:
182
+ results_to_sentence_srt(res, srt_path,
183
+ end_at_last_word=end_at_last_word,
184
+ end_before_period=end_before_period,
185
+ start_at_first_word=start_at_first_word,
186
+ strip=strip)
187
+
188
+
189
+ def results_to_sentence_srt(res: dict, srt_path,
190
+ end_at_last_word=False,
191
+ end_before_period=False,
192
+ start_at_first_word=False,
193
+ strip=False):
194
+ """
195
+
196
+ Parameters
197
+ ----------
198
+ res: dict
199
+ results from modified model
200
+ srt_path: str
201
+ output path of srt
202
+ end_at_last_word: bool
203
+ set end-of-sentence to timestamp-of-last-token
204
+ end_before_period: bool
205
+ set end-of-sentence to timestamp-of-last-non-period-token
206
+ start_at_first_word: bool
207
+ set start-of-sentence to timestamp-of-first-token
208
+ strip: bool
209
+ perform strip() on each sentence
210
+
211
+ """
212
+ strict = any((end_at_last_word, end_before_period, start_at_first_word))
213
+ segs = tighten_timestamps(res,
214
+ end_at_last_word=end_at_last_word,
215
+ end_before_period=end_before_period,
216
+ start_at_first_word=start_at_first_word)['segments'] \
217
+ if strict else res['segments']
218
+
219
+ max_idx = len(segs) - 1
220
+ i = 1
221
+ while i <= max_idx:
222
+ if not (segs[i]['end'] - segs[i]['start']):
223
+ if segs[i - 1]['end'] == segs[i]['end']:
224
+ segs[i - 1]['text'] += (' ' + segs[i]['text'].strip())
225
+ del segs[i]
226
+ max_idx -= 1
227
+ continue
228
+ else:
229
+ segs[i]['start'] = segs[i - 1]['end']
230
+ i += 1
231
+
232
+ to_srt(segs, srt_path, strip=strip)
233
+
234
+
235
+ def results_to_word_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
236
+ """
237
+
238
+ Parameters
239
+ ----------
240
+ res: dict
241
+ results from modified model
242
+ srt_path: str
243
+ output path of srt
244
+ combine_compound: bool
245
+ concatenate words without inbetween spacing
246
+ strip: bool
247
+ perform strip() on each word
248
+ min_dur: bool
249
+ minimum duration for each word (i.e. concat the words if it is less than specified value; Default 0.02)
250
+
251
+ """
252
+ to_srt(group_word_timestamps(res, combine_compound=combine_compound, min_dur=min_dur),
253
+ srt_path, strip=strip)
254
+
255
+
256
+ def results_to_token_srt(res: dict, srt_path, combine_compound=False, strip=False, min_dur: float = None):
257
+ """
258
+
259
+ Parameters
260
+ ----------
261
+ res: dict
262
+ results from modified model
263
+ srt_path: str
264
+ output path of srt
265
+ combine_compound: bool
266
+ concatenate words without inbetween spacing
267
+ strip: bool
268
+ perform strip() on each token
269
+ min_dur: bool
270
+ minimum duration for each token (i.e. concat the tokens if it is less than specified value; Default 0.02)
271
+
272
+ """
273
+ to_srt(group_word_timestamps(res, combine_compound=combine_compound, ts_key='word_timestamps', min_dur=min_dur),
274
+ srt_path, strip=strip)
275
+
276
+
277
+ def _get_min_estimation(estimations: List[Union[list, np.ndarray]],
278
+ min_: (int, float) = None,
279
+ max_: (int, float) = None) -> np.ndarray:
280
+ estimations = deepcopy(estimations)
281
+ estimations = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, estimations))
282
+ prev_min = min_ or 0
283
+ curr_max = max_ or np.max(estimations[-1])
284
+
285
+ min_est = []
286
+ for curr_est in estimations:
287
+ curr_min = curr_est[np.logical_and(curr_max > curr_est, curr_est > prev_min)]
288
+ curr_min = np.min(curr_min) if curr_min.shape[0] else prev_min
289
+ min_est.append(curr_min)
290
+ prev_min = curr_min
291
+
292
+ return np.array(min_est)
293
+
294
+
295
+ def _get_max_estimation(estimations: List[Union[list, np.ndarray]],
296
+ max_: (int, float) = None,
297
+ min_: (int, float) = None) -> np.ndarray:
298
+ estimations = deepcopy(estimations)
299
+ estimations = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, estimations))
300
+ prev_max = max_ or np.max(estimations[-1])
301
+ curr_min = np.min(estimations[0]) if min_ is None else min_
302
+
303
+ max_est = []
304
+ for curr_est in reversed(estimations):
305
+ curr_max = curr_est[np.logical_and(prev_max > curr_est, curr_est > curr_min)]
306
+ curr_max = np.max(curr_max) if curr_max.shape[0] else prev_max
307
+ max_est.append(curr_max)
308
+ prev_max = curr_max
309
+
310
+ max_est.reverse()
311
+ return np.array(max_est)
312
+
313
+
314
+ def _remove_overestimation(x: Union[np.ndarray, List[Union[int, float]]], alt_est: List[Union[list, np.ndarray]] = None,
315
+ max_: (int, float) = None, min_: (int, float) = None,
316
+ aggressive=False) -> np.ndarray:
317
+ x = np.array(x) if isinstance(x, list) else deepcopy(x)
318
+ if alt_est is not None:
319
+ alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
320
+ assert x.ndim == 1
321
+ assert alt_est is None or len(alt_est) == x.shape[0]
322
+ max_val = x[-1] if max_ is None else max_
323
+ min_val = x[0] if min_ is None else min_
324
+
325
+ def curr_max_min(val):
326
+ if min_ is None:
327
+ return val
328
+ return max(min_, val)
329
+
330
+ if min_ is not None:
331
+ x[x < min_] = min_
332
+ reduce_ = np.min if aggressive else np.mean
333
+ for i in range(x.shape[-1] - 1, -1, -1):
334
+ if x[i] > max_val or (i > 1 and x[i] < reduce_(x[:i])): # spikes or dips
335
+ if alt_est is None or alt_est[i] is None:
336
+ x[i] = max_val
337
+ else:
338
+ tmp_min = min_val if i < 2 else curr_max_min(np.mean(x[:i]))
339
+ alt_ = alt_est[i][np.logical_and(alt_est[i] < max_val, alt_est[i] > tmp_min)]
340
+ x[i] = max_val if alt_.shape[0] == 0 else alt_[0]
341
+ max_val = x[i]
342
+ return x
343
+
344
+
345
+ def _remove_underestimation(x: Union[np.ndarray, List[Union[int, float]]],
346
+ alt_est: List[Union[list, np.ndarray]] = None,
347
+ min_: (int, float) = None, max_: (int, float) = None,
348
+ aggressive=False) -> np.ndarray:
349
+ x = np.array(x) if isinstance(x, list) else deepcopy(x)
350
+ if alt_est is not None:
351
+ alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
352
+ assert x.ndim == 1
353
+ assert alt_est is None or len(alt_est) == x.shape[0]
354
+ min_val = x[0] if min_ is None else min_
355
+ max_val = x[-1] if max_ is None else max_
356
+
357
+ def curr_min_max(val):
358
+ if max_ is None:
359
+ return val
360
+ return min(max_, val)
361
+
362
+ if max_ is not None:
363
+ x[x > max_] = max_
364
+ reduce_ = np.max if aggressive else np.mean
365
+ max_i_reduce = x.shape[-1] - 2
366
+ for i in range(0, x.shape[-1]):
367
+ if x[i] < min_val or (i < max_i_reduce and x[i] > reduce_(x[i + 1:])): # dips or spikes
368
+ if alt_est is None or alt_est[i] is None:
369
+ x[i] = min_val
370
+ else:
371
+ tmp_max = max_val if i >= max_i_reduce else curr_min_max(np.mean(x[i + 1:]))
372
+ alt_ = alt_est[i][np.logical_and(alt_est[i] > min_val, alt_est[i] < tmp_max)]
373
+ x[i] = min_val if alt_.shape[0] == 0 else alt_[0]
374
+ min_val = x[i]
375
+ return x
376
+
377
+
378
+ def _merge_max_min_estimation(mx: Union[np.ndarray, List[Union[int, float]]],
379
+ mn: Union[np.ndarray, List[Union[int, float]]],
380
+ alt_est: List[Union[list, np.ndarray]] = None) -> np.ndarray:
381
+ mx = np.array(mx) if isinstance(mx, list) else deepcopy(mx)
382
+ mn = np.array(mn) if isinstance(mn, list) else deepcopy(mn)
383
+ if alt_est is not None:
384
+ alt_est = list(map(lambda est_: np.array(est_) if isinstance(est_, list) else est_, alt_est))
385
+ assert mx.ndim == 1 and mn.ndim == 1
386
+ assert mx.shape[0] == mn.shape[0]
387
+ assert alt_est is None or len(alt_est) == mx.shape[0]
388
+
389
+ pref_mx = np.var(mx) > np.var(mn)
390
+ if pref_mx:
391
+ mn[0] = mx[0]
392
+ prev_min = mn[0]
393
+ for i in range(1, mn.shape[0]):
394
+ if prev_min > mn[i]:
395
+ if mn[i] > mx[i]: # prev_min > mn[i] > mx[i]
396
+ mn[i] = prev_min
397
+ elif mx[i] > mn[i]:
398
+ if prev_min > mx[i]: # prev_min > mx[i] > mn[i]
399
+ mn[i] = prev_min
400
+ else: # mx[i] > prev_min > mn[i]
401
+ alt_ = alt_est[i][np.logical_and(alt_est[i] > prev_min, alt_est[i] < mx[i])]
402
+ mn[i] = (mx[i] if pref_mx else prev_min) if alt_.shape[0] == 0 else alt_[0]
403
+ else: # prev_min > mn[i] == mx[i]
404
+ mn[i] = prev_min
405
+ elif mn[i] > prev_min:
406
+ # if prev_min > mx[i]: # mn[i] > prev_min > mx[i]
407
+ # pass
408
+ if mx[i] > prev_min:
409
+ if mn[i] > mx[i]: # mn[i] > mx[i] > prev_min
410
+ pass
411
+ elif mx[i] > mn[i]: # mx[i] > mn[i] > prev_min
412
+ alt_ = alt_est[i][np.logical_and(alt_est[i] > mn[i], alt_est[i] < mx[i])]
413
+ if alt_.shape[0]:
414
+ mn[i] = alt_[0]
415
+ elif pref_mx:
416
+ mn[i] = mx[i]
417
+ # else: # mx[i] == mn[i] > prev_min
418
+ # pass
419
+ # else: # mn[i] > mx[i] == prev_min
420
+ # pass
421
+ else: # mn[i] == prev_min
422
+ if mx[i] > mn[i]: # mx[i] > mn[i] == prev_min
423
+ alt_ = alt_est[i][np.logical_and(alt_est[i] > mn[i], alt_est[i] < mx[i])]
424
+ if alt_.shape[0]:
425
+ mn[i] = alt_[0]
426
+ elif pref_mx:
427
+ mn[i] = mx[i]
428
+ # elif mn[i] > mx[i]: # mn[i] == prev_min > mx[i]
429
+ # pass
430
+ # else: # mn[i] == prev_min == mx[i]
431
+ # pass
432
+
433
+ prev_min = mn[i]
434
+
435
+ return mn
436
+
437
+
438
+ def _avg_merge_min_max(mx: Union[np.ndarray, List[Union[int, float]]],
439
+ mn: Union[np.ndarray, List[Union[int, float]]],
440
+ alt_timestamps: List[Union[List[Union[int, float]], np.ndarray]] = None,
441
+ max_: (int, float) = None, min_: (int, float) = None):
442
+ mx = np.array(mx) if isinstance(mx, list) else deepcopy(mx)
443
+ mn = np.array(mn) if isinstance(mn, list) else deepcopy(mn)
444
+ assert mx.ndim == mn.ndim == 1
445
+ assert mx.shape[0] == mn.shape[0]
446
+
447
+ avg_ = (mx + mn) / 2
448
+
449
+ if check_ascending_sequence(avg_, verbose=False):
450
+ return avg_
451
+
452
+ if not max_:
453
+ max_ = max(mx[-1], mn[-1])
454
+ if min_ is None:
455
+ min_ = min(mn[0], mx[0])
456
+
457
+ return _stabilize_timestamps(avg_, alt_timestamps, max_=max_, min_=min_)
458
+
459
+
460
+ def _stabilize_timestamps(timestamps: Union[np.ndarray, List[Union[int, float]]],
461
+ alt_timestamps: List[Union[List[Union[int, float]], np.ndarray]] = None,
462
+ max_: (int, float) = None, min_: (int, float) = None, aggressive=False) -> np.ndarray:
463
+ mx = _remove_overestimation(timestamps, alt_est=alt_timestamps, max_=max_, min_=min_, aggressive=aggressive)
464
+ mn = _remove_underestimation(timestamps, alt_est=alt_timestamps, max_=max_, min_=min_, aggressive=aggressive)
465
+ return _merge_max_min_estimation(mx, mn, alt_timestamps)
466
+
467
+
468
+ def _stabilize_more_timestamps(timestamps: List[Union[list, np.ndarray]],
469
+ max_: (int, float) = None, min_: (int, float) = None, average=True) -> np.ndarray:
470
+ mx = _get_max_estimation(timestamps, max_=max_, min_=min_)
471
+ mn = _get_min_estimation(timestamps, max_=max_, min_=min_)
472
+ if average:
473
+ return _avg_merge_min_max(mx, mn, timestamps, max_=max_, min_=min_)
474
+ return _merge_max_min_estimation(mx, mn, timestamps)
475
+
476
+
477
+ def stabilize_timestamps(segments: Union[List[dict], dict],
478
+ top_focus=False, aggressive=False, average=True) -> List[dict]:
479
+ """
480
+
481
+ Parameters
482
+ ----------
483
+ segments: Union[List[dict], dict]
484
+ result['segments'] or result
485
+ top_focus: bool
486
+ adhere closely to the top predictions for word timestamps
487
+ aggressive: bool
488
+ only if top_focus=True,
489
+ allow greater variation in word_timestamps/whole_word_timestamps
490
+ average: bool
491
+ only if top_focus=False,
492
+ average min and max of unstable_word_timestamps to get word_timestamps/whole_word_timestamps
493
+
494
+ """
495
+ if isinstance(segments, dict):
496
+ segments = segments['segments']
497
+ if not segments:
498
+ warnings.warn('No Segments Found')
499
+ return []
500
+ missing_ts_idx = set(map(lambda x: None if x[1].get('unstable_word_timestamps') else x[0], enumerate(segments))) - {
501
+ None}
502
+ no_word_timestamps = len(missing_ts_idx) == len(segments)
503
+ if not no_word_timestamps and missing_ts_idx:
504
+ warnings.warn(f'Segments {list(missing_ts_idx)} are missing unstable_word_timestamps. '
505
+ f'Word-level timestamp stabilization will skipped')
506
+
507
+ segments = deepcopy(segments)
508
+ sectioned_segments: List[List] = [[]]
509
+ for i, seg in enumerate(segments, 1):
510
+ sectioned_segments[-1].append(seg)
511
+ if seg['anchor_point']:
512
+ if i < len(segments):
513
+ sectioned_segments.append([])
514
+
515
+ assert all(set(len(set(s['offset'] for s in segs)) == 1 for segs in sectioned_segments))
516
+
517
+ sectioned_segments_timestamps = [dict(min_=segs[-1]['offset'],
518
+ max_=segs[-1]['next_offset'],
519
+ timestamps=list(chain.from_iterable((s['start'], s['end']) for s in segs)),
520
+ alt_timestamps=list(chain.from_iterable((s['alt_start_timestamps'],
521
+ s['alt_end_timestamps'])
522
+ for s in segs)))
523
+ for segs in sectioned_segments]
524
+
525
+ sectioned_stab_timestamps = [_stabilize_timestamps(**kwargs).reshape(-1, 2) for kwargs in
526
+ sectioned_segments_timestamps]
527
+
528
+ for i in range(len(sectioned_segments)):
529
+ for j in range(len(sectioned_segments[i])):
530
+ sectioned_segments[i][j]['start'], sectioned_segments[i][j]['end'] = sectioned_stab_timestamps[i][j]
531
+
532
+ if not missing_ts_idx:
533
+ if top_focus:
534
+ top_word_ts = [ts_['timestamps'][0] for ts_ in
535
+ sectioned_segments[i][j]['unstable_word_timestamps']]
536
+ alt_word_ts = [ts_['timestamps'][1:] for ts_ in
537
+ sectioned_segments[i][j]['unstable_word_timestamps']]
538
+ temp_stab_word_ts = _stabilize_timestamps(top_word_ts, alt_word_ts,
539
+ max_=sectioned_segments[i][j]['end'],
540
+ min_=sectioned_segments[i][j]['start'],
541
+ aggressive=aggressive)
542
+ else:
543
+ word_ts = [ts_['timestamps'] for ts_ in sectioned_segments[i][j]['unstable_word_timestamps']]
544
+ temp_stab_word_ts = _stabilize_more_timestamps(word_ts,
545
+ max_=sectioned_segments[i][j]['end'],
546
+ min_=sectioned_segments[i][j]['start'],
547
+ average=average)
548
+
549
+ temp_stab_word_ts = [{'word': sectioned_segments[i][j]['unstable_word_timestamps'][k]['word'],
550
+ 'token': sectioned_segments[i][j]['unstable_word_timestamps'][k]['token'],
551
+ 'timestamp': temp_stab_word_ts[k]}
552
+ for k in range(temp_stab_word_ts.shape[0])]
553
+
554
+ sectioned_segments[i][j]['word_timestamps'] = temp_stab_word_ts
555
+
556
+ return list(chain.from_iterable(sectioned_segments))
557
+
558
+
559
+ def save_as_json(results, path):
560
+ with open(path, 'w', encoding='utf-8') as f:
561
+ json.dump(results, f)
562
+
563
+
564
+ def add_whole_word_ts(tokenizer: Tokenizer, segments: Union[List[dict], dict], merge_non_space: bool = None,
565
+ prepend_punctuations: Union[List[str], Tuple[str]] = None,
566
+ append_punctuations: Union[List[str], Tuple[str]] = None):
567
+ merge_non_space = (tokenizer.language in ['en'] or tokenizer.language is None) \
568
+ if merge_non_space is None else merge_non_space
569
+ if prepend_punctuations is None:
570
+ prepend_punctuations = r'“¿([{'
571
+ if append_punctuations is None:
572
+ append_punctuations = r'.。,,!!??::”)]}、'
573
+ if isinstance(segments, dict):
574
+ segments = segments['segments']
575
+ if not segments:
576
+ print('No segments found, whole-word timestamps cannot be added.')
577
+ return
578
+
579
+ missing_idx = set(-1 if seg.get('word_timestamps') else i for i, seg in enumerate(segments)) - {-1}
580
+
581
+ if missing_idx:
582
+ if len(missing_idx) == len(segments):
583
+ print('No word_timestamps found, whole-word timestamps cannot be added.')
584
+ return
585
+ print(f'Some word_timestamps not found, '
586
+ f'whole-word timestamps cannot be added to the following segments: {tuple(missing_idx)}')
587
+
588
+ failed_idx = []
589
+
590
+ for seg_idx, seg in enumerate(segments):
591
+ if seg.get('word_timestamps'):
592
+ prev_idx = 0
593
+ remaining_text = seg['text']
594
+ has_prepend = False
595
+ whole_word_timestamps: List[dict] = []
596
+ for wts_idx in range(1, len(seg['word_timestamps']) + 1):
597
+ max_ts = seg['word_timestamps'][wts_idx - 1]['timestamp']
598
+ tokens = [wts['token'] for wts in seg['word_timestamps'][prev_idx: wts_idx]]
599
+ temp_whole_word = tokenizer.decode(tokens)
600
+ if temp_whole_word == remaining_text[:len(temp_whole_word)]:
601
+ prev_idx = wts_idx
602
+ remaining_text = remaining_text[len(temp_whole_word):]
603
+ if (not merge_non_space or temp_whole_word.startswith(' ') or not whole_word_timestamps) and \
604
+ temp_whole_word not in append_punctuations and \
605
+ not has_prepend:
606
+ has_prepend = temp_whole_word.strip() in prepend_punctuations
607
+ whole_word_timestamps.append(dict(word=temp_whole_word, timestamp=max_ts))
608
+ else:
609
+ has_prepend = False
610
+ if whole_word_timestamps == []:
611
+ continue
612
+ whole_word_timestamps[-1]['word'] += temp_whole_word
613
+ whole_word_timestamps[-1]['timestamp'] = max_ts
614
+ if remaining_text:
615
+ failed_idx.append(seg_idx)
616
+ whole_word_timestamps = []
617
+ seg['whole_word_timestamps'] = whole_word_timestamps or None
618
+ else:
619
+ seg['whole_word_timestamps'] = None
620
+
621
+ if failed_idx:
622
+ print(f'Failed to add whole-word timestamps to the following segments: {tuple(failed_idx)}')
623
+
624
+
625
+ def _load_audio_waveform(audio: Union[str, bytes, np.ndarray, torch.Tensor], h: int, w: int) -> np.ndarray:
626
+ """
627
+
628
+ Parameters
629
+ ----------
630
+ audio: Union[str, bytes, np.ndarray, torch.Tensor], shape = (*)
631
+ The path to audio or bytes of audio file or a NumPy array or Tensor containing the audio waveform in 16 kHz
632
+ h: int
633
+ Height of waveform image
634
+ w: int
635
+ Width of waveform image
636
+
637
+ Returns
638
+ -------
639
+ Audio waveform image as a NumPy array, in uint8 dtype.
640
+ """
641
+
642
+ try:
643
+ if isinstance(audio, str):
644
+ stream = ffmpeg.input(audio, threads=0)
645
+ inp = None
646
+
647
+ else:
648
+ if isinstance(audio, bytes):
649
+ stream = ffmpeg.input('pipe:', threads=0)
650
+ inp = audio
651
+ else:
652
+ warnings.warn('A resampled input causes an unexplained temporal shift in waveform image '
653
+ 'that will skew the timestamp suppression and may result in inaccurate timestamps.\n'
654
+ 'Use audio_for_mask for transcribe() to provide the original audio track '
655
+ 'as the path or bytes of the audio file.',
656
+ stacklevel=2)
657
+ stream = ffmpeg.input('pipe:', threads=0, ac=1, format='s16le')
658
+ if isinstance(audio, torch.Tensor):
659
+ audio = np.array(audio)
660
+ inp = (audio * 32768.0).astype(np.int16).tobytes()
661
+
662
+ waveform, err = (
663
+ stream.filter('aformat', channel_layouts='mono')
664
+ .filter('highpass', f='200').filter('lowpass', f='3000')
665
+ .filter('showwavespic', s=f'{w}x{h}')
666
+ .output('-', pix_fmt='gray', format='rawvideo')
667
+ .run(cmd="ffmpeg", capture_stdout=True, capture_stderr=True, input=inp)
668
+ )
669
+ except ffmpeg.Error as e:
670
+ raise RuntimeError(f"Failed to load audio in waveform: {e.stderr.decode()}") from e
671
+ else:
672
+ if not waveform:
673
+ partial_file = b'partial file' in err and b'Output file is empty' in err
674
+ add_msg = '\nMetadata for decoding are likely at end of file, try to use path of audio instead.' \
675
+ if partial_file and isinstance(audio, bytes) else ''
676
+ raise RuntimeError(f"Failed to load audio in waveform: {err.decode()}" + add_msg)
677
+ return np.frombuffer(waveform, dtype=np.uint8).reshape(h, w)
678
+
679
+
680
+ def _remove_lower_quantile(waveform: np.ndarray,
681
+ upper_quantile: float = None,
682
+ lower_quantile: float = None,
683
+ lower_threshold: float = None) -> np.ndarray:
684
+ """
685
+ Removes lower quantile of amplitude from waveform image
686
+ """
687
+ if upper_quantile is None:
688
+ upper_quantile = 0.85
689
+ if lower_quantile is None:
690
+ lower_quantile = 0.15
691
+ if lower_threshold is None:
692
+ lower_threshold = 0.15
693
+ waveform = deepcopy(waveform)
694
+ wave_sums = waveform.sum(0)
695
+ mx = np.quantile(wave_sums, upper_quantile, -1)
696
+ mn = np.quantile(wave_sums, lower_quantile, -1)
697
+ mn_threshold = (mx - mn) * lower_threshold + mn
698
+ waveform[:, wave_sums < mn_threshold] = 0
699
+ return waveform
700
+
701
+
702
+ def _wave_to_ts_filter(waveform: np.ndarray, suppress_middle=True,
703
+ max_index: (list, int) = None) -> np.ndarray:
704
+ """
705
+ Returns A NumPy array mask of sections with amplitude zero
706
+ """
707
+ assert waveform.ndim <= 2, f'waveform have at most 2 dims but found {waveform.ndim}'
708
+ if waveform.ndim == 1:
709
+ wave_sum = waveform
710
+ else:
711
+ wave_sum = waveform.sum(-2)
712
+
713
+ wave_filter = wave_sum.astype(bool)
714
+
715
+ if not suppress_middle:
716
+ nonzero_indices = wave_filter.nonzero()[0]
717
+ wave_filter[nonzero_indices[0]:nonzero_indices[-1] + 1] = True
718
+ if max_index is not None:
719
+ wave_filter[max_index + 1:] = False
720
+
721
+ return ~wave_filter
722
+
723
+
724
+ # modified version of whisper.transcribe.transcribe
725
+ def transcribe_word_level(
726
+ model: "Whisper",
727
+ audio: Union[str, np.ndarray, torch.Tensor],
728
+ *,
729
+ verbose: bool = False,
730
+ temperature: Union[float, Tuple[float, ...]] = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0),
731
+ compression_ratio_threshold: Optional[float] = 2.4,
732
+ logprob_threshold: Optional[float] = -1.0,
733
+ no_speech_threshold: Optional[float] = 0.6,
734
+ condition_on_previous_text: bool = True,
735
+ stab=True, top_focus=False, ts_num: int = 10,
736
+ alpha: float = None, print_unstab=False,
737
+ suppress_silence: bool = True,
738
+ suppress_middle: bool = True,
739
+ suppress_word_ts: bool = True,
740
+ remove_background: bool = True,
741
+ silence_threshold: float = 0.1,
742
+ prepend_punctuations: Union[List[str], Tuple[str]] = None,
743
+ append_punctuations: Union[List[str], Tuple[str]] = None,
744
+ audio_for_mask: (str, bytes) = None,
745
+ **decode_options):
746
+ """
747
+ Transcribe an audio file using Whisper
748
+
749
+ Parameters
750
+ ----------
751
+ model: Whisper
752
+ The Whisper model instance
753
+
754
+ audio: Union[str, np.ndarray, torch.Tensor]
755
+ The path to the audio file to open, or the audio waveform
756
+
757
+ verbose: bool
758
+ Whether to display the decoded text (with finalized timestamps) to the console
759
+
760
+ temperature: Union[float, Tuple[float, ...]]
761
+ Temperature for sampling. It can be a tuple of temperatures, which will be successfully used
762
+ upon failures according to either `compression_ratio_threshold` or `logprob_threshold`.
763
+
764
+ compression_ratio_threshold: float
765
+ If the gzip compression ratio is above this value, treat as failed
766
+
767
+ logprob_threshold: float
768
+ If the average log probability over sampled tokens is below this value, treat as failed
769
+
770
+ no_speech_threshold: float
771
+ If the no_speech probability is higher than this value AND the average log probability
772
+ over sampled tokens is below `logprob_threshold`, consider the segment as silent
773
+
774
+ condition_on_previous_text: bool
775
+ if True, the previous output of the model is provided as a prompt for the next window;
776
+ disabling may make the text inconsistent across windows, but the model becomes less prone to
777
+ getting stuck in a failure loop, such as repetition looping or timestamps going out of sync.
778
+
779
+ stab: bool
780
+ Stabilizing timestamps by cross compare timestamps and using additional top timestamp predictions
781
+ to fill in when appropriate to ensure timestamps are chronological.
782
+
783
+ top_focus: bool
784
+ Adhere closely to the top predictions for token timestamps stabilization
785
+
786
+ ts_num: int
787
+ Number of top timestamp predictions to save for each word for postprocessing stabilization (default: 10).
788
+
789
+ alpha: float
790
+ Amount of noise to add to audio to produce slightly difference results.
791
+ audio_features *= torch.rand_like(audio_features) * alpha + 1
792
+
793
+ print_unstab: bool
794
+ Whether to display the text (without stabilize timestamps) being decoded to the console
795
+ (i.e. behaves like verbose before model was modified)
796
+
797
+ suppress_silence: bool
798
+ Suppress timestamp tokens that are marked as silent
799
+
800
+ suppress_middle: bool
801
+ Suppress any silent timestamps tokens of middle of the segment instead of only beginning and ending
802
+
803
+ suppress_word_ts: bool
804
+ Suppress timestamp tokens of words that are marked as silent
805
+
806
+ remove_background: bool
807
+ Whether to remove background noise from waveform so that it is marked silent.
808
+ Determined by parameters part of decode_options (i.e. specify like other options here):
809
+ upper_quantile: float
810
+ The upper quantile of amplitude to determine a max amplitude, mx (Default: 0.85)
811
+ lower_quantile: float
812
+ The lower quantile of amplitude to determine a min amplitude, mn (Default: 0.15)
813
+ lower_threshold: float
814
+ Suppressed sections of waveform where amplitude < lower_threshold*(mx-mn) + mn. (Default: 0.15)
815
+
816
+ silence_threshold: float:
817
+ Audio segments silence average >= silence_threshold
818
+ then that segment will not have background removed even if remove_background=True.
819
+ e.g. 0.5 means if less than half of the audio segment is silent then background will be removed accordingly
820
+
821
+ prepend_punctuations: Union[List[str], Tuple[str]]
822
+ Punctuations to prepend to next word (Default: “¿([{)
823
+
824
+ append_punctuations: Union[List[str], Tuple[str]]
825
+ Punctuations to append to previous word (Default: .。,,!!??::”)]}、)
826
+
827
+ audio_for_mask: (str, bytes)
828
+ Original audio track as path or bytes of audio file.
829
+ Since resampled audio may shift the waveform image,
830
+ this is an alternative to 'audio' option to generate suppression mask from the original audio.
831
+
832
+ decode_options: dict
833
+ Keyword arguments to construct `DecodingOptions` instances
834
+
835
+ Returns
836
+ -------
837
+ A dictionary containing the resulting text ("text") and segment-level details ("segments"), and
838
+ the spoken language ("language"), which is detected when `decode_options["language"]` is None.
839
+ """
840
+
841
+ if 'no_captions_threshold' in decode_options:
842
+ warnings.warn('no_captions_threshold is deprecated. '
843
+ 'Please use no_speech_threshold instead.', DeprecationWarning, stacklevel=2)
844
+ no_speech_threshold = decode_options.pop('no_captions_threshold')
845
+
846
+ dtype = torch.float16 if decode_options.get("fp16", True) else torch.float32
847
+ if model.device == torch.device("cpu"):
848
+ if torch.cuda.is_available():
849
+ warnings.warn("Performing inference on CPU when CUDA is available")
850
+ if dtype == torch.float16:
851
+ warnings.warn("FP16 is not supported on CPU; using FP32 instead")
852
+ dtype = torch.float32
853
+
854
+ if dtype == torch.float32:
855
+ decode_options["fp16"] = False
856
+
857
+ if 'max_initial_timestamp' not in decode_options:
858
+ decode_options['max_initial_timestamp'] = None
859
+
860
+ mel = log_mel_spectrogram(audio)
861
+
862
+ if decode_options.get("language", None) is None:
863
+ if verbose:
864
+ print("Detecting language using up to the first 30 seconds. Use `--language` to specify the language")
865
+ segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
866
+ _, probs = model.detect_language(segment)
867
+ decode_options["language"] = max(probs, key=probs.get)
868
+ print(f"Detected language: {LANGUAGES[decode_options['language']]}")
869
+
870
+ mel = mel.unsqueeze(0)
871
+ language = decode_options["language"]
872
+ task = decode_options.get("task", "transcribe")
873
+ tokenizer = get_tokenizer(model.is_multilingual, language=language, task=task)
874
+
875
+ def decode_with_fallback(segment: torch.Tensor, suppress_ts_mask: Tensor = None) \
876
+ -> Union[List[DecodingResult], tuple]:
877
+ temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
878
+ kwargs = {**decode_options}
879
+ t = temperatures[0]
880
+ if t == 0:
881
+ best_of = kwargs.pop("best_of", None)
882
+ else:
883
+ best_of = kwargs.get("best_of", None)
884
+
885
+ options = DecodingOptions(**kwargs, temperature=t)
886
+ results, ts_tokens, ts_logits_ = model.decode(segment, options, ts_num=ts_num, alpha=alpha,
887
+ suppress_ts_mask=suppress_ts_mask,
888
+ suppress_word_ts=suppress_word_ts)
889
+
890
+ kwargs.pop("beam_size", None) # no beam search for t > 0
891
+ kwargs.pop("patience", None) # no patience for t > 0
892
+ kwargs["best_of"] = best_of # enable best_of for t > 0
893
+ for t in temperatures[1:]:
894
+ needs_fallback = [
895
+ compression_ratio_threshold is not None
896
+ and result.compression_ratio > compression_ratio_threshold
897
+ or logprob_threshold is not None
898
+ and result.avg_logprob < logprob_threshold
899
+ for result in results
900
+ ]
901
+ if any(needs_fallback):
902
+ options = DecodingOptions(**kwargs, temperature=t)
903
+ retries, r_ts_tokens, r_ts_logits = model.decode(segment[needs_fallback], options,
904
+ ts_num=ts_num, alpha=alpha,
905
+ suppress_ts_mask=suppress_ts_mask,
906
+ suppress_word_ts=suppress_word_ts)
907
+ for retry_index, original_index in enumerate(np.nonzero(needs_fallback)[0]):
908
+ results[original_index] = retries[retry_index]
909
+ ts_tokens[original_index] = r_ts_tokens[retry_index]
910
+ ts_logits_[original_index] = r_ts_logits[retry_index]
911
+
912
+ return results, ts_tokens, ts_logits_
913
+
914
+ seek = 0
915
+ input_stride = exact_div(
916
+ N_FRAMES, model.dims.n_audio_ctx
917
+ ) # mel frames per output token: 2
918
+ time_precision = (
919
+ input_stride * HOP_LENGTH / SAMPLE_RATE
920
+ ) # time per output token: 0.02 (seconds)
921
+ all_tokens = []
922
+ all_segments = []
923
+ prompt_reset_since = 0
924
+
925
+ initial_prompt = decode_options.pop("initial_prompt", None) or []
926
+ if initial_prompt:
927
+ initial_prompt = tokenizer.encode(" " + initial_prompt.strip())
928
+ all_tokens.extend(initial_prompt)
929
+
930
+ def _to_list(x: (Tensor, None)):
931
+ if x is None:
932
+ return x
933
+ return x.tolist()
934
+
935
+ def add_segment(
936
+ *, offset: float, start: float, end: float, text_tokens: Tensor, result: DecodingResult,
937
+ start_timestamps: list = None, end_timestamps: list = None, word_timestamps: Tensor = None,
938
+ start_ts_logits: list = None, end_ts_logits: list = None, word_ts_logits: Tensor = None
939
+ ):
940
+ no_eot_mask = text_tokens < tokenizer.eot
941
+ text_tokens_no_eot = text_tokens[no_eot_mask]
942
+ text = tokenizer.decode(text_tokens_no_eot)
943
+
944
+ if len(text.strip()) == 0: # skip empty text output
945
+ return
946
+
947
+ if word_timestamps is not None:
948
+ assert word_timestamps.shape[0] == text_tokens.shape[0]
949
+ if word_ts_logits is None:
950
+ word_ts_fields = zip(text_tokens_no_eot, word_timestamps[no_eot_mask], repeat(None))
951
+ else:
952
+ assert word_ts_logits.shape[0] == text_tokens.shape[0]
953
+ word_ts_fields = zip(text_tokens_no_eot, word_timestamps[no_eot_mask], word_ts_logits[no_eot_mask])
954
+
955
+ word_timestamps = [dict(word=tokenizer.decode([token]),
956
+ token=token.item(),
957
+ timestamps=timestamps_.tolist(),
958
+ timestamp_logits=_to_list(ts_logits_))
959
+ for token, timestamps_, ts_logits_ in word_ts_fields]
960
+
961
+ all_segments.append(
962
+ {
963
+ "id": len(all_segments),
964
+ "seek": seek,
965
+ 'offset': offset, # offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
966
+ "start": start,
967
+ "end": end,
968
+ "text": text,
969
+ "tokens": result.tokens,
970
+ "temperature": result.temperature,
971
+ "avg_logprob": result.avg_logprob,
972
+ "compression_ratio": result.compression_ratio,
973
+ "no_speech_prob": get_new_attrs(result, 'no_caption_prob'),
974
+ "alt_start_timestamps": start_timestamps,
975
+ "start_ts_logits": start_ts_logits,
976
+ "alt_end_timestamps": end_timestamps,
977
+ "end_ts_logits": end_ts_logits,
978
+ "unstable_word_timestamps": word_timestamps,
979
+ 'anchor_point': False
980
+ }
981
+ )
982
+ if print_unstab or (verbose and not stab):
983
+ print(f'[{format_timestamp(start)} --> {format_timestamp(end)}] "{text}"')
984
+ if word_timestamps is not None:
985
+ ts_str = (f' ->[{format_timestamp(ts_["timestamps"][0])}] "{ts_["word"].strip()}"' for ts_ in
986
+ word_timestamps)
987
+ print('\n'.join(ts_str), end='\n\n')
988
+
989
+ if suppress_silence:
990
+ ts_scale = HOP_LENGTH / SAMPLE_RATE / time_precision
991
+ wf = _load_audio_waveform(audio_for_mask or audio, 100, int(mel.shape[-1] * ts_scale))
992
+
993
+ upper_quantile = decode_options.pop('upper_quantile', 0.85)
994
+ lower_quantile = decode_options.pop('lower_quantile', 0.15)
995
+ lower_threshold = decode_options.pop('lower_threshold', 0.15)
996
+
997
+ while seek < mel.shape[-1]:
998
+ timestamp_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
999
+ remaining_duration = float((mel.shape[-1] - seek) * HOP_LENGTH / SAMPLE_RATE)
1000
+ segment = pad_or_trim(mel[:, :, seek:], N_FRAMES).to(model.device).to(dtype)
1001
+ segment_duration = min(float(segment.shape[-1] * HOP_LENGTH / SAMPLE_RATE), remaining_duration)
1002
+ segment_max_ts = segment_duration / time_precision
1003
+
1004
+ if suppress_silence:
1005
+ wf_seek = int(seek * ts_scale)
1006
+ segment_wf = wf[..., wf_seek:wf_seek + 1501]
1007
+ if remove_background and \
1008
+ (1 - segment_wf.sum(0).clip(max=1).mean()) < silence_threshold:
1009
+ segment_wf = _remove_lower_quantile(segment_wf.astype(np.float32),
1010
+ upper_quantile=upper_quantile,
1011
+ lower_quantile=lower_quantile,
1012
+ lower_threshold=lower_threshold)
1013
+ segment_wf = pad_or_trim(segment_wf, 1501)
1014
+ suppress_ts_mask = torch.from_numpy(_wave_to_ts_filter(segment_wf,
1015
+ suppress_middle=suppress_middle,
1016
+ max_index=int(segment_max_ts)))
1017
+
1018
+ if suppress_ts_mask.all(): # segment is silent
1019
+ seek += segment.shape[-1] # fast-forward to the next segment boundary
1020
+ continue
1021
+ else:
1022
+ suppress_ts_mask = None
1023
+
1024
+ decode_options["prompt"] = all_tokens[prompt_reset_since:]
1025
+ result, finalized_ts_tokens, ts_logits = decode_with_fallback(segment,
1026
+ suppress_ts_mask=suppress_ts_mask)
1027
+
1028
+ result = result[0]
1029
+ tokens = torch.tensor(result.tokens)
1030
+ finalized_ts_tokens = torch.tensor(finalized_ts_tokens[0])
1031
+ ts_logits = torch.tensor(ts_logits[0])
1032
+
1033
+ if no_speech_threshold is not None:
1034
+ # no voice activity check
1035
+ should_skip = get_new_attrs(result, 'no_caption_prob') > no_speech_threshold
1036
+ if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
1037
+ # don't skip if the logprob is high enough, despite the no_speech_prob
1038
+ should_skip = False
1039
+
1040
+ if should_skip:
1041
+ seek += segment.shape[-1] # fast-forward to the next segment boundary
1042
+ continue
1043
+
1044
+ timestamp_tokens: torch.Tensor = tokens.ge(tokenizer.timestamp_begin)
1045
+ consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0].add_(1)
1046
+ if len(consecutive) > 0: # if the output contains two consecutive timestamp tokens
1047
+ last_slice = 0
1048
+ for current_slice in consecutive:
1049
+ sliced_tokens = tokens[last_slice:current_slice]
1050
+ sliced_ts_tokens = finalized_ts_tokens[last_slice:current_slice]
1051
+ sliced_ts_logits = ts_logits[last_slice:current_slice]
1052
+ start_timestamp_position = (
1053
+ sliced_tokens[0].item() - tokenizer.timestamp_begin
1054
+ )
1055
+ end_timestamp_position = (
1056
+ sliced_tokens[-1].item() - tokenizer.timestamp_begin
1057
+ )
1058
+
1059
+ word_ts = timestamp_offset + sliced_ts_tokens * time_precision
1060
+
1061
+ add_segment(
1062
+ offset=timestamp_offset,
1063
+ start=timestamp_offset + start_timestamp_position * time_precision,
1064
+ end=min(timestamp_offset + end_timestamp_position * time_precision,
1065
+ timestamp_offset + segment_duration),
1066
+ text_tokens=sliced_tokens[1:-1],
1067
+ result=result,
1068
+ start_timestamps=word_ts[0].tolist(),
1069
+ end_timestamps=word_ts[-1].tolist(),
1070
+ word_timestamps=word_ts[1:-1],
1071
+ start_ts_logits=sliced_ts_logits[0].tolist(),
1072
+ end_ts_logits=sliced_ts_logits[-1].tolist(),
1073
+ word_ts_logits=sliced_ts_logits[1:-1]
1074
+ )
1075
+ last_slice = current_slice
1076
+ last_timestamp_position = (
1077
+ min(tokens[last_slice - 1].item() - tokenizer.timestamp_begin, segment_max_ts)
1078
+ )
1079
+ seek += last_timestamp_position * input_stride
1080
+ all_tokens.extend(tokens[: last_slice + 1].tolist())
1081
+ else:
1082
+ duration = segment_duration
1083
+ timestamps = tokens[timestamp_tokens.nonzero().flatten()]
1084
+ if len(timestamps) > 0:
1085
+ # no consecutive timestamps but it has a timestamp; use the last one.
1086
+ # single timestamp at the end means no speech after the last timestamp.
1087
+ last_timestamp_position = min(timestamps[-1].item() - tokenizer.timestamp_begin, segment_max_ts)
1088
+ duration = last_timestamp_position * time_precision
1089
+
1090
+ word_ts = timestamp_offset + finalized_ts_tokens * time_precision
1091
+
1092
+ add_segment(
1093
+ offset=timestamp_offset,
1094
+ start=timestamp_offset,
1095
+ end=timestamp_offset + duration,
1096
+ text_tokens=tokens,
1097
+ result=result,
1098
+ word_timestamps=word_ts,
1099
+ word_ts_logits=ts_logits
1100
+ )
1101
+
1102
+ seek += segment.shape[-1]
1103
+ all_tokens.extend(tokens.tolist())
1104
+
1105
+ if all_segments:
1106
+ all_segments[-1]['anchor_point'] = True
1107
+ all_segments[-1]['next_offset'] = float(seek * HOP_LENGTH / SAMPLE_RATE)
1108
+ if not condition_on_previous_text or result.temperature > 0.5:
1109
+ # do not feed the prompt tokens if a high temperature was used
1110
+ prompt_reset_since = len(all_tokens)
1111
+
1112
+ if len(all_segments) > 1 and all_segments[-1]['alt_start_timestamps'] is None:
1113
+ all_segments[-1]['alt_start_timestamps'] = all_segments[-2]['alt_end_timestamps']
1114
+
1115
+ if stab:
1116
+ all_segments = stabilize_timestamps(all_segments, top_focus=top_focus)
1117
+ add_whole_word_ts(tokenizer, all_segments,
1118
+ prepend_punctuations=prepend_punctuations,
1119
+ append_punctuations=append_punctuations)
1120
+ if verbose:
1121
+ print('\nSTABILIZED\n')
1122
+ for seg_ in all_segments:
1123
+ print(f'[{format_timestamp(seg_["start"])} --> {format_timestamp(seg_["end"])}] "{seg_["text"]}"')
1124
+ if seg_['word_timestamps']:
1125
+ ts_str = (f' ->[{format_timestamp(ts_["timestamp"])}] "{ts_["word"].strip()}"' for ts_ in
1126
+ seg_['word_timestamps'])
1127
+ print('\n'.join(ts_str), end='\n\n')
1128
+
1129
+ return dict(text=tokenizer.decode(all_tokens[len(initial_prompt):]), segments=all_segments, language=language)
1130
+
1131
+
1132
+ def _suppress_ts(ts_logits: Tensor, suppress_ts_mask: Tensor = None):
1133
+ if suppress_ts_mask is not None:
1134
+ ts_logits[:, suppress_ts_mask] = -np.inf
1135
+
1136
+
1137
+ def _ts_topk(ts_logits: Tensor, k: int, prev_ts: Tensor = None) -> Tensor:
1138
+ temp_ts = torch.stack(torch.topk(ts_logits, k, dim=-1), 0).unsqueeze(-2)
1139
+ return temp_ts if prev_ts is None else torch.cat([prev_ts, temp_ts], dim=-2)
1140
+
1141
+
1142
+ # modified version of whisper.GreedyDecoder
1143
+ class GreedyDecoderWordLevel(GreedyDecoder):
1144
+ def __init__(self, *args, **kwargs):
1145
+ self.ts_num: int = kwargs.pop('ts_num', 10)
1146
+ self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
1147
+ self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
1148
+ super(GreedyDecoderWordLevel, self).__init__(*args, **kwargs)
1149
+ self.ts = None
1150
+
1151
+ def _suppress_ts(self, logits: Tensor):
1152
+ _suppress_ts(logits[:, self.timestamp_begin:],
1153
+ suppress_ts_mask=self.suppress_ts_mask)
1154
+
1155
+ def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
1156
+ self.ts = ts
1157
+
1158
+ self._suppress_ts(logits)
1159
+
1160
+ if self.temperature == 0:
1161
+ next_tokens = logits.argmax(dim=-1)
1162
+ else:
1163
+ next_tokens = Categorical(logits=logits / self.temperature).sample()
1164
+
1165
+ logprobs = F.log_softmax(logits.float(), dim=-1)
1166
+ current_logprobs = logprobs[torch.arange(logprobs.shape[0]), next_tokens]
1167
+ sum_logprobs += current_logprobs * (tokens[:, -1] != self.eot)
1168
+
1169
+ next_tokens[tokens[:, -1] == self.eot] = self.eot
1170
+ tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
1171
+
1172
+ completed = (tokens[:, -1] == self.eot).all()
1173
+ return tokens, completed
1174
+
1175
+ def finalize(self, tokens: Tensor, sum_logprobs: Tensor):
1176
+ # make sure each sequence has at least one EOT token at the end
1177
+ tokens = F.pad(tokens, (0, 1), value=self.eot)
1178
+ return tokens, sum_logprobs.tolist(), self.ts.transpose(1, 0)[None]
1179
+
1180
+
1181
+ # modified version of whisper.BeamSearchDecoder
1182
+ class BeamSearchDecoderWordLevel(BeamSearchDecoder):
1183
+
1184
+ def __init__(self, *args, **kwargs):
1185
+ self.ts_num: int = kwargs.pop('ts_num', 10)
1186
+ self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
1187
+ self.timestamp_begin = kwargs.pop('timestamp_begin', 50364)
1188
+ super(BeamSearchDecoderWordLevel, self).__init__(*args, **kwargs)
1189
+ self.ts = None
1190
+ self.finished_ts_ls = None
1191
+
1192
+ def reset(self):
1193
+ self.finished_sequences = None
1194
+ self.finished_ts_ls = None
1195
+
1196
+ def _suppress_ts(self, logits: Tensor):
1197
+ _suppress_ts(logits[:, self.timestamp_begin:],
1198
+ suppress_ts_mask=self.suppress_ts_mask)
1199
+
1200
+ def update_with_ts(self, tokens: Tensor, logits: Tensor, sum_logprobs: Tensor, ts: Tensor) -> Tuple[Tensor, bool]:
1201
+ if tokens.shape[0] % self.beam_size != 0:
1202
+ raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
1203
+
1204
+ self.ts = ts
1205
+
1206
+ n_audio = tokens.shape[0] // self.beam_size
1207
+ if self.finished_sequences is None: # for the first update
1208
+ self.finished_sequences = [{} for _ in range(n_audio)]
1209
+ self.finished_ts_ls = [{} for _ in range(n_audio)]
1210
+
1211
+ logprobs = F.log_softmax(logits.float(), dim=-1)
1212
+ next_tokens, source_indices, finished_sequences, finished_ts_ls = [], [], [], []
1213
+
1214
+ self._suppress_ts(logprobs)
1215
+
1216
+ for i in range(n_audio):
1217
+ scores, sources, finished, finished_ts = {}, {}, {}, {}
1218
+
1219
+ # STEP 1: calculate the cumulative log probabilities for possible candidates
1220
+ for j in range(self.beam_size):
1221
+ idx = i * self.beam_size + j
1222
+ prefix = tokens[idx].tolist()
1223
+ for logprob, token in zip(*logprobs[idx].topk(self.beam_size + 1)):
1224
+ new_logprob = (sum_logprobs[idx] + logprob).item()
1225
+ sequence = tuple(prefix + [token.item()])
1226
+ scores[sequence] = new_logprob
1227
+ sources[sequence] = idx
1228
+
1229
+ # STEP 2: rank the candidates and keep the top beam_size sequences for each audio
1230
+ saved = 0
1231
+ for sequence in sorted(scores, key=scores.get, reverse=True):
1232
+ if sequence[-1] == self.eot:
1233
+ finished[sequence] = scores[sequence]
1234
+ finished_ts[sequence] = self.ts[:, sources[sequence]]
1235
+ else:
1236
+ sum_logprobs[len(next_tokens)] = scores[sequence]
1237
+ next_tokens.append(sequence)
1238
+ source_indices.append(sources[sequence])
1239
+
1240
+ saved += 1
1241
+ if saved == self.beam_size:
1242
+ break
1243
+
1244
+ finished_sequences.append(finished)
1245
+ finished_ts_ls.append(finished_ts)
1246
+
1247
+ tokens = torch.tensor(next_tokens, device=tokens.device)
1248
+ self.inference.rearrange_kv_cache(source_indices)
1249
+ self.ts = self.ts[:, source_indices]
1250
+
1251
+ # add newly finished sequences to self.finished_sequences
1252
+ assert len(self.finished_sequences) == len(finished_sequences)
1253
+ for previously_finished, newly_finished, \
1254
+ prev_ts_ls, new_ts_ls in \
1255
+ zip(self.finished_sequences, finished_sequences,
1256
+ self.finished_ts_ls, finished_ts_ls):
1257
+ for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
1258
+ if len(previously_finished) >= self.max_candidates:
1259
+ break # the candidate list is full
1260
+ previously_finished[seq] = newly_finished[seq]
1261
+ prev_ts_ls[seq] = new_ts_ls[seq]
1262
+
1263
+ # mark as completed if all audio has enough number of samples
1264
+ completed = all(
1265
+ len(sequences) >= self.max_candidates for sequences in self.finished_sequences
1266
+ )
1267
+ return tokens, completed
1268
+
1269
+ def finalize(self, preceding_tokens: Tensor, sum_logprobs: Tensor):
1270
+ # collect all finished sequences, including patience, and add unfinished ones if not enough
1271
+ self.ts = self.ts.reshape(self.ts.shape[0], *preceding_tokens.shape[:2], *self.ts.shape[2:])
1272
+ sum_logprobs = sum_logprobs.cpu()
1273
+ for i, (sequences, ts_) in \
1274
+ enumerate(zip(self.finished_sequences, self.finished_ts_ls)):
1275
+ if len(sequences) < self.beam_size: # when not enough sequences are finished
1276
+ for j in list(np.argsort(sum_logprobs[i]))[::-1]:
1277
+ sequence = preceding_tokens[i, j].tolist() + [self.eot]
1278
+ seq_tuple = tuple(sequence)
1279
+ sequences[seq_tuple] = sum_logprobs[i][j].item()
1280
+ ts_[seq_tuple] = self.ts[:, i, j]
1281
+ if len(sequences) >= self.beam_size:
1282
+ break
1283
+
1284
+ tokens: List[List[Tensor]] = [
1285
+ [torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences
1286
+ ]
1287
+ sum_logprobs: List[List[float]] = [
1288
+ list(sequences.values()) for sequences in self.finished_sequences
1289
+ ]
1290
+ final_ts: List[List[Tensor]] = [
1291
+ list(sequences.values()) for sequences in self.finished_ts_ls
1292
+ ]
1293
+ return tokens, sum_logprobs, final_ts
1294
+
1295
+
1296
+ class DecodingTaskWordLevel(DecodingTask):
+
+     def __init__(self, *args, **kwargs):
+         self.ts_num: int = kwargs.pop('ts_num', 10)
+         self.alpha: float = kwargs.pop('alpha', None)  # experimental
+         self.suppress_ts_mask: Tensor = kwargs.pop('suppress_ts_mask', None)
+         self.suppress_word_ts: bool = kwargs.pop('suppress_word_ts', True)
+         super(DecodingTaskWordLevel, self).__init__(*args, **kwargs)
+         if hasattr(self.decoder, 'beam_size'):
+             self.decoder = BeamSearchDecoderWordLevel(self.decoder.beam_size,
+                                                       self.decoder.eot,
+                                                       self.inference,
+                                                       self.decoder.patience,
+                                                       ts_num=self.ts_num,
+                                                       suppress_ts_mask=self.suppress_ts_mask,
+                                                       timestamp_begin=self.tokenizer.timestamp_begin)
+         else:
+             self.decoder = GreedyDecoderWordLevel(self.decoder.temperature,
+                                                   self.decoder.eot,
+                                                   ts_num=self.ts_num,
+                                                   suppress_ts_mask=self.suppress_ts_mask,
+                                                   timestamp_begin=self.tokenizer.timestamp_begin)
+
+     # modified version of whisper.DecodingTask._main_loop
+     def _main_loop(self, audio_features: Tensor, tokens: Tensor):
+         assert audio_features.shape[0] == tokens.shape[0]
+         n_batch = tokens.shape[0]
+         sum_logprobs: Tensor = torch.zeros(n_batch, device=audio_features.device)
+         no_speech_probs = [np.nan] * n_batch
+
+         # ts = None
+
+         try:
+             for i in range(self.sample_len):
+                 if self.alpha:
+                     logits = self.inference.logits(tokens,
+                                                    audio_features * (torch.rand_like(audio_features) * self.alpha + 1))
+                 else:
+                     logits = self.inference.logits(tokens, audio_features)
+
+                 if i == 0 and get_new_attrs(self.tokenizer, 'no_captions') is not None:  # save no_speech_probs
+                     probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
+                     no_speech_probs = probs_at_sot[:, get_new_attrs(self.tokenizer, 'no_captions')].tolist()
+
+                 # now we need to consider the logits at the last token only
+                 logits = logits[:, -1]
+
+                 ts_logits = torch.clone(logits[:, self.tokenizer.timestamp_begin:])
+                 if self.suppress_word_ts:
+                     _suppress_ts(ts_logits, self.suppress_ts_mask)
+                 ts = _ts_topk(ts_logits, k=self.ts_num, prev_ts=self.decoder.ts)
+
+                 # apply the logit filters, e.g. for suppressing or applying a penalty to certain tokens
+                 for logit_filter in self.logit_filters:
+                     logit_filter.apply(logits, tokens)
+
+                 # expand the tokens tensor with the selected next tokens
+                 tokens, completed = self.decoder.update_with_ts(tokens, logits, sum_logprobs, ts)
+
+                 if completed or tokens.shape[-1] > self.n_ctx:
+                     break
+         finally:
+             self.inference.cleanup_caching()
+
+         return tokens, sum_logprobs, no_speech_probs
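The helpers `_suppress_ts` and `_ts_topk` used in the loop above are defined earlier in this file and are not shown here. As a rough, non-authoritative sketch of the kind of top-k timestamp collection the loop relies on (the stacked values/indices layout is an assumption inferred from how `self.ts` is indexed in `update_with_ts` and `finalize`):

# Illustrative sketch only -- not the file's actual helper.
def _ts_topk_sketch(ts_logits: Tensor, k: int, prev_ts: Optional[Tensor] = None) -> Tensor:
    # keep the k most likely timestamp candidates for the current decoding step
    values, indices = torch.topk(ts_logits, k, dim=-1)
    # channel 0 holds the logit values, channel 1 the timestamp token offsets; add a step axis
    step = torch.stack([values, indices.to(values.dtype)], dim=0).unsqueeze(-2)
    # accumulate across decoding steps: shape (2, n_batch, n_steps, k)
    return step if prev_ts is None else torch.cat([prev_ts, step], dim=-2)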
+
+     # modified version of whisper.DecodingTask.run
+     @torch.no_grad()
+     def run(self, mel: Tensor) \
+             -> Union[List[DecodingResult], Tuple[List[DecodingResult], List[List[int]]]]:
+         self.decoder.reset()
+         tokenizer: Tokenizer = self.tokenizer
+         n_audio: int = mel.shape[0]
+
+         audio_features: Tensor = self._get_audio_features(mel)  # encoder forward pass
+         tokens: Tensor = torch.tensor([self.initial_tokens]).expand(n_audio, -1)
+
+         # detect language if requested, overwriting the language token
+         languages, language_probs = self._detect_language(audio_features, tokens)
+         if self.options.task == "lang_id":
+             return [
+                 DecodingResult(audio_features=features, language=language, language_probs=probs)
+                 for features, language, probs in zip(audio_features, languages, language_probs)
+             ]
+
+         # repeat the audio & text tensors by the group size, for beam search or best-of-n sampling
+         audio_features = audio_features.repeat_interleave(self.n_group, dim=0)
+         tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
+
+         # call the main sampling loop
+         tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
+
+         # reshape the tensors to have (n_audio, n_group) as the first two dimensions
+         audio_features = audio_features[:: self.n_group]
+         no_speech_probs = no_speech_probs[:: self.n_group]
+         assert audio_features.shape[0] == len(no_speech_probs) == n_audio
+
+         tokens = tokens.reshape(n_audio, self.n_group, -1)
+         sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
+
+         # get the final candidates for each group, and slice between the first sampled token and EOT
+         tokens, sum_logprobs, ts = self.decoder.finalize(tokens, sum_logprobs)
+         tokens: List[List[Tensor]] = [
+             [t[self.sample_begin: (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens
+         ]
+         ts: List[List[Tensor]] = [[t[:, :tokens[i][j].shape[-1]] for j, t in enumerate(s)] for i, s in enumerate(ts)]
+
+         # select the top-ranked sample in each group
+         selected = self.sequence_ranker.rank(tokens, sum_logprobs)
+         tokens: List[List[int]] = [t[i].tolist() for i, t in zip(selected, tokens)]
+         ts: List[List[int]] = [t[i].tolist() for i, t in zip(selected, ts)]
+         texts: List[str] = [tokenizer.decode(t).strip() for t in tokens]
+
+         sum_logprobs: List[float] = [lp[i] for i, lp in zip(selected, sum_logprobs)]
+         avg_logprobs: List[float] = [lp / (len(t) + 1) for t, lp in zip(tokens, sum_logprobs)]
+
+         fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
+         if len(set(map(len, fields))) != 1:
+             raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
+
+         return [
+             DecodingResult(
+                 audio_features=features,
+                 language=language,
+                 tokens=tokens,
+                 text=text,
+                 avg_logprob=avg_logprob,
+                 **(dict(no_caption_prob=no_speech_prob) if hasattr(DecodingResult, 'no_caption_prob') else dict(
+                     no_speech_prob=no_speech_prob)),
+                 temperature=self.options.temperature,
+                 compression_ratio=compression_ratio(text),
+             )
+             for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)
+         ], ts
+
+
+ # modified version of whisper.decoding.decode
+ @torch.no_grad()
+ def decode_word_level(model: "Whisper", mel: Tensor, options: DecodingOptions = DecodingOptions(),
+                       ts_num: int = None, alpha: float = None, suppress_ts_mask: Tensor = None,
+                       suppress_word_ts=False) -> \
+         Union[DecodingResult, List[DecodingResult], tuple]:
+     """
+     Performs decoding of 30-second audio segment(s), provided as Mel spectrogram(s).
+
+     Parameters
+     ----------
+     model: Whisper
+         the Whisper model instance
+
+     mel: torch.Tensor, shape = (80, 3000) or (*, 80, 3000)
+         A tensor containing the Mel spectrogram(s)
+
+     options: DecodingOptions
+         A dataclass that contains all necessary options for decoding 30-second segments
+
+     ts_num: int
+         Number of additional top timestamp predictions to save for each word for postprocessing stabilization (default: 10)
+
+     alpha: float
+         Amount of noise to add to the audio features to produce slightly different results.
+         audio_features *= torch.rand_like(audio_features) * alpha + 1
+
+     suppress_ts_mask: (list, Tensor)
+         Mask to suppress timestamp token(s) during decoding
+
+     suppress_word_ts: bool
+         Whether to use suppress_ts_mask to suppress the timestamp tokens of words
+
+     Returns
+     -------
+     result: Union[DecodingResult, List[DecodingResult]]
+         The result(s) of decoding contained in `DecodingResult` dataclass instance(s)
+     """
+     single = mel.ndim == 2
+     if single:
+         mel = mel.unsqueeze(0)
+
+     result, ts = DecodingTaskWordLevel(model, options,
+                                        ts_num=ts_num,
+                                        alpha=alpha,
+                                        suppress_ts_mask=suppress_ts_mask,
+                                        suppress_word_ts=suppress_word_ts).run(mel)
+
+     if single:
+         result = result[0]
+         ts_tokens = ts[0][1]
+         ts_logits = ts[0][0]
+     else:
+         ts_tokens = [ts_[1] for ts_ in ts]
+         ts_logits = [ts_[0] for ts_ in ts]
+
+     return result, ts_tokens, ts_logits
+
+
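A minimal usage sketch for `decode_word_level`, assuming this module's imports are in scope; the model size, audio path, and `ts_num` value are arbitrary placeholders:

# Sketch: decode one 30-second mel segment and keep the word-level timestamp data.
model = whisper.load_model('base')                       # placeholder model size
audio = whisper.load_audio('audio.wav')                  # placeholder path
mel = log_mel_spectrogram(pad_or_trim(audio)).to(model.device)
result, ts_tokens, ts_logits = decode_word_level(model, mel, DecodingOptions(fp16=False), ts_num=10)
print(result.text)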
+ def modify_model(model: whisper.model.Whisper):
+     model.decode = MethodType(decode_word_level, model)
+     model.transcribe = MethodType(transcribe_word_level, model)
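For reference, a hedged end-to-end sketch of how the patched model is typically used; the importable module name `stable_whisper`, the model size, the audio path, and the shape of the returned dict are assumptions here, and `transcribe_word_level` is defined earlier in this file:

import whisper
from stable_whisper import modify_model   # assuming this file is importable as `stable_whisper`

model = whisper.load_model('base')         # placeholder model size
modify_model(model)                        # swaps in decode_word_level / transcribe_word_level
result = model.transcribe('audio.wav')     # placeholder path; a Whisper-style result dict is assumed
print(result['text'])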