# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/2B. Whisper quantization (semantic token) model.ipynb.

# %% auto 0
__all__ = ['RQBottleneckTransformer', 'make_model']

# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 2
import io
import sys
import time
import torch
import torchaudio

# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 3
from pathlib import Path
import json
from fastprogress import progress_bar, master_bar
import fastprogress
import numpy as np
import pylab as plt
import pandas as pd
import random

import whisper
from huggingface_hub import hf_hub_download
from fastcore.basics import store_attr

from torch import nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data.dataloader import DataLoader

import webdataset as wds
from . import utils

from vector_quantize_pytorch import ResidualVQ

from fastcore.script import *
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 9
def merge_in(dataset_fun):
    """Merge a dataset into the current one returning samples with the union of keys. Pass in a function
    that takes a URL of a sample and returns a dataset for it (called every time the URL changes).
    It requires (and validates) that both datasets have the same ordering of keys so you have
    to use it before any sample shuffling. Shard shuffling is ok.
    """
    def merge_loop(main_samples):
        #print("new merge loop:", dataset_fun)
        merged_samples = None
        cur_url = None
        i = None
        for s in main_samples:
            url = s['__url__']
            if url != cur_url:
                # this will open a new file when we get the first sample with a new __url__
                merged_samples = iter(dataset_fun(url))
                cur_url = url
            try:
                merge_s = next(merged_samples)
            except StopIteration:
                # if the original shard got repeated we won't observe a __url__ change
                # in this case restart the dataset from the beginning
                merged_samples = iter(dataset_fun(url))
                merge_s = next(merged_samples)
            assert merge_s['__key__'] == s['__key__'], f"sample keys don't match: {merge_s['__key__']}, {s['__key__']} in file {s['__url__']}"
            news = {}
            news.update(merge_s)
            news.update(s)
            yield news
    return merge_loop

# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 10
def derived_dataset(proc_dataset_path, kind, key='audio'):
    def deriver(url):
        url = str(Path(proc_dataset_path)/(Path(url).name.replace(key, kind) + ".gz"))
        return wds.WebDataset(
            wds.SimpleShardList([url])
        ).decode()
    return deriver
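
# A minimal usage sketch of `merge_in` and `derived_dataset` (not part of the original
# notebook). It assumes a raw audio shard named `librilight-small-000000-flac-000001.tar`
# and a processed directory containing a matching `librilight-small-000000-vad-000001.tar.gz`
# with the same sample keys; both names and the `processed/` path are hypothetical.
def _example_merge_vad(shard_url="librilight-small-000000-flac-000001.tar",
                       proc_dataset_path="processed/"):
    ds = wds.WebDataset([shard_url]).compose(
        wds.decode(wds.torch_audio),
        wds.rename(audio="flac;mp3;wav;ogg"),
        # every raw sample gets the keys of the matching sample from the derived VAD shard
        merge_in(derived_dataset(proc_dataset_path, 'vad', key='flac')),
    )
    return next(iter(ds))  # a dict with both the audio keys and the 'vad.npy' key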
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 17
def add_masks(samples):
    for s in samples:
        seconds = s['tend'] - s['tstart']
        # a mask (downsampled to the Whisper encoder token rate of 50/s) is used
        # to teach the model the concept of padding
        # this lets us decode shorter sequences later
        mask = torch.zeros(30*16000//320, dtype=torch.bool)
        mask[:int(seconds * 16000) // 320] = 1
        s['mask'] = mask
        yield s

def tokenize_text(samples, ttoks_size=200, model="base.en", language="en"):
    multilingual = not model.endswith(".en")
    tokenizer = whisper.tokenizer.get_tokenizer(multilingual, language=language, task="transcribe")
    for s in samples:
        ttoks = tokenizer.encode(s['txt'])
        tokens = list(tokenizer.sot_sequence) + ttoks
        rpad = ttoks_size - len(tokens)
        s['in_ttoks'] = F.pad(torch.tensor(tokens), (0, rpad), value=tokenizer.eot)
        s['out_ttoks'] = F.pad(torch.tensor(tokens[1:] + [tokenizer.eot]), (0, rpad), value=-100)
        yield s
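
# An illustrative sketch (not from the original notebook) of what `add_masks` and
# `tokenize_text` produce for a single toy sample. The 'txt', 'tstart' and 'tend'
# values are made up; in the real pipeline they come from the VAD and transcript shards.
def _example_masks_and_tokens():
    sample = dict(txt="Hello world.", tstart=0.0, tend=2.5)
    s = next(tokenize_text(add_masks([sample])))
    # s['mask'] is a 1500-element bool tensor with the first 125 positions set
    # (2.5 s * 16000 Hz / 320 samples per encoder token = 125 tokens at 50 tok/s);
    # s['in_ttoks'] starts with the Whisper SOT sequence, s['out_ttoks'] is shifted
    # by one token and padded with -100 so the CE loss ignores the padding positions
    return s['mask'].sum(), s['in_ttoks'][:5], s['out_ttoks'][:5]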
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 22
def load_dataset(
        shard_spec:str,
        proc_dataset_path:Path,      # processed VAD and txt files
        samples:int,                 # set the per-GPU sample count
        txt_label:str="base.en-txt", # the label of the files containing transcriptions
        model:str="base.en",
        key:str="flac",
        language:str=None,
        validation:bool=False,
    ):
    from . import wh_transcribe
    shards = utils.shard_glob(shard_spec)

    if not language and model.endswith('en'): language = 'en'
    assert language, "please provide the dataset language for multilang models"

    same_on_all_nodes = lambda urls: urls # will only be used for validation
    ds = wds.WebDataset(shards, resampled=not validation, nodesplitter=same_on_all_nodes).compose(
        wds.decode(wds.torch_audio),
        wds.select(lambda x: 'wav' in x or 'flac' in x or 'mp3' in x or 'ogg' in x), # skip samples without audio
        wds.rename(audio="flac;mp3;wav;ogg"),
        merge_in(derived_dataset(proc_dataset_path, 'vad', key=key)),
        wds.map_dict(**{"vad.npy":wh_transcribe.chunk_merger}),
        wh_transcribe.split_to_chunks,
        utils.resampler(16000, 'samples_16k'),
        merge_in(derived_dataset(proc_dataset_path, txt_label, key=key)),
    )
    if 'librilight' in shards[0]:
        ds = ds.compose(
            # drop the first and last segment because they tend to be inaccurate
            # (the transcriptions don't have the "LibriVox" headers and "end of chapter" suffixes)
            wds.select(lambda x: x['i'] != 0 and x['i'] != x['imax']),
        )
    ds = ds.compose(
        add_masks,
        lambda x: tokenize_text(x, model=model, language=language),
        wds.to_tuple('samples_16k', 'mask', 'in_ttoks', 'out_ttoks'),
        wds.batched(32),
    )
    ds.total_samples = samples
    return ds
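
# A hedged usage sketch for `load_dataset` (not in the original notebook). The shard
# glob, processed-data path and sample count below are placeholders; in the real
# training setup they come from the command line of the training script.
def _example_load_dataset():
    train_ds = load_dataset(
        'librilight-small-*-flac-*.tar',    # hypothetical shard spec
        proc_dataset_path=Path('processed/'),
        samples=100_000,                    # per-GPU sample count used by the trainer
        txt_label="base.en-txt",
        model="base.en",
        key="flac",
    )
    # each batch is a tuple of (samples_16k, mask, in_ttoks, out_ttoks), batched by 32
    return next(iter(train_ds))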
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 28
from whisperspeech.train import *
from whisperspeech.modules import *

# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 29
import dataclasses

def rand(start, end):
    return random.random() * (end - start) + start

def logrand(start, end):
    return 10**rand(math.log10(start), math.log10(end))

@dataclasses.dataclass
class Tunables:
    init_std :float = 1.5
    embeddings_std :float = 4.5e-2
    embeddings_lr_scale: float = 1
    output_mult :float = 1
    query_mult :float = 2
    rope :bool = True
    mask_embs :bool = True # force embeddings corresponding to the input audio padding to a constant value
    downsample_conv: bool = False
    downsample_mean: bool = True

    codebook_dim: int = 32
    codebook_decay: float = 0.9

    lr0 :float = .9e-3
    clip_gradient_norm :float = 2
    weight_decay :float = 1e-3
    warmup_steps :float = 850

    random :bool = False

    def __post_init__(self):
        # randomize the hyperparams if requested
        if self.random:
            self.init_std = logrand(1, 2)
            self.embeddings_std = logrand(3e-2,6e-2)
            self.embeddings_lr_scale = 2**rand(0,3)
            self.output_mult = 2**rand(-3,3)
            self.query_mult = logrand(1,8)
            self.codebook_dim = int(logrand(30,50))
            self.codebook_decay = logrand(0.86,0.95)
            self.rope = True
            self.mask_embs = True
            self.downsample_mean = True
            self.lr0 = logrand(.8e-3,1e-3)
            self.clip_gradient_norm = 10**rand(-1,1)
            self.warmup_steps = logrand(700,1000)

    @staticmethod
    def upgrade(args):
        args = {k:v for k,v in args.items()}
        def old_default(name, value):
            if name not in args: args[name] = value
        old_default('output_mult', 1)
        old_default('query_mult', 1)
        old_default('rope', False)
        old_default('mask_embs', False)
        old_default('downsample_conv', False)
        old_default('downsample_mean', False)
        if 'encoder_depth_ratio' in args: del args['encoder_depth_ratio']
        if 'vq_codes' in args: del args['vq_codes']
        return args
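
# A short sketch (not part of the original notebook) showing how `Tunables` is meant
# to be used: the defaults give the tuned configuration, `random=True` re-samples a
# fresh set of hyperparameters for a search run, and `upgrade` fills in defaults missing
# from older checkpoints so their saved `tunables` dicts still load. The `old_args`
# values below are made up.
def _example_tunables():
    tuned = Tunables()                      # the tuned defaults
    searched = Tunables(random=True)        # __post_init__ re-samples lr0, init_std, ...
    old_args = {'init_std': 1.0, 'embeddings_std': 0.05, 'encoder_depth_ratio': 0.25}
    restored = Tunables(**Tunables.upgrade(old_args))  # drops removed keys, adds new ones
    return tuned, searched, restored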
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 30
import math

# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 31
class RQBottleneckTransformer(nn.Module):
    def __init__(self, vq_codes=512, q_depth=12, depth=1, n_head=2, head_width=64, ffn_mult=4,
                 codebook_dim=2, threshold_ema_dead_code=2, use_cosine_sim = False, kl_loss_mul=1,
                 downsample=1,
                 whisper_model_name='tiny.en', tunables=Tunables()):
        super().__init__()
        width = n_head * head_width
        store_attr("codebook_dim,vq_codes,q_depth,n_head,head_width,ffn_mult,depth,use_cosine_sim,downsample,whisper_model_name")
        self.width = width
        self.base_width = 3 * head_width
        self.vq_codes = vq_codes
        self.tunables = tunables
        self.stoks_len = 1500//downsample
        self.stoks_per_sec = self.stoks_len//30

        qk_scale = self.tunables.query_mult * 8 / math.sqrt(head_width)

        self.kl_loss_mul = kl_loss_mul

        n_mlp = width * ffn_mult
        self.mlp = nn.Sequential(
            nn.Linear(width, n_mlp), nn.GELU(), nn.Linear(n_mlp, width)
        )
        self.mlp_ln = LayerNorm(width)

        if tunables.downsample_conv:
            self.downsample_conv = nn.Conv1d(width, width, kernel_size=3, stride=downsample, padding=1)
        else:
            self.downsample_conv = None

        if tunables.mask_embs: vq_codes = vq_codes + 1
        self.rq = ResidualVQ(
            dim = width,
            codebook_size = vq_codes,          # codebook size
            decay = tunables.codebook_decay,   # the exponential moving average decay, lower means the dictionary will change faster
            commitment_weight = 1.,            # the weight on the commitment loss
            threshold_ema_dead_code = threshold_ema_dead_code,
            use_cosine_sim = use_cosine_sim,
            codebook_dim = codebook_dim,
            num_quantizers= 1,
        )

        self.ce_lossf = nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_lossf = nn.KLDivLoss(reduction='batchmean')

        self.positional_embedding = nn.Embedding(1500, width) # FIXME: should be self.stoks_len

        self.out_blocks = nn.Sequential(*[
            ResidualAttentionBlock(width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, rope=tunables.rope) for _ in range(depth)
        ])
        self.ln_post = LayerNorm(width)

        self.whmodel = None

        self.apply(self.init_transformer)
        self.register_buffer('val_true', torch.zeros(1).cuda())
        self.register_buffer('val_total', torch.zeros(1).cuda())
    def setup(self, device):
        self.ensure_whisper(device)

    def init_transformer(self, m):
        if isinstance(m, LinearHead):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, QueryHead):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            torch.nn.init.constant_(m.weight, 0)
        elif isinstance(m, nn.Embedding):
            m.no_weight_decay = True
            m.lr_scale = self.tunables.embeddings_lr_scale
            std = self.tunables.embeddings_std
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.Linear):
            m.lr_scale = 1/(m.weight.shape[1] / self.base_width)
            std = self.tunables.init_std / m.weight.shape[1]
            torch.nn.init.trunc_normal_(m.weight, std=std, a=-3*std, b=3*std)
            if m.bias is not None:
                torch.nn.init.trunc_normal_(m.bias, std=std, a=-3*std, b=3*std)
        elif isinstance(m, nn.LayerNorm):
            m.no_weight_decay = True
            torch.nn.init.constant_(m.bias, 0)
            torch.nn.init.constant_(m.weight, 1)

    @property
    def device(self):
        return next(self.parameters()).device

    #
    # training
    #
    def extract_teacher(self, samples, input_toks, output_toks):
        embs = self.whmodel[0].encoder(whisper.log_mel_spectrogram(samples))
        teacher_logits = self.whmodel[0].decoder(input_toks, embs)
        # set teacher logits to 0 for padding positions so KLDivLoss ignores them
        teacher_logits[output_toks == -100] = 0
        return embs, teacher_logits

    def downsample_embeddings(self, x):
        if self.downsample_conv is not None:
            return x[:,::self.downsample] + self.downsample_conv(x.transpose(-1,-2)).transpose(-2,-1)
        elif self.tunables.downsample_mean:
            bs,slen,depth = x.shape
            return x.reshape(bs,slen//self.downsample,self.downsample,depth).mean(-2)
        else:
            return x[:,::self.downsample]

    def forward(self, samples, mask, input_toks, output_toks):
        embs, teacher_logits = self.extract_teacher(samples, input_toks, output_toks)

        x = self.downsample_embeddings(embs)
        x = x + self.mlp(self.mlp_ln(x))
        # VQ bottleneck
        quantized, self.indices, self.commit_loss = self.rq(x)
        self.commit_loss = self.commit_loss.mean()

        x = quantized.repeat_interleave(self.downsample, -2)
        project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
        if self.tunables.mask_embs: x[~mask] = project_out(self.rq.layers[0]._codebook.embed[0,self.vq_codes])
        positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
        x = x + self.positional_embedding(positions)
        x = self.ln_post(self.out_blocks(x))

        logits = self.whmodel[0].decoder(input_toks, x)
        self.ce_loss = self.ce_lossf(logits.view(-1,logits.shape[-1]), output_toks.view(-1))
        self.kl_loss = self.kl_lossf(F.log_softmax(logits, dim=-1), F.softmax(teacher_logits, dim=-1))
        loss = self.ce_loss + self.kl_loss_mul * self.kl_loss + self.commit_loss

        if not self.training:
            valid_toks = output_toks != -100
            self.val_true += (logits.argmax(-1)[valid_toks] == output_toks[valid_toks]).float().sum()
            self.val_total += valid_toks.float().sum()

        return x, loss

    def get_metrics(self):
        metrics = {
            'acc_0': (self.val_true / self.val_total).item(),
        }
        self.val_true[:] = 0
        self.val_total[:] = 0
        return metrics
    #
    # inference
    #
    @classmethod
    def load_model(cls, ref="collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model",
                   repo_id=None, filename=None, local_filename=None):
        if repo_id is None and filename is None and local_filename is None:
            if ":" in ref:
                repo_id, filename = ref.split(":", 1)
            else:
                local_filename = ref
        if not local_filename:
            local_filename = hf_hub_download(repo_id=repo_id, filename=filename)
        spec = torch.load(local_filename)
        vqmodel = cls(**spec['config'], tunables=Tunables(**Tunables.upgrade(spec.get('tunables', {}))))
        vqmodel.load_state_dict(spec['state_dict'])
        vqmodel.eval()
        return vqmodel
    def load_checkpoint(self, local_filename):
        spec = torch.load(local_filename, map_location='cpu')
        assert 'pytorch-lightning_version' in spec, 'not a valid PyTorch Lightning checkpoint'
        state_dict = {k.replace('model.', ''):v
                      for k,v in spec['state_dict'].items()}
        self.load_state_dict(state_dict)
        return self

    def save_model(self, fname, store_parameters=True):
        torch.save(dict(config = self.__stored_args__,
                        tunables = dataclasses.asdict(self.tunables),
                        state_dict = self.state_dict() if store_parameters else None), fname)

    def ensure_whisper(self, device):
        # the list wrapper is a hack to make sure the whole of Whisper is not sucked into self.parameters()
        if self.whmodel is None: self.whmodel = [whisper.load_model(self.whisper_model_name, device=device)]
        self.decoding_options = whisper.DecodingOptions()
        multilingual = not self.whisper_model_name.endswith('.en')
        self.tokenizer = whisper.tokenizer.get_tokenizer(multilingual)

    def quantize(self, embs):
        x = self.downsample_embeddings(embs)
        x = x + self.mlp(self.mlp_ln(x))
        _, stoks, _ = self.rq(x)
        if self.q_depth == 1:
            stoks = stoks.squeeze(-1)
        return stoks

    def dequantize(self, stoks):
        assert self.q_depth == 1
        assert len(stoks.shape) == 1, "batch processing is not supported"
        if isinstance(stoks, np.ndarray): stoks = torch.tensor(stoks)
        # remove padding
        padding = torch.nonzero(stoks == self.vq_codes)
        if padding.any(): stoks = stoks[:padding[0,0]]
        stoks = F.pad(stoks, (0,self.stoks_len - stoks.shape[-1]), value=self.vq_codes if self.tunables.mask_embs else 0)
        x = self.rq.layers[0]._codebook.embed[0,stoks.to(torch.long).view(-1)]
        x = x.repeat_interleave(self.downsample, -2)
        project_out = getattr(self.rq, 'project_out', None) or self.rq.layers[0].project_out
        x = project_out(x).unsqueeze(0)
        positions = torch.arange(0, x.shape[-2], dtype=torch.long, device=x.device)
        x = x + self.positional_embedding(positions)
        return self.ln_post(self.out_blocks(x))

    def encode_audio(self, audio):
        if isinstance(audio, str):
            x, sr = torchaudio.load(audio)
            x = torchaudio.transforms.Resample(sr, 16000)(x)[0]
            audio = x.unsqueeze(0)
        return self.encode_mel(whisper.log_mel_spectrogram(audio).to(self.device))

    def encode_mel(self, mel):
        assert len(mel.shape) == 3, "invalid mel spectrogram shape, expect (batch,chn,time)"
        self.ensure_whisper(self.device)
        n = mel.shape[-1]
        if n > whisper.audio.N_FRAMES:
            padding = 0
            padded = mel[:,:,:whisper.audio.N_FRAMES]
        else:
            padding = -n % whisper.audio.N_FRAMES
            padded = F.pad(mel, (0, padding), value=-1.5)
        embs = self.whmodel[0].encoder(padded)#.to(self.whmodel[0].device))#[:,:n//2]
        stoks = self.quantize(embs)
        if self.tunables.mask_embs:
            return stoks[:,:n//2//self.downsample]
        else:
            return stoks

    def decode_text(self, stoks, decoding_options=None):
        self.ensure_whisper(self.device)
        if decoding_options is None: decoding_options = self.decoding_options
        embs = self.dequantize(stoks).to(self.whmodel[0].device)
        return self.whmodel[0].decode(embs, decoding_options)
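
# An end-to-end inference sketch (not part of the original notebook). It assumes a CUDA
# device is available (the validation buffers are created on the GPU) and that a file
# `speech.wav` exists; the audio path is only illustrative, the model reference is the
# default one from `load_model` above.
def _example_inference(audio_path="speech.wav"):
    vqmodel = RQBottleneckTransformer.load_model(
        "collabora/spear-tts-pytorch:whisper-vq-stoks-medium-en+pl.model").cuda()
    stoks = vqmodel.encode_audio(audio_path)   # semantic token ids, shape (batch, tokens)
    text = vqmodel.decode_text(stoks[0])       # decode back to text with the Whisper decoder
    return stoks, text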
# %% ../nbs/2B. Whisper quantization (semantic token) model.ipynb 33
def make_model(size:str, tunables:Tunables=Tunables(), dataset:torch.utils.data.Dataset=None):
    if size == 'base.en-2d-4096c':
        model = RQBottleneckTransformer(codebook_dim=32, vq_codes=4096, q_depth=1, n_head=8, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'base.en-2d-512c':
        model = RQBottleneckTransformer(codebook_dim=32, vq_codes=512, q_depth=1, n_head=8, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'base.en-2d-512c-dim64':
        model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'base-2d-512c-dim64':
        model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=8, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'base-2d-1024c-dim64':
        model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=8, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'medium-2d-512c-dim64':
        model = RQBottleneckTransformer(codebook_dim=64, vq_codes=512, q_depth=1, n_head=16, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    if size == 'medium-2d-1024c-dim64':
        model = RQBottleneckTransformer(codebook_dim=64, vq_codes=1024, q_depth=1, n_head=16, depth=1,
                                        downsample=2, threshold_ema_dead_code=0, use_cosine_sim=True,
                                        whisper_model_name=size.split("-")[0], tunables=tunables)
        return model
    raise ValueError(f"invalid model size: {size}")
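
# A hedged construction sketch (not in the original notebook): `make_model` is the entry
# point used to pick one of the predefined bottleneck configurations; the size string
# below is one of the valid options listed above and a CUDA device is assumed.
def _example_make_model():
    model = make_model('base.en-2d-512c-dim64', tunables=Tunables(random=True))
    model.setup('cuda')   # loads the frozen Whisper teacher ('base.en') onto the GPU
    return model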