# Page header residue (not code): tonic · Laion WhisperSpeech Demo · 33d9042 · raw · history blame · 12.9 kB
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb.
# %% auto 0
# Names exported by a star-import of this module. Rotary, rotate_half and
# rope_rotate are defined further down in this file but are not listed here.
__all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention',
'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings']
# %% ../nbs/A. Neural modules.ipynb 2
import torch
import numpy as np
import math
from torch import Tensor, nn
import torch.nn.functional as F
from typing import Dict, Iterable, Optional
# import xformers.ops as xops
# %% ../nbs/A. Neural modules.ipynb 3
# Code in this file is mostly borrowed from
# https://github.com/openai/whisper/blob/main/whisper/model.py
# and is under the MIT License
class LayerNorm(nn.LayerNorm):
    """`nn.LayerNorm` that normalizes a float32 copy and returns the input's dtype."""

    def forward(self, x):
        # Remember the caller's dtype, normalize in float32, then cast back so
        # the output dtype always matches the input dtype.
        input_dtype = x.dtype
        normalized = super().forward(x.float())
        return normalized.type(input_dtype)
# Used in μP to initialize the weights and configure the optimizer
# These two layers map the transformer width into a fixed dimension
class LinearHead(nn.Linear):
    """Plain `nn.Linear` under a distinct type name; adds no behavior of its own."""
class QueryHead(nn.Linear):
    """Plain `nn.Linear` under a distinct type name; adds no behavior of its own."""
# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
def init_transformer(m):
    """Initialize a single module in place (minGPT-style).

    Takes one module, so it can be used as the callback of `nn.Module.apply`.

    * `nn.Linear` / `nn.Embedding`: weight drawn from a truncated normal with
      std 0.02; a Linear bias, when present, is zeroed.
    * `nn.LayerNorm`: bias set to 0 and weight set to 1, when present.

    Any other module type is left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # A LayerNorm built without affine parameters has weight/bias set to
        # None; skip those instead of crashing inside the initializer.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1.0)
# %% ../nbs/A. Neural modules.ipynb 4
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding

    Builds a `(length, channels)` tensor whose first `channels // 2` columns
    are sines and whose remaining columns are cosines of `position * rate`.
    The rates are spaced geometrically from 1 down to `1 / max_timescale`.

    `channels` must be even. With `channels == 2` there is a single rate, 1.
    """
    assert channels % 2 == 0
    half = channels // 2
    # With a single timescale (channels == 2) the step count `half - 1` is 0;
    # dividing by it would give inf and turn every embedding into NaN, so the
    # denominator is clamped to 1 (the only rate is then exp(0) == 1).
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    # Outer product: one row per position, one column per rate.
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
# %% ../nbs/A. Neural modules.ipynb 5
class MultiHeadAttention(nn.Module):
    """Multi-head attention with optional rotary embeddings, KV cache and fused projections.

    Queries and keys are each multiplied by `sqrt(qk_scale)` after projection.
    `convert_for_eval` folds that factor into fused projection layers
    (`qkv` for self-attention; `q` + `kv` when `cross=True`).

    `cached_kvx` is only read by this class: when the `kvx` passed to
    `forward` is that very object, the key/value projection is skipped and the
    previously filled cache is used. Nothing in this class assigns it after
    construction.
    """

    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.cross = cross
        # Multipliers applied to the positions before the rotary lookup.
        self.query_subsampling = 1
        self.key_subsampling = 1

        self.cached_kvx = None
        self.register_buffer('k_cache', None)
        self.register_buffer('v_cache', None)

        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)
        # Fused projections; populated by convert_for_eval().
        self.qkv = None
        self.kv = None

    def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
        """Allocate zeroed key/value caches.

        Each cache has shape
        `(max_batch_size, n_head, max_seq_len, n_state // n_head)` and lives on
        the device of the corresponding projection weight.
        """
        cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
        self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
        self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)

    def merge_linears(self, layers, mults):
        """Fuse several `nn.Linear` layers sharing one input into a single layer.

        The result computes `cat([layer(x) * m for layer, m in zip(layers, mults)], -1)`.
        Layers without a bias contribute a zero bias segment; if no layer has
        a bias the fused layer has none either. The fused layer is created on
        the device and in the dtype of the first layer's weight.
        """
        reference = layers[0].weight
        # nn.Linear stores its weight as (out_features, in_features).
        din = reference.shape[1]
        dout_total = sum(layer.weight.shape[0] for layer in layers)
        has_bias = any(layer.bias is not None for layer in layers)
        new = nn.Linear(din, dout_total, bias=has_bias,
                        device=reference.device, dtype=reference.dtype)
        with torch.no_grad():
            new.weight[:] = torch.cat([layer.weight * m for layer, m in zip(layers, mults)])
            if has_bias:
                new.bias[:] = torch.cat([
                    layer.weight.new_zeros(layer.weight.shape[0]) if layer.bias is None else layer.bias * m
                    for layer, m in zip(layers, mults)
                ])
        return new

    def convert_for_eval(self):
        """Build the fused projection layers used by `forward`.

        Raises:
            AttributeError: if the module was already converted.
        """
        if self.qkv is not None or self.kv is not None:
            raise AttributeError("already converted")
        # Size of each chunk when splitting a fused output: the out_features
        # of one projection (weight is (out_features, in_features)).
        self.odim = self.key.weight.shape[0]
        if self.cross:
            self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
            self.kv = self.merge_linears([self.key, self.value],
                                         [self.sqrt_qk_scale, 1])
        else:
            self.qkv = self.merge_linears([self.query, self.key, self.value],
                                          [self.sqrt_qk_scale, self.sqrt_qk_scale, 1])

    def split_heads(self, x, x_positions, rope=False, subsampling=1):
        """Reshape `x` into heads and optionally apply the rotary rotation.

        The last dimension is split into `n_head` chunks and the result is
        permuted to put the head axis second
        (assumes `x` is (batch, seq, channels) — TODO confirm with callers).
        When `rope` is truthy the rotation uses `x_positions * subsampling`.
        """
        x = x.view(*x.shape[:2], self.n_head, -1)
        if rope:
            x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        qx,
        q_positions,
        kvx,
        kv_positions,
        causal = False,
        mask=None,
    ):
        """Attend from `qx` to `kvx`.

        Args:
            qx: input the queries are projected from.
            q_positions: positions of the queries; used for the rotary lookup
                and to select rows of `mask`.
            kvx: input the keys/values are projected from. Skipped when it is
                the same object as `self.cached_kvx`.
            kv_positions: positions of the keys/values; used for the rotary
                lookup and as the write index into the KV cache.
            causal: forwarded as `is_causal` to scaled_dot_product_attention.
            mask: optional attention mask; rows are selected with `q_positions`.
        """
        q = k = v = None
        if self.qkv is not None:
            # Fused self-attention projection (scale already folded in).
            q, k, v = self.qkv(qx).split(self.odim, dim=-1)
        elif self.kv is not None:
            q = self.q(qx)

        if q is None:
            q = self.query(qx) * self.sqrt_qk_scale
        q = self.split_heads(q, q_positions, rope=self.rotary is not None,
                             subsampling=self.query_subsampling)

        if kvx is not self.cached_kvx:
            if k is None and self.kv is not None:
                # Fused cross-attention projection; only computed when its
                # result is actually used (i.e. kvx is not already cached).
                k, v = self.kv(kvx).split(self.odim, dim=-1)
            if k is None:
                k = self.key(kvx) * self.sqrt_qk_scale
            k = self.split_heads(k, kv_positions, rope=self.rotary is not None,
                                 subsampling=self.key_subsampling)
            if v is None:
                v = self.value(kvx)
            # Values are never rotated.
            v = self.split_heads(v, kv_positions)
            if self.k_cache is not None:
                self.k_cache[:,:,kv_positions] = k
                self.v_cache[:,:,kv_positions] = v

        if self.k_cache is not None:
            # With a cache allocated, attention always runs over the whole cache.
            k, v = self.k_cache, self.v_cache

        if mask is not None:
            mask = mask[q_positions]

        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)

        # Merge the heads back into the channel dimension before the output projection.
        return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
# %% ../nbs/A. Neural modules.ipynb 6
# modified from https://blog.eleuther.ai/rotary-embeddings/
import torch
class Rotary(torch.nn.Module):
    """Caches the cos/sin tables used for rotary position embeddings.

    `forward` returns two tensors of shape `(1, table_len, 1, dim)`; the
    caller picks rows by position (see `rope_rotate`, which indexes dim 1).
    """

    # Minimum number of positions in the cached tables. The tables are indexed
    # by explicit positions rather than by the length of `x`, so they are kept
    # at least this long regardless of the input length.
    MIN_CACHED_LEN = 2500

    def __init__(self, dim, base=10000):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        """Return `(cos, sin)` tables covering at least `x.shape[seq_dim]` positions.

        The tables are rebuilt when none exist yet, when `x` is longer than
        the cached tables, or when `x` lives on a different device than the
        cached tables.
        """
        seq_len = x.shape[seq_dim]
        stale = (
            not self.seq_len_cached
            or seq_len > self.seq_len_cached
            # The tables are plain attributes, so they do not follow the
            # module when it is moved to another device.
            or self.cos_cached.device != x.device
        )
        if stale:
            # Grow with the input so sequences longer than the minimum get a
            # table that actually covers them (and is not rebuilt every call).
            self.seq_len_cached = max(seq_len, self.MIN_CACHED_LEN)

            t = torch.arange(self.seq_len_cached, device=x.device).type_as(self.inv_freq)
            freqs = torch.einsum("i,j->ij", t, self.inv_freq)
            # Duplicate the frequencies so the table spans the full head dim.
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
    """Swap the two halves of the last dimension, negating the second half.

    For `x = [x1, x2]` along the last axis this returns `[-x2, x1]`. The split
    point is `x.shape[-1] // 2`.
    """
    split_at = x.shape[-1] // 2
    front = x[..., :split_at]
    back = x[..., split_at:]
    last_dim = len(x.shape) - 1
    return torch.cat((-back, front), dim=last_dim)
def rope_rotate(x, positions, cos, sin):
    """Apply the rotary rotation to `x`.

    `cos` and `sin` are lookup tables indexed along dim 1 by `positions`; the
    selected rows are combined as `x * cos + rotate_half(x) * sin`.
    """
    cos_at = cos[:, positions]
    sin_at = sin[:, positions]
    swapped = rotate_half(x)
    return x * cos_at + swapped * sin_at
# %% ../nbs/A. Neural modules.ipynb 7
class ResidualAttentionBlock(nn.Module):
    """Pre-norm transformer block: self-attention, optional cross-attention, MLP.

    Each sub-layer is applied to a layer-normalized copy of the stream and
    its output is added back to the stream.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
                 qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)

        self.cross_attn = (
            MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True) if cross_attention else None
        )
        self.cross_attn_ln = LayerNorm(n_state) if cross_attention else None

        # Feed-forward hidden width is a multiple of the model width.
        n_mlp = n_state * ffn_mult
        self.mlp = nn.Sequential(
            nn.Linear(n_state, n_mlp), nn.GELU(), nn.Linear(n_mlp, n_state)
        )
        self.mlp_ln = LayerNorm(n_state)

    def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None, dtype=torch.float32):
        """Allocate the KV caches of the attention layers in this block.

        Args:
            max_batch_size: batch capacity of the caches.
            max_seq_len: sequence capacity of the self-attention cache.
            max_cross_seq_len: sequence capacity of the cross-attention cache;
                only used when the block has a cross-attention layer.
            dtype: element type of the caches (defaults to float32, the
                previous fixed behavior).
        """
        self.attn.setup_kv_cache(max_batch_size, max_seq_len, dtype=dtype)
        if self.cross_attn is not None:
            self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len, dtype=dtype)

    def forward(
        self,
        x: Tensor,
        x_positions: Tensor = None,
        xa: Optional[Tensor] = None,
        xa_positions: Optional[Tensor] = None,
        causal = False,
        mask=None,
    ):
        """Run the block on `x`.

        `causal` and `mask` are passed to the self-attention only; the
        cross-attention (when present) attends from `x` to `xa` without them.
        """
        lnx = self.attn_ln(x)
        x = x + self.attn(lnx, x_positions, lnx, x_positions, causal=causal, mask=mask)
        if self.cross_attn is not None:
            lnx = self.cross_attn_ln(x)
            x = x + self.cross_attn(lnx, x_positions, xa, xa_positions)
        x = x + self.mlp(self.mlp_ln(x))
        return x
# %% ../nbs/A. Neural modules.ipynb 8
class BaseDecoder(nn.Module):
    """Stack of cross-attending `ResidualAttentionBlock`s followed by a LayerNorm.

    A `(length, length)` mask with `-inf` strictly above the diagonal and 0
    elsewhere is registered as a non-persistent buffer and handed to every
    layer.
    """

    def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
        super().__init__()
        self.length = length
        self.width = width

        blocks = []
        for _ in range(math.floor(depth)):
            blocks.append(
                ResidualAttentionBlock(
                    self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult, cross_attention=True, rope=rope
                )
            )
        self.layers = nn.ModuleList(blocks)
        self.ln_post = LayerNorm(width)

        # -inf above the main diagonal, zeros on and below it.
        mask = torch.full((length, length), -torch.inf).triu_(1)
        self.register_buffer("mask", mask, persistent=False)

    def forward(self, x, x_positions, xenc, xenc_positions):
        """Pass `x` through every layer (cross-attending to `xenc`) and normalize the result."""
        hidden = x
        for block in self.layers:
            hidden = block(hidden, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
        return self.ln_post(hidden)
# %% ../nbs/A. Neural modules.ipynb 9
class EmbeddingProjector(nn.Linear):
    """Plain `nn.Linear` under a distinct type name; used by `FlexEmbeddings`
    to map between the embedding-table width and the model width."""


class FlexEmbeddings(nn.Module):
    """Token embeddings with an optional narrower main table and extra special tokens.

    Token ids below `codes` are looked up in `main` (width `frozen_width`,
    projected to `width` by `emb_to_hidden` when the two differ). Ids from
    `codes` upwards are looked up in `special` (index `id - codes`).
    `unembed` produces logits over the main codes followed by the special
    codes.

    `convert_for_eval` pre-merges these tables; the merged versions are used
    only while the module is not in training mode.
    """

    def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
        super().__init__()
        self.codes = codes
        self.special_codes = special_codes
        if frozen_width is None:
            frozen_width = width

        self.main = nn.Embedding(codes, frozen_width or width)
        self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
        self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
        if special_codes:
            # A caller-provided embedding takes precedence over a fresh one.
            self.special = special_embedding if special_embedding is not None else nn.Embedding(special_codes, width)

        self.register_buffer('merged_in', None)
        self.register_buffer('merged_out', None)
        self.register_buffer('bias_out', None)

    def set_frozen_embeddings(self, values):
        """Overwrite the main table with `values` and tag it with `lr_scale = 0`."""
        with torch.no_grad():
            self.main.weight[:] = values
            self.main.lr_scale = 0

    @torch.no_grad()
    def convert_for_eval(self):
        """Pre-merge the main and special tables for the non-training paths.

        Does nothing when there are no special codes. The output-side merge
        is skipped (leaving `merged_out` as None) when the main table cannot
        be brought to the special table's width, i.e. there is no
        `hidden_to_emb` projection and the widths differ.
        """
        if not self.special_codes:
            return
        # in: one embedding table covering main codes followed by special codes
        main_w = self.main.weight
        if self.emb_to_hidden is not None:
            main_w = self.emb_to_hidden(main_w)
        weight = torch.cat([main_w, self.special.weight], dim=0)
        self.merged_in = nn.Embedding(*weight.shape, _weight=weight)

        # out: one (codes + special_codes, width) matrix for F.linear
        weight = self.main.weight
        if self.hidden_to_emb is not None:
            # Fold the hidden->embedding projection into the main table.
            weight = weight @ self.hidden_to_emb.weight
        elif weight.shape[1] != self.special.weight.shape[1]:
            # The main table is narrower/wider than the special table and
            # there is no projection to reconcile them, so the two cannot be
            # stacked; keep the unmerged unembedding path.
            self.merged_out = None
            self.bias_out = None
            return
        self.merged_out = torch.cat([weight, self.special.weight], dim=0).contiguous()
        if self.hidden_to_emb is not None:
            # The projection bias turns into a per-code logit offset; special
            # codes bypass the projection and get no offset.
            self.bias_out = torch.cat([
                self.hidden_to_emb.bias @ self.main.weight.T,
                torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
            ], dim=0)
        else:
            self.bias_out = None

    def forward(self, toks):
        """Embed token ids `toks`."""
        if not self.training and self.merged_in is not None:
            return self.merged_in(toks)

        if self.special_codes:
            special_mask = toks >= self.codes
            # Special ids are out of range for the main table; look up id 0
            # in their place and overwrite those slots below.
            embs = self.main(torch.where(special_mask, 0, toks))
        else:
            embs = self.main(toks)

        if self.emb_to_hidden is not None:
            embs = self.emb_to_hidden(embs)

        if self.special_codes:
            embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)

        return embs

    def unembed(self, embs):
        """Return float logits over the main codes, followed by the special codes if any."""
        if not self.training and self.merged_out is not None:
            return F.linear(embs, self.merged_out, self.bias_out)  # embs @ self.merged_out.T + self.bias_out

        orig_embs = embs
        if self.hidden_to_emb is not None:
            embs = self.hidden_to_emb(embs)

        main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()

        if not self.special_codes:
            return main_logits

        # Special logits are computed from the unprojected hidden states.
        special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
        return torch.cat([main_logits, special_logits], dim=-1)