Italia-9B / modello_italia.py
leafspark's picture
add model
56811f1 verified
# Copyright Lightning AI. Licensed under the Apache License 2.0, see LICENSE file.
# Derived from https://github.com/Lightning-AI/litgpt/blob/main/litgpt/model.py
import math
from typing import Optional, Tuple
import torch
import torch.nn as nn
from dataclasses import dataclass
from pathlib import Path
from typing import Optional, Union
from sentencepiece import SentencePieceProcessor
import torch
@dataclass
class ItaliaConfig:
    """Hyper-parameters consumed by the model classes in this file.

    Field order is part of the interface (the dataclass accepts positional
    arguments), so new fields should only be appended.
    """

    # --- sequence / vocabulary -------------------------------------------
    # Upper bound for `Italia.max_seq_length`.
    block_size: int = 4_096
    vocab_size: int = 50_000
    # NOTE(review): presumably `padded_vocab_size` is `vocab_size` rounded up
    # to a multiple of `padding_multiple` -- not computed in this file.
    padding_multiple: int = 512
    # Number of rows of the embedding table and of outputs of the LM head.
    padded_vocab_size: int = 50_176

    # --- transformer geometry --------------------------------------------
    # Per-head channel count.
    head_size: int = 160
    # Number of `Block` layers stacked by `Italia`.
    n_layer: int = 34
    # Number of query heads.
    n_head: int = 32
    # Width of the residual stream.
    n_embd: int = 5_120
    # Only compared against 1.0 when sizing the KV cache in this file.
    rotary_percentage: float = 0.4
    # NOTE(review): not consulted by `Block.forward` in this file.
    parallel_residual: bool = True
    # Bias switch for the attention and MLP linear layers.
    bias: bool = True
    # Bias switch for the output (LM head) linear layer.
    lm_head_bias: bool = True
    # Number of key/value groups used when splitting the fused QKV projection.
    n_query_groups: int = 32
    # When True the MLP branch reuses the attention branch's normalised input.
    shared_attention_norm: bool = True
    # Epsilon handed to every LayerNorm.
    norm_eps: float = 1e-05
    # Hidden width of the MLP.
    intermediate_size: int = 12_800

    # --- rotary position embedding ---------------------------------------
    # Divisor applied to position indices when building the RoPE tables.
    rope_condense_ratio: int = 1
    # Number of leading per-head channels that receive the rotary embedding.
    rope_n_elem: int = 64
    # Base of the geometric frequency progression.
    rope_base: int = 10_000
class Tokenizer:
    """Thin wrapper around a SentencePiece model stored in a checkpoint directory."""

    def __init__(self, checkpoint_dir: Union[Path, str]) -> None:
        """Load `tokenizer.model` from `checkpoint_dir`.

        Raises:
            NotADirectoryError: if `checkpoint_dir` does not exist.
            FileNotFoundError: if it contains no `tokenizer.model` file.
        """
        checkpoint_dir = Path(checkpoint_dir)
        if not checkpoint_dir.exists():
            raise NotADirectoryError(
                f"The checkpoint directory does not exist: {str(checkpoint_dir)}"
            )

        self.use_bos = True
        self.bos_id = None
        self.eos_id = None

        vocabulary_path = checkpoint_dir / "tokenizer.model"
        if not vocabulary_path.is_file():
            raise FileNotFoundError(
                f"tokenizer.model not found in {str(checkpoint_dir)}"
            )

        self.processor = SentencePieceProcessor(model_file=str(vocabulary_path))
        self.backend = "sentencepiece"
        # Special-token ids are taken from the loaded model.
        self.bos_id = self.processor.bos_id()
        self.eos_id = self.processor.eos_id()

    @property
    def vocab_size(self) -> int:
        """Vocabulary size reported by the underlying processor."""
        return self.processor.vocab_size()

    def token_to_id(self, token: str) -> int:
        """Return the id the processor assigns to the piece `token`."""
        return self.processor.piece_to_id(token)

    def encode(
        self,
        string: str,
        device: Optional[torch.device] = None,
        max_length: int = -1,
    ) -> torch.Tensor:
        """Encode `string` into an int tensor that always starts with the BOS id.

        When `max_length` is positive the sequence (BOS included) is truncated
        to that many tokens.
        """
        ids = [self.bos_id] + self.processor.encode(string)
        if max_length > 0:
            ids = ids[:max_length]
        return torch.tensor(ids, dtype=torch.int, device=device)

    def decode(self, tensor: torch.Tensor) -> str:
        """Decode a 0-d or 1-d tensor of ids to text, stripped of outer whitespace."""
        if tensor.ndim == 0:
            ids = [tensor.item()]
        else:
            ids = tensor.tolist()
        text = self.processor.decode(ids)
        return text.strip()
class Italia(nn.Module):
    """Decoder-only transformer: token embedding, a stack of `Block`s, a final
    LayerNorm and a linear LM head.

    The rotary-embedding tables (`cos`, `sin`) are stored as non-persistent
    buffers and are (re)built whenever `max_seq_length` changes. Incremental
    decoding requires `set_kv_cache()` to be called first.
    """

    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        assert config.padded_vocab_size is not None
        self.config = config

        self.lm_head = nn.Linear(
            config.n_embd, config.padded_vocab_size, bias=config.lm_head_bias
        )
        self.transformer = nn.ModuleDict(
            dict(
                wte=nn.Embedding(config.padded_vocab_size, config.n_embd),
                h=nn.ModuleList(Block(config) for _ in range(config.n_layer)),
                ln_f=nn.LayerNorm(config.n_embd, eps=config.norm_eps),
            )
        )
        # Goes through the property setter, which also creates the RoPE buffers.
        self.max_seq_length = self.config.block_size
        self.mask_cache: Optional[torch.Tensor] = None

    @property
    def max_seq_length(self) -> int:
        """Longest sequence the RoPE tables (and a future KV cache) are sized for."""
        return self._max_seq_length

    @max_seq_length.setter
    def max_seq_length(self, value: int) -> None:
        """
        When doing inference, the sequences used might be shorter than the model's context length.
        This allows setting a smaller number to avoid allocating unused memory.

        Raises:
            ValueError: if `value` exceeds `config.block_size`.
        """
        if value > self.config.block_size:
            raise ValueError(
                f"Cannot attend to {value}, block size is only {self.config.block_size}"
            )
        self._max_seq_length = value
        if not hasattr(self, "cos"):
            # First call (from `__init__`): register the tables as buffers that
            # are excluded from the state dict.
            cos, sin = self.rope_cache()
            self.register_buffer("cos", cos, persistent=False)
            self.register_buffer("sin", sin, persistent=False)
        elif value != self.cos.size(0):
            # Length changed: rebuild on the device the buffers already use.
            self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def reset_parameters(self) -> None:
        """Recompute the RoPE tables in place of the current buffers.

        The tables are rebuilt on the device the existing buffers live on.
        Rebuilding with `device=None` would silently move them to the default
        device and break `forward` for a model that has been moved elsewhere.
        """
        self.cos, self.sin = self.rope_cache(device=self.cos.device)

    def forward(
        self, idx: torch.Tensor, input_pos: Optional[torch.Tensor] = None
    ) -> torch.Tensor:
        """Return logits of shape (b, t, padded_vocab_size) for token ids `idx` (b, t).

        Args:
            idx: token ids; its second dimension is the sequence length.
            input_pos: when given, the positions of `idx` inside the KV cache;
                enables cached decoding and requires a prior `set_kv_cache()`.

        Raises:
            ValueError: if the sequence is longer than `max_seq_length`.
            TypeError: if `input_pos` is given but no mask cache has been built.
        """
        T = idx.size(1)
        if self.max_seq_length < T:
            raise ValueError(
                f"Cannot forward sequence of length {T}, max seq length is only {self.max_seq_length}."
            )

        if input_pos is not None:  # use the kv cache
            cos = self.cos.index_select(0, input_pos)
            sin = self.sin.index_select(0, input_pos)
            if self.mask_cache is None:
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            # Keep only the mask rows of the positions being computed.
            mask = self.mask_cache.index_select(2, input_pos)
        else:
            cos = self.cos[:T]
            sin = self.sin[:T]
            # No explicit mask: the attention layer applies causal masking itself.
            mask = None

        x = self.transformer.wte(idx)  # token embeddings of shape (b, t, n_embd)
        for block in self.transformer.h:
            x = block(x, cos, sin, mask, input_pos)
        x = self.transformer.ln_f(x)
        return self.lm_head(x)  # (b, t, padded_vocab_size)

    def rope_cache(
        self, device: Optional[torch.device] = None
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Build `(cos, sin)` tables for `max_seq_length` positions on `device`."""
        return build_rope_cache(
            seq_len=self.max_seq_length,
            n_elem=self.config.rope_n_elem,
            device=device,
            condense_ratio=self.config.rope_condense_ratio,
            base=self.config.rope_base,
        )

    def set_kv_cache(
        self,
        batch_size: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        """Allocate a KV cache in every block and (re)build the causal mask cache.

        Args:
            batch_size: batch dimension of the cache tensors.
            rope_cache_length: width of the RoPE tables; defaults to
                `self.cos.size(-1)`.
            device: device for the caches.
            dtype: dtype for the KV caches.
        """
        if rope_cache_length is None:
            rope_cache_length = self.cos.size(-1)
        max_seq_length = self.max_seq_length

        # The caches are sized for the current `max_seq_length`.
        for block in self.transformer.h:
            block.attn.kv_cache = block.attn.build_kv_cache(
                batch_size, max_seq_length, rope_cache_length, device, dtype
            )

        # Reuse an existing mask when its size still matches.
        if self.mask_cache is None or self.mask_cache.size(3) != max_seq_length:
            self.mask_cache = build_mask_cache(max_seq_length, device)

    def clear_kv_cache(self) -> None:
        """Drop the mask cache and every block's KV cache."""
        self.mask_cache = None
        for block in self.transformer.h:
            block.attn.kv_cache = None
class Block(nn.Module):
    """One transformer layer: attention and MLP branches added to the residual.

    Both branches are computed from the block input and summed with it.
    NOTE(review): `config.parallel_residual` is not consulted here; the block
    always uses this parallel form.
    """

    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        self.norm_1 = nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        self.attn = CausalSelfAttention(config)
        # A second norm is only needed when the MLP branch does not share the
        # attention branch's normalised input. `forward` reads `self.norm_2`
        # in that case, so it must exist or the block raises AttributeError.
        self.norm_2 = (
            None
            if config.shared_attention_norm
            else nn.LayerNorm(config.n_embd, eps=config.norm_eps)
        )
        self.mlp = MLP(config)
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply the block to `x` and return a tensor of the same shape.

        Args:
            x: input activations.
            cos, sin: RoPE tables forwarded to the attention layer.
            mask: optional attention mask forwarded to the attention layer.
            input_pos: optional cache positions forwarded to the attention layer.
        """
        n_1 = self.norm_1(x)
        h = self.attn(n_1, cos, sin, mask, input_pos)
        n_2 = n_1 if self.config.shared_attention_norm else self.norm_2(x)
        x = self.mlp(n_2) + h + x
        return x
class CausalSelfAttention(nn.Module):
    """Causal self-attention with a fused QKV projection, partial rotary
    embedding and an optional KV cache.

    The fused projection is laid out per query group: each group holds
    `n_head // n_query_groups` query heads followed by one key and one value
    head.
    """

    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        shape = (config.n_head + 2 * config.n_query_groups) * config.head_size
        linear_module = nn.Linear
        # key, query, value projections for all heads, but in a batch
        self.attn = linear_module(config.n_embd, shape, bias=config.bias)
        # output projection
        self.proj = linear_module(config.n_embd, config.n_embd, bias=config.bias)
        # Set by `Italia.set_kv_cache()`; stays None until then.
        self.kv_cache: Optional[KVCache] = None
        self.config = config

    def forward(
        self,
        x: torch.Tensor,
        cos: torch.Tensor,
        sin: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        input_pos: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Attend over `x` of shape (B, T, n_embd) and return the same shape.

        Args:
            x: input activations.
            cos, sin: RoPE tables for the positions in `x`.
            mask: boolean attention mask; when None, causal masking is applied.
            input_pos: positions of `x` in the KV cache; enables cached decoding.

        Raises:
            TypeError: if `input_pos` is given but no KV cache has been set.
        """
        B, T, _ = (
            x.size()
        )  # batch size, sequence length, embedding dimensionality (n_embd)

        qkv = self.attn(x)

        # assemble into a number of query groups to support MHA, MQA and GQA together (see `config.n_query_groups`)
        q_per_kv = self.config.n_head // self.config.n_query_groups
        total_qkv = q_per_kv + 2  # each group has 1+ queries, 1 key, and 1 value
        qkv = qkv.view(
            B, T, self.config.n_query_groups, total_qkv, self.config.head_size
        )
        qkv = qkv.permute(0, 2, 3, 1, 4)  # (B, n_query_groups, total_qkv, T, hs)

        # split batched computation into three
        q, k, v = qkv.split((q_per_kv, 1, 1), dim=2)

        # Repeat each group's key/value head across that group's query heads
        # when there are fewer k/v heads than query heads. Without this, GQA
        # (1 < n_query_groups < n_head) fails with a head-count mismatch, both
        # in attention and when writing into the KV cache, which
        # `build_kv_cache` sizes with `n_head` heads in that case.
        # Multi-query decoding with a cache is left unexpanded because
        # `build_kv_cache` allocates a single head for it.
        if self.config.n_query_groups != self.config.n_head and (
            input_pos is None or self.config.n_query_groups != 1
        ):
            k = k.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )
            v = v.expand(
                B, self.config.n_query_groups, q_per_kv, T, self.config.head_size
            )

        q = q.reshape(B, -1, T, self.config.head_size)  # (B, nh_q, T, hs)
        k = k.reshape(B, -1, T, self.config.head_size)  # (B, nh_k, T, hs)
        v = v.reshape(B, -1, T, self.config.head_size)  # (B, nh_v, T, hs)

        # Rotary embedding is applied only to the first `rope_n_elem` channels
        # of each head; the remaining channels pass through unchanged.
        q_roped = apply_rope(q[..., : self.config.rope_n_elem], cos, sin)
        k_roped = apply_rope(k[..., : self.config.rope_n_elem], cos, sin)
        q = torch.cat((q_roped, q[..., self.config.rope_n_elem :]), dim=-1)
        k = torch.cat((k_roped, k[..., self.config.rope_n_elem :]), dim=-1)

        if input_pos is not None:
            if not isinstance(self.kv_cache, KVCache):
                raise TypeError("You need to call `gpt.set_kv_cache()`")
            # Write the new keys/values and get back the full cached tensors.
            k, v = self.kv_cache(input_pos, k, v)

        y = self.scaled_dot_product_attention(q, k, v, mask)

        y = y.reshape(
            B, T, self.config.n_embd
        )  # re-assemble all head outputs side by side

        # output projection
        return self.proj(y)

    def scaled_dot_product_attention(
        self,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Run attention scaled by 1/sqrt(head_size); result is (B, T, nh, hs).

        Causal masking is used when no explicit `mask` is supplied.
        """
        scale = 1.0 / math.sqrt(self.config.head_size)
        y = torch.nn.functional.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, scale=scale, is_causal=mask is None
        )
        # Move the head dimension next to the channel dimension for the caller's reshape.
        return y.transpose(1, 2)

    def build_kv_cache(
        self,
        batch_size: int,
        max_seq_length: int,
        rope_cache_length: Optional[int] = None,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> "KVCache":
        """Allocate a zero-filled `KVCache` for this layer.

        Args:
            batch_size: batch dimension of the cache.
            max_seq_length: number of cached positions.
            rope_cache_length: width of the RoPE tables; required whenever
                `rotary_percentage != 1.0`.
            device, dtype: placement and dtype of the cache tensors.

        Raises:
            TypeError: if `rope_cache_length` is omitted while
                `rotary_percentage != 1.0`.
        """
        # A single head is cached for multi-query attention, `n_head` otherwise.
        heads = 1 if self.config.n_query_groups == 1 else self.config.n_head
        v_shape = (batch_size, heads, max_seq_length, self.config.head_size)
        if rope_cache_length is None:
            if self.config.rotary_percentage != 1.0:
                raise TypeError(
                    "Please pass the `rope_cache_length=gpt.cos.size(-1)` value"
                )
            k_shape = v_shape
        else:
            k_shape = (
                batch_size,
                heads,
                max_seq_length,
                rope_cache_length + self.config.head_size - self.config.rope_n_elem,
            )
        return KVCache(k_shape, v_shape, device=device, dtype=dtype)
class MLP(nn.Module):
    """Feed-forward branch: linear up-projection, tanh-approximated GELU,
    linear down-projection back to `n_embd`."""

    def __init__(self, config: ItaliaConfig) -> None:
        super().__init__()
        width, hidden = config.n_embd, config.intermediate_size
        use_bias = config.bias
        # Up-projection first, then down-projection.
        self.fc = nn.Linear(width, hidden, bias=use_bias)
        self.proj = nn.Linear(hidden, width, bias=use_bias)
        self.config = config

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map `x` (..., n_embd) to a tensor of the same shape."""
        activated = torch.nn.functional.gelu(self.fc(x), approximate="tanh")
        return self.proj(activated)
def build_rope_cache(
    seq_len: int,
    n_elem: int,
    device: Optional[torch.device] = None,
    base: int = 10000,
    condense_ratio: int = 1,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Precompute the cosine and sine tables for rotary position embedding.

    Derived from: https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/labml_nn/
    transformers/rope/__init__.py. MIT License:
    https://github.com/labmlai/annotated_deep_learning_paper_implementations/blob/master/license.

    Args:
        seq_len: number of positions (rows of the returned tables).
        n_elem: number of rotated channels (columns of the returned tables).
        device: device on which the tables are created.
        base: base of the geometric frequency progression.
        condense_ratio: divisor applied to the position indices.

    Returns:
        `(cos, sin)`, each of shape `(seq_len, n_elem)` for even `n_elem`.
    """
    # Inverse frequencies: theta_i = base^(-2i / n_elem) for i = 0 .. n_elem/2 - 1
    exponents = torch.arange(0, n_elem, 2, device=device).float() / n_elem
    inv_freq = 1.0 / (base**exponents)

    # Position indices `[0, 1, ..., seq_len - 1]`, optionally condensed
    positions = torch.arange(seq_len, device=device) / condense_ratio

    # Angle for every (position, frequency) pair. The columns are then
    # duplicated so the table width equals `n_elem`.
    angles = torch.outer(positions, inv_freq).repeat(1, 2)

    return torch.cos(angles), torch.sin(angles)
def apply_rope(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
    """Rotate `x` with the given cosine/sine tables.

    The last dimension of `x` is split into two halves `(a, b)`; the result is
    `x * cos + (-b, a) * sin`, cast back to the dtype of `x`.
    """
    half = x.size(-1) // 2
    first_half = x[..., :half]  # (B, nh, T, hs/2)
    second_half = x[..., half:]  # (B, nh, T, hs/2)
    swapped = torch.cat((-second_half, first_half), dim=-1)  # (B, nh, T, hs)
    result = (x * cos) + (swapped * sin)
    # The tables may be in a wider dtype than `x`; return in the input dtype.
    return result.to(dtype=x.dtype)
class KVCache(nn.Module):
    """Pre-allocated key/value buffers, updated in place along dimension 2."""

    def __init__(
        self,
        k_shape: Tuple[int, int, int, int],
        v_shape: Tuple[int, int, int, int],
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> None:
        super().__init__()
        # Zero-filled buffers that are excluded from the state dict.
        for name, shape in (("k", k_shape), ("v", v_shape)):
            storage = torch.zeros(shape, device=device, dtype=dtype)
            self.register_buffer(name, storage, persistent=False)

    def forward(
        self, input_pos: torch.Tensor, k: torch.Tensor, v: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Write `k`/`v` at `input_pos` and return the full cached tensors."""
        # move the buffer to the activation dtype for when AMP is used
        self.k = self.k.to(k.dtype)
        self.v = self.v.to(v.dtype)
        # In-place update; `index_copy_` returns the updated buffer itself.
        cached_k = self.k.index_copy_(2, input_pos, k)
        cached_v = self.v.index_copy_(2, input_pos, v)
        return cached_k, cached_v

    def reset_parameters(self) -> None:
        """Zero both buffers in place."""
        for buffer in (self.k, self.v):
            torch.nn.init.zeros_(buffer)
def build_mask_cache(
    max_seq_length: int, device: Optional[torch.device] = None
) -> torch.Tensor:
    """Return a boolean lower-triangular mask of shape (1, 1, L, L), L = `max_seq_length`.

    Entry `[0, 0, i, j]` is True exactly when `j <= i`.
    """
    side = (max_seq_length, max_seq_length)
    lower_triangle = torch.tril(torch.ones(side, device=device, dtype=torch.bool))
    # Two leading singleton dimensions are added in front of the (L, L) matrix.
    return lower_triangle[None, None]