import math
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Tuple, Union

import torch
import torch.nn as nn
from sentencepiece import SentencePieceProcessor


@dataclass
class ItaliaConfig:
    block_size: int = 4096
    vocab_size: int = 50_000
    padding_multiple: int = 512
    padded_vocab_size: int = 50176
    head_size: int = 160
    n_layer: int = 34
    n_head: int = 32
    n_embd: int = 5120
    rotary_percentage: float = 0.4
    parallel_residual: bool = True
    bias: bool = True
    lm_head_bias: bool = True
    n_query_groups: int = 32
    shared_attention_norm: bool = True
    norm_eps: float = 1e-5
    intermediate_size: int = 12800
    rope_condense_ratio: int = 1
    rope_n_elem: int = 64
    rope_base: int = 10000
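
# How the derived defaults above fit together (informational only; the values are stored
# directly rather than recomputed at runtime):
#   n_embd            = n_head * head_size                  -> 32 * 160 = 5120
#   rope_n_elem       = int(rotary_percentage * head_size)  -> int(0.4 * 160) = 64
#   padded_vocab_size = ceil(vocab_size / padding_multiple) * padding_multiple
#                        -> ceil(50_000 / 512) * 512 = 98 * 512 = 50_176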


class Tokenizer:
    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )

        self.use_bos = True
        self.bos_id = None
        self.eos_id = None

        if (vocabulary_path := checkpoint_dir / "tokenizer.model").is_file():
            self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
            self.backend = "sentencepiece"
            self.bos_id = self.processor.bos_id()
            self.eos_id = self.processor.eos_id()
        else:
            raise FileNotFoundError(
                f"tokenizer.model not found in {str(checkpoint_dir)}"
            )

    @property
    def vocab_size(self) -> int:
        return self.processor.vocab_size()

    def token_to_id(self, token: str) -> int:
        return self.processor.piece_to_id(token)

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        max_length: int = -1,
    ) -> torch.Tensor:
        tokens = self.processor.encode(string)
        # Prepend the beginning-of-sequence token unless it was explicitly disabled.
        if self.use_bos:
            tokens = [self.bos_id] + tokens

        if max_length > 0:
            tokens = tokens[:max_length]
        return torch.tensor(tokens, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        tokens = [tensor.item()] if tensor.ndim == 0 else tensor.tolist()
        return self.processor.decode(tokens).strip()
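
# Minimal usage sketch for the tokenizer above (the checkpoint path is hypothetical; any
# directory containing a SentencePiece `tokenizer.model` file will do). Kept as a comment
# so importing this module has no side effects:
#
#   tokenizer = Tokenizer("checkpoints/italia")
#   ids = tokenizer.encode("Buongiorno, come stai?")
#   text = tokenizer.decode(ids)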


class Italia(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=nn.LayerNorm(config.n_embd, eps=config.norm_eps),
            )
        )
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller value to avoid allocating unused memory.
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # First assignment: create the RoPE caches as non-persistent buffers.
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # The sequence length changed, so the RoPE caches must be rebuilt.
            self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def reset_parameters(self) -> None:
        self.cos, self.sin = self.rope_cache()

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is not None:
            # Incremental decoding with a KV cache: gather the RoPE values and the
            # causal-mask rows for the requested positions only.
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            mask = None

        x = self.transformer.wte(idx)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
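
# Sketch of incremental decoding with the KV cache (kept as a comment so importing this
# module has no side effects; `model` and a 1-D tensor `prompt_ids` are assumed to exist):
#
#   model.set_kv_cache(batch_size=1)
#   input_pos = torch.arange(prompt_ids.size(0))
#   logits = model(prompt_ids.unsqueeze(0), input_pos)    # prefill the prompt
#   next_id = logits[:, -1].argmax(dim=-1, keepdim=True)
#   input_pos = torch.tensor([prompt_ids.size(0)])
#   logits = model(next_id, input_pos)                    # then decode one token at a time
#   model.clear_kv_cache()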


class Block(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        self.norm_1 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        # `norm_2` is only needed when the attention and MLP branches do not share a norm.
        self.norm_2 = (
            None
            if config.shared_attention_norm
            else nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        )
        self.attn = CausalSelfAttention(config)
        self.mlp = MLP(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Parallel residual: the attention and MLP branches both read a normed copy of `x`
        # and their outputs are added back onto the residual stream.
        n_1 = self.norm_1(x)
        h = self.attn(n_1, cos, sin, mask, input_pos)
        n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
        x = self.mlp(n_2) + h + x
        return x


class CausalSelfAttention(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        # Packed projection for queries, keys and values.
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        self.attn = nn.Linear(config.n_embd, shape, bias=config.bias)
        self.proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        self.kv_cache: Optional[KVCache] = None

        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        B, T, _ = x.size()  # batch size, sequence length, embedding dimension

        qkv = self.attn(x)

        # Split the packed projection into per-group (q, ..., q, k, v) blocks.
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group packs `q_per_kv` queries plus one key and one value
        qkv = qkv.view(B, T, self.config.n_query_groups, total_qkv, self.config.head_size)
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, head_size)

        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        q = q.reshape(B, -1, T, self.config.head_size)
        k = k.reshape(B, -1, T, self.config.head_size)
        v = v.reshape(B, -1, T, self.config.head_size)

        # Rotary embeddings are applied only to the first `rope_n_elem` channels of each head.
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(B, T, self.config.n_embd)  # re-assemble all head outputs side by side

        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        return y.transpose(1, 2)  # (B, T, n_head, head_size)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
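
# With the default ItaliaConfig, n_query_groups == n_head == 32, so the attention above
# is plain multi-head attention rather than grouped-query attention: q_per_kv = 32 // 32 = 1,
# total_qkv = 3, and the packed QKV projection has width
# (n_head + 2 * n_query_groups) * head_size = (32 + 64) * 160 = 15_360.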


class MLP(nn.Module):
    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        self.fc = nn.Linear(config.n_embd, config.intermediate_size, bias=config.bias)
        self.proj = nn.Linear(config.intermediate_size, config.n_embd, bias=config.bias)

        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.fc(x)
        x = torch.nn.functional.gelu(x, approximate="tanh")
        return self.proj(x)
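
# With the defaults, the MLP expands n_embd = 5120 to intermediate_size = 12_800
# (a 2.5x expansion), applies the tanh approximation of GELU, and projects back down.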


def build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Enhanced Transformer with Rotary Position Embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.
    """
    # Frequencies: theta_i = 1 / base^(2i / n_elem) for i in [0, n_elem / 2)
    theta = 1.0 / (base ** (torch.arange(0, n_elem, 2, device=device).float() / n_elem))

    # Position indices [0, 1, ..., seq_len - 1], optionally condensed for context extension
    seq_idx = torch.arange(seq_len, device=device) / condense_ratio

    # Outer product of positions and frequencies, repeated to match the rotate-half layout
    idx_theta = torch.outer(seq_idx, theta).repeat(1, 2)

    return torch.cos(idx_theta), torch.sin(idx_theta)
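
# For the defaults (seq_len = block_size = 4096, n_elem = rope_n_elem = 64), both
# returned tensors have shape (4096, 64): `theta` holds n_elem // 2 = 32 frequencies,
# and the outer product with the positions is repeated along the last dimension so it
# lines up with the rotate-half convention used by `apply_rope` below.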


def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    head_size = x.size(-1)
    x1 = x[..., : head_size // 2]
    x2 = x[..., head_size // 2 :]
    rotated = torch.cat((-x2, x1), dim=-1)  # rotate-half: (x1, x2) -> (-x2, x1)
    roped = (x * cos) + (rotated * sin)
    return roped.to(dtype=x.dtype)


class KVCache(nn.Module):
    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        self.register_buffer(
            "k", torch.zeros(k_shape, device=device, dtype=dtype), persistent=False
        )
        self.register_buffer(
            "v", torch.zeros(v_shape, device=device, dtype=dtype), persistent=False
        )

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        # Move the cache to the incoming dtype in case the model was converted after
        # the cache was allocated.
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)

        # Write the new keys/values at the given positions and return the full cache.
        k = self.k.index_copy_(2, input_pos, k)
        v = self.v.index_copy_(2, input_pos, v)
        return k, v

    def reset_parameters(self) -> None:
        torch.nn.init.zeros_(self.k)
        torch.nn.init.zeros_(self.v)


def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    ones = torch.ones((max_seq_length, max_seq_length), device=device, dtype=torch.bool)
    return torch.tril(ones).unsqueeze(0).unsqueeze(0)
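

if __name__ == "__main__":
    # Smoke-test sketch with a deliberately tiny configuration (not the released Italia
    # hyperparameters) so a forward pass runs quickly on CPU without any checkpoint.
    cfg = ItaliaConfig(
        block_size=128,
        padded_vocab_size=512,
        head_size=16,
        n_layer=2,
        n_head=4,
        n_embd=64,
        n_query_groups=4,
        intermediate_size=256,
        rope_n_elem=8,
    )
    model = Italia(cfg)
    idx = torch.randint(0, cfg.padded_vocab_size, (1, 16))
    logits = model(idx)
    print(logits.shape)  # expected: torch.Size([1, 16, 512])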