Spaces:
Paused
Paused
# AUTOGENERATED! DO NOT EDIT! File to edit: ../nbs/A. Neural modules.ipynb. | |
# %% auto 0 | |
__all__ = ['LayerNorm', 'LinearHead', 'QueryHead', 'init_transformer', 'sinusoids', 'MultiHeadAttention', | |
'ResidualAttentionBlock', 'BaseDecoder', 'EmbeddingProjector', 'FlexEmbeddings'] | |
# %% ../nbs/A. Neural modules.ipynb 2 | |
import torch | |
import numpy as np | |
import math | |
from torch import Tensor, nn | |
import torch.nn.functional as F | |
from typing import Dict, Iterable, Optional | |
# import xformers.ops as xops | |
# %% ../nbs/A. Neural modules.ipynb 3 | |
# Code in this file is mostly borrowed from | |
# https://github.com/openai/whisper/blob/main/whisper/model.py | |
# and is under the MIT License | |
class LayerNorm(nn.LayerNorm):
    """nn.LayerNorm that computes in float32 and returns the caller's dtype."""

    def forward(self, x):
        input_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(input_dtype)
# Used in μP to initialize the weights and configure the optimizer.
# These two layers map the transformer width into a fixed dimension; they are
# ordinary nn.Linear layers given their own types so that code can tell them
# apart from other linear layers.
class LinearHead(nn.Linear):
    """Marker subclass of nn.Linear for output heads (behaves exactly like nn.Linear)."""


class QueryHead(nn.Linear):
    """Marker subclass of nn.Linear for attention query projections (behaves exactly like nn.Linear)."""
# based on https://github.com/karpathy/minGPT/blob/master/mingpt/model.py#L163
def init_transformer(m):
    """Initialize one module in place, minGPT style.

    Presumably meant to be passed to ``nn.Module.apply``.

    * ``nn.Linear`` / ``nn.Embedding``: weight drawn from a truncated normal
      with std 0.02; Linear biases (when present) are zeroed.
    * ``nn.LayerNorm``: weight set to 1 and bias set to 0, skipping whichever
      of them the layer was built without.

    Other module types are left untouched.
    """
    if isinstance(m, (nn.Linear, nn.Embedding)):
        torch.nn.init.trunc_normal_(m.weight, std=.02)
        if isinstance(m, nn.Linear) and m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
    elif isinstance(m, nn.LayerNorm):
        # LayerNorm(elementwise_affine=False) has weight=bias=None and
        # LayerNorm(bias=False) has bias=None; init.constant_ fails on None.
        if m.bias is not None:
            torch.nn.init.constant_(m.bias, 0)
        if m.weight is not None:
            torch.nn.init.constant_(m.weight, 1.0)
# %% ../nbs/A. Neural modules.ipynb 4 | |
def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding.

    Builds a ``(length, channels)`` table whose first ``channels // 2``
    columns are ``sin(t * inv_timescale)`` and last ``channels // 2`` columns
    are ``cos(t * inv_timescale)`` for positions ``t = 0 .. length - 1``.
    The inverse timescales decay geometrically from 1 to ``1 / max_timescale``
    (a single timescale of 1 when ``channels == 2``).

    Args:
        length: number of positions (rows).
        channels: embedding width; must be even.
        max_timescale: largest timescale of the geometric progression.
    """
    assert channels % 2 == 0
    half = channels // 2
    # For channels == 2 the denominator ``half - 1`` would be 0, making the
    # increment inf and the table NaN (inf * 0); clamp it to keep one timescale of 1.
    log_timescale_increment = np.log(max_timescale) / max(half - 1, 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(half))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
# %% ../nbs/A. Neural modules.ipynb 5 | |
class MultiHeadAttention(nn.Module):
    """Multi-head attention (self or cross) with optional RoPE and a preallocated KV cache.

    Args:
        n_state: model width; split evenly into ``n_head`` heads.
        n_head: number of attention heads.
        qk_scale: extra attention-logit scale. Its square root is applied to
            both the query and key projections, so q·k is multiplied by
            ``qk_scale`` on top of the default scaling done by
            ``F.scaled_dot_product_attention``.
        rope: apply rotary position embeddings to queries and keys.
        cross: mark this as a cross-attention layer; within this class it only
            changes how ``convert_for_eval`` fuses the projections.
    """

    def __init__(self, n_state: int, n_head: int, qk_scale: float = 1, rope: bool = False, cross=False):
        super().__init__()
        self.n_state = n_state
        self.n_head = n_head
        self.sqrt_qk_scale = math.sqrt(qk_scale)
        self.query = QueryHead(n_state, n_state)
        self.key = nn.Linear(n_state, n_state, bias=False)
        self.value = nn.Linear(n_state, n_state)
        self.out = nn.Linear(n_state, n_state)
        self.cross = cross
        self.query_subsampling = 1
        self.key_subsampling = 1
        self.cached_kvx = None
        self.register_buffer('k_cache', None)
        self.register_buffer('v_cache', None)
        self.rotary = None
        if rope:
            self.rotary = Rotary(n_state // n_head)
        # Fused projections, populated by convert_for_eval().
        self.qkv = None
        self.kv = None

    def setup_kv_cache(self, max_batch_size, max_seq_len, dtype=torch.float32):
        """Allocate zeroed key/value caches of shape (max_batch_size, n_head, max_seq_len, n_state // n_head)."""
        cache_shape = (max_batch_size, self.n_head, max_seq_len, self.n_state//self.n_head)
        self.k_cache = torch.zeros(cache_shape, dtype=dtype, device=self.key.weight.device)
        self.v_cache = torch.zeros(cache_shape, dtype=dtype, device=self.value.weight.device)

    def merge_linears(self, layers, mults):
        """Fuse several ``nn.Linear`` layers into one wider layer.

        The returned layer computes
        ``cat([layer(x) * mult for layer, mult in zip(layers, mults)], dim=-1)``
        in a single matmul. Layers without a bias contribute a zero bias. The
        new layer uses the device and dtype of ``layers[0].weight``.
        """
        ref = layers[0].weight
        # nn.Linear stores weights as (out_features, in_features).
        din = ref.shape[1]
        total_out = sum(x.weight.shape[0] for x in layers)
        new = nn.Linear(din, total_out).to(device=ref.device, dtype=ref.dtype)
        with torch.no_grad():
            new.weight[:] = torch.cat([x.weight * m for x, m in zip(layers, mults)])
            new.bias[:] = torch.cat([
                torch.zeros(x.weight.shape[0], device=ref.device, dtype=ref.dtype)
                if x.bias is None else x.bias * m
                for x, m in zip(layers, mults)
            ])
        return new

    def convert_for_eval(self):
        """Fuse the projections (folding in ``sqrt_qk_scale``) for faster inference.

        Self-attention gets one fused ``qkv`` layer. Cross-attention gets a
        scaled query layer ``q`` plus a fused ``kv`` layer, because its keys and
        values are computed from a different input than the queries.

        Raises:
            AttributeError: if the module was already converted.
        """
        if self.qkv is not None or self.kv is not None:
            raise AttributeError("already converted")
        # Size of each fused chunk = out_features of the original projections.
        self.odim = self.key.weight.shape[0]
        if self.cross:
            self.q = self.merge_linears([self.query], [self.sqrt_qk_scale])
            self.kv = self.merge_linears([self.key, self.value],
                                         [self.sqrt_qk_scale, 1])
        else:
            self.qkv = self.merge_linears([self.query, self.key, self.value],
                                          [self.sqrt_qk_scale, self.sqrt_qk_scale, 1])

    def split_heads(self, x, x_positions, rope=False, subsampling=1):
        """Reshape to (batch, n_head, seq, head_dim), optionally applying RoPE first.

        Assumes ``x`` is (batch, seq, n_state) — TODO confirm with callers.
        RoPE is looked up at ``x_positions * subsampling``.
        """
        x = x.view(*x.shape[:2], self.n_head, -1)
        if rope:
            x = rope_rotate(x, x_positions * subsampling, *self.rotary(x))
        return x.permute(0, 2, 1, 3)

    def forward(
        self,
        qx,
        q_positions,
        kvx,
        kv_positions,
        causal = False,
        mask=None,
    ):
        """Attend from ``qx`` (at ``q_positions``) to ``kvx`` (at ``kv_positions``).

        When a KV cache is allocated, new keys/values are written at
        ``kv_positions`` and attention runs over the whole cache. ``mask``
        (if given) is indexed by ``q_positions`` to select the rows for the
        current queries.
        """
        if self.qkv is not None:
            # Fused self-attention: q, k and v are all projected from qx.
            q, k, v = self.qkv(qx).split(self.odim, dim=-1)
        elif self.kv is not None:
            q = self.q(qx)
            k, v = self.kv(kvx).split(self.odim, dim=-1)
        else:
            q, k, v = None, None, None

        if q is None:
            q = self.query(qx) * self.sqrt_qk_scale
        q = self.split_heads(q, q_positions, rope=self.rotary is not None, subsampling=self.query_subsampling)

        # NOTE(review): cached_kvx is set to None in __init__ and never reassigned
        # in this class, so this branch runs whenever kvx is not None.
        if kvx is not self.cached_kvx:
            if k is None:
                k = self.key(kvx) * self.sqrt_qk_scale
            k = self.split_heads(k, kv_positions, rope=self.rotary is not None, subsampling=self.key_subsampling)
            if v is None:
                v = self.value(kvx)
            v = self.split_heads(v, kv_positions)
            if self.k_cache is not None:
                # Store the new keys/values in their cache slots.
                self.k_cache[:, :, kv_positions] = k
                self.v_cache[:, :, kv_positions] = v

        if self.k_cache is not None:
            # With a cache, attend over all cached keys/values.
            k, v = self.k_cache, self.v_cache

        if mask is not None:
            mask = mask[q_positions]

        wv = F.scaled_dot_product_attention(q, k, v, attn_mask=mask, dropout_p=0, is_causal=causal)
        # (batch, n_head, seq, head_dim) -> (batch, seq, n_head * head_dim)
        return self.out(wv.permute(0, 2, 1, 3).flatten(start_dim=2))
# %% ../nbs/A. Neural modules.ipynb 6 | |
# modified from https://blog.eleuther.ai/rotary-embeddings/
import torch
class Rotary(torch.nn.Module):
    """Rotary position embedding (RoPE) cos/sin table generator.

    ``forward`` returns ``(cos, sin)`` tables of shape ``(1, cache_len, 1, dim)``.
    ``rope_rotate`` indexes them with position ids along dim 1. Those
    positions are not bounded by the length of the current input (e.g. with a
    KV cache), so the table always covers at least ``min_cache_len``
    positions, and more when the input itself is longer.

    Args:
        dim: per-head feature size (even).
        base: frequency base.
        min_cache_len: minimum number of positions precomputed.
    """

    # Class-level default so instances created without this attribute still work.
    min_cache_len = 2500

    def __init__(self, dim, base=10000, min_cache_len=2500):
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)
        self.min_cache_len = min_cache_len
        self.seq_len_cached = None
        self.cos_cached = None
        self.sin_cached = None

    def forward(self, x, seq_dim=1):
        seq_len = x.shape[seq_dim]
        needs_rebuild = (
            not self.seq_len_cached
            or self.cos_cached is None
            or seq_len > self.seq_len_cached
            # The cached tables are plain attributes (not buffers), so they do
            # not follow .to(device); rebuild when the input moved.
            or self.cos_cached.device != x.device
        )
        if needs_rebuild:
            self.seq_len_cached = max(self.min_cache_len, seq_len)
            inv_freq = self.inv_freq.to(x.device)
            t = torch.arange(self.seq_len_cached, device=x.device).type_as(inv_freq)
            freqs = torch.einsum("i,j->ij", t, inv_freq)
            emb = torch.cat((freqs, freqs), dim=-1)
            self.cos_cached = emb.cos()[None, :, None, :]
            self.sin_cached = emb.sin()[None, :, None, :]
        return self.cos_cached, self.sin_cached
# rotary pos emb helpers:
def rotate_half(x):
    """Split the last dimension into halves (a, b) and return cat(-b, a)."""
    half = x.shape[-1] // 2
    first, second = x[..., :half], x[..., half:]
    return torch.cat((-second, first), dim=x.dim() - 1)


def rope_rotate(x, positions, cos, sin):
    """Apply RoPE: ``x * cos + rotate_half(x) * sin``, tables indexed by ``positions`` along dim 1."""
    cos_at = cos[:, positions]
    sin_at = sin[:, positions]
    return x * cos_at + rotate_half(x) * sin_at
# %% ../nbs/A. Neural modules.ipynb 7 | |
class ResidualAttentionBlock(nn.Module):
    """Pre-LayerNorm transformer block: self-attention, optional cross-attention, then an MLP.

    Every sub-layer is applied residually as ``x = x + sublayer(norm(x))``.
    """

    def __init__(self, n_state: int, n_head: int, cross_attention: bool = False, rope: bool = False,
                 qk_scale: float = 1, ffn_mult: int = 4):
        super().__init__()
        self.attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope)
        self.attn_ln = LayerNorm(n_state)

        if cross_attention:
            self.cross_attn = MultiHeadAttention(n_state, n_head, qk_scale=qk_scale, rope=rope, cross=True)
            self.cross_attn_ln = LayerNorm(n_state)
        else:
            self.cross_attn = None
            self.cross_attn_ln = None

        hidden = n_state * ffn_mult
        self.mlp = nn.Sequential(nn.Linear(n_state, hidden), nn.GELU(), nn.Linear(hidden, n_state))
        self.mlp_ln = LayerNorm(n_state)

    def setup_kv_cache(self, max_batch_size, max_seq_len, max_cross_seq_len=None):
        """Allocate KV caches for self-attention and, when present, cross-attention.

        ``max_cross_seq_len`` is passed straight into the cross-attention cache
        shape, so it has to be given when the block has cross-attention.
        """
        self.attn.setup_kv_cache(max_batch_size, max_seq_len)
        if self.cross_attn is not None:
            self.cross_attn.setup_kv_cache(max_batch_size, max_cross_seq_len)

    def forward(
        self,
        x: Tensor,
        x_positions: Tensor = None,
        xa: Optional[Tensor] = None,
        xa_positions: Optional[Tensor] = None,
        causal = False,
        mask=None,
    ):
        # Self-attention: the same normalized tensor serves as queries and keys/values.
        normed = self.attn_ln(x)
        x = x + self.attn(normed, x_positions, normed, x_positions, causal=causal, mask=mask)
        if self.cross_attn is not None:
            normed = self.cross_attn_ln(x)
            x = x + self.cross_attn(normed, x_positions, xa, xa_positions)
        return x + self.mlp(self.mlp_ln(x))
# %% ../nbs/A. Neural modules.ipynb 8 | |
class BaseDecoder(nn.Module):
    """Stack of cross-attending residual blocks followed by a final LayerNorm.

    Each layer receives ``self.mask`` for its self-attention. It is a
    ``(length, length)`` float mask with ``-inf`` strictly above the diagonal
    and 0 elsewhere. Layers are called with ``causal=False``, so this mask is
    what restricts attention to earlier positions.
    """

    def __init__(self, depth=6, n_head=6, width=384, qk_scale=1, ffn_mult=4, length=2250, rope=False):
        super().__init__()
        self.length = length
        self.width = width
        blocks = [
            ResidualAttentionBlock(self.width, n_head, qk_scale=qk_scale, ffn_mult=ffn_mult,
                                   cross_attention=True, rope=rope)
            for _ in range(math.floor(depth))
        ]
        self.layers = nn.ModuleList(blocks)
        self.ln_post = LayerNorm(width)
        causal_mask = torch.full((length, length), -torch.inf).triu_(1)
        self.register_buffer("mask", causal_mask, persistent=False)

    def forward(self, x, x_positions, xenc, xenc_positions):
        for layer in self.layers:
            x = layer(x, x_positions, xenc, xenc_positions, causal=False, mask=self.mask)
        return self.ln_post(x)
# %% ../nbs/A. Neural modules.ipynb 9 | |
class EmbeddingProjector(nn.Linear):
    """Marker subclass of nn.Linear for projections between embedding width and model width."""


class FlexEmbeddings(nn.Module):
    """Token embedding and tied un-embedding with an optional frozen width and special tokens.

    Ids ``0 .. codes-1`` are looked up in ``self.main`` (width ``frozen_width``)
    and, when ``frozen_width != width``, projected to ``width`` by
    ``emb_to_hidden``. Ids ``>= codes`` are special tokens looked up in
    ``self.special`` (width ``width``) at index ``id - codes``.

    Args:
        codes: size of the main vocabulary.
        width: model (hidden) width.
        special_codes: number of special tokens (falsy for none).
        frozen_width: width of the main table; defaults to ``width``.
        special_embedding: optional module used instead of a fresh
            ``nn.Embedding(special_codes, width)``.
        unembed: build the ``hidden_to_emb`` projection that ``unembed`` needs
            when ``frozen_width != width``.
    """

    def __init__(self, codes, width, special_codes=None, frozen_width=None, special_embedding=None, unembed=True):
        super().__init__()
        self.codes = codes
        self.special_codes = special_codes
        if frozen_width is None: frozen_width = width
        self.main = nn.Embedding(codes, frozen_width or width)
        self.emb_to_hidden = EmbeddingProjector(frozen_width, width) if frozen_width != width else None
        self.hidden_to_emb = EmbeddingProjector(width, frozen_width) if unembed and frozen_width != width else None
        if special_codes:
            self.special = special_embedding or nn.Embedding(special_codes, width)
        self.register_buffer('merged_in', None)
        self.register_buffer('merged_out', None)
        self.register_buffer('bias_out', None)

    def set_frozen_embeddings(self, values):
        """Copy ``values`` into the main table and tag it with ``lr_scale = 0``.

        ``lr_scale`` is presumably read by optimizer setup elsewhere to freeze
        the table — confirm with the training code.
        """
        with torch.no_grad():
            self.main.weight[:] = values
            self.main.lr_scale = 0

    def convert_for_eval(self):
        """Precompute merged tables so eval-mode ``forward`` and ``unembed`` are single ops.

        Only applies when special tokens exist.

        * ``merged_in`` is an ``nn.Embedding`` over all ``codes + special_codes``
          ids in hidden width.
        * ``merged_out`` and ``bias_out`` turn ``unembed`` into one ``F.linear``.

        The un-embedding tables are left as None when this module cannot
        un-embed. That happens when the main table is projected in
        (``frozen_width != width``) but no ``hidden_to_emb`` exists
        (``unembed=False``).
        """
        if not self.special_codes: return
        # The merged tables are inference constants: build them without autograd
        # history (non-leaf buffers also break copy.deepcopy of the module).
        with torch.no_grad():
            # in
            main_w = self.main.weight
            if self.emb_to_hidden is not None: main_w = self.emb_to_hidden(main_w)
            weight = torch.cat([main_w, self.special.weight], dim=0)
            self.merged_in = nn.Embedding(*weight.shape, _weight=weight)

            # out
            if self.emb_to_hidden is not None and self.hidden_to_emb is None:
                # main table width != hidden width and no projection back: unembed is impossible.
                self.merged_out = None
                self.bias_out = None
                return
            weight = self.main.weight
            if self.hidden_to_emb is not None:
                # hidden_to_emb(h) @ main.T == h @ (main @ W).T + b @ main.T
                weight = weight @ self.hidden_to_emb.weight
            # Rows: main ids then special ids, laid out as F.linear expects (out, in).
            self.merged_out = torch.cat([weight, self.special.weight], dim=0).contiguous()
            if self.hidden_to_emb is not None:
                self.bias_out = torch.cat([
                    self.hidden_to_emb.bias @ self.main.weight.T,
                    torch.zeros(self.special.weight.shape[0], device=weight.device, dtype=weight.dtype)
                ], dim=0)
            else:
                self.bias_out = None

    def forward(self, toks):
        """Embed token ids into hidden width; ids ``>= codes`` use the special table."""
        if not self.training and self.merged_in is not None:
            return self.merged_in(toks)
        if self.special_codes:
            special_mask = toks >= self.codes
            # Special ids are replaced by 0 for the main lookup; their rows are overwritten below.
            embs = self.main(torch.where(special_mask, 0, toks))
        else:
            embs = self.main(toks)
        if self.emb_to_hidden is not None: embs = self.emb_to_hidden(embs)
        if self.special_codes:
            embs[special_mask] = self.special(toks[special_mask] - self.codes).to(embs.dtype)
        return embs

    def unembed(self, embs):
        """Map hidden states to logits over the ``codes`` main ids followed by the special ids."""
        if not self.training and self.merged_out is not None:
            # NOTE(review): unlike the path below, this result is not cast to float32.
            return F.linear(embs, self.merged_out, self.bias_out) # embs @ self.merged_out + self.bias_out
        orig_embs = embs
        if self.hidden_to_emb is not None: embs = self.hidden_to_emb(embs)
        main_logits = (embs @ self.main.weight.to(embs.dtype).T).float()
        if not self.special_codes:
            return main_logits
        special_logits = (orig_embs @ self.special.weight.to(orig_embs.dtype).T).float()
        return torch.cat([main_logits, special_logits], dim=-1)