import torch import torch.nn.functional as F from einops import rearrange from torch import nn # helpers def exists(val): return val is not None def default(val, d): return val if exists(val) else d def cast_tuple(val, depth=1): if isinstance(val, list): val = tuple(val) return val if isinstance(val, tuple) else (val,) * depth def max_neg_value(t): return -torch.finfo(t.dtype).max def stable_softmax(t, dim=-1, alpha=32**2): t = t / alpha t = t - torch.amax(t, dim=dim, keepdim=True).detach() return (t * alpha).softmax(dim=dim) def route_args(router, args, depth): routed_args = [(dict(), dict()) for _ in range(depth)] matched_keys = [key for key in args.keys() if key in router] for key in matched_keys: val = args[key] for depth, ((f_args, g_args), routes) in enumerate( zip(routed_args, router[key]) ): new_f_args, new_g_args = map( lambda route: ({key: val} if route else {}), routes ) routed_args[depth] = ({**f_args, **new_f_args}, {**g_args, **new_g_args}) return routed_args # classes class SequentialSequence(nn.Module): def __init__(self, layers, args_route={}, layer_dropout=0.0): super().__init__() assert all( len(route) == len(layers) for route in args_route.values() ), "each argument route map must have the same depth as the number of sequential layers" self.layers = layers self.args_route = args_route self.layer_dropout = layer_dropout def forward(self, x, **kwargs): args = route_args(self.args_route, kwargs, len(self.layers)) layers_and_args = list(zip(self.layers, args)) for (f, g), (f_args, g_args) in layers_and_args: x = x + f(x, **f_args) x = x + g(x, **g_args) return x class DivideMax(nn.Module): def __init__(self, dim): super().__init__() self.dim = dim def forward(self, x): maxes = x.amax(dim=self.dim, keepdim=True).detach() return x / maxes # https://arxiv.org/abs/2103.17239 class LayerScale(nn.Module): def __init__(self, dim, depth, fn): super().__init__() if depth <= 18: init_eps = 0.1 elif depth > 18 and depth <= 24: init_eps = 1e-5 else: init_eps = 1e-6 scale = torch.zeros(1, 1, dim).fill_(init_eps) self.scale = nn.Parameter(scale) self.fn = fn def forward(self, x, **kwargs): return self.fn(x, **kwargs) * self.scale # layer norm class PreNorm(nn.Module): def __init__(self, dim, fn, sandwich=False): super().__init__() self.norm = nn.LayerNorm(dim) self.norm_out = nn.LayerNorm(dim) if sandwich else nn.Identity() self.fn = fn def forward(self, x, **kwargs): x = self.norm(x) x = self.fn(x, **kwargs) return self.norm_out(x) # feed forward class GEGLU(nn.Module): def forward(self, x): x, gates = x.chunk(2, dim=-1) return x * F.gelu(gates) class FeedForward(nn.Module): def __init__(self, dim, dropout=0.0, mult=4.0): super().__init__() self.net = nn.Sequential( nn.Linear(dim, dim * mult * 2), GEGLU(), nn.Dropout(dropout), nn.Linear(dim * mult, dim), ) def forward(self, x): return self.net(x) # Attention class Attention(nn.Module): def __init__(self, dim, seq_len, causal=True, heads=8, dim_head=64, dropout=0.0): super().__init__() inner_dim = dim_head * heads self.heads = heads self.seq_len = seq_len self.scale = dim_head**-0.5 self.causal = causal self.to_qkv = nn.Linear(dim, inner_dim * 3, bias=False) self.to_out = nn.Sequential(nn.Linear(inner_dim, dim), nn.Dropout(dropout)) def forward(self, x, mask=None): b, n, _, h, device = *x.shape, self.heads, x.device softmax = torch.softmax qkv = self.to_qkv(x).chunk(3, dim=-1) q, k, v = map(lambda t: rearrange(t, "b n (h d) -> b h n d", h=h), qkv) q = q * self.scale dots = torch.einsum("b h i d, b h j d -> b h i j", q, k) mask_value = max_neg_value(dots) if exists(mask): mask = rearrange(mask, "b j -> b () () j") dots.masked_fill_(~mask, mask_value) del mask if self.causal: i, j = dots.shape[-2:] mask = torch.ones(i, j, device=device).triu_(j - i + 1).bool() dots.masked_fill_(mask, mask_value) attn = softmax(dots, dim=-1) out = torch.einsum("b h i j, b h j d -> b h i d", attn, v) out = rearrange(out, "b h n d -> b n (h d)") out = self.to_out(out) return out # main transformer class class Transformer(nn.Module): def __init__( self, *, dim, depth, seq_len, causal=True, heads=8, dim_head=64, ff_mult=4, attn_dropout=0.0, ff_dropout=0.0, sparse_attn=False, sandwich_norm=False, ): super().__init__() layers = nn.ModuleList([]) sparse_layer = cast_tuple(sparse_attn, depth) for ind, sparse_attn in zip(range(depth), sparse_layer): attn = Attention( dim, causal=causal, seq_len=seq_len, heads=heads, dim_head=dim_head, dropout=attn_dropout, ) ff = FeedForward(dim, mult=ff_mult, dropout=ff_dropout) layers.append( nn.ModuleList( [ LayerScale( dim, ind + 1, PreNorm(dim, attn, sandwich=sandwich_norm) ), LayerScale( dim, ind + 1, PreNorm(dim, ff, sandwich=sandwich_norm) ), ] ) ) execute_type = SequentialSequence route_attn = ((True, False),) * depth attn_route_map = {"mask": route_attn} self.layers = execute_type(layers, args_route=attn_route_map) def forward(self, x, **kwargs): return self.layers(x, **kwargs)