import torch
import torch.nn as nn
from einops import rearrange

from torchtune.modules import RotaryPositionalEmbeddings


class RMSNorm(nn.Module):
    r"""Root mean square layer normalization.

    Reference: https://github.com/meta-llama/llama/blob/main/llama/model.py
    """

    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(dim))

    def forward(self, x):
        # Scale each vector by the reciprocal of its root mean square, then
        # apply the learned per-dimension gain.
        norm_x = torch.mean(x ** 2, dim=-1, keepdim=True)
        output = x * torch.rsqrt(norm_x + self.eps) * self.weight
        return output
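

# Illustrative sanity check (an assumption, not part of the original file):
# with the default unit weights, RMSNorm rescales every feature vector to a
# root mean square of roughly 1. `_rmsnorm_demo` is a hypothetical helper name.
def _rmsnorm_demo():
    norm = RMSNorm(dim=4)
    x = torch.randn(2, 3, 4)
    y = norm(x)
    print(y.pow(2).mean(dim=-1))  # values close to 1.0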


class MLP(nn.Module):
    def __init__(self, dim: int) -> None:
        super().__init__()

        # Position-wise feed-forward network with a 4x hidden expansion and a
        # SiLU activation, without biases.
        self.fc1 = nn.Linear(dim, 4 * dim, bias=False)
        self.silu = nn.SiLU()
        self.fc2 = nn.Linear(4 * dim, dim, bias=False)

    def forward(self, x):
        x = self.fc1(x)
        x = self.silu(x)
        x = self.fc2(x)
        return x
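

# Illustrative check (an assumption, not part of the original file): with the
# 4x expansion and no biases, the MLP holds 2 * dim * 4 * dim weights, e.g.
# 8,388,608 parameters at dim=1024. `_mlp_param_demo` is a hypothetical name.
def _mlp_param_demo():
    mlp = MLP(dim=1024)
    print(sum(p.numel() for p in mlp.parameters()))  # 8388608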


class Attention(nn.Module):

    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        assert dim % n_heads == 0

        self.n_heads = n_heads
        self.dim = dim
        self.rotary_embed = rotary_embed

        self.flash = hasattr(torch.nn.functional, 'scaled_dot_product_attention')
        assert self.flash, "Must have flash attention."

        self.c_attn = nn.Linear(dim, 3 * dim, bias=False)
        self.c_proj = nn.Linear(dim, dim, bias=False)

    def forward(self, x):
        r"""
        Args:
            x: (b, t, h*d)

        Constants:
            b: batch_size
            t: time_steps
            r: 3 (query, key, value)
            h: n_heads
            d: head_dim
        """
        # Project to q, k, v and split heads, keeping the (b, t, h, d) layout
        # that torchtune's RotaryPositionalEmbeddings expects.
        q, k, v = rearrange(self.c_attn(x), 'b t (r h d) -> r b t h d', r=3, h=self.n_heads)

        q = self.rotary_embed(q)
        k = self.rotary_embed(k)

        # scaled_dot_product_attention expects (b, h, t, d).
        q, k, v = [rearrange(z, 'b t h d -> b h t d') for z in (q, k, v)]

        if self.flash:
            y = torch.nn.functional.scaled_dot_product_attention(
                q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False
            )

        y = rearrange(y, 'b h t d -> b t (h d)')

        y = self.c_proj(y)

        return y
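

# Reference sketch (illustration only, not used by Attention above): what
# scaled_dot_product_attention computes here with attn_mask=None,
# dropout_p=0.0 and is_causal=False. `_manual_attention` is a hypothetical
# helper name introduced for this sketch.
def _manual_attention(q, k, v):
    scores = q @ k.transpose(-2, -1) / (q.shape[-1] ** 0.5)  # (b, h, t, t)
    return scores.softmax(dim=-1) @ v  # (b, h, t, d)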


class TransformerBlock(nn.Module):
    def __init__(self, dim: int, n_heads: int, rotary_embed: RotaryPositionalEmbeddings):
        super().__init__()

        self.dim = dim
        self.n_heads = n_heads

        self.att_norm = RMSNorm(dim)
        self.ffn_norm = RMSNorm(dim)
        self.att = Attention(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
        self.mlp = MLP(dim=dim)

    def forward(self, x: torch.Tensor):
        # Pre-norm residual connections around attention and the MLP.
        x = x + self.att(self.att_norm(x))
        x = x + self.mlp(self.ffn_norm(x))
        return x
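

# Hypothetical usage sketch (an assumption, not part of the original file):
# TransformerBlocks are typically stacked in an nn.ModuleList, sharing one
# rotary embedding whose dim equals the per-head dimension (dim // n_heads).
class _TransformerStack(nn.Module):
    def __init__(self, dim: int, n_heads: int, n_layers: int):
        super().__init__()
        rotary_embed = RotaryPositionalEmbeddings(dim=dim // n_heads)
        self.layers = nn.ModuleList([
            TransformerBlock(dim=dim, n_heads=n_heads, rotary_embed=rotary_embed)
            for _ in range(n_layers)
        ])

    def forward(self, x: torch.Tensor):
        for layer in self.layers:
            x = layer(x)
        return x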


if __name__ == '__main__':
    # head_dim = dim // n_heads = 1024 // 8 = 128, so the rotary embedding is
    # built with dim=128.
    rotary_embed_128 = RotaryPositionalEmbeddings(dim=128)

    transformer_block = TransformerBlock(
        dim=1024,
        n_heads=8,
        rotary_embed=rotary_embed_128
    )

    x = torch.randn(2, 128, 1024)
    y = transformer_block(x)
    print(y.shape)  # torch.Size([2, 128, 1024])