from typing import Optional, Union
import einops
import torch
import torch.nn as nn
import torch.nn.functional as F
class Attention(nn.Module):
"""
    Minimal causal multi-head self-attention layer.
"""
def __init__(
self,
d_model: int,
n_heads: int,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
):
super().__init__()
self.d_model = d_model
self.n_heads = n_heads
factory_kwargs = {"device": device, "dtype": dtype}
self.d_head, remainder = divmod(self.d_model, self.n_heads)
assert not remainder, f"{n_heads=} must divide {d_model=} evenly"
self.lin_qkv = nn.Linear(
self.d_model,
3 * self.d_model,
**factory_kwargs,
)
self.lin_out = nn.Linear(self.d_model, self.d_model, **factory_kwargs)
def forward(
self,
inputs: torch.Tensor,
    ) -> torch.Tensor:
bsz, seq_len, _ = inputs.size()
# Create the queries, keys, values
qkv = einops.rearrange(
self.lin_qkv(inputs),
"b s (three n_h d_h) -> three b s n_h d_h",
b=bsz,
s=seq_len,
three=3,
n_h=self.n_heads,
d_h=self.d_head,
)
q, k, v = qkv
bsz, seq_len, n_heads, d_head = q.shape
shape_kwargs = dict(b=bsz, n_h=n_heads, s=seq_len, d_h=d_head)
q = einops.rearrange(q, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
k = einops.rearrange(k, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
v = einops.rearrange(v, "b s n_h d_h -> b n_h s d_h", **shape_kwargs)
        # Causal multi-head self-attention (softmax scaling and masking are handled inside SDPA)
attn_output = F.scaled_dot_product_attention(q, k, v, is_causal=True)
attn_output = einops.rearrange(
attn_output,
"b n_h s d_h -> b s (n_h d_h)",
b=bsz,
n_h=n_heads,
s=seq_len,
d_h=d_head,
)
# Final projection
out = self.lin_out(attn_output)
return out
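

# A minimal usage sketch (not part of the original layer): the batch size,
# sequence length, and model width below are illustrative assumptions, chosen
# only to show that the output keeps the (batch, seq_len, d_model) shape.
def _demo_attention() -> None:
    torch.manual_seed(0)
    attn = Attention(d_model=64, n_heads=4)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)
    out = attn(x)
    assert out.shape == x.shape  # (2, 8, 64)
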
class MLP(nn.Module):
"""
    Basic two-layer MLP (d_model -> 4 * d_model -> d_model) with optional dropout.
"""
def __init__(
self,
d_model: int,
act_fn: nn.Module,
dropout_prob: Optional[float] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
print(f"Shapes: d_model: {d_model}, act_fn: {act_fn}, dropout_prob: {dropout_prob}, device: {device}, dtype: {dtype}")
self.d_model = d_model
self.act_fn = act_fn
self.dropout_prob = dropout_prob
factory_kwargs = {"device": device, "dtype": dtype}
self.lin_0 = nn.Linear(self.d_model, 4 * self.d_model, **factory_kwargs)
self.lin_1 = nn.Linear(4 * self.d_model, self.d_model, **factory_kwargs)
self.dropout = nn.Dropout(self.dropout_prob) if self.dropout_prob else None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = self.lin_0(inputs)
x = self.act_fn(x)
x = self.lin_1(x)
if self.dropout is not None:
x = self.dropout(x)
return x
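

# A minimal usage sketch (hyperparameters are illustrative assumptions): the MLP
# expands to 4 * d_model, applies the activation, projects back to d_model, and
# optionally applies dropout, so the input shape is preserved.
def _demo_mlp() -> None:
    torch.manual_seed(0)
    mlp = MLP(d_model=64, act_fn=nn.GELU(), dropout_prob=0.1)
    x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)
    out = mlp(x)
    assert out.shape == x.shape  # (2, 8, 64)
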
class SwiGLUMLP(nn.Module):
"""
Llama 3 SwiGLU MLP layer with optional Dropout.
"""
def __init__(
self,
d_model: int,
intermediate_size: int,
act_fn: nn.Module,
dropout_prob: Optional[float] = None,
device: Optional[Union[str, torch.device]] = None,
dtype: Optional[torch.dtype] = None,
) -> None:
super().__init__()
print(f"Shapes: d_model: {d_model}, intermediate_size: {intermediate_size}, act_fn: {act_fn}, dropout_prob: {dropout_prob}, device: {device}, dtype: {dtype}")
self.d_model = d_model
self.intermediate_size = intermediate_size
self.act_fn = act_fn
self.dropout_prob = dropout_prob
factory_kwargs = {"device": device, "dtype": dtype}
self.gate_proj = nn.Linear(self.d_model, self.intermediate_size, **factory_kwargs)
self.up_proj = nn.Linear(self.d_model, self.intermediate_size, **factory_kwargs)
self.down_proj = nn.Linear(self.intermediate_size, self.d_model, **factory_kwargs)
self.dropout = nn.Dropout(self.dropout_prob) if self.dropout_prob else None
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
x = self.down_proj(self.act_fn(self.gate_proj(inputs)) * self.up_proj(inputs))
if self.dropout is not None:
x = self.dropout(x)
return x
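

# A minimal usage sketch (hyperparameters are assumptions): SwiGLU computes
# down_proj(act_fn(gate_proj(x)) * up_proj(x)); SiLU is used here as the gate
# activation, as in Llama 3.
def _demo_swiglu_mlp() -> None:
    torch.manual_seed(0)
    mlp = SwiGLUMLP(d_model=64, intermediate_size=128, act_fn=nn.SiLU())
    x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)
    out = mlp(x)
    assert out.shape == x.shape  # (2, 8, 64)
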
class Block(nn.Module):
"""
Basic transformer block.
Schematic:
β”Œβ”€β”€β”€β”€β”€β”€β”
β”‚inputsβ”‚
β””β”¬β”€β”¬β”€β”€β”€β”˜
β”‚β”Œβ–½β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚β”‚norm_0, attnβ”‚
β”‚β””β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ–½β”€β–½β”€β”€β”
β”‚ add β”‚
β””β”¬β”€β”¬β”€β”€β”˜
β”‚β”Œβ–½β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚β”‚norm_1, mlpβ”‚
β”‚β””β”¬β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
β”Œβ–½β”€β–½β”€β”€β”
β”‚ add β”‚
β””β”¬β”€β”€β”€β”€β”˜
β”Œβ–½β”€β”€β”€β”€β”€β”€β”
β”‚outputsβ”‚
β””β”€β”€β”€β”€β”€β”€β”€β”˜
"""
def __init__(
self,
d_model: int,
n_heads: int,
act_fn: nn.Module,
dropout_prob: Optional[float] = None,
dtype: Optional[torch.dtype] = None,
device: Optional[Union[str, torch.device]] = None,
):
super().__init__()
factory_kwargs = {"device": device, "dtype": dtype}
self.attn = Attention(d_model=d_model, n_heads=n_heads, **factory_kwargs)
self.mlp = MLP(d_model=d_model, act_fn=act_fn, dropout_prob=dropout_prob, **factory_kwargs)
self.norm_0 = nn.LayerNorm(d_model, **factory_kwargs)
self.norm_1 = nn.LayerNorm(d_model, **factory_kwargs)
def forward(self, inputs: torch.Tensor) -> torch.Tensor:
outputs = self.attn(self.norm_0(inputs)) + inputs
outputs = self.mlp(self.norm_1(outputs)) + outputs
return outputs
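

# A minimal end-to-end sketch (depth and widths are illustrative assumptions):
# stacking a few pre-norm blocks, as in the schematic above, keeps the
# (batch, seq_len, d_model) shape throughout.
def _demo_block_stack() -> None:
    torch.manual_seed(0)
    blocks = nn.Sequential(
        *[Block(d_model=64, n_heads=4, act_fn=nn.GELU(), dropout_prob=0.1) for _ in range(2)]
    )
    x = torch.randn(2, 8, 64)  # (batch, seq_len, d_model)
    out = blocks(x)
    assert out.shape == x.shape  # (2, 8, 64)


if __name__ == "__main__":
    _demo_attention()
    _demo_mlp()
    _demo_swiglu_mlp()
    _demo_block_stack()
    print("All demo forward passes succeeded.")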