NeoPy's picture
Upload 115 files
96134ee verified
raw
history blame
58.5 kB
import os
import sys
import gzip
import zlib
import tqdm
import torch
import base64
import string
import logging
import tiktoken
import itertools
import numba as nb
import numpy as np
import torch.nn as nn
import torch.nn.functional as F
from contextlib import contextmanager
from torch.distributions import Categorical
from functools import cached_property, lru_cache
from dataclasses import dataclass, replace
from torch.nn.functional import scaled_dot_product_attention
sys.path.append(os.getcwd())
from main.library.utils import load_audio
# Language code -> lowercase English language name. The order matters: get_encoding()
# slices the first `num_languages` keys to build the "<|lang|>" special tokens.
LANGUAGES = {"en": "english", "zh": "chinese", "de": "german", "es": "spanish", "ru": "russian", "ko": "korean", "fr": "french", "ja": "japanese", "pt": "portuguese", "tr": "turkish", "pl": "polish", "ca": "catalan", "nl": "dutch", "ar": "arabic", "sv": "swedish", "it": "italian", "id": "indonesian", "hi": "hindi", "fi": "finnish", "vi": "vietnamese", "he": "hebrew", "uk": "ukrainian", "el": "greek", "ms": "malay", "cs": "czech", "ro": "romanian", "da": "danish", "hu": "hungarian", "ta": "tamil", "no": "norwegian", "th": "thai", "ur": "urdu", "hr": "croatian", "bg": "bulgarian", "lt": "lithuanian", "la": "latin", "mi": "maori", "ml": "malayalam", "cy": "welsh", "sk": "slovak", "te": "telugu", "fa": "persian", "lv": "latvian", "bn": "bengali", "sr": "serbian", "az": "azerbaijani", "sl": "slovenian", "kn": "kannada", "et": "estonian", "mk": "macedonian", "br": "breton", "eu": "basque", "is": "icelandic", "hy": "armenian", "ne": "nepali", "mn": "mongolian", "bs": "bosnian", "kk": "kazakh", "sq": "albanian", "sw": "swahili", "gl": "galician", "mr": "marathi", "pa": "punjabi", "si": "sinhala", "km": "khmer", "sn": "shona", "yo": "yoruba", "so": "somali", "af": "afrikaans", "oc": "occitan", "ka": "georgian", "be": "belarusian", "tg": "tajik", "sd": "sindhi", "gu": "gujarati", "am": "amharic", "yi": "yiddish", "lo": "lao", "uz": "uzbek", "fo": "faroese", "ht": "haitian creole", "ps": "pashto", "tk": "turkmen", "nn": "nynorsk", "mt": "maltese", "sa": "sanskrit", "lb": "luxembourgish", "my": "myanmar", "bo": "tibetan", "tl": "tagalog", "mg": "malagasy", "as": "assamese", "tt": "tatar", "haw": "hawaiian", "ln": "lingala", "ha": "hausa", "ba": "bashkir", "jw": "javanese", "su": "sundanese", "yue": "cantonese"}
# Reverse lookup (language name -> code), extended with a few alternative names.
# get_tokenizer() consults it when the requested language is not a key of LANGUAGES.
TO_LANGUAGE_CODE = {**{language: code for code, language in LANGUAGES.items()}, "burmese": "my", "valencian": "ca", "flemish": "nl", "haitian": "ht", "letzeburgesch": "lb", "pushto": "ps", "panjabi": "pa", "moldavian": "ro", "moldovan": "ro", "sinhalese": "si", "castilian": "es", "mandarin": "zh"}
# Per-checkpoint alignment-head dumps, keyed by model name (see load_model).
# Whisper.set_alignment_heads() base85-decodes and gunzips each value into a boolean
# array that is reshaped to (n_text_layer, n_text_head).
_ALIGNMENT_HEADS = {"tiny": b"ABzY8bu8Lr0{>%RKn9Fp%m@SkK7Kt=7ytkO", "base": b"ABzY8KQ!870{>%RzyTQH3`Q^yNP!>##QT-<FaQ7m", "small": b"ABzY8DmU6=0{>%Rpa?J`kvJ6qF(V^F86#Xh7JUGMK}P<N0000", "medium": b"ABzY8B0Jh+0{>%R7}kK1fFL7w6%<-Pf*t^=N)Qr&0RR9", "large-v1": b"ABzY8r9j$a0{>%R7#4sLmoOs{s)o3~84-RPdcFk!JR<kSfC2yj", "large-v2": b"ABzY8zd+h!0{>%R7=D0pU<_bnWW*tkYAhobTNnu$jnkEkXqp)j;w1Tzk)UH3X%SZd&fFZ2fC2yj", "large-v3": b"ABzY8gWO1E0{>%R7(9S+Kn!D~%ngiGaR?*L!iJG9p-nab0JQ=-{D1-g00", "large-v3-turbo": b"ABzY8j^C+e0{>%RARaKHP%t(lGR*)0g!tONPyhe`"}
# Audio front-end constants. N_FFT and HOP_LENGTH are the n_fft / hop_length arguments of
# torch.stft in log_mel_spectrogram; SAMPLE_RATE is the rate audio files are loaded at.
# CHUNK_LENGTH is presumably the window length in seconds (it is multiplied by SAMPLE_RATE below).
SAMPLE_RATE, N_FFT, HOP_LENGTH, CHUNK_LENGTH = 16000, 400, 160, 30
# Number of audio samples in one chunk (30 * 16000 = 480000).
N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE
# Audio samples covered by one decoder token position (two STFT hops).
N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2
def exact_div(x, y):
    """Return ``x // y`` after asserting that ``y`` divides ``x`` without remainder."""
    quotient, remainder = divmod(x, y)
    assert remainder == 0
    return quotient
# Spectrogram frames in one chunk (480000 / 160 = 3000).
N_FRAMES = exact_div(N_SAMPLES, HOP_LENGTH)
# Spectrogram frames per second of audio (16000 / 160 = 100).
FRAMES_PER_SECOND = exact_div(SAMPLE_RATE, HOP_LENGTH)
# Token positions per second of audio (16000 / 320 = 50); find_alignment divides
# DTW time indices by this to obtain seconds.
TOKENS_PER_SECOND = exact_div(SAMPLE_RATE, N_SAMPLES_PER_TOKEN)
def load_model(name = "base", device = "cpu"):
    """Load the Whisper checkpoint called ``name`` and return the model on ``device``.

    The checkpoint is read from ``assets/models/speaker_diarization/models/<name>.pt``
    and must contain the keys ``"dims"`` and ``"model_state_dict"``. ``name`` must
    also be a key of ``_ALIGNMENT_HEADS``; otherwise a ``KeyError`` is raised.
    """
    weights_path = os.path.join("assets", "models", "speaker_diarization", "models", name + ".pt")
    heads_dump = _ALIGNMENT_HEADS[name]

    # NOTE(review): torch.load unpickles the file; only load checkpoints from a trusted source.
    with open(weights_path, "rb") as handle:
        state = torch.load(handle, map_location=device)

    whisper = Whisper(ModelDimensions(**state["dims"]))
    whisper.load_state_dict(state["model_state_dict"])
    whisper.set_alignment_heads(heads_dump)
    return whisper.to(device)
def merge_punctuations(alignment, prepended, appended):
    """Fold punctuation-only entries of ``alignment`` into their neighbouring words, in place.

    Each entry must expose mutable ``word`` and ``tokens`` attributes. An entry that is
    absorbed into a neighbour is left in the list with ``word == ""`` and ``tokens == []``.

    * ``prepended``: characters that attach to the *following* word (e.g. opening quotes).
    * ``appended``: characters that attach to the *preceding* word (e.g. full stops).
    """
    count = len(alignment)

    # Pass 1 (right to left): glue leading punctuation onto the word that follows it.
    target = count - 1
    for source in range(count - 2, -1, -1):
        left, right = alignment[source], alignment[target]
        if left.word.startswith(" ") and left.word.strip() in prepended:
            right.word = left.word + right.word
            right.tokens = left.tokens + right.tokens
            left.word = ""
            left.tokens = []
        else:
            target = source

    # Pass 2 (left to right): glue trailing punctuation onto the word that precedes it.
    target = 0
    for source in range(1, count):
        left, right = alignment[target], alignment[source]
        if not left.word.endswith(" ") and right.word in appended:
            left.word = left.word + right.word
            left.tokens = left.tokens + right.tokens
            right.word = ""
            right.tokens = []
        else:
            target = source
@dataclass(eq=False)
class WordTiming:
    """Timing record for one aligned word, as produced by ``find_alignment``.

    Instances are mutable (``merge_punctuations`` rewrites ``word`` and ``tokens`` in
    place). ``eq=False`` keeps identity-based equality and hashing, exactly as with a
    plain class, while still providing a readable ``repr`` for debugging.
    """

    word: str  # decoded text of the word
    tokens: list  # token ids that make up the word
    start: float  # start time; find_alignment derives it as a DTW index / TOKENS_PER_SECOND
    end: float  # end time, same unit as ``start``
    probability: float  # mean probability of the word's tokens
@contextmanager
def disable_sdpa():
    """Context manager that sets ``MultiHeadAttention.use_sdpa`` to ``False`` for its body.

    The previous value is restored on exit, including when the body raises. Relies on
    ``use_sdpa`` being defined as a class attribute of ``MultiHeadAttention``.
    """
    saved_flag = MultiHeadAttention.use_sdpa
    MultiHeadAttention.use_sdpa = False
    try:
        yield
    finally:
        MultiHeadAttention.use_sdpa = saved_flag
def median_filter(x, filter_width):
pad_width = filter_width // 2
if x.shape[-1] <= pad_width: return x
if (ndim := x.ndim) <= 2: x = x[None, None, :]
assert (filter_width > 0 and filter_width % 2 == 1)
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
if result is None: result = x.unfold(-1, filter_width, 1).sort()[0][..., filter_width // 2]
if ndim <= 2: result = result[0, 0]
return result
@nb.jit(nopython=True)
def backtrace(trace):
    """Walk the DTW ``trace`` matrix from its bottom-right cell back to the origin.

    ``trace`` holds one step code per cell: 0 = diagonal, 1 = up (row - 1),
    2 = left (column - 1). The matrix is modified in place (first row/column are
    overwritten). Returns an array with two rows (row indices, column indices) listing
    the path in forward order, expressed in 0-based coordinates of the cost input.
    """
    i = trace.shape[0] - 1
    j = trace.shape[1] - 1
    # Force the borders so a path that reaches an edge runs straight to cell (0, 0).
    trace[0, :] = 2
    trace[:, 0] = 1
    result = []
    while i > 0 or j > 0:
        # trace has one extra leading row/column, hence the -1 shift.
        result.append((i - 1, j - 1))
        if trace[i, j] == 0:
            i -= 1
            j -= 1
        elif trace[i, j] == 1: i -= 1
        elif trace[i, j] == 2: j -= 1
        else: raise ValueError
    # The path was collected end-to-start: reverse it, then transpose to (2, path_length).
    return np.array(result)[::-1, :].T
@nb.jit(nopython=True, parallel=True)
def dtw_cpu(x):
    """Run dynamic time warping over the 2-D cost matrix ``x`` and return the optimal path.

    Fills an (N + 1, M + 1) cumulative-cost table column by column, recording for every
    cell which predecessor was cheapest, then delegates to ``backtrace``.
    """
    N, M = x.shape
    # Everything starts at +inf except the origin, so paths must begin at (0, 0).
    cost = np.ones((N + 1, M + 1), dtype=np.float32) * np.inf
    trace = -np.ones((N + 1, M + 1), dtype=np.float32)
    cost[0, 0] = 0
    for j in range(1, M + 1):
        for i in range(1, N + 1):
            c0 = cost[i - 1, j - 1]  # diagonal predecessor
            c1 = cost[i - 1, j]  # predecessor in the previous row
            c2 = cost[i, j - 1]  # predecessor in the previous column
            # Strict comparisons: on ties the final branch (code 2) is taken.
            if c0 < c1 and c0 < c2: c, t = c0, 0
            elif c1 < c0 and c1 < c2: c, t = c1, 1
            else: c, t = c2, 2
            cost[i, j] = x[i - 1, j - 1] + c
            trace[i, j] = t
    return backtrace(trace)
def dtw(x):
    """Run ``dtw_cpu`` on tensor ``x`` after converting it to a float64 NumPy array on the CPU."""
    cost_matrix = x.double().cpu().numpy()
    return dtw_cpu(cost_matrix)
def find_alignment(model, tokenizer, text_tokens, mel, num_frames, *, medfilt_width = 7, qk_scale = 1.0):
    """Align ``text_tokens`` to audio frames using the decoder's cross-attention weights.

    Forward hooks capture the cross-attention QK matrix of every decoder block during a
    single forward pass (run with SDPA disabled so the matrices are materialised). The
    heads listed in ``model.alignment_heads`` are normalised, median-filtered, averaged,
    and passed to dynamic time warping to obtain a token-to-time path.

    Args:
        model: Whisper model (callable as ``model(mel, tokens)``).
        tokenizer: tokenizer providing ``sot_sequence``, ``no_timestamps``, ``eot`` and
            ``split_to_word_tokens``.
        text_tokens: list of text token ids to align.
        mel: mel spectrogram of the segment; a batch axis is added before the forward pass.
        num_frames: number of valid mel frames; attention is cropped to ``num_frames // 2``.
        medfilt_width: width of the median filter applied to the attention weights.
        qk_scale: multiplier applied to the QK matrix before the softmax.

    Returns:
        A list of ``WordTiming``; empty when there are no text tokens or no words.
    """
    if len(text_tokens) == 0: return []

    tokens = torch.tensor([*tokenizer.sot_sequence, tokenizer.no_timestamps, *text_tokens, tokenizer.eot]).to(model.device)
    sot_len = len(tokenizer.sot_sequence)

    QKs = [None] * model.dims.n_text_layer
    hooks = []
    try:
        for i, block in enumerate(model.decoder.blocks):
            # index=i binds the layer number now; a plain closure would see the last i only.
            hooks.append(block.cross_attn.register_forward_hook(lambda _, ins, outs, index=i: QKs.__setitem__(index, outs[-1][0])))

        with torch.no_grad(), disable_sdpa():
            logits = model(mel.unsqueeze(0), tokens.unsqueeze(0))[0]
            token_probs = logits[sot_len:, : tokenizer.eot].softmax(dim=-1)
            text_token_probs = token_probs[np.arange(len(text_tokens)), text_tokens].tolist()
    finally:
        # Always detach the hooks so a failed forward pass cannot leave them on the model.
        for hook in hooks:
            hook.remove()

    selected = torch.stack([QKs[_l][_h] for _l, _h in model.alignment_heads.indices().T])
    weights = (selected[:, :, : num_frames // 2] * qk_scale).softmax(dim=-1)
    std, mean = torch.std_mean(weights, dim=-2, keepdim=True, unbiased=False)
    weights = median_filter((weights - mean) / std, medfilt_width)

    # Average over heads, keep only the rows of the text tokens (drop the prefix and eot).
    matrix = weights.mean(axis=0)[sot_len:-1]
    text_indices, time_indices = dtw(-matrix)

    words, word_tokens = tokenizer.split_to_word_tokens(text_tokens + [tokenizer.eot])
    if len(word_tokens) <= 1: return []

    word_boundaries = np.pad(np.cumsum([len(t) for t in word_tokens[:-1]]), (1, 0))
    # A "jump" is a path step on which the text index advances to the next token.
    jumps = np.pad(np.diff(text_indices), (1, 0), constant_values=1).astype(bool)
    jump_times = time_indices[jumps] / TOKENS_PER_SECOND

    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])]

    return [WordTiming(word, tokens, start, end, probability) for word, tokens, start, end, probability in zip(words, word_tokens, start_times, end_times, word_probabilities)]
def add_word_timestamps(*, segments, model, tokenizer, mel, num_frames, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", last_speech_timestamp, **kwargs):
    """Attach a ``"words"`` list with per-word timings to every segment, in place.

    All text tokens of ``segments`` are aligned in one ``find_alignment`` call (extra
    ``kwargs`` are forwarded to it). Over-long words next to sentence-ending marks are
    clamped, punctuation is merged into neighbouring words, and each segment's
    ``"start"``/``"end"`` are reconciled with its first and last word. Returns ``None``.
    """
    if len(segments) == 0: return
    # Keep only text tokens: ids below eot (special and timestamp tokens are dropped).
    text_tokens_per_segment = [[token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments]
    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
    alignment = find_alignment(model, tokenizer, text_tokens, mel, num_frames, **kwargs)
    word_durations = np.array([t.end - t.start for t in alignment])
    word_durations = word_durations[word_durations.nonzero()]
    # Median of the non-zero word durations, capped at 0.7; twice that is the allowed maximum.
    median_duration = min(0.7, float(np.median(word_durations) if len(word_durations) > 0 else 0.0))
    max_duration = median_duration * 2
    if len(word_durations) > 0:
        sentence_end_marks = ".。!!??"
        # Clamp over-long words that are, or directly follow, a sentence-ending mark.
        for i in range(1, len(alignment)):
            if alignment[i].end - alignment[i].start > max_duration:
                if alignment[i].word in sentence_end_marks: alignment[i].end = alignment[i].start + max_duration
                elif alignment[i - 1].word in sentence_end_marks: alignment[i].start = alignment[i].end - max_duration
    merge_punctuations(alignment, prepend_punctuations, append_punctuations)
    # Alignment times are relative to the segment window; shift by the window's seek position.
    time_offset = segments[0]["seek"] * HOP_LENGTH / SAMPLE_RATE
    word_index = 0
    for segment, text_tokens in zip(segments, text_tokens_per_segment):
        saved_tokens = 0
        words = []
        # Consume aligned words until this segment's token count has been covered.
        while word_index < len(alignment) and saved_tokens < len(text_tokens):
            timing = alignment[word_index]
            # Entries emptied by merge_punctuations have word == "" and are skipped.
            if timing.word: words.append(dict(word=timing.word, start=round(time_offset + timing.start, 2), end=round(time_offset + timing.end, 2), probability=timing.probability))
            saved_tokens += len(timing.tokens)
            word_index += 1
        if len(words) > 0:
            # After a long gap since the last speech, shorten implausibly long leading words.
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (words[0]["end"] - words[0]["start"] > max_duration or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)):
                if (len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration): words[0]["end"] = words[1]["start"] = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                words[0]["start"] = max(0, words[0]["end"] - max_duration)
            # Use the segment start when the first word begins more than 0.5 before it;
            # otherwise move the segment start to the first word.
            if (segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]): words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
            else: segment["start"] = words[0]["start"]
            # Same reconciliation for the segment end and the last word.
            if (segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]): words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
            else: segment["end"] = words[-1]["end"]
            last_speech_timestamp = segment["end"]
        segment["words"] = words
@lru_cache(maxsize=None)
def mel_filters(device, n_mels):
    """Load the ``mel_<n_mels>`` filterbank from ``mel_filters.npz`` as a tensor on ``device``.

    Only 80 and 128 mel bins are accepted. Results are cached per ``(device, n_mels)``.
    """
    assert n_mels in {80, 128}

    filters_path = os.path.join("assets", "models", "speaker_diarization", "assets", "mel_filters.npz")
    with np.load(filters_path, allow_pickle=False) as archive:
        filterbank = archive[f"mel_{n_mels}"]
        return torch.from_numpy(filterbank).to(device)
def log_mel_spectrogram(audio, n_mels = 80, padding = 0, device = None):
    """Compute the normalised log-mel spectrogram of ``audio``.

    Args:
        audio: a file path (loaded with ``load_audio`` at ``SAMPLE_RATE``), a NumPy array,
            or a tensor containing the waveform.
        n_mels: number of mel bins (forwarded to ``mel_filters``).
        padding: number of zero samples appended to the end of the waveform.
        device: when given, the waveform is moved there before the STFT.

    Returns:
        A tensor whose values are clamped to within 8.0 of the maximum log10 power,
        then shifted and scaled by ``(x + 4.0) / 4.0``.
    """
    if not torch.is_tensor(audio):
        if isinstance(audio, str):
            audio = load_audio(logging.getLogger(__name__), audio, sample_rate=SAMPLE_RATE).astype(np.float32)
        audio = torch.from_numpy(audio)

    if device is not None:
        audio = audio.to(device)
    if padding > 0:
        audio = F.pad(audio, (0, padding))

    filters = mel_filters(audio.device, n_mels)
    window = torch.hann_window(N_FFT).to(audio.device)
    stft = torch.stft(audio, N_FFT, HOP_LENGTH, window=window, return_complex=True)
    # Drop the last STFT frame and convert magnitudes to power.
    power = stft[..., :-1].abs() ** 2

    log_spec = torch.clamp(filters @ power, min=1e-10).log10()
    log_spec = torch.maximum(log_spec, log_spec.max() - 8.0)
    return (log_spec + 4.0) / 4.0
def pad_or_trim(array, length = N_SAMPLES, *, axis = -1):
    """Return ``array`` cut or zero-padded at the end so that ``axis`` has size ``length``.

    Works on both torch tensors and NumPy arrays; the input is returned as-is when it
    already has the requested size.
    """
    is_tensor = torch.is_tensor(array)
    size = array.shape[axis]

    if size > length:
        if is_tensor:
            array = array.index_select(dim=axis, index=torch.arange(length, device=array.device))
        else:
            array = array.take(indices=range(length), axis=axis)
    elif size < length:
        widths = [(0, 0)] * array.ndim
        widths[axis] = (0, length - size)
        if is_tensor:
            # F.pad expects a flat list ordered from the last dimension to the first.
            array = F.pad(array, [amount for pair in reversed(widths) for amount in pair])
        else:
            array = np.pad(array, widths)

    return array
def get_end(segments):
    """Return the ``"end"`` of the last word found in ``segments``.

    Segments are scanned from last to first. When no segment has any word, the last
    segment's own ``"end"`` is returned, or ``None`` for an empty list.
    """
    # Computed up front, mirroring the eager evaluation of a ``next()`` default.
    fallback = segments[-1]["end"] if segments else None

    for segment in reversed(segments):
        for word in reversed(segment["words"]):
            return word["end"]
    return fallback
def transcribe_function(model, audio, *, verbose = None, temperature = (0.0, 0.2, 0.4, 0.6, 0.8, 1.0), compression_ratio_threshold = 2.4, logprob_threshold = -1.0, no_speech_threshold = 0.6, condition_on_previous_text = True, initial_prompt = None, carry_initial_prompt = False, word_timestamps = False, prepend_punctuations = "\"'“¿([{-", append_punctuations = "\"'.。,,!!??::”)]}、", clip_timestamps = "0", hallucination_silence_threshold = None, fp16 = False, **decode_options):
    """Transcribe ``audio`` with ``model`` by decoding successive windows of N_FRAMES frames.

    Args:
        model: Whisper model (bound as ``Whisper.transcribe``).
        audio: anything accepted by ``log_mel_spectrogram`` (path, array or tensor).
        verbose: when not ``None`` the detected language is printed; the progress bar is
            shown only when this is exactly ``False``.
        temperature: a single value or a sequence tried in order by the fallback loop.
        compression_ratio_threshold / logprob_threshold / no_speech_threshold: quality
            gates that trigger a retry at the next temperature or skip a silent window.
        condition_on_previous_text: feed earlier output as the prompt of the next window.
        initial_prompt / carry_initial_prompt: optional text prepended to the prompt.
        word_timestamps: add per-word timings through ``add_word_timestamps``.
        prepend_punctuations / append_punctuations: forwarded to ``add_word_timestamps``.
        clip_timestamps: comma-separated string or list of start,end,... times in seconds.
        hallucination_silence_threshold: enables skipping of suspicious segments that are
            surrounded by silence (only used together with ``word_timestamps``).
        fp16: stored in ``decode_options["fp16"]``; the mel input is always float32 here.
        **decode_options: forwarded to ``DecodingOptions``.

    Returns:
        ``dict(text=..., segments=[...], language=...)``.
    """
    dtype = torch.float32
    decode_options["fp16"] = fp16
    # Pad with one full chunk so the last window can always be sliced.
    mel = log_mel_spectrogram(audio, model.dims.n_mels, padding=N_SAMPLES)
    content_frames = mel.shape[-1] - N_FRAMES
    content_duration = float(content_frames * HOP_LENGTH / SAMPLE_RATE)
    if decode_options.get("language", None) is None:
        # NOTE(review): non-multilingual models default to "vi" here; the value is also
        # returned as the result's language — confirm this is intended.
        if not model.is_multilingual: decode_options["language"] = "vi"
        else:
            mel_segment = pad_or_trim(mel, N_FRAMES).to(model.device).to(dtype)
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None: print(f"{LANGUAGES[decode_options['language']].title()}")
    language = decode_options["language"]
    task = decode_options.get("task", "transcribe")
    tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=task)
    # Convert clip timestamps (seconds) into (start_frame, end_frame) pairs.
    if isinstance(clip_timestamps, str): clip_timestamps = [float(ts) for ts in (clip_timestamps.split(",") if clip_timestamps else [])]
    seek_points = [round(ts * FRAMES_PER_SECOND) for ts in clip_timestamps]
    if len(seek_points) == 0: seek_points.append(0)
    # An odd count means the last clip is open-ended: close it at the end of the audio.
    if len(seek_points) % 2 == 1: seek_points.append(content_frames)
    seek_clips = list(zip(seek_points[::2], seek_points[1::2]))
    punctuation = "\"'“¿([{-\"'.。,,!!??::”)]}、"
    def decode_with_fallback(segment):
        # Decode at each temperature in turn until the result passes the quality gates.
        temperatures = ([temperature] if isinstance(temperature, (int, float)) else temperature)
        decode_result = None
        for t in temperatures:
            kwargs = {**decode_options}
            if t > 0:
                # Sampling: beam-search options are dropped.
                kwargs.pop("beam_size", None)
                kwargs.pop("patience", None)
            else: kwargs.pop("best_of", None)
            decode_result = model.decode(segment, DecodingOptions(**kwargs, temperature=t))
            needs_fallback = False
            if (compression_ratio_threshold is not None and decode_result.compression_ratio > compression_ratio_threshold): needs_fallback = True
            if (logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = True
            # A window judged silent is accepted as-is; retrying would not help.
            if (no_speech_threshold is not None and decode_result.no_speech_prob > no_speech_threshold and logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold): needs_fallback = False
            if not needs_fallback: break
        return decode_result
    clip_idx = 0
    seek = seek_clips[clip_idx][0]
    # Mel frames per encoder output position, and the matching duration in seconds.
    input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)
    time_precision = (input_stride * HOP_LENGTH / SAMPLE_RATE)
    all_tokens, all_segments = [], []
    prompt_reset_since = 0
    remaining_prompt_length = model.dims.n_text_ctx // 2 - 1
    if initial_prompt is not None:
        initial_prompt_tokens = tokenizer.encode(" " + initial_prompt.strip())
        all_tokens.extend(initial_prompt_tokens)
        remaining_prompt_length -= len(initial_prompt_tokens)
    else: initial_prompt_tokens = []
    def new_segment(*, start, end, tokens, result):
        # Reads the enclosing `seek` at call time, i.e. the current window's start frame.
        tokens = tokens.tolist()
        return {"seek": seek, "start": start, "end": end, "text": tokenizer.decode([token for token in tokens if token < tokenizer.eot]), "tokens": tokens, "temperature": result.temperature, "avg_logprob": result.avg_logprob, "compression_ratio": result.compression_ratio, "no_speech_prob": result.no_speech_prob}
    with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
        last_speech_timestamp = 0.0
        while clip_idx < len(seek_clips):
            seek_clip_start, seek_clip_end = seek_clips[clip_idx]
            if seek < seek_clip_start: seek = seek_clip_start
            if seek >= seek_clip_end:
                # Current clip exhausted: move on to the next one.
                clip_idx += 1
                if clip_idx < len(seek_clips): seek = seek_clips[clip_idx][0]
                continue
            time_offset = float(seek * HOP_LENGTH / SAMPLE_RATE)
            window_end_time = float((seek + N_FRAMES) * HOP_LENGTH / SAMPLE_RATE)
            segment_size = min(N_FRAMES, content_frames - seek, seek_clip_end - seek)
            mel_segment = mel[:, seek : seek + segment_size]
            segment_duration = segment_size * HOP_LENGTH / SAMPLE_RATE
            mel_segment = pad_or_trim(mel_segment, N_FRAMES).to(model.device).to(dtype)
            # Build the prompt from earlier output, optionally keeping the initial prompt in front.
            if carry_initial_prompt: decode_options["prompt"] = initial_prompt_tokens + all_tokens[max(len(initial_prompt_tokens), prompt_reset_since):][-remaining_prompt_length:]
            else: decode_options["prompt"] = all_tokens[prompt_reset_since:]
            result = decode_with_fallback(mel_segment)
            tokens = torch.tensor(result.tokens)
            if no_speech_threshold is not None:
                should_skip = result.no_speech_prob > no_speech_threshold
                # Do not skip when the decoder was nevertheless confident about its text.
                if (logprob_threshold is not None and result.avg_logprob > logprob_threshold):
                    should_skip = False
                if should_skip:
                    seek += segment_size
                    continue
            previous_seek = seek
            current_segments = []
            def word_anomaly_score(word):
                # Penalise low probability and implausibly short or long words.
                probability = word.get("probability", 0.0)
                duration = word["end"] - word["start"]
                score = 0.0
                if probability < 0.15: score += 1.0
                if duration < 0.133: score += (0.133 - duration) * 15
                if duration > 2.0: score += duration - 2.0
                return score
            def is_segment_anomaly(segment):
                if segment is None or not segment["words"]: return False
                # Only the first 8 non-punctuation words are scored.
                words = [w for w in segment["words"] if w["word"] not in punctuation]
                words = words[:8]
                score = sum(word_anomaly_score(w) for w in words)
                return score >= 3 or score + 0.01 >= len(words)
            def next_words_segment(segments):
                # First segment that has at least one word, or None.
                return next((s for s in segments if s["words"]), None)
            timestamp_tokens = tokens.ge(tokenizer.timestamp_begin)
            single_timestamp_ending = timestamp_tokens[-2:].tolist() == [False, True]
            # Positions right after two consecutive timestamp tokens: segment boundaries.
            consecutive = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
            consecutive.add_(1)
            if len(consecutive) > 0:
                slices = consecutive.tolist()
                if single_timestamp_ending:
                    slices.append(len(tokens))
                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    current_segments.append(new_segment(start=time_offset + (sliced_tokens[0].item() - tokenizer.timestamp_begin) * time_precision, end=time_offset + (sliced_tokens[-1].item() - tokenizer.timestamp_begin) * time_precision, tokens=sliced_tokens, result=result))
                    last_slice = current_slice
                # Either consume the whole window or advance to the last timestamp seen.
                if single_timestamp_ending: seek += segment_size
                else: seek += (tokens[last_slice - 1].item() - tokenizer.timestamp_begin) * input_stride
            else:
                # No consecutive timestamps: the whole window becomes one segment.
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin): duration = (timestamps[-1].item() - tokenizer.timestamp_begin) * time_precision
                current_segments.append(new_segment(start=time_offset, end=time_offset + duration, tokens=tokens, result=result))
                seek += segment_size
            if word_timestamps:
                add_word_timestamps(segments=current_segments, model=model, tokenizer=tokenizer, mel=mel_segment, num_frames=segment_size, prepend_punctuations=prepend_punctuations, append_punctuations=append_punctuations, last_speech_timestamp=last_speech_timestamp)
                if not single_timestamp_ending:
                    # Resume right after the last aligned word.
                    last_word_end = get_end(current_segments)
                    if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND)
                if hallucination_silence_threshold is not None:
                    threshold = hallucination_silence_threshold
                    if not single_timestamp_ending:
                        last_word_end = get_end(current_segments)
                        if last_word_end is not None and last_word_end > time_offset: seek = round(last_word_end * FRAMES_PER_SECOND) if (window_end_time - last_word_end) > threshold else (previous_seek + segment_size)
                    # If the first segment looks anomalous and starts after a long gap, skip
                    # the gap and decode again from there.
                    first_segment = next_words_segment(current_segments)
                    if first_segment is not None and is_segment_anomaly(first_segment):
                        gap = first_segment["start"] - time_offset
                        if gap > threshold:
                            seek = previous_seek + round(gap * FRAMES_PER_SECOND)
                            continue
                    hal_last_end = last_speech_timestamp
                    for si in range(len(current_segments)):
                        segment = current_segments[si]
                        if not segment["words"]: continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(current_segments[si + 1 :])
                            hal_next_start = next_segment["words"][0]["start"] if next_segment is not None else (time_offset + segment_duration)
                            # Drop this and all later segments when the anomaly is bounded by
                            # silence (or window edges) on both sides.
                            if (segment["start"] - hal_last_end > threshold or segment["start"] < threshold or segment["start"] - time_offset < 2.0) and (hal_next_start - segment["end"] > threshold or is_segment_anomaly(next_segment) or window_end_time - segment["end"] < 2.0):
                                seek = round(max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND)
                                if content_duration - segment["end"] < threshold: seek = content_frames
                                current_segments[si:] = []
                                break
                        hal_last_end = segment["end"]
                last_word_end = get_end(current_segments)
                if last_word_end is not None: last_speech_timestamp = last_word_end
            # Blank out zero-length or whitespace-only segments (they are still kept in the list).
            for _, segment in enumerate(current_segments):
                if segment["start"] == segment["end"] or segment["text"].strip() == "":
                    segment["text"] = ""
                    segment["tokens"] = []
                    segment["words"] = []
            all_segments.extend([{"id": i, **segment} for i, segment in enumerate(current_segments, start=len(all_segments))])
            all_tokens.extend([token for segment in current_segments for token in segment["tokens"]])
            # Stop conditioning on earlier text when disabled or after a high-temperature decode.
            if not condition_on_previous_text or result.temperature > 0.5: prompt_reset_since = len(all_tokens)
            pbar.update(min(content_frames, seek) - previous_seek)
    return dict(text=tokenizer.decode(all_tokens[len(initial_prompt_tokens) :]), segments=all_segments, language=language)
def compression_ratio(text):
    """Return the ratio of the UTF-8 size of ``text`` to its zlib-compressed size.

    Highly repetitive text compresses well and therefore yields a large ratio.
    """
    raw = text.encode("utf-8")
    packed = zlib.compress(raw)
    return len(raw) / len(packed)
def sinusoids(length, channels, max_timescale=10000):
    """Build a ``(length, channels)`` table of sinusoidal positional embeddings.

    The first ``channels // 2`` columns hold sines and the remaining columns hold cosines
    of the position scaled by geometrically decreasing frequencies. ``channels`` must be even.
    """
    assert channels % 2 == 0

    half = channels // 2
    log_increment = np.log(max_timescale) / (half - 1)
    inv_timescales = torch.exp(-log_increment * torch.arange(half))
    positions = torch.arange(length)[:, np.newaxis]
    scaled_time = positions * inv_timescales[np.newaxis, :]

    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
@torch.no_grad()
def detect_language_function(model, mel, tokenizer = None):
    """Detect the spoken language from the logits that follow the start-of-transcript token.

    Args:
        model: Whisper model (bound as ``Whisper.detect_language``).
        mel: one mel spectrogram (2-D) or a batch (3-D). Input whose last two dimensions
            already equal ``(n_audio_ctx, n_audio_state)`` is treated as encoder output
            and is not encoded again.
        tokenizer: optional tokenizer; built with ``get_tokenizer`` when omitted.

    Returns:
        ``(language_tokens, language_probs)``: the most likely language token id(s) and a
        dict (or list of dicts for batched input) mapping language code to probability.

    Raises:
        ValueError: if the tokenizer has no language token in its start sequence.
    """
    if tokenizer is None:
        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages)
    if tokenizer.language is None or tokenizer.language_token not in tokenizer.sot_sequence:
        raise ValueError("This model doesn't have language tokens so it can't perform language identification")

    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if mel.shape[-2:] != (model.dims.n_audio_ctx, model.dims.n_audio_state):
        mel = model.encoder(mel)

    n_audio = mel.shape[0]
    start_tokens = torch.tensor([[tokenizer.sot]] * n_audio).to(mel.device)
    logits = model.logits(start_tokens, mel)[:, 0]

    # Suppress every token that is not a language token.
    mask = torch.ones(logits.shape[-1], dtype=torch.bool)
    mask[list(tokenizer.all_language_tokens)] = False
    logits[:, mask] = -np.inf

    language_tokens = logits.argmax(dim=-1)
    # Computed once: the softmax and the CPU copy do not depend on the loop variables.
    language_token_probs = logits.softmax(dim=-1).cpu()
    language_probs = [
        {
            code: language_token_probs[i, token].item()
            for token, code in zip(tokenizer.all_language_tokens, tokenizer.all_language_codes)
        }
        for i in range(n_audio)
    ]

    if single:
        language_tokens = language_tokens[0]
        language_probs = language_probs[0]
    return language_tokens, language_probs
@lru_cache(maxsize=None)
def get_tokenizer(multilingual, *, num_languages = 99, language = None, task = None):
    """Return a cached ``Tokenizer`` for the requested model flavour.

    ``language`` may be a language code or a language name (case-insensitive); names are
    mapped to codes through ``TO_LANGUAGE_CODE``. An unknown language raises ``ValueError``.
    Multilingual tokenizers default to language ``"en"`` and task ``"transcribe"``;
    non-multilingual ones use the ``"gpt2"`` encoding with no language or task.
    """
    if language is not None:
        language = language.lower()
        if language not in LANGUAGES:
            if language not in TO_LANGUAGE_CODE:
                raise ValueError
            language = TO_LANGUAGE_CODE[language]

    if not multilingual:
        return Tokenizer(encoding_name="gpt2", num_languages=num_languages, language=None, task=None)

    return Tokenizer(encoding_name="multilingual", num_languages=num_languages, language=language or "en", task=task or "transcribe")
@lru_cache(maxsize=None)
def get_encoding(name = "gpt2", num_languages = 99):
    """Build (and cache) the ``tiktoken.Encoding`` stored in ``<name>.tiktoken``.

    Each line of the vocabulary file holds a base64-encoded token and its integer rank.
    Special tokens are numbered consecutively after the regular vocabulary, in the order
    listed in ``specials``.

    Args:
        name: base name of the vocabulary file under ``assets/.../assets``.
        num_languages: how many entries of ``LANGUAGES`` receive a ``<|lang|>`` token.
    """
    vocab_path = os.path.join("assets", "models", "speaker_diarization", "assets", f"{name}.tiktoken")

    ranks = {}
    # The context manager closes the file even if a line fails to parse.
    with open(vocab_path) as vocab_file:
        for line in vocab_file:
            fields = line.split()
            if not fields:
                # Tolerate blank lines (e.g. a trailing newline at the end of the file).
                continue
            token, rank = fields
            ranks[base64.b64decode(token)] = int(rank)

    n_vocab = len(ranks)
    specials = ["<|endoftext|>", "<|startoftranscript|>", *[f"<|{lang}|>" for lang in list(LANGUAGES.keys())[:num_languages]], "<|translate|>", "<|transcribe|>", "<|startoflm|>", "<|startofprev|>", "<|nospeech|>", "<|notimestamps|>", *[f"<|{i * 0.02:.2f}|>" for i in range(1501)]]

    special_tokens = {}
    for token in specials:
        special_tokens[token] = n_vocab
        n_vocab += 1

    return tiktoken.Encoding(name=os.path.basename(vocab_path), explicit_n_vocab=n_vocab, pat_str=r"""'s|'t|'re|'ve|'m|'ll|'d| ?\p{L}+| ?\p{N}+| ?[^\s\p{L}\p{N}]+|\s+(?!\S)|\s+""", mergeable_ranks=ranks, special_tokens=special_tokens)
@dataclass(eq=False)
class DecodingOptions:
    """Options for one decoding run; also accepted as keyword arguments by ``decode_function``.

    Declared as a dataclass because ``decode_function`` applies keyword overrides with
    ``dataclasses.replace``, which only works on dataclass instances. ``eq=False`` keeps
    identity-based equality and hashing, as with a plain class. Field order and defaults
    match the previous ``__init__`` signature, so positional construction is unchanged.
    """

    task: str = "transcribe"
    language: "str | None" = None
    temperature: float = 0.0
    sample_len: "int | None" = None
    best_of: "int | None" = None
    beam_size: "int | None" = None
    patience: "float | None" = None
    length_penalty: "float | None" = None
    prompt: "str | list | None" = None
    prefix: "str | list | None" = None
    suppress_tokens: "str | list | None" = "-1"
    suppress_blank: bool = True
    without_timestamps: bool = False
    max_initial_timestamp: "float | None" = 1.0
    fp16: bool = False
@torch.no_grad()
def decode_function(model, mel, options = DecodingOptions(), **kwargs):
    """Decode one mel spectrogram (2-D) or a batch of them (3-D) with ``DecodingTask``.

    Keyword arguments override the corresponding fields of ``options``. A single result
    is returned for 2-D input; otherwise the list of results for the batch.
    """
    single = mel.ndim == 2
    if single:
        mel = mel.unsqueeze(0)

    if kwargs:
        options = replace(options, **kwargs)

    decoded = DecodingTask(model, options).run(mel)
    if single:
        return decoded[0]
    return decoded
@dataclass
class ModelDimensions:
    """Size hyper-parameters of a Whisper model; built from ``checkpoint["dims"]`` in load_model."""
    # The n_audio_* values and n_mels are passed to AudioEncoder, the n_text_* values and
    # n_vocab to TextDecoder (see Whisper.__init__).
    n_mels: int  # input channels of the encoder's first convolution
    n_audio_ctx: int  # length of the encoder's positional embedding
    n_audio_state: int  # encoder hidden width
    n_audio_head: int  # attention heads per encoder block
    n_audio_layer: int  # number of encoder blocks
    n_vocab: int  # rows of the decoder token embedding; also decides is_multilingual
    n_text_ctx: int  # length of the decoder's positional embedding and causal mask
    n_text_state: int  # decoder hidden width
    n_text_head: int  # attention heads per decoder block
    n_text_layer: int  # number of decoder blocks
class LayerNorm(nn.LayerNorm):
    """Layer normalisation that computes in float32 and returns the input's dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normalised = super().forward(x.float())
        return normalised.type(input_dtype)
class Linear(nn.Linear):
    """Linear layer that casts its weight and bias to the input's dtype before use."""

    def forward(self, x):
        weight = self.weight.to(x.dtype)
        bias = self.bias.to(x.dtype) if self.bias is not None else None
        return F.linear(x, weight, bias)
class Conv1d(nn.Conv1d):
    """1-D convolution that casts its weight and bias to the input's dtype before use."""

    def _conv_forward(self, x, weight, bias):
        cast_weight = weight.to(x.dtype)
        cast_bias = bias.to(x.dtype) if bias is not None else None
        return super()._conv_forward(x, cast_weight, cast_bias)
class TextDecoder(nn.Module):
    """Transformer decoder: token + positional embeddings, cross-attending blocks, tied output."""

    def __init__(self, n_vocab, n_ctx, n_state, n_head, n_layer):
        super().__init__()
        self.token_embedding = nn.Embedding(n_vocab, n_state)
        self.positional_embedding = nn.Parameter(torch.empty(n_ctx, n_state))
        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head, cross_attention=True) for _ in range(n_layer)])
        self.ln = LayerNorm(n_state)

        # Causal mask: -inf strictly above the diagonal, 0 elsewhere. Not saved in checkpoints.
        causal_mask = torch.empty(n_ctx, n_ctx).fill_(-np.inf).triu_(1)
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(self, x, xa, kv_cache = None):
        """Return float32 vocabulary logits for token ids ``x`` given audio features ``xa``.

        With a non-empty ``kv_cache`` the positional embedding is offset by the number of
        positions already cached.
        """
        if kv_cache:
            offset = next(iter(kv_cache.values())).shape[1]
        else:
            offset = 0

        embedded = self.token_embedding(x)
        positions = self.positional_embedding[offset : offset + x.shape[-1]]
        hidden = (embedded + positions).to(xa.dtype)

        for block in self.blocks:
            hidden = block(hidden, xa, mask=self.mask, kv_cache=kv_cache)
        hidden = self.ln(hidden)

        # The output projection reuses the token-embedding weights.
        output_weights = torch.transpose(self.token_embedding.weight.to(hidden.dtype), 0, 1)
        return (hidden @ output_weights).float()
class AudioEncoder(nn.Module):
    """Audio encoder: two GELU convolutions, sinusoidal positions, self-attention blocks."""

    def __init__(self, n_mels, n_ctx, n_state, n_head, n_layer):
        super().__init__()
        self.conv1 = Conv1d(n_mels, n_state, kernel_size=3, padding=1)
        # The second convolution has stride 2.
        self.conv2 = Conv1d(n_state, n_state, kernel_size=3, stride=2, padding=1)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, n_state))
        self.blocks = nn.ModuleList([ResidualAttentionBlock(n_state, n_head) for _ in range(n_layer)])
        self.ln_post = LayerNorm(n_state)

    def forward(self, x):
        """Encode a batch of mel spectrograms.

        After the convolutions the per-item shape must equal the positional embedding's
        shape; otherwise an ``AssertionError`` is raised.
        """
        features = F.gelu(self.conv1(x))
        features = F.gelu(self.conv2(features))
        features = features.permute(0, 2, 1)

        assert features.shape[1:] == self.positional_embedding.shape
        features = (features + self.positional_embedding).to(features.dtype)

        for block in self.blocks:
            features = block(features)
        return self.ln_post(features)
class Whisper(nn.Module):
    """Whisper model: an ``AudioEncoder`` and a ``TextDecoder`` plus decoding helpers."""

    def __init__(self, dims):
        super().__init__()
        self.dims = dims
        d = self.dims
        self.encoder = AudioEncoder(d.n_mels, d.n_audio_ctx, d.n_audio_state, d.n_audio_head, d.n_audio_layer)
        self.decoder = TextDecoder(d.n_vocab, d.n_text_ctx, d.n_text_state, d.n_text_head, d.n_text_layer)

        # Default alignment heads: every head in the second half of the decoder layers.
        default_heads = torch.zeros(d.n_text_layer, d.n_text_head, dtype=torch.bool)
        default_heads[d.n_text_layer // 2 :] = True
        self.register_buffer("alignment_heads", default_heads.to_sparse(), persistent=False)

    def set_alignment_heads(self, dump):
        """Replace ``alignment_heads`` with the mask encoded in ``dump`` (base85 + gzip)."""
        raw = gzip.decompress(base64.b85decode(dump))
        flags = np.frombuffer(raw, dtype=bool).copy()
        heads = torch.from_numpy(flags).reshape(self.dims.n_text_layer, self.dims.n_text_head)
        self.register_buffer("alignment_heads", heads.to_sparse(), persistent=False)

    def embed_audio(self, mel):
        """Run the encoder on ``mel``."""
        return self.encoder(mel)

    def logits(self, tokens, audio_features):
        """Run the decoder on ``tokens`` given already-encoded ``audio_features``."""
        return self.decoder(tokens, audio_features)

    def forward(self, mel, tokens):
        """Encode ``mel`` and decode ``tokens`` against it."""
        audio_features = self.encoder(mel)
        return self.decoder(tokens, audio_features)

    @property
    def device(self):
        """Device of the model's first parameter."""
        return next(self.parameters()).device

    @property
    def is_multilingual(self):
        """``True`` when the vocabulary has at least 51865 entries."""
        return self.dims.n_vocab >= 51865

    @property
    def num_languages(self):
        """Number of language tokens, derived from the vocabulary size."""
        return self.dims.n_vocab - 51765 - int(self.is_multilingual)

    def install_kv_cache_hooks(self, cache = None):
        """Hook every decoder key/value projection so that its outputs are cached.

        Returns ``(cache, hooks)``: a dict keyed by projection module and the hook handles,
        which the caller must remove. A passed-in ``cache`` is copied, not modified.
        """
        cache = dict(cache) if cache is not None else {}
        hooks = []

        def save_to_cache(module, _, output):
            # The first call, or an output longer than the text context, replaces the entry;
            # otherwise the new positions are appended along dimension 1.
            if module not in cache or output.shape[1] > self.dims.n_text_ctx:
                cache[module] = output
            else:
                cache[module] = torch.cat([cache[module], output], dim=1).detach()
            return cache[module]

        def install_hooks(layer: nn.Module):
            if isinstance(layer, MultiHeadAttention):
                hooks.append(layer.key.register_forward_hook(save_to_cache))
                hooks.append(layer.value.register_forward_hook(save_to_cache))

        self.decoder.apply(install_hooks)
        return cache, hooks

    # Module-level functions bound as methods.
    detect_language = detect_language_function
    transcribe = transcribe_function
    decode = decode_function
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, then an MLP."""

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden_width = n_state * 4
        self.mlp = nn.Sequential(Linear(n_state, hidden_width), nn.GELU(), Linear(hidden_width, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def forward(self, x, xa = None, mask = None, kv_cache = None):
        """Apply the block to ``x``; ``xa`` is attended to when cross-attention is enabled."""
        attended, _ = self.attn(self.attn_ln(x), mask=mask, kv_cache=kv_cache)
        x = x + attended

        if self.cross_attn:
            crossed, _ = self.cross_attn(self.cross_attn_ln(x), xa, kv_cache=kv_cache)
            x = x + crossed

        return x + self.mlp(self.mlp_ln(x))
class MultiHeadAttention(nn.Module):
    """Multi-head attention used for both self-attention and cross-attention.

    ``forward`` returns ``(output, qk)``. ``qk`` is ``None`` on the fused
    ``scaled_dot_product_attention`` path; when the class attribute ``use_sdpa`` is
    ``False`` the attention is computed explicitly and ``qk`` holds the detached
    pre-softmax scores, which ``find_alignment`` reads through forward hooks.
    """

    # Class-wide switch toggled by the ``disable_sdpa`` context manager.
    use_sdpa = True

    def __init__(self, n_state, n_head):
        super().__init__()
        self.n_head = n_head
        self.query = Linear(n_state, n_state)
        self.key = Linear(n_state, n_state, bias=False)
        self.value = Linear(n_state, n_state)
        self.out = Linear(n_state, n_state)

    def forward(self, x, xa = None, mask = None, kv_cache = None):
        """Attend from ``x`` to ``xa`` (cross-attention) or to ``x`` itself (``xa is None``)."""
        if kv_cache is None or xa is None or self.key not in kv_cache:
            # Project keys/values here; any installed cache hooks fire on these calls.
            source = x if xa is None else xa
            k, v = self.key(source), self.value(source)
        else:
            # Cross-attention keys/values were already computed for this audio: reuse them.
            k, v = kv_cache[self.key], kv_cache[self.value]

        wv, qk = self.qkv_attention(self.query(x), k, v, mask)
        return self.out(wv), qk

    def qkv_attention(self, q, k, v, mask = None):
        """Return ``(attention_output, qk)`` for projected ``q``, ``k`` and ``v``.

        The inputs are split into ``n_head`` heads along their last dimension. ``qk`` has
        shape (batch, n_head, query_len, key_len) on the explicit path and is ``None``
        on the fused path.
        """
        _, n_ctx, n_state = q.shape
        q = q.view(*q.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        k = k.view(*k.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)
        v = v.view(*v.shape[:2], self.n_head, -1).permute(0, 2, 1, 3)

        if MultiHeadAttention.use_sdpa:
            attended = scaled_dot_product_attention(q, k, v, is_causal=mask is not None and n_ctx > 1)
            return attended.permute(0, 2, 1, 3).flatten(start_dim=2), None

        # Explicit attention. Scaling q and k by head_dim ** -0.25 each gives the usual
        # 1 / sqrt(head_dim) factor on their product.
        scale = (n_state // self.n_head) ** -0.25
        qk = (q * scale) @ (k * scale).transpose(-1, -2)
        if mask is not None:
            qk = qk + mask[:n_ctx, :n_ctx]
        qk = qk.float()

        weights = F.softmax(qk, dim=-1).to(q.dtype)
        attended = (weights @ v).permute(0, 2, 1, 3).flatten(start_dim=2)
        return attended, qk.detach()
class LogitFilter:
    """Base class for filters that modify decoder logits in place."""

    def apply(self, logits, tokens):
        """Adjust ``logits`` in place given the ``tokens`` so far. The base version does nothing."""
class SuppressBlank(LogitFilter):
    """Forbid a blank (space) or end-of-text token as the very first sampled token."""

    def __init__(self, tokenizer, sample_begin):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin

    def apply(self, logits, tokens):
        """Set the blank and eot logits to -inf, only when sampling is at its first position."""
        if tokens.shape[1] != self.sample_begin:
            return
        blocked = self.tokenizer.encode(" ") + [self.tokenizer.eot]
        logits[:, blocked] = -np.inf
class SuppressTokens(LogitFilter):
    """Unconditionally forbid a fixed set of token ids."""

    def __init__(self, suppress_tokens):
        # Materialised as a list so it can be used directly as an index.
        self.suppress_tokens = list(suppress_tokens)

    def apply(self, logits, tokens):
        """Set the logits of every suppressed token to -inf, in place."""
        blocked = self.suppress_tokens
        logits[:, blocked] = -np.inf
class Inference:
    """Base interface for the object that produces logits during decoding.

    All three methods are no-ops here and are meant to be overridden.
    """

    def logits(self, tokens, audio_features):
        """Return logits for ``tokens`` given ``audio_features`` (no-op here)."""
        return None

    def rearrange_kv_cache(self, source_indices):
        """Reorder any cached state according to ``source_indices`` (no-op here)."""
        return None

    def cleanup_caching(self):
        """Release any cached state (no-op here)."""
        return None
class PyTorchInference(Inference):
    """Inference backend that drives ``model.decoder`` with a key/value cache."""

    def __init__(self, model, initial_token_length):
        self.model = model
        self.initial_token_length = initial_token_length
        self.kv_cache = {}
        self.hooks = []
        # Self-attention key projections first, then the value projections.
        key_projections = [blk.attn.key for blk in self.model.decoder.blocks]
        value_projections = [blk.attn.value for blk in self.model.decoder.blocks]
        self.kv_modules = key_projections + value_projections

    def logits(self, tokens, audio_features):
        """Run the decoder, installing the cache hooks on first use.

        Once more tokens than the initial prompt are present, only the last
        token is fed to the decoder; earlier positions come from the cache.
        """
        if not self.kv_cache:
            self.kv_cache, self.hooks = self.model.install_kv_cache_hooks()
        past_initial_prompt = tokens.shape[-1] > self.initial_token_length
        if past_initial_prompt:
            tokens = tokens[:, -1:]
        return self.model.decoder(tokens, audio_features, kv_cache=self.kv_cache)

    def cleanup_caching(self):
        """Remove every installed hook and drop the cache."""
        for handle in self.hooks:
            handle.remove()
        self.kv_cache = {}
        self.hooks = []

    def rearrange_kv_cache(self, source_indices):
        """Reindex the cached self-attention tensors along their first dimension.

        Nothing is done when ``source_indices`` is already ``[0, 1, ..., n-1]``.
        """
        identity_order = list(range(len(source_indices)))
        if source_indices != identity_order:
            for projection in self.kv_modules:
                reordered = self.kv_cache[projection][source_indices]
                self.kv_cache[projection] = reordered.detach()
class SequenceRanker:
    """Base interface for choosing one sequence out of each group of candidates."""

    def rank(self, tokens, sum_logprobs):
        """Return the selected index per group (no-op here; subclasses override)."""
        return None
class MaximumLikelihoodRanker(SequenceRanker):
    """Pick the candidate with the highest length-normalised log-probability."""

    def __init__(self, length_penalty):
        self.length_penalty = length_penalty

    def rank(self, tokens, sum_logprobs):
        """Return, for each group, the index of the best-scoring candidate.

        Each candidate's summed log-probability is divided by its length when
        ``length_penalty`` is ``None``, otherwise by
        ``((5 + length) / 6) ** length_penalty``.
        """
        def normaliser(length):
            if self.length_penalty is None:
                return length
            return ((5 + length) / 6) ** self.length_penalty

        best_indices = []
        for group_logprobs, group_tokens in zip(sum_logprobs, tokens):
            lengths = [len(candidate) for candidate in group_tokens]
            group_scores = [logprob / normaliser(length) for logprob, length in zip(group_logprobs, lengths)]
            best_indices.append(np.argmax(group_scores))
        return best_indices
class TokenDecoder:
    """Base interface for the strategy that picks the next token at each step.

    All three methods are no-ops here and are meant to be overridden.
    """

    def reset(self):
        """Clear any per-run state (no-op here)."""
        return None

    def update(self, tokens, logits, sum_logprobs):
        """Choose the next tokens from ``logits`` (no-op here)."""
        return None

    def finalize(self, tokens, sum_logprobs):
        """Produce the final candidates after decoding stops (no-op here)."""
        return None
class GreedyDecoder(TokenDecoder):
    """Argmax decoding at temperature 0, categorical sampling otherwise."""

    def __init__(self, temperature, eot):
        self.temperature = temperature
        self.eot = eot

    def update(self, tokens, logits, sum_logprobs):
        """Append one token per row and accumulate its log-probability.

        ``sum_logprobs`` is updated in place. Rows whose last token is already
        EOT contribute nothing to it and are extended with another EOT.

        Returns:
            ``(tokens, completed)`` where ``tokens`` has one extra column and
            ``completed`` is true when every row now ends with EOT.
        """
        if self.temperature == 0:
            next_tokens = logits.argmax(dim=-1)
        else:
            next_tokens = Categorical(logits=logits / self.temperature).sample()
        logprobs = F.log_softmax(logits.float(), dim=-1)
        row_index = torch.arange(logprobs.shape[0])
        chosen_logprobs = logprobs[row_index, next_tokens]
        # Rows that already ended are masked out of the running sum.
        sum_logprobs += chosen_logprobs * (tokens[:, -1] != self.eot)
        already_ended = tokens[:, -1] == self.eot
        next_tokens[already_ended] = self.eot
        tokens = torch.cat([tokens, next_tokens[:, None]], dim=-1)
        completed = (tokens[:, -1] == self.eot).all()
        return tokens, completed

    def finalize(self, tokens, sum_logprobs):
        """Pad one EOT onto the last dimension and return the sums as a list."""
        padded = F.pad(tokens, (0, 1), value=self.eot)
        return padded, sum_logprobs.tolist()
class BeamSearchDecoder(TokenDecoder):
    """Beam search over groups of ``beam_size`` consecutive rows per audio.

    Finished candidates (those ending in EOT) are collected per audio in
    ``finished_sequences`` until ``max_candidates`` of them exist.
    """

    def __init__(self, beam_size, eot, inference, patience = None):
        self.beam_size = beam_size
        self.eot = eot
        self.inference = inference
        self.patience = patience or 1.0
        self.max_candidates = round(beam_size * self.patience)
        self.finished_sequences = None
        assert (self.max_candidates > 0)

    def reset(self):
        """Forget the finished sequences collected so far."""
        self.finished_sequences = None

    def update(self, tokens, logits, sum_logprobs):
        """Advance every beam by one token.

        ``sum_logprobs`` is overwritten in place with the scores of the
        surviving beams, and the inference object's cache is reordered to
        match the rows the survivors came from.

        Returns:
            ``(next_tokens, completed)`` where ``completed`` is true once every
            audio has at least ``max_candidates`` finished sequences.

        Raises:
            ValueError: if the number of rows is not a multiple of ``beam_size``.
        """
        beam = self.beam_size
        if tokens.shape[0] % beam != 0:
            raise ValueError(f"{tokens.shape}[0] % {self.beam_size} != 0")
        n_audio = tokens.shape[0] // beam
        if self.finished_sequences is None:
            self.finished_sequences = [{} for _ in range(n_audio)]
        logprobs = F.log_softmax(logits.float(), dim=-1)
        next_tokens, source_indices, finished_sequences = [], [], []
        for audio_idx in range(n_audio):
            scores, sources, finished = {}, {}, {}
            first_row = audio_idx * beam
            # Expand each beam of this audio with its top (beam + 1) continuations.
            for row in range(first_row, first_row + beam):
                prefix = tokens[row].tolist()
                top_logprobs, top_tokens = logprobs[row].topk(beam + 1)
                for logprob, token in zip(top_logprobs, top_tokens):
                    candidate = tuple(prefix + [token.item()])
                    scores[candidate] = (sum_logprobs[row] + logprob).item()
                    sources[candidate] = row
            # Walk candidates best-first: EOT-terminated ones are set aside,
            # the rest fill the next beam until it holds `beam` entries.
            kept = 0
            for candidate in sorted(scores, key=scores.get, reverse=True):
                if candidate[-1] == self.eot:
                    finished[candidate] = scores[candidate]
                    continue
                sum_logprobs[len(next_tokens)] = scores[candidate]
                next_tokens.append(candidate)
                source_indices.append(sources[candidate])
                kept += 1
                if kept == beam:
                    break
            finished_sequences.append(finished)
        self.inference.rearrange_kv_cache(source_indices)
        assert len(self.finished_sequences) == len(finished_sequences)
        for previously_finished, newly_finished in zip(self.finished_sequences, finished_sequences):
            for seq in sorted(newly_finished, key=newly_finished.get, reverse=True):
                if len(previously_finished) >= self.max_candidates:
                    break
                previously_finished[seq] = newly_finished[seq]
        stacked = torch.tensor(next_tokens, device=tokens.device)
        completed = all(len(sequences) >= self.max_candidates for sequences in self.finished_sequences)
        return stacked, completed

    def finalize(self, preceding_tokens, sum_logprobs):
        """Top up each audio's finished set from its live beams and return it.

        For every audio with fewer than ``beam_size`` finished sequences, the
        live beams are visited in descending order of score and added with an
        EOT appended until ``beam_size`` is reached.

        Returns:
            ``(tokens, sum_logprobs)``: per audio, a list of token tensors and
            the list of their scores.
        """
        sum_logprobs = sum_logprobs.cpu()
        for audio_idx, sequences in enumerate(self.finished_sequences):
            if len(sequences) >= self.beam_size:
                continue
            for beam_idx in list(np.argsort(sum_logprobs[audio_idx]))[::-1]:
                sequence = preceding_tokens[audio_idx, beam_idx].tolist() + [self.eot]
                sequences[tuple(sequence)] = sum_logprobs[audio_idx][beam_idx].item()
                if len(sequences) >= self.beam_size:
                    break
        final_tokens = [[torch.tensor(seq) for seq in sequences.keys()] for sequences in self.finished_sequences]
        final_logprobs = [list(sequences.values()) for sequences in self.finished_sequences]
        return final_tokens, final_logprobs
class ApplyTimestampRules(LogitFilter):
    """Constrain where timestamp tokens may appear in the sampled sequence."""

    def __init__(self, tokenizer, sample_begin, max_initial_timestamp_index):
        self.tokenizer = tokenizer
        self.sample_begin = sample_begin
        self.max_initial_timestamp_index = max_initial_timestamp_index

    def apply(self, logits, tokens):
        """Mask ``logits`` in place so the next token obeys the timestamp rules.

        Tokens at or above ``tokenizer.timestamp_begin`` are treated as
        timestamps; everything below as text.
        """
        tokenizer = self.tokenizer
        ts_begin = tokenizer.timestamp_begin
        if tokenizer.no_timestamps is not None:
            logits[:, tokenizer.no_timestamps] = -np.inf

        n_rows = tokens.shape[0]
        for k in range(n_rows):
            sampled_tokens = tokens[k, self.sample_begin :]
            seq = sampled_tokens.tolist()
            last_was_timestamp = len(seq) >= 1 and seq[-1] >= ts_begin
            penultimate_was_timestamp = len(seq) < 2 or seq[-2] >= ts_begin
            if last_was_timestamp:
                if penultimate_was_timestamp:
                    # Two timestamps in a row: the next token may not be a timestamp.
                    logits[k, ts_begin:] = -np.inf
                else:
                    # A lone timestamp: everything below EOT is forbidden next.
                    logits[k, : tokenizer.eot] = -np.inf
            timestamps = sampled_tokens[sampled_tokens.ge(ts_begin)]
            if timestamps.numel() > 0:
                # Forbid timestamps below the last one seen; unless a pair is
                # being closed, the last one itself is forbidden too.
                if last_was_timestamp and not penultimate_was_timestamp:
                    first_allowed = timestamps[-1]
                else:
                    first_allowed = timestamps[-1] + 1
                logits[k, ts_begin:first_allowed] = -np.inf

        if tokens.shape[1] == self.sample_begin:
            # First sampled token: only timestamps are allowed, optionally
            # capped at `max_initial_timestamp_index`.
            logits[:, :ts_begin] = -np.inf
            if self.max_initial_timestamp_index is not None:
                last_allowed = ts_begin + self.max_initial_timestamp_index
                logits[:, last_allowed + 1 :] = -np.inf

        # If the total timestamp probability beats every single text token,
        # force a timestamp by masking all text tokens.
        logprobs = F.log_softmax(logits.float(), dim=-1)
        for k in range(n_rows):
            timestamp_logprob = logprobs[k, ts_begin:].logsumexp(dim=-1)
            max_text_token_logprob = logprobs[k, :ts_begin].max()
            if timestamp_logprob > max_text_token_logprob:
                logits[k, :ts_begin] = -np.inf
class DecodingTask:
    """Decode a batch of mel inputs with a model according to ``options``.

    The constructor builds the tokenizer, the cached inference wrapper, the
    sequence ranker, the token decoder (beam search when ``options.beam_size``
    is set, greedy/sampling otherwise) and the list of logit filters applied
    before every decoding step.
    """

    def __init__(self, model, options):
        self.model = model
        # Fall back to English when no language was requested.
        language = options.language or "en"
        tokenizer = get_tokenizer(model.is_multilingual, num_languages=model.num_languages, language=language, task=options.task)
        self.tokenizer = tokenizer
        self.options = self._verify_options(options)
        self.n_group = options.beam_size or options.best_of or 1
        self.n_ctx = model.dims.n_text_ctx
        self.sample_len = options.sample_len or model.dims.n_text_ctx // 2
        self.sot_sequence = tokenizer.sot_sequence
        if self.options.without_timestamps:
            self.sot_sequence = tokenizer.sot_sequence_including_notimestamps
        # _get_initial_tokens reads sot_sequence, n_ctx and sample_len set above.
        self.initial_tokens = self._get_initial_tokens()
        self.sample_begin = len(self.initial_tokens)
        self.sot_index = self.initial_tokens.index(tokenizer.sot)
        self.inference = PyTorchInference(model, len(self.initial_tokens))
        self.sequence_ranker = MaximumLikelihoodRanker(options.length_penalty)
        if options.beam_size is not None:
            self.decoder = BeamSearchDecoder(options.beam_size, tokenizer.eot, self.inference, options.patience)
        else:
            self.decoder = GreedyDecoder(options.temperature, tokenizer.eot)
        self.logit_filters = []
        if self.options.suppress_blank:
            self.logit_filters.append(SuppressBlank(self.tokenizer, self.sample_begin))
        if self.options.suppress_tokens:
            self.logit_filters.append(SuppressTokens(self._get_suppress_tokens()))
        if not options.without_timestamps:
            max_initial_timestamp_index = None
            if options.max_initial_timestamp:
                # Convert the limit into a timestamp-token index using the
                # duration covered by one audio context position.
                max_initial_timestamp_index = round(self.options.max_initial_timestamp / (CHUNK_LENGTH / model.dims.n_audio_ctx))
            self.logit_filters.append(ApplyTimestampRules(tokenizer, self.sample_begin, max_initial_timestamp_index))

    def _verify_options(self, options):
        """Reject incompatible option combinations and return ``options`` unchanged.

        Raises:
            ValueError: for beam search combined with ``best_of``, ``best_of``
                at temperature 0, ``patience`` without beam search, or a
                ``length_penalty`` outside ``[0, 1]``.
        """
        if options.beam_size is not None and options.best_of is not None:
            raise ValueError("beam_size and best_of can't be given together")
        if options.temperature == 0 and options.best_of is not None:
            raise ValueError("best_of with greedy sampling (T=0) is not compatible")
        if options.patience is not None and options.beam_size is None:
            raise ValueError("patience requires beam_size to be given")
        if options.length_penalty is not None and not (0 <= options.length_penalty <= 1):
            raise ValueError("length_penalty (alpha) should be a value between 0 and 1")
        return options

    def _get_initial_tokens(self):
        """Build the token prefix fed to the decoder before sampling starts.

        The result is ``[sot_prev] + prompt`` (when a prompt is given), then
        the start-of-transcript sequence, then the optional prefix. String
        prompts/prefixes are stripped, given a leading space and encoded.
        """
        tokens = list(self.sot_sequence)
        if prefix := self.options.prefix:
            prefix_tokens = self.tokenizer.encode(" " + prefix.strip()) if isinstance(prefix, str) else prefix
            if self.sample_len is not None:
                # NOTE(review): when this length is 0 the slice `[-0:]` keeps
                # the whole prefix instead of dropping it; behaviour preserved.
                max_prefix_len = self.n_ctx // 2 - self.sample_len
                prefix_tokens = prefix_tokens[-max_prefix_len:]
            tokens = tokens + prefix_tokens
        if prompt := self.options.prompt:
            prompt_tokens = self.tokenizer.encode(" " + prompt.strip()) if isinstance(prompt, str) else prompt
            # Keep only the most recent part of the prompt.
            tokens = [self.tokenizer.sot_prev] + prompt_tokens[-(self.n_ctx // 2 - 1) :] + tokens
        return tuple(tokens)

    def _get_suppress_tokens(self):
        """Return the sorted, de-duplicated tuple of token ids to suppress.

        ``options.suppress_tokens`` may be a comma-separated string or an
        iterable of ints; ``-1`` stands for the tokenizer's non-speech tokens.
        A fixed set of special tokens is always added. The caller's
        ``options.suppress_tokens`` object is never modified.
        """
        suppress_tokens = self.options.suppress_tokens
        if isinstance(suppress_tokens, str):
            suppress_tokens = [int(t) for t in suppress_tokens.split(",")]
        # Work on a private copy: extending the option's own list would leak
        # the added special tokens back into the caller's options object.
        suppress_tokens = [] if suppress_tokens is None else list(suppress_tokens)
        if -1 in suppress_tokens:
            suppress_tokens = [t for t in suppress_tokens if t >= 0]
            suppress_tokens.extend(self.tokenizer.non_speech_tokens)
        suppress_tokens.extend([self.tokenizer.transcribe, self.tokenizer.translate, self.tokenizer.sot, self.tokenizer.sot_prev, self.tokenizer.sot_lm])
        if self.tokenizer.no_speech is not None:
            suppress_tokens.append(self.tokenizer.no_speech)
        return tuple(sorted(set(suppress_tokens)))

    def _get_audio_features(self, mel):
        """Return encoder features for ``mel``, encoding it only if necessary.

        If the last two dimensions of ``mel`` already equal
        ``(n_audio_ctx, n_audio_state)`` it is treated as pre-computed
        features and returned as is; otherwise it goes through the encoder.

        Raises:
            TypeError: if the features are not float16 (with ``options.fp16``)
                or float32 (without it).
        """
        if self.options.fp16:
            mel = mel.half()
        expected_shape = (self.model.dims.n_audio_ctx, self.model.dims.n_audio_state)
        audio_features = mel if mel.shape[-2:] == expected_shape else self.model.encoder(mel)
        expected_dtype = torch.float16 if self.options.fp16 else torch.float32
        if audio_features.dtype != expected_dtype:
            # The exception used to be returned instead of raised, which let
            # it flow on as if it were the feature tensor.
            raise TypeError(f"audio_features has an incorrect dtype: {audio_features.dtype}")
        return audio_features

    def _detect_language(self, audio_features, tokens):
        """Return per-audio languages and, if detection ran, their probabilities.

        Detection runs when no language was given or the task is ``"lang_id"``.
        When no language was given, the detected language tokens are written
        into ``tokens`` in place, right after the start-of-transcript token.
        """
        languages = [self.options.language] * audio_features.shape[0]
        lang_probs = None
        if self.options.language is None or self.options.task == "lang_id":
            lang_tokens, lang_probs = self.model.detect_language(audio_features, self.tokenizer)
            languages = [max(probs, key=probs.get) for probs in lang_probs]
            if self.options.language is None:
                tokens[:, self.sot_index + 1] = lang_tokens
        return languages, lang_probs

    def _main_loop(self, audio_features, tokens):
        """Sample up to ``sample_len`` tokens, always cleaning up the cache.

        Returns:
            ``(tokens, sum_logprobs, no_speech_probs)``; ``no_speech_probs``
            stays NaN-filled when the tokenizer has no no-speech token.
        """
        n_batch = tokens.shape[0]
        sum_logprobs = torch.zeros(n_batch, device=audio_features.device)
        no_speech_probs = [np.nan] * n_batch
        try:
            for i in range(self.sample_len):
                logits = self.inference.logits(tokens, audio_features)
                if i == 0 and self.tokenizer.no_speech is not None:
                    # The no-speech probability is read at the SOT position
                    # of the very first forward pass.
                    probs_at_sot = logits[:, self.sot_index].float().softmax(dim=-1)
                    no_speech_probs = probs_at_sot[:, self.tokenizer.no_speech].tolist()
                # Only the logits of the last position drive sampling.
                logits = logits[:, -1]
                for logit_filter in self.logit_filters:
                    logit_filter.apply(logits, tokens)
                tokens, completed = self.decoder.update(tokens, logits, sum_logprobs)
                if completed or tokens.shape[-1] > self.n_ctx:
                    break
        finally:
            self.inference.cleanup_caching()
        return tokens, sum_logprobs, no_speech_probs

    @torch.no_grad()
    def run(self, mel):
        """Decode ``mel`` and return one ``DecodingResult`` per audio.

        For the ``"lang_id"`` task only language detection is performed.

        Raises:
            RuntimeError: if the per-audio result fields end up with
                different lengths.
        """
        self.decoder.reset()
        tokenizer = self.tokenizer
        n_audio = mel.shape[0]
        audio_features = self._get_audio_features(mel)
        tokens = torch.tensor([self.initial_tokens]).repeat(n_audio, 1)
        languages, language_probs = self._detect_language(audio_features, tokens)
        if self.options.task == "lang_id":
            return [DecodingResult(audio_features=features, language=language, language_probs=probs) for features, language, probs in zip(audio_features, languages, language_probs)]
        # One row per candidate: each audio is repeated n_group times.
        tokens = tokens.repeat_interleave(self.n_group, dim=0).to(audio_features.device)
        tokens, sum_logprobs, no_speech_probs = self._main_loop(audio_features, tokens)
        # Collapse the per-candidate rows back to one entry per audio.
        audio_features = audio_features[:: self.n_group]
        no_speech_probs = no_speech_probs[:: self.n_group]
        assert audio_features.shape[0] == len(no_speech_probs) == n_audio
        tokens = tokens.reshape(n_audio, self.n_group, -1)
        sum_logprobs = sum_logprobs.reshape(n_audio, self.n_group)
        tokens, sum_logprobs = self.decoder.finalize(tokens, sum_logprobs)
        # Keep only the sampled part of each candidate, up to its first EOT.
        tokens = [[t[self.sample_begin : (t == tokenizer.eot).nonzero()[0, 0]] for t in s] for s in tokens]
        selected = self.sequence_ranker.rank(tokens, sum_logprobs)
        tokens = [t[i].tolist() for i, t in zip(selected, tokens)]
        texts = [tokenizer.decode(t).strip() for t in tokens]
        chosen_logprobs = [lp[i] for i, lp in zip(selected, sum_logprobs)]
        avg_logprobs = [lp / (len(t) + 1) for t, lp in zip(tokens, chosen_logprobs)]
        fields = (texts, languages, tokens, audio_features, avg_logprobs, no_speech_probs)
        if len(set(map(len, fields))) != 1:
            raise RuntimeError(f"inconsistent result lengths: {list(map(len, fields))}")
        return [DecodingResult(audio_features=features, language=language, tokens=tokens, text=text, avg_logprob=avg_logprob, no_speech_prob=no_speech_prob, temperature=self.options.temperature, compression_ratio=compression_ratio(text)) for text, language, tokens, features, avg_logprob, no_speech_prob in zip(*fields)]
class DecodingResult:
    """Plain container for the outcome of decoding one audio segment."""

    def __init__(self, audio_features, language, language_probs = None, tokens = None, text = "", avg_logprob = np.nan, no_speech_prob = np.nan, temperature = np.nan, compression_ratio = np.nan):
        # Missing containers are replaced with fresh empty ones per instance.
        if language_probs is None:
            language_probs = {}
        if tokens is None:
            tokens = []
        self.audio_features = audio_features
        self.language = language
        self.language_probs = language_probs
        self.tokens = tokens
        self.text = text
        self.avg_logprob = avg_logprob
        self.no_speech_prob = no_speech_prob
        self.temperature = temperature
        self.compression_ratio = compression_ratio
class Tokenizer:
    """Wrapper around the encoding returned by ``get_encoding``.

    Adds quick access to the special tokens, the start-of-transcript
    sequence for the configured language/task, and helpers that split token
    sequences into words.
    """

    def __init__(self, encoding_name, num_languages = 2, language = None, task = None, sot_sequence = ()):
        self.encoding = get_encoding(name=encoding_name, num_languages=num_languages)
        self.num_languages = num_languages
        self.language = language
        self.task = task
        # The passed-in value is only a placeholder; it is recomputed below.
        self.sot_sequence = sot_sequence
        self.special_tokens = {}
        for special in self.encoding.special_tokens_set:
            self.special_tokens[special] = self.encoding.encode_single_token(special)
        sot = self.special_tokens["<|startoftranscript|>"]
        langs = tuple(LANGUAGES.keys())[: self.num_languages]
        sequence = [sot]
        if self.language is not None:
            # Language token ids are computed as an offset from the SOT id,
            # in the order of the LANGUAGES table.
            sequence.append(sot + 1 + langs.index(self.language))
        if self.task is not None:
            task_name = "<|transcribe|>" if self.task == "transcribe" else "<|translate|>"
            sequence.append(self.special_tokens[task_name])
        self.sot_sequence = tuple(sequence)

    def encode(self, text, **kwargs):
        """Encode ``text`` with the underlying encoding."""
        return self.encoding.encode(text, **kwargs)

    def decode(self, token_ids, **kwargs):
        """Decode ``token_ids`` after dropping every id at or above ``timestamp_begin``."""
        text_ids = [t for t in token_ids if t < self.timestamp_begin]
        return self.encoding.decode(text_ids, **kwargs)

    def decode_with_timestamps(self, token_ids, **kwargs):
        """Decode ``token_ids`` without filtering anything out."""
        return self.encoding.decode(token_ids, **kwargs)

    @cached_property
    def eot(self):
        return self.encoding.eot_token

    @cached_property
    def transcribe(self):
        return self.special_tokens["<|transcribe|>"]

    @cached_property
    def translate(self):
        return self.special_tokens["<|translate|>"]

    @cached_property
    def sot(self):
        return self.special_tokens["<|startoftranscript|>"]

    @cached_property
    def sot_lm(self):
        return self.special_tokens["<|startoflm|>"]

    @cached_property
    def sot_prev(self):
        return self.special_tokens["<|startofprev|>"]

    @cached_property
    def no_speech(self):
        return self.special_tokens["<|nospeech|>"]

    @cached_property
    def no_timestamps(self):
        return self.special_tokens["<|notimestamps|>"]

    @cached_property
    def timestamp_begin(self):
        return self.special_tokens["<|0.00|>"]

    @cached_property
    def language_token(self):
        """Token id of the configured language.

        Raises:
            ValueError: if no language is configured.
        """
        if self.language is None:
            raise ValueError("This tokenizer does not have language token configured")
        return self.to_language_token(self.language)

    def to_language_token(self, language):
        """Return the special-token id for ``language``.

        Raises:
            KeyError: if ``<|language|>`` is not a known special token.
        """
        if token := self.special_tokens.get(f"<|{language}|>", None):
            return token
        raise KeyError(f"Language {language} not found in tokenizer.")

    @cached_property
    def all_language_tokens(self):
        """Ids of the special tokens whose name is a key of ``LANGUAGES``."""
        result = [token_id for token, token_id in self.special_tokens.items() if token.strip("<|>") in LANGUAGES]
        return tuple(result)[: self.num_languages]

    @cached_property
    def all_language_codes(self):
        """Language codes obtained by decoding each language token."""
        return tuple(self.decode([_l]).strip("<|>") for _l in self.all_language_tokens)

    @cached_property
    def sot_sequence_including_notimestamps(self):
        """The start-of-transcript sequence followed by the no-timestamps token."""
        return tuple(list(self.sot_sequence) + [self.no_timestamps])

    @cached_property
    def non_speech_tokens(self):
        """Sorted token ids for a fixed list of symbols and musical marks.

        A symbol contributes its encoding with and without a leading space,
        but only when that encoding is a single token; the musical symbols
        always contribute their first token.
        """
        symbols = list('"#()*+/:;<=>@[\\]^_`{|}~「」『』')
        symbols += ("<< >> <<< >>> -- --- -( -[ (' (\" (( )) ((( ))) [[ ]] {{ }} ♪♪ ♪♪♪".split())
        miscellaneous = set("♩♪♫♬♭♮♯")
        assert all(0x2640 <= ord(c) <= 0x267F for c in miscellaneous)
        # " -" and " '" are always included via their first token.
        result = {self.encoding.encode(" -")[0], self.encoding.encode(" '")[0]}
        for symbol in symbols + list(miscellaneous):
            for tokens in [self.encoding.encode(symbol), self.encoding.encode(" " + symbol)]:
                if len(tokens) == 1 or symbol in miscellaneous:
                    result.add(tokens[0])
        return tuple(sorted(result))

    def split_to_word_tokens(self, tokens):
        """Split ``tokens`` into words, per code point for languages in the set below."""
        if self.language in {"zh", "ja", "th", "lo", "my", "yue"}:
            return self.split_tokens_on_unicode(tokens)
        return self.split_tokens_on_spaces(tokens)

    def split_tokens_on_unicode(self, tokens):
        """Group tokens so that each group decodes to valid text.

        Tokens are accumulated until their decoding no longer contains the
        replacement character U+FFFD, unless the full decoding has a genuine
        U+FFFD at that same position.

        Returns:
            ``(words, word_tokens)``: the decoded groups and their token lists.
        """
        replacement_char = "\ufffd"
        # Decode the whole sequence once; doing this inside the loop repeated
        # identical work on every incomplete token.
        decoded_full = self.decode_with_timestamps(tokens)
        words, word_tokens, current_tokens = [], [], []
        unicode_offset = 0
        for token in tokens:
            current_tokens.append(token)
            decoded = self.decode_with_timestamps(current_tokens)
            if (replacement_char not in decoded or decoded_full[unicode_offset + decoded.index(replacement_char)] == replacement_char):
                words.append(decoded)
                word_tokens.append(current_tokens)
                current_tokens = []
                unicode_offset += len(decoded)
        return words, word_tokens

    def split_tokens_on_spaces(self, tokens):
        """Merge the unicode-level pieces into space-delimited words.

        A piece starts a new word when its first token is at or above EOT,
        when it starts with a space, when its stripped text is found in
        ``string.punctuation``, or when no word exists yet; otherwise it is
        appended to the previous word.
        """
        subwords, subword_tokens_list = self.split_tokens_on_unicode(tokens)
        words, word_tokens = [], []
        for subword, subword_tokens in zip(subwords, subword_tokens_list):
            special = subword_tokens[0] >= self.eot
            with_space = subword.startswith(" ")
            punctuation = subword.strip() in string.punctuation
            if special or with_space or punctuation or len(words) == 0:
                words.append(subword)
                word_tokens.append(subword_tokens)
            else:
                words[-1] = words[-1] + subword
                word_tokens[-1].extend(subword_tokens)
        return words, word_tokens