# tortoise5c / tortoise/models/transformer.py
# Author: djkesu - commit 3bbf2c7 ("added model"), 6.41 kB
import torch
import torch.nn.functional as F
from einops import rearrange
from torch import nn
# helpers
def exists(val):
    """Return ``True`` when *val* is anything other than ``None``."""
    return not (val is None)


def default(val, d):
    """Return *val*, falling back to *d* only when *val* is ``None``.

    Falsy-but-present values such as ``0`` or ``""`` are returned unchanged.
    """
    if not exists(val):
        return d
    return val
def cast_tuple(val, depth=1):
    """Coerce *val* into a tuple.

    Args:
        val: a tuple (returned as-is), a list (converted element-wise), or any
            other single value (repeated *depth* times).
        depth: repetition count used only for the single-value case.

    Returns:
        A tuple. Note that tuples and lists keep their own length; *depth* is
        not applied to them.
    """
    if isinstance(val, tuple):
        return val
    if isinstance(val, list):
        return tuple(val)
    return (val,) * depth
def max_neg_value(t):
    """Return the most negative finite value of ``t``'s floating-point dtype.

    Used as a fill value for masked attention logits. ``torch.finfo`` only
    accepts floating-point dtypes, so integer tensors raise ``TypeError``.
    """
    dtype_info = torch.finfo(t.dtype)
    return -dtype_info.max
def stable_softmax(t, dim=-1, alpha=32**2):
    """Softmax along *dim* after a scaled, detached max-subtraction.

    The logits are divided by *alpha*, shifted so that their maximum along
    *dim* is zero (the maximum is detached, so no gradient flows through it),
    then multiplied back by *alpha* before the softmax.
    """
    shrunk = t / alpha
    peak = shrunk.amax(dim=dim, keepdim=True).detach()
    return torch.softmax((shrunk - peak) * alpha, dim=dim)
def route_args(router, args, depth):
    """Distribute keyword arguments to the ``f``/``g`` halves of each layer.

    Args:
        router: maps a keyword name to a per-layer sequence of
            ``(to_f, to_g)`` flags; a truthy flag sends the keyword to that half.
        args: keyword arguments to distribute.
        depth: number of layers.

    Returns:
        A list of ``depth`` ``(f_kwargs, g_kwargs)`` dict pairs. Keys of *args*
        that are absent from *router* are dropped. If a route is shorter than
        *depth*, the remaining layers do not receive that key.
    """
    per_layer = [({}, {}) for _ in range(depth)]
    for name, value in args.items():
        if name not in router:
            continue
        for (f_kwargs, g_kwargs), (to_f, to_g) in zip(per_layer, router[name]):
            if to_f:
                f_kwargs[name] = value
            if to_g:
                g_kwargs[name] = value
    return per_layer
# classes
class SequentialSequence(nn.Module):
    """Run a stack of ``(f, g)`` residual sub-layer pairs in order.

    Each layer applies ``x = x + f(x, **f_args)`` followed by
    ``x = x + g(x, **g_args)``. Keyword arguments passed to ``forward`` are
    distributed to ``f``/``g`` per layer according to ``args_route`` via
    ``route_args``.

    Args:
        layers: sequence of ``(f, g)`` callable pairs, one per layer.
        args_route: maps a keyword name to a per-layer tuple of
            ``(to_f, to_g)`` flags. Each tuple must have one entry per layer.
            Defaults to no routing (an empty dict owned by this instance).
        layer_dropout: stored as ``self.layer_dropout``; ``forward`` does not
            read it.
    """

    def __init__(self, layers, args_route=None, layer_dropout=0.0):
        super().__init__()
        # Use ``None`` rather than a ``{}`` default: the dict is kept on the
        # instance, so a shared default would be aliased by every instance.
        if args_route is None:
            args_route = {}
        assert all(
            len(route) == len(layers) for route in args_route.values()
        ), "each argument route map must have the same depth as the number of sequential layers"
        self.layers = layers
        self.args_route = args_route
        self.layer_dropout = layer_dropout

    def forward(self, x, **kwargs):
        """Apply every layer's two residual sub-layers to *x* in order."""
        routed = route_args(self.args_route, kwargs, len(self.layers))
        for (f, g), (f_args, g_args) in zip(self.layers, routed):
            x = x + f(x, **f_args)
            x = x + g(x, **g_args)
        return x
class DivideMax(nn.Module):
    """Divide a tensor by its maximum along ``dim``.

    The maximum is detached, so it is treated as a constant for gradients.
    There is no guard against a zero or negative maximum.
    """

    def __init__(self, dim):
        super().__init__()
        self.dim = dim

    def forward(self, x):
        denom = x.amax(dim=self.dim, keepdim=True)
        return x / denom.detach()
# https://arxiv.org/abs/2103.17239
class LayerScale(nn.Module):
    """Multiply a sub-module's output by a learnable per-channel scale.

    Args:
        dim: size of the last (channel) axis; the scale has shape ``(1, 1, dim)``.
        depth: layer position used to choose the initial scale (``Transformer``
            passes the 1-based layer index). Up to 18 gives 0.1, 19-24 gives
            1e-5, and anything deeper gives 1e-6.
        fn: callable wrapped by this module; it receives ``forward``'s kwargs.
    """

    def __init__(self, dim, depth, fn):
        super().__init__()
        if depth <= 18:
            init_eps = 0.1
        elif depth <= 24:
            init_eps = 1e-5
        else:
            init_eps = 1e-6
        self.scale = nn.Parameter(torch.full((1, 1, dim), init_eps))
        self.fn = fn

    def forward(self, x, **kwargs):
        out = self.fn(x, **kwargs)
        return out * self.scale
# layer norm
class PreNorm(nn.Module):
    """Apply ``LayerNorm`` to the input before calling ``fn``.

    With ``sandwich=True`` a second ``LayerNorm`` is also applied to ``fn``'s
    output; otherwise the output passes through ``nn.Identity``.
    """

    def __init__(self, dim, fn, sandwich=False):
        super().__init__()
        self.norm = nn.LayerNorm(dim)
        if sandwich:
            self.norm_out = nn.LayerNorm(dim)
        else:
            self.norm_out = nn.Identity()
        self.fn = fn

    def forward(self, x, **kwargs):
        return self.norm_out(self.fn(self.norm(x), **kwargs))
# feed forward
class GEGLU(nn.Module):
    """Gated GELU: split the last axis in two halves and return ``a * gelu(b)``."""

    def forward(self, x):
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)


class FeedForward(nn.Module):
    """Position-wise GEGLU feed-forward network.

    Args:
        dim: input and output feature size (last axis).
        dropout: dropout probability applied to the gated hidden activations.
        mult: hidden-width multiplier; the hidden size is ``int(dim * mult)``.
            May be a float, such as the default ``4.0``.
    """

    def __init__(self, dim, dropout=0.0, mult=4.0):
        super().__init__()
        # nn.Linear requires integer sizes. With the float default ``mult=4.0``,
        # ``dim * mult`` is a float, so cast it once. Computing the hidden size
        # a single time also guarantees the first projection is exactly twice
        # the width GEGLU hands to the second projection.
        hidden_dim = int(dim * mult)
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim * 2),
            GEGLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, dim),
        )

    def forward(self, x):
        return self.net(x)
# Attention
class Attention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        dim: feature size of the input and output (last axis).
        seq_len: stored as ``self.seq_len``; ``forward`` does not read it.
        causal: if True, each query position is blocked from attending to
            later key positions.
        heads: number of attention heads.
        dim_head: per-head feature size; the projected width is
            ``heads * dim_head``.
        dropout: dropout probability applied after the output projection.
    """

    def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0):
        super().__init__()
        self.heads = heads
        self.seq_len = seq_len
        self.scale = dim_head**-0.5
        self.causal = causal
        inner_dim = heads * dim_head
        # A single bias-free projection yields queries, keys and values side by side.
        self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False)
        self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout))

    def forward(self, x, mask=None):
        """Attend over *x* (3-D: batch, positions, features).

        *mask*, if given, is a 2-D ``(batch, keys)`` tensor used with ``~``;
        positions where it is False are excluded (presumably a bool tensor
        with True = keep; TODO confirm against callers).
        """
        # Unpacking the shape enforces a 3-D input before any other work.
        batch, length, _ = x.shape
        heads = self.heads

        q, k, v = (
            rearrange(part, "b n (h d) -> b h n d", h=heads)
            for part in self.to_qkv(x).chunk(3, dim=-1)
        )
        scores = torch.einsum("b h i d, b h j d -> b h i j", q * self.scale, k)
        fill = max_neg_value(scores)

        if mask is not None:
            # Broadcast the key mask over heads and query positions.
            keep = rearrange(mask, "b j -> b () () j")
            scores.masked_fill_(~keep, fill)

        if self.causal:
            n_query, n_key = scores.shape[-2:]
            # Strictly-upper triangle (offset for n_key > n_query) marks future keys.
            future = torch.ones(n_query, n_key, device=x.device).triu_(n_key - n_query + 1).bool()
            scores.masked_fill_(future, fill)

        weights = torch.softmax(scores, dim=-1)
        mixed = torch.einsum("b h i j, b h j d -> b h i d", weights, v)
        merged = rearrange(mixed, "b h n d -> b n (h d)")
        return self.to_out(merged)
# main transformer class
class Transformer(nn.Module):
    """Stack of ``depth`` pre-norm attention / feed-forward residual layers.

    Each layer is a pair ``(LayerScale(PreNorm(Attention)),
    LayerScale(PreNorm(FeedForward)))`` executed by ``SequentialSequence``.
    A ``mask`` keyword passed to ``forward`` is routed to the attention half
    of every layer only.

    ``sparse_attn`` is expanded per layer with ``cast_tuple`` and only bounds
    the number of layers built; no sparse attention is constructed here.
    """

    def __init__(
        self,
        *,
        dim,
        depth,
        seq_len,
        causal=True,
        heads=8,
        dim_head=64,
        ff_mult=4,
        attn_dropout=0.0,
        ff_dropout=0.0,
        sparse_attn=False,
        sandwich_norm=False,
    ):
        super().__init__()
        blocks = nn.ModuleList([])
        per_layer_sparse = cast_tuple(sparse_attn, depth)
        for layer_idx, _sparse in zip(range(depth), per_layer_sparse):
            layer_number = layer_idx + 1
            attention = Attention(
                dim,
                causal=causal,
                seq_len=seq_len,
                heads=heads,
                dim_head=dim_head,
                dropout=attn_dropout,
            )
            feed_forward = FeedForward(dim, mult=ff_mult, dropout=ff_dropout)
            attn_branch = LayerScale(
                dim, layer_number, PreNorm(dim, attention, sandwich=sandwich_norm)
            )
            ff_branch = LayerScale(
                dim, layer_number, PreNorm(dim, feed_forward, sandwich=sandwich_norm)
            )
            blocks.append(nn.ModuleList([attn_branch, ff_branch]))

        # (to_attention, to_feed_forward) for every layer: mask only reaches attention.
        mask_routes = ((True, False),) * depth
        self.layers = SequentialSequence(blocks, args_route={"mask": mask_routes})

    def forward(self, x, **kwargs):
        return self.layers(x, **kwargs)