from functools import partial
from typing import List, Optional, Tuple, Union

import torch
import torch.utils.checkpoint
from torch import nn
from transformers.cache_utils import Cache, DynamicCache, StaticCache
from transformers.modeling_attn_mask_utils import AttentionMaskConverter
from transformers.modeling_outputs import BaseModelOutputWithPast
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.utils import logging

logger = logging.get_logger(__name__)


class AR_head(nn.Module):
    """Small transformer head: a stack of ``LlamaDecoderLayer`` blocks plus a final RMSNorm.

    Besides the decoder stack, the module owns:

    * ``linear_head``: projects ``hidden_size`` -> ``sub_vocab_size``. It is NOT applied
      in ``forward``; presumably the caller applies it to the returned hidden states
      (verify against caller).
    * ``codebooks``: ``num_codebooks - 1`` embedding tables of size
      ``sub_vocab_size x hidden_size``. They are also not used in ``forward`` here.

    Args:
        config: Llama-style config. Fields read in this class are ``hidden_size``,
            ``rms_norm_eps``, ``initializer_range``, ``output_attentions``,
            ``output_hidden_states``, ``use_cache``, ``use_return_dict``,
            ``_attn_implementation`` and ``max_position_embeddings``.
        codebook_size: total vocabulary size. It is split evenly, so
            ``sub_vocab_size = codebook_size // num_codebooks``.
        num_codebooks: number of codebooks the vocabulary is split into.
        num_layers: number of decoder layers (default 3). Must be >= 1, because
            ``_update_causal_mask`` inspects ``self.layers[0]``.
    """

    def __init__(self, config, codebook_size, num_codebooks, num_layers=3):
        super().__init__()
        self.num_codebooks = num_codebooks
        self.sub_vocab_size = codebook_size // self.num_codebooks

        # Submodules are created in the same order as before, so default-init RNG
        # consumption is unchanged.
        self.linear_head = nn.Linear(config.hidden_size, self.sub_vocab_size)
        self.layers = nn.ModuleList(
            [LlamaDecoderLayer(config, layer_idx) for layer_idx in range(num_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.codebooks = nn.ModuleList(
            [
                nn.Embedding(self.sub_vocab_size, config.hidden_size)
                for _ in range(self.num_codebooks - 1)
            ]
        )

        self.config = config
        # Toggled externally. When True (and training), each decoder layer runs under
        # activation checkpointing.
        self.gradient_checkpointing = False

        # `_init_weights` only handles leaf nn.Linear / nn.Embedding modules. It must
        # therefore be mapped over every submodule of the decoder stack; passing the
        # ModuleList itself matches neither branch and initializes nothing.
        self.layers.apply(self._init_weights)

    def set_input_embeddings(self, value):
        """Install the embedding module ``forward`` uses to embed ``input_ids``."""
        self.embed_tokens = value

    def _init_weights(self, module):
        """Initialize a single module in place.

        Linear and Embedding weights are drawn from N(0, config.initializer_range).
        Linear biases are zeroed, and so is the padding row of an Embedding if one is
        set. Any other module type is left untouched, so use ``Module.apply`` to reach
        nested submodules.
        """
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        cache_position: Optional[torch.LongTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        """Run the decoder stack and final RMSNorm.

        Exactly one of ``input_ids`` / ``inputs_embeds`` must be given. ``input_ids``
        are embedded with the module installed via ``set_input_embeddings``. Any flag
        left as None falls back to the corresponding ``self.config`` value.

        Returns:
            If ``return_dict`` is true, a ``BaseModelOutputWithPast`` with the
            normalized last hidden state, the cache (legacy format when it is a
            ``Cache``), and optionally all hidden states and attentions. Otherwise, a
            tuple of the non-None values among those four, in that order.

        Raises:
            ValueError: if both or neither of ``input_ids`` / ``inputs_embeds`` are
                given; if a ``StaticCache`` is used without ``cache_position``; or if
                ``input_ids`` are given but no input embedding has been set.
        """
        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # XOR: valid only when exactly one of the two inputs is provided.
        if (input_ids is None) ^ (inputs_embeds is not None):
            raise ValueError(
                "You cannot specify both input_ids and inputs_embeds at the same time, and must specify either one"
            )

        if self.gradient_checkpointing and self.training and use_cache:
            logger.warning_once(
                "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`."
            )
            use_cache = False

        if inputs_embeds is None:
            # `embed_tokens` only exists after `set_input_embeddings` was called.
            embed_tokens = getattr(self, "embed_tokens", None)
            if embed_tokens is None:
                raise ValueError(
                    "AR_head has no input embedding; call `set_input_embeddings` before passing `input_ids`."
                )
            inputs_embeds = embed_tokens(input_ids)

        past_seen_tokens = 0
        if use_cache:  # kept for BC (cache positions)
            if not isinstance(past_key_values, StaticCache):
                past_key_values = DynamicCache.from_legacy_cache(past_key_values)
                past_seen_tokens = past_key_values.get_seq_length()

        if cache_position is None:
            if isinstance(past_key_values, StaticCache):
                raise ValueError("cache_position is a required argument when using StaticCache.")
            cache_position = torch.arange(
                past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device
            )

        if position_ids is None:
            position_ids = cache_position.unsqueeze(0)

        causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position)

        hidden_states = inputs_embeds

        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = None

        use_checkpointing = self.gradient_checkpointing and self.training
        checkpoint_fn = None
        if use_checkpointing:
            # `_gradient_checkpointing_func` is normally installed by
            # PreTrainedModel.gradient_checkpointing_enable. This class is a plain
            # nn.Module, so fall back to torch's checkpoint when it is absent.
            checkpoint_fn = getattr(self, "_gradient_checkpointing_func", None)
            if checkpoint_fn is None:
                checkpoint_fn = partial(torch.utils.checkpoint.checkpoint, use_reentrant=False)

        for decoder_layer in self.layers:
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            if use_checkpointing:
                # Positional order matches LlamaDecoderLayer's signature.
                layer_outputs = checkpoint_fn(
                    decoder_layer.__call__,
                    hidden_states,
                    causal_mask,
                    position_ids,
                    past_key_values,
                    output_attentions,
                    use_cache,
                    cache_position,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=causal_mask,
                    position_ids=position_ids,
                    past_key_value=past_key_values,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                    cache_position=cache_position,
                )

            hidden_states = layer_outputs[0]

            # Layer output layout: (hidden, [attn_weights], [cache]).
            if use_cache:
                next_decoder_cache = layer_outputs[2 if output_attentions else 1]

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = None
        if use_cache:
            next_cache = (
                next_decoder_cache.to_legacy_cache() if isinstance(next_decoder_cache, Cache) else next_decoder_cache
            )
        if not return_dict:
            return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )

    # TODO: As of torch==2.2.0, the `attention_mask` passed to the model in `generate` is 2D and of dynamic length even when the static
    # KV cache is used. This is an issue for torch.compile which then recaptures cudagraphs at each decode steps due to the dynamic shapes.
    # (`recording cudagraph tree for symint key 13`, etc.), which is VERY slow. A workaround is `@torch.compiler.disable`, but this prevents using
    # `fullgraph=True`. See more context in https://github.com/huggingface/transformers/pull/29114
    def _update_causal_mask(self, attention_mask, input_tensor, cache_position):
        """Build the attention mask passed to the decoder layers.

        For ``flash_attention_2``: returns ``attention_mask`` unchanged if it contains
        any 0.0 entry, else None.

        Otherwise: returns an additive float mask of shape
        ``(batch, 1, seq_len, target_len)``. Entries are 0 where attention is allowed
        and ``finfo(dtype).min`` where it is blocked. The mask combines causality with
        positions from ``cache_position`` and any 2D padding mask or 4D custom mask.
        """
        if self.config._attn_implementation == "flash_attention_2":
            if attention_mask is not None and 0.0 in attention_mask:
                return attention_mask
            return None

        dtype, device = input_tensor.dtype, input_tensor.device
        min_dtype = torch.finfo(dtype).min
        sequence_length = input_tensor.shape[1]
        if hasattr(self.layers[0].self_attn, "past_key_value"):  # static cache
            target_length = self.config.max_position_embeddings
        else:  # dynamic cache
            target_length = (
                attention_mask.shape[-1] if isinstance(attention_mask, torch.Tensor) else cache_position[-1] + 1
            )

        causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
        if sequence_length != 1:
            causal_mask = torch.triu(causal_mask, diagonal=1)
        # Unblock every key position <= the query's absolute (cached) position.
        causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
        causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
        if attention_mask is not None:
            causal_mask = causal_mask.clone()  # copy to contiguous memory for in-place edit
            if attention_mask.dim() == 2:
                mask_length = attention_mask.shape[-1]
                padding_mask = causal_mask[..., :mask_length].eq(0.0) * attention_mask[:, None, None, :].eq(0.0)
                causal_mask[..., :mask_length] = causal_mask[..., :mask_length].masked_fill(padding_mask, min_dtype)
            elif attention_mask.dim() == 4:
                # backwards compatibility: we allow passing a 4D attention mask shorter than the input length with
                # cache. In that case, the 4D attention mask attends to the newest tokens only.
                if attention_mask.shape[-2] < cache_position[0] + sequence_length:
                    offset = cache_position[0]
                else:
                    offset = 0
                mask_shape = attention_mask.shape
                mask_slice = (attention_mask.eq(0.0)).to(dtype=dtype) * min_dtype
                causal_mask[
                    : mask_shape[0], : mask_shape[1], offset : mask_shape[2] + offset, : mask_shape[3]
                ] = mask_slice

        if (
            self.config._attn_implementation == "sdpa"
            and attention_mask is not None
            and attention_mask.device.type == "cuda"
        ):
            # TODO: For dynamo, rather use a check on fullgraph=True once this is possible (https://github.com/pytorch/pytorch/pull/120400).
            is_tracing = (
                torch.jit.is_tracing()
                or isinstance(input_tensor, torch.fx.Proxy)
                or (hasattr(torch, "_dynamo") and torch._dynamo.is_compiling())
            )
            if not is_tracing and torch.any(attention_mask != 1):
                # Attend to all tokens in fully masked rows in the causal_mask, for example the relevant first rows when
                # using left padding. This is required by F.scaled_dot_product_attention memory-efficient attention path.
                # Details: https://github.com/pytorch/pytorch/issues/110213
                causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype)

        return causal_mask