from typing import Optional, Union

import einops
import torch
import torch.nn as nn
import torch.nn.functional as F


class Attention(nn.Module):
    """
    Minimal causal multi-head self-attention layer.
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        super().__init__()
        self.d_model = d_model
        self.n_heads = n_heads
        factory_kwargs = {"device": device, "dtype": dtype}
        self.d_head, remainder = divmod(self.d_model, self.n_heads)
        assert not remainder, f"{n_heads=} must divide {d_model=} evenly"
        self.lin_qkv = nn.Linear(
            self.d_model,
            3 * self.d_model,
            **factory_kwargs,
        )
        self.lin_out = nn.Linear(self.d_model, self.d_model, **factory_kwargs)

    def forward(
        self,
        inputs: torch.Tensor,
    ) -> torch.Tensor:
        bsz, seq_len, _ = inputs.size()
        # Create the queries, keys, and values with a single fused projection,
        # then split the last dimension into (3, n_heads, d_head).
        qkv = einops.rearrange(
            self.lin_qkv(inputs),
            "b s (three n_h d_h) -> three b s n_h d_h",
            b=bsz,
            s=seq_len,
            three=3,
            n_h=self.n_heads,
            d_h=self.d_head,
        )
        q, k, v = qkv
        bsz, seq_len, n_heads, d_head = q.shape
        # Move the head axis ahead of the sequence axis, as expected by
        # scaled_dot_product_attention: (b, n_h, s, d_h).
        shape_kwargs = dict(b=bsz, n_h=n_heads, s=seq_len, d_h=d_head)
        q = einops.rearrange(q, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
        k = einops.rearrange(k, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
        v = einops.rearrange(v, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
        # Causal multi-head self-attention
        attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        # Merge the heads back into a single d_model-sized feature dimension.
        attn_output = einops.rearrange(
            attn_output,
            "b n_h s d_h -> b s (n_h d_h)",
            b=bsz,
            n_h=n_heads,
            s=seq_len,
            d_h=d_head,
        )
        # Final projection
        out = self.lin_out(attn_output)
        return out
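

# A minimal usage sketch of the Attention layer. The sizes below (d_model=64,
# n_heads=4, batch of 2, sequence length 16) are illustrative assumptions, not
# values taken from any particular model.
def _demo_attention() -> None:
    attn = Attention(d_model=64, n_heads=4)
    x = torch.randn(2, 16, 64)  # (batch, seq_len, d_model)
    out = attn(x)
    # Causal self-attention preserves the input shape.
    assert out.shape == (2, 16, 64)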


class MLP(nn.Module):
    """
    Basic MLP layer with optional Dropout.
    """

    def __init__(
        self,
        d_model: int,
        act_fn: nn.Module,
        dropout_prob: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        print(
            f"MLP config: d_model={d_model}, act_fn={act_fn}, "
            f"dropout_prob={dropout_prob}, device={device}, dtype={dtype}"
        )
        self.d_model = d_model
        self.act_fn = act_fn
        self.dropout_prob = dropout_prob
        factory_kwargs = {"device": device, "dtype": dtype}
        self.lin_0 = nn.Linear(self.d_model, 4 * self.d_model, **factory_kwargs)
        self.lin_1 = nn.Linear(4 * self.d_model, self.d_model, **factory_kwargs)
        self.dropout = nn.Dropout(self.dropout_prob) if self.dropout_prob else None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.lin_0(inputs)
        x = self.act_fn(x)
        x = self.lin_1(x)
        if self.dropout is not None:
            x = self.dropout(x)
        return x
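

# A small sketch of the MLP block. nn.GELU() is an assumed choice of
# activation here; any nn.Module activation can be passed in.
def _demo_mlp() -> None:
    mlp = MLP(d_model=64, act_fn=nn.GELU(), dropout_prob=0.1)
    x = torch.randn(2, 16, 64)
    out = mlp(x)
    # The 4 * d_model hidden width is projected back down, so shapes match.
    assert out.shape == x.shape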


class SwiGLUMLP(nn.Module):
    """
    Llama 3 SwiGLU MLP layer with optional Dropout.
    """

    def __init__(
        self,
        d_model: int,
        intermediate_size: int,
        act_fn: nn.Module,
        dropout_prob: Optional[float] = None,
        device: Optional[Union[str, torch.device]] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        print(
            f"SwiGLUMLP config: d_model={d_model}, intermediate_size={intermediate_size}, "
            f"act_fn={act_fn}, dropout_prob={dropout_prob}, device={device}, dtype={dtype}"
        )
        self.d_model = d_model
        self.intermediate_size = intermediate_size
        self.act_fn = act_fn
        self.dropout_prob = dropout_prob
        factory_kwargs = {"device": device, "dtype": dtype}
        self.gate_proj = nn.Linear(self.d_model, self.intermediate_size, **factory_kwargs)
        self.up_proj = nn.Linear(self.d_model, self.intermediate_size, **factory_kwargs)
        self.down_proj = nn.Linear(self.intermediate_size, self.d_model, **factory_kwargs)
        self.dropout = nn.Dropout(self.dropout_prob) if self.dropout_prob else None

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        x = self.down_proj(self.act_fn(self.gate_proj(inputs)) * self.up_proj(inputs))
        if self.dropout is not None:
            x = self.dropout(x)
        return x
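

# A sketch of the SwiGLU variant. intermediate_size=256 and nn.SiLU() are
# illustrative assumptions; Llama-style models pair SwiGLU with SiLU and a
# larger intermediate width.
def _demo_swiglu_mlp() -> None:
    mlp = SwiGLUMLP(d_model=64, intermediate_size=256, act_fn=nn.SiLU())
    x = torch.randn(2, 16, 64)
    out = mlp(x)
    assert out.shape == x.shape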


class Block(nn.Module):
    """
    Basic transformer block.

    Schematic:

        ┌──────┐
        │inputs│
        └┬────┬┘
         │    │
         │ ┌──▽─────────┐
         │ │norm_0, attn│
         │ └──┬─────────┘
         │    │
        ┌▽────▽┐
        │ add  │
        └┬────┬┘
         │    │
         │ ┌──▽────────┐
         │ │norm_1, mlp│
         │ └──┬────────┘
         │    │
        ┌▽────▽┐
        │ add  │
        └──┬───┘
           │
        ┌──▽────┐
        │outputs│
        └───────┘
    """

    def __init__(
        self,
        d_model: int,
        n_heads: int,
        act_fn: nn.Module,
        dropout_prob: Optional[float] = None,
        dtype: Optional[torch.dtype] = None,
        device: Optional[Union[str, torch.device]] = None,
    ):
        super().__init__()
        factory_kwargs = {"device": device, "dtype": dtype}
        self.attn = Attention(d_model=d_model, n_heads=n_heads, **factory_kwargs)
        self.mlp = MLP(d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, **factory_kwargs)
        self.norm_0 = nn.LayerNorm(d_model, **factory_kwargs)
        self.norm_1 = nn.LayerNorm(d_model, **factory_kwargs)

    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
        outputs = self.attn(self.norm_0(inputs)) + inputs
        outputs = self.mlp(self.norm_1(outputs)) + outputs
        return outputs
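

# End-to-end sketch: stack a couple of Blocks and run a forward pass. All
# sizes are assumptions chosen only to keep the example small.
def _demo_block_stack() -> None:
    blocks = nn.Sequential(
        *[Block(d_model=64, n_heads=4, act_fn=nn.GELU(), dropout_prob=0.1) for _ in range(2)]
    )
    x = torch.randn(2, 16, 64)  # (batch, seq_len, d_model)
    out = blocks(x)
    # Residual blocks preserve the input shape.
    assert out.shape == x.shape


if __name__ == "__main__":
    _demo_attention()
    _demo_mlp()
    _demo_swiglu_mlp()
    _demo_block_stack()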