import torch import torch.nn.functional as F from typing import Optional from torch import nn class RMSNorm(nn.Module): def __init__(self, dim, eps): super().__init__() self.eps = eps self.weight = nn.Parameter(torch.ones(dim)) def forward(self, x): # Root Mean Square Layer Normalization rms = torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) return x * rms * self.weight class RotaryEmbedding(nn.Module): def __init__(self, dim, max_seq_len=2048, theta=10000): super().__init__() self.dim = dim self.max_seq_len = max_seq_len freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim)) self.register_buffer("freqs", freqs) t = torch.arange(max_seq_len, dtype=self.freqs.dtype) freqs = torch.outer(t, self.freqs) cos = freqs.cos() sin = freqs.sin() self.register_buffer('cos', cos) self.register_buffer('sin', sin) def rotate_half(self, x): rot_dim = x.shape[-1] x1 = x[..., :rot_dim // 2] x2 = x[..., rot_dim // 2:] return torch.cat((-x2, x1), dim=-1) def apply_rotary_emb(self, t, x): rot_dim = self.freqs.shape[-1] cos = self.cos[t, :rot_dim] sin = self.sin[t, :rot_dim] rotated_x = (x[..., :rot_dim] * cos) + (self.rotate_half(x[..., :rot_dim]) * sin) if x.shape[-1] > rot_dim: rotated_x = torch.cat((rotated_x, x[..., rot_dim:]), dim=-1) return rotated_x def forward(self, x, seq_dim=-2): seq_len = x.shape[seq_dim] t = torch.arange(seq_len, device=x.device) return self.apply_rotary_emb(t, x) class Attention(nn.Module): def __init__(self, args): super().__init__() self.dim = args.dim self.num_heads = args.n_heads self.kv_heads = args.n_kv_heads self.head_dim = args.dim // args.n_heads self.kv_head_dim = args.dim // args.n_kv_heads assert self.head_dim * args.n_heads == args.dim, "args.dim must be divisible by args.n_heads" assert self.kv_head_dim * args.n_kv_heads == args.dim, "args.dim must be divisible by args.n_kv_heads" self.query_proj = nn.Linear(args.dim, args.dim, bias=False) self.key_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) self.value_proj = nn.Linear(args.dim, args.n_kv_heads * self.head_dim, bias=False) self.rope = RotaryEmbedding(self.head_dim) self.out_proj = nn.Linear(args.dim, args.dim, bias=False) self.dropout = nn.Dropout(args.dropout) # # Caching storage (keys and values) cached_keys = None cached_values = None self.register_buffer('cached_keys', cached_keys) self.register_buffer('cached_values', cached_values) def forward(self, x, mask=None, use_cache=False): # # batch_size = x.size(0) batch_size, seq_len, C = x.size() query = self.query_proj(x) key = self.key_proj(x) value = self.value_proj(x) # Reshape for attention computation query = query.view(batch_size, seq_len, self.num_heads, self.head_dim) key = key.view(batch_size, seq_len, self.kv_heads, self.head_dim) value = value.view(batch_size, seq_len, self.kv_heads, self.head_dim) # Transpose for attention computation query = query.transpose(1, 2) # [batch, num_heads, seq_len, head_dim] key = key.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] value = value.transpose(1, 2) # [batch, num_kv_heads, seq_len, head_dim] query = self.rope(query) key = self.rope(key) # # If kv_heads are less than num_heads, repeat them # if self.kv_heads < self.num_heads: # key = key.repeat_interleave(self.num_heads // self.kv_heads, dim=1) # value = value.repeat_interleave(self.num_heads // self.kv_heads, dim=1) # # Compute attention # attn_weights = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.head_dim) # if mask is not None: # attn_weights = attn_weights + mask # attn_weights = F.softmax(attn_weights, dim=-1) # # Compute output # output = torch.matmul(attn_weights, value) # Flash-attn output = F.scaled_dot_product_attention(query, key, value, is_causal=True, dropout_p=self.dropout.p, enable_gqa=True) # Update cache only if using cache if use_cache: self.cached_keys = key self.cached_values = value else: # Reset cached values during training (to prevent unwanted accumulation) self.cached_keys = None self.cached_values = None output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1) # [batch, seq_len, num_heads * head_dim] return self.out_proj(output) class FeedForward(nn.Module): def __init__(self, args): """ Initialize the FeedForward module. Args: dim (int): Input dimension. hidden_dim (int): Hidden dimension of the feedforward layer. # 2304 ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None. Attributes: w1 (nn.Linear): Linear transformation for the first layer. w2 (nn.Linear): Linear transformation for the second layer. w3 (nn.Linear): Linear transformation for the third layer. """ super().__init__() self.w1 = nn.Linear(args.dim, args.intermediate_dim, bias=False) self.w2 = nn.Linear(args.intermediate_dim, args.dim, bias=False) self.w3 = nn.Linear(args.dim, args.intermediate_dim, bias=False) def forward(self, x): return self.w2(F.silu(self.w1(x)) * self.w3(x)) class TransformerBlock(nn.Module): def __init__(self, layer_id: int, args): """ Initialize a TransformerBlock. Args: layer_id (int): Identifier for the layer. args (ModelArgs): Model configuration parameters. Attributes: n_heads (int): Number of attention heads. dim (int): Dimension size of the model. head_dim (int): Dimension size of each attention head. attention (Attention): Attention module. feed_forward (FeedForward): FeedForward module. layer_id (int): Identifier for the layer. attention_norm (RMSNorm): Layer normalization for attention output. ffn_norm (RMSNorm): Layer normalization for feedforward output. """ super().__init__() self.n_heads = args.n_heads self.dim = args.dim self.head_dim = args.dim // args.n_heads self.attention = Attention(args) self.feed_forward = FeedForward(args) self.layer_id = layer_id self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps) self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps) def forward( self, x: torch.Tensor, mask: Optional[torch.Tensor], use_cache: bool ): """ Perform a forward pass through the TransformerBlock. Args: x (torch.Tensor): Input tensor. mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. use_cache (bool): whether to use kv_cache Returns: torch.Tensor: Output tensor after applying attention and feedforward layers. """ h = x + self.attention(self.attention_norm(x), mask=mask, use_cache=use_cache) out = h + self.feed_forward(self.ffn_norm(h)) return out class Transformer(nn.Module): def __init__(self, args): """ Initialize a Transformer model. Args: args (ModelArgs): Model configuration parameters. Attributes: args (ModelArgs): Model configuration parameters. vocab_size (int): Vocabulary size. n_layers (int): Number of layers in the model. tok_embeddings (nn.Embedding): Token embeddings. layers (torch.nn.ModuleList): List of Transformer blocks. norm (RMSNorm): Layer normalization for the model output. output (nn.Linear): Linear layer for final output. """ super().__init__() self.args = args self.vocab_size = args.vocab_size self.n_layers = args.n_layers self.tok_embeddings = nn.Embedding(args.vocab_size, args.dim) self.layers = torch.nn.ModuleList() for layer_id in range(args.n_layers): self.layers.append(TransformerBlock(layer_id, args)) self.norm = RMSNorm(args.dim, eps=args.norm_eps) # self.output = nn.Linear( # args.dim, args.vocab_size, bias=False # ) # # weight sharing # self.output.weight = self.tok_embeddings.weight # weight initialization self.apply(self._init_weights) def _init_weights(self, module): std = self.args.init_scale if isinstance(module, nn.Linear): module.weight.data.normal_(mean=0.0, std=std) # if module.bias is not None: # module.bias.data.zero_() elif isinstance(module, nn.Embedding): module.weight.data.normal_(mean=0.0, std=std) def forward(self, tokens: torch.Tensor, mask: torch.Tensor = None, use_cache: bool = False): """ Perform a forward pass through the Transformer model. Args: tokens (torch.Tensor): Input token indices. mask (torch.Tensor, optional): Masking tensor for attention. Defaults to None. use_cache (bool): whether to use kv_cache Returns: torch.Tensor: Output logits after applying the Transformer model. """ _, seqlen = tokens.shape h = self.tok_embeddings(tokens) if mask is None: mask = torch.triu(torch.ones((seqlen, seqlen), dtype=torch.bool, device=tokens.device), diagonal=1) mask = mask.unsqueeze(0).unsqueeze(0) mask = mask * -1e4 for layer in self.layers: h = layer(h, mask, use_cache) h = self.norm(h) # output = self.output(h).float() output = F.linear(h, self.tok_embeddings.weight) return output def generate(self, input_ids, max_length, min_length=None, num_return_sequences=1, pad_token_id=None, do_sample=True, temperature=0.8, top_k=50, top_p=0.95 ): self.eval() # batch_size = input_ids.shape[0] min_length = min_length if min_length is not None else input_ids.shape[1] with torch.no_grad(): for ret_seq in range(num_return_sequences): print(f"Sequence #{ret_seq + 1}:") for _ in range(max_length - input_ids.shape[1]): outputs = self(input_ids, use_cache=True) next_token_logits = outputs[:, -1, :] # Apply temperature next_token_logits = next_token_logits / temperature # Apply top-k filtering if top_k > 0: indices_to_remove = next_token_logits < torch.topk(next_token_logits, top_k)[0][..., -1, None] next_token_logits[indices_to_remove] = float('-inf') # Apply top-p (nucleus) filtering if top_p < 1.0: sorted_logits, sorted_indices = torch.sort(next_token_logits, descending=True) cumulative_probs = torch.cumsum(torch.softmax(sorted_logits, dim=-1), dim=-1) sorted_indices_to_remove = cumulative_probs > top_p sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone() sorted_indices_to_remove[..., 0] = 0 indices_to_remove = sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) next_token_logits[indices_to_remove] = float('-inf') # Sample from the filtered distribution if do_sample: probs = torch.softmax(next_token_logits, dim=-1) next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1) else: next_tokens = torch.argmax(next_token_logits, dim=-1) input_ids = torch.cat([input_ids, next_tokens.unsqueeze(-1)], dim=-1) # Stop if all sequences have hit the pad token if pad_token_id is not None and (next_tokens == pad_token_id).all(): break # Stop if we've reached min_length if input_ids.shape[1] < min_length: continue return input_ids