""" PyTorch LLaMA model.""" |
|
import copy |
|
import math |
|
from typing import List, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import nn |
|
|
|
from transformers.cache_utils import Cache, StaticCache |
|
from transformers.modeling_attn_mask_utils import AttentionMaskConverter |
|
from transformers.modeling_outputs import BaseModelOutputWithPast |
|
from transformers.models.llama.modeling_llama import ( |
|
LLAMA_INPUTS_DOCSTRING, |
|
LlamaAttention, |
|
LlamaDecoderLayer, |
|
LlamaForCausalLM, |
|
LlamaModel, |
|
LlamaPreTrainedModel, |
|
_prepare_4d_causal_attention_mask_with_cache_position, |
|
logger, |
|
repeat_kv, |
|
rotate_half, |
|
) |
|
from transformers.utils import add_start_docstrings_to_model_forward, is_flash_attn_greater_or_equal_2_10 |
|
|
|
from .cache_utils import AutoLayerCache, LayerCache |
|
from .configuration_lckv import LCKVLlamaConfig |
|
from .utils import IterStep, LayerTypeParser, flash_attention_forward |
|
|
|
|
|
def apply_rotary(q, cos, sin, unsqueeze_dim=1): |
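    """Apply rotary position embedding to a single tensor (query or key states only)."""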
|
cos = cos.unsqueeze(unsqueeze_dim) |
|
sin = sin.unsqueeze(unsqueeze_dim) |
|
q_embed = (q * cos) + (rotate_half(q) * sin) |
|
return q_embed |
|
|
|
|
|
class LCKVLlamaAttention(LlamaAttention): |
|
""" |
|
    LCKV attention layer. Layers that reuse key/value states computed by another layer do not
    instantiate their own key and value projections.
|
""" |
|
|
|
def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None): |
|
super().__init__(config, layer_idx) |
|
self.layer_type = LayerTypeParser(config.layer_types)[layer_idx] |
|
self.sliding_window = config.sliding_window if self.layer_type.use_sliding_window else None |
|
|
|
|
|
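        # Layers that reuse key/value states from another layer never compute their own,
        # so the key/value projections can be dropped.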
if not self.layer_type.computes_kv: |
|
del self.k_proj |
|
del self.v_proj |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[Cache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
**kwargs, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
bsz, q_len, _ = hidden_states.size() |
|
cos, sin = position_embeddings |
|
|
|
query_states = self.q_proj(hidden_states) |
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
query_states = apply_rotary(query_states, cos, sin) |
|
|
|
|
|
if self.layer_type.computes_kv: |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
key_states = apply_rotary(key_states, cos, sin) |
|
|
|
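            # Write this layer's freshly computed key/value states into the cache so that
            # other layers (and later decoding steps) can attend to them.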
if isinstance(past_key_value, Cache): |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
past_key_value.layer_set(self.layer_idx, key_states, value_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
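        # Read the key/value states of the layer this layer attends to. When decoding a single
        # token and attending to a higher layer whose KV is not yet computed, zero-filled
        # states are substituted.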
key_states, value_states = past_key_value.layer_get( |
|
self.layer_type.attends_to, |
|
zerofill=self.layer_type.attends_top and q_len == 1, |
|
) |
|
|
|
|
|
key_states = repeat_kv(key_states, self.num_key_value_groups) |
|
value_states = repeat_kv(value_states, self.num_key_value_groups) |
|
attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) |
|
|
|
if attention_mask is not None: |
|
causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] |
|
attn_weights = attn_weights + causal_mask |
|
|
|
|
|
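        # Mask the diagonal so a token cannot attend to its own key/value, which is needed
        # when a layer attends to a higher layer whose KV for the current token is unavailable.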
if self.config.force_nodiag or self.layer_type.attends_top: |
|
kv_len = key_states.size(2) |
|
mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min) |
|
mask = mask.tril(diagonal=kv_len - q_len).triu(diagonal=kv_len - q_len) |
|
attn_weights = attn_weights + mask |
|
|
|
|
|
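        # Sliding-window mask: keys more than `sliding_window` positions behind the query are masked out.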
if self.sliding_window: |
|
kv_len = key_states.size(2) |
|
mask = attn_weights.new_full((q_len, kv_len), torch.finfo(attn_weights.dtype).min) |
|
mask = mask.tril(diagonal=kv_len - q_len - self.sliding_window) |
|
attn_weights = attn_weights + mask |
|
|
|
|
|
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) |
|
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) |
|
attn_output = torch.matmul(attn_weights, value_states) |
|
|
|
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim): |
|
raise ValueError( |
|
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is" |
|
f" {attn_output.size()}" |
|
) |
|
|
|
attn_output = attn_output.transpose(1, 2).contiguous() |
|
attn_output = attn_output.reshape(bsz, q_len, -1) |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
class LCKVLlamaFlashAttention2(LCKVLlamaAttention): |
|
""" |
|
    Flash Attention 2 implementation of LCKV attention. As in the eager version, layers that reuse
    key/value states computed by another layer do not instantiate their own key and value projections.
|
""" |
|
|
|
def __init__(self, config: LCKVLlamaConfig, layer_idx: Optional[int] = None): |
|
super().__init__(config, layer_idx) |
|
|
|
|
|
|
|
|
|
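        # flash_attn < 2.1 generates a top-left aligned causal mask, while bottom-right alignment
        # is needed here; remember which behaviour the installed version uses.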
self._flash_attn_uses_top_left_mask = not is_flash_attn_greater_or_equal_2_10() |
|
|
|
def forward( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_value: Optional[LayerCache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: |
|
|
|
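        # FlashAttention does not support returning attention weights, so `output_attentions` is forced off.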
output_attentions = False |
|
|
|
bsz, q_len, _ = hidden_states.size() |
|
cos, sin = position_embeddings |
|
|
|
|
|
|
|
|
|
query_states = self.q_proj(hidden_states) |
|
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) |
|
query_states = apply_rotary(query_states, cos, sin) |
|
|
|
|
|
if self.layer_type.computes_kv: |
|
key_states = self.k_proj(hidden_states) |
|
value_states = self.v_proj(hidden_states) |
|
|
|
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) |
|
key_states = apply_rotary(key_states, cos, sin) |
|
|
|
if isinstance(past_key_value, Cache): |
|
|
|
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} |
|
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) |
|
|
|
past_key_value.layer_set(self.layer_idx, key_states, value_states) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
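        # Read the key/value states of the layer this layer attends to (zero-filled during single-token
        # decoding when the target layer's KV has not been computed yet).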
key_states, value_states = past_key_value.layer_get( |
|
self.layer_type.attends_to, |
|
zerofill=self.layer_type.attends_top and q_len == 1, |
|
) |
|
|
|
|
|
|
|
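        # FlashAttention expects tensors of shape (batch, seq_len, num_heads, head_dim).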
query_states = query_states.transpose(1, 2) |
|
key_states = key_states.transpose(1, 2) |
|
value_states = value_states.transpose(1, 2) |
|
|
|
dropout_rate = self.attention_dropout if self.training else 0.0 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
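        # If the inputs were silently upcast to float32 (e.g. by float32 layer norms or autocast),
        # cast them back to the expected compute dtype before calling flash attention.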
input_dtype = query_states.dtype |
|
if input_dtype == torch.float32: |
|
if torch.is_autocast_enabled(): |
|
target_dtype = torch.get_autocast_gpu_dtype() |
|
|
|
elif hasattr(self.config, "_pre_quantization_dtype"): |
|
target_dtype = self.config._pre_quantization_dtype |
|
else: |
|
target_dtype = self.q_proj.weight.dtype |
|
|
|
            logger.warning_once(
                "The input hidden states seem to have been silently cast to float32. This is probably"
                " because you have upcast embedding or layer norm layers to float32; we will cast the inputs"
                f" back to {target_dtype}."
            )
|
|
|
query_states = query_states.to(target_dtype) |
|
key_states = key_states.to(target_dtype) |
|
value_states = value_states.to(target_dtype) |
|
|
|
attn_output = flash_attention_forward( |
|
query_states, |
|
key_states, |
|
value_states, |
|
attention_mask, |
|
q_len, |
|
position_ids=position_ids, |
|
dropout=dropout_rate, |
|
sliding_window=self.sliding_window, |
|
use_top_left_mask=self._flash_attn_uses_top_left_mask, |
|
is_causal=self.is_causal, |
|
no_diag=(self.config.force_nodiag or self.layer_type.attends_top), |
|
) |
|
|
|
attn_output = attn_output.reshape(bsz, q_len, -1).contiguous() |
|
attn_output = self.o_proj(attn_output) |
|
|
|
if not output_attentions: |
|
attn_weights = None |
|
|
|
return attn_output, attn_weights, past_key_value |
|
|
|
|
|
LCKV_LLAMA_ATTENTION_CLASSES = { |
|
"eager": LCKVLlamaAttention, |
|
"flash_attention_2": LCKVLlamaFlashAttention2, |
|
} |
|
|
|
|
|
class LCKVLlamaDecoderLayer(LlamaDecoderLayer): |
|
def __init__(self, config: LCKVLlamaConfig, layer_idx: int): |
|
super().__init__(config, layer_idx) |
|
self.self_attn = LCKV_LLAMA_ATTENTION_CLASSES[config._attn_implementation](config=config, layer_idx=layer_idx) |
|
|
|
|
|
class LCKVLlamaPreTrainedModel(LlamaPreTrainedModel): |
|
config_class = LCKVLlamaConfig |
|
supports_gradient_checkpointing = False |
|
_no_split_modules = ["LCKVLlamaDecoderLayer"] |
|
_supports_flash_attn_2 = True |
|
_supports_sdpa = False |
|
|
|
|
|
class LCKVLlamaModel(LCKVLlamaPreTrainedModel, LlamaModel): |
|
def __init__(self, config: LCKVLlamaConfig): |
|
LCKVLlamaPreTrainedModel.__init__(self, config) |
|
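        # Build the vanilla LlamaModel on a copy of the config, then replace its decoder layers
        # with LCKV decoder layers below.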
LlamaModel.__init__(self, copy.deepcopy(config)) |
|
self.layers = nn.ModuleList([LCKVLlamaDecoderLayer(config, layer_idx=i) for i in range(config.num_hidden_layers)]) |
|
self.parser = LayerTypeParser(config.layer_types) |
|
|
|
|
|
self.post_init() |
|
|
|
@add_start_docstrings_to_model_forward(LLAMA_INPUTS_DOCSTRING) |
|
def forward( |
|
self, |
|
input_ids: torch.LongTensor = None, |
|
attention_mask: Optional[torch.Tensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[LayerCache] = None, |
|
inputs_embeds: Optional[torch.FloatTensor] = None, |
|
use_cache: Optional[bool] = None, |
|
output_attentions: Optional[bool] = None, |
|
output_hidden_states: Optional[bool] = None, |
|
return_dict: Optional[bool] = None, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
) -> Union[Tuple, BaseModelOutputWithPast]: |
|
output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions |
|
output_hidden_states = ( |
|
output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states |
|
) |
|
use_cache = use_cache if use_cache is not None else self.config.use_cache |
|
return_dict = return_dict if return_dict is not None else self.config.use_return_dict |
|
|
|
if (input_ids is None) ^ (inputs_embeds is not None): |
|
            raise ValueError(
                "You must specify exactly one of input_ids or inputs_embeds"
            )
|
|
|
if self.gradient_checkpointing and self.training and use_cache: |
|
logger.warning_once( |
|
"`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`." |
|
) |
|
use_cache = False |
|
|
|
if inputs_embeds is None: |
|
inputs_embeds = self.embed_tokens(input_ids) |
|
|
|
|
|
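        # LCKV requires a LayerCache. Convert whatever cache was passed in and set it up with a
        # zero-valued placeholder of shape (batch, num_kv_heads, 1, head_dim) for layers whose
        # key/value states have not been computed yet.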
if not isinstance(past_key_values, LayerCache): |
|
placeholder = inputs_embeds.new_zeros( |
|
inputs_embeds.shape[0], |
|
self.config.num_key_value_heads, |
|
1, |
|
getattr(self.config, "head_dim", self.config.hidden_size // self.config.num_attention_heads) |
|
) |
|
|
|
if past_key_values is None: |
|
past_key_values = LayerCache() |
|
elif isinstance(past_key_values, Cache): |
|
past_key_values = AutoLayerCache.from_cache(past_key_values) |
|
else: |
|
raise NotImplementedError("Only DynamicCache is supported for now.") |
|
|
|
past_key_values.setup(placeholder) |
|
|
|
if cache_position is None: |
|
past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0 |
|
cache_position = torch.arange( |
|
past_seen_tokens, past_seen_tokens + inputs_embeds.shape[1], device=inputs_embeds.device |
|
) |
|
if position_ids is None: |
|
position_ids = cache_position.unsqueeze(0) |
|
|
|
causal_mask = self._update_causal_mask( |
|
attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions |
|
) |
|
hidden_states = inputs_embeds |
|
|
|
|
|
position_embeddings = self.rotary_emb(hidden_states, position_ids) |
|
|
|
|
|
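        # Use token-by-token (sequential) modeling when explicitly configured, or when the prompt is
        # no longer than the total number of iterative passes and some layer attends to a higher layer.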
        use_sequential = (
            self.config.use_sequential
            or (
                inputs_embeds.shape[1] <= self.config.forward_passes + self.config.backward_passes
                and self.parser.attends_top()
            )
        )
|
|
|
if use_sequential: |
|
|
|
iteration_outputs = self._modeling_sequential( |
|
hidden_states, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
output_hidden_states=output_hidden_states, |
|
) |
|
|
|
else: |
|
|
|
|
|
past_key_values.initialize(self.parser, inputs_embeds.shape[1]) |
|
|
|
|
|
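            # Build the iteration plan describing which layer slices to run in each forward/backward pass.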
plan = self.parser.iteration_plan(self.config.forward_passes, self.config.backward_passes) |
|
|
|
iteration_outputs = self._modeling_with_plan( |
|
hidden_states, |
|
attention_mask=causal_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
output_hidden_states=output_hidden_states, |
|
modeling_plan=plan, |
|
) |
|
|
|
hidden_states = iteration_outputs.last_hidden_state |
|
all_hidden_states = iteration_outputs.hidden_states |
|
all_self_attns = iteration_outputs.attentions |
|
next_decoder_cache = iteration_outputs.past_key_values |
|
|
|
hidden_states = self.norm(hidden_states) |
|
|
|
|
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
|
|
if not return_dict: |
|
return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) |
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
def _iterate_layers( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[LayerCache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
output_hidden_states: Optional[bool] = False, |
|
layer_slice: Optional[slice] = None, |
|
) -> BaseModelOutputWithPast: |
|
""" |
|
        Iterates over the model's decoder layers (optionally restricted to `layer_slice`), calling each in turn.
|
""" |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = None |
|
|
|
|
|
if layer_slice is None: |
|
layer_slice = slice(None) |
|
|
|
for decoder_layer in self.layers[layer_slice]: |
|
if output_hidden_states: |
|
all_hidden_states += (hidden_states,) |
|
|
|
if self.gradient_checkpointing and self.training: |
|
layer_outputs = self._gradient_checkpointing_func( |
|
decoder_layer.__call__, |
|
hidden_states, |
|
attention_mask, |
|
position_ids, |
|
past_key_values, |
|
output_attentions, |
|
use_cache, |
|
cache_position, |
|
position_embeddings, |
|
) |
|
else: |
|
layer_outputs = decoder_layer( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_value=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
) |
|
|
|
hidden_states = layer_outputs[0] |
|
|
|
if use_cache: |
|
next_decoder_cache = layer_outputs[2 if output_attentions else 1] |
|
|
|
if output_attentions: |
|
all_self_attns += (layer_outputs[1],) |
|
|
|
next_cache = next_decoder_cache if use_cache else None |
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
def _modeling_with_plan( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[LayerCache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
output_hidden_states: Optional[bool] = False, |
|
modeling_plan: List[IterStep] = None, |
|
) -> BaseModelOutputWithPast: |
|
""" |
|
        Runs the decoder layers according to `modeling_plan`, iteratively updating the hidden states.
|
""" |
|
|
|
all_hidden_states = () if output_hidden_states else None |
|
all_self_attns = () if output_attentions else None |
|
next_decoder_cache = None |
|
|
|
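        # Each step of the plan specifies a slice of layers to run, whether gradients are required,
        # and whether its outputs should update the running hidden states and the cache.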
for step in modeling_plan: |
|
end = len(self.layers) if step.layer_slice.stop is None else step.layer_slice.stop |
|
iteration_func = self._iterate_layers if step.requires_grad else torch.no_grad()(self._iterate_layers) |
|
|
|
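            # Only commit new key/value states to the cache when this step is marked as an update step.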
if isinstance(past_key_values, Cache): |
|
past_key_values._update = step.update |
|
|
|
iteration_outputs = iteration_func( |
|
hidden_states, |
|
attention_mask=attention_mask, |
|
position_ids=position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=cache_position, |
|
position_embeddings=position_embeddings, |
|
output_hidden_states=output_hidden_states, |
|
layer_slice=step.layer_slice |
|
) |
|
|
|
|
|
if step.update: |
|
hidden_states = iteration_outputs.last_hidden_state |
|
|
|
if output_hidden_states: |
|
all_hidden_states = all_hidden_states[:end] + iteration_outputs.hidden_states |
|
|
|
if output_attentions: |
|
all_self_attns = all_self_attns[:end] + iteration_outputs.attentions |
|
|
|
if use_cache: |
|
next_decoder_cache = iteration_outputs.past_key_values |
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state=hidden_states, |
|
past_key_values=next_decoder_cache, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
def _modeling_sequential( |
|
self, |
|
hidden_states: torch.Tensor, |
|
attention_mask: Optional[torch.LongTensor] = None, |
|
position_ids: Optional[torch.LongTensor] = None, |
|
past_key_values: Optional[LayerCache] = None, |
|
output_attentions: bool = False, |
|
use_cache: bool = False, |
|
cache_position: Optional[torch.LongTensor] = None, |
|
position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, |
|
output_hidden_states: Optional[bool] = False, |
|
) -> BaseModelOutputWithPast: |
|
""" |
|
Sequentially update the hidden states, token by token. |
|
""" |
|
seq_len = hidden_states.shape[1] |
|
last_hidden_state = [] |
|
all_hidden_states = [] |
|
all_self_attns = [] |
|
|
|
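        # Run the full layer stack one token at a time, slicing every per-token input accordingly.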
for i in range(seq_len): |
|
m_hidden_states = hidden_states[:, i:i+1] |
|
m_attention_mask = ( |
|
(attention_mask[:, : i + 1] if attention_mask.ndim == 2 else attention_mask[:, :, i : i + 1]) |
|
if attention_mask is not None |
|
else None |
|
) |
|
m_position_ids = position_ids[:, i:i+1] if position_ids is not None else None |
|
m_cache_position = cache_position[i:i+1] if cache_position is not None else None |
|
m_position_embeddings = ( |
|
position_embeddings[0][:, i:i+1], |
|
position_embeddings[1][:, i:i+1] |
|
) |
|
|
|
outputs = self._iterate_layers( |
|
m_hidden_states, |
|
attention_mask=m_attention_mask, |
|
position_ids=m_position_ids, |
|
past_key_values=past_key_values, |
|
output_attentions=output_attentions, |
|
use_cache=use_cache, |
|
cache_position=m_cache_position, |
|
position_embeddings=m_position_embeddings, |
|
output_hidden_states=output_hidden_states |
|
) |
|
|
|
last_hidden_state.append(outputs.last_hidden_state) |
|
|
|
if output_hidden_states: |
|
all_hidden_states.append(outputs.hidden_states) |
|
|
|
if output_attentions: |
|
all_self_attns.append(outputs.attentions) |
|
|
|
if use_cache: |
|
past_key_values = outputs.past_key_values |
|
|
|
last_hidden_state = torch.cat(last_hidden_state, dim=1) |
|
|
|
if output_hidden_states: |
|
all_hidden_states = [ |
|
torch.cat([hs[i] for hs in all_hidden_states], dim=1) for i in range(len(all_hidden_states[0])) |
|
] |
|
|
|
if output_attentions: |
|
|
|
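            # Attention weights differ in shape across decoding steps; only those of the last token are kept.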
all_self_attns = all_self_attns[-1] |
|
|
|
return BaseModelOutputWithPast( |
|
last_hidden_state=last_hidden_state, |
|
past_key_values=past_key_values, |
|
hidden_states=all_hidden_states, |
|
attentions=all_self_attns, |
|
) |
|
|
|
def _update_causal_mask( |
|
self, |
|
attention_mask: torch.Tensor, |
|
input_tensor: torch.Tensor, |
|
cache_position: torch.Tensor, |
|
past_key_values: Cache, |
|
output_attentions: bool, |
|
    ):
        """Adapted from `LlamaModel._update_causal_mask` to handle the layer cache."""
|
if self.config._attn_implementation == "flash_attention_2": |
|
if attention_mask is not None and 0.0 in attention_mask: |
|
return attention_mask |
|
return None |
|
|
|
|
|
|
|
|
|
past_seen_tokens = past_key_values.get_seq_length() if isinstance(past_key_values, Cache) else 0 |
|
using_static_cache = isinstance(past_key_values, StaticCache) |
|
|
|
|
|
if self.config._attn_implementation == "sdpa" and not using_static_cache and not output_attentions: |
|
if AttentionMaskConverter._ignore_causal_mask_sdpa( |
|
attention_mask, |
|
inputs_embeds=input_tensor, |
|
past_key_values_length=past_seen_tokens, |
|
is_training=self.training, |
|
): |
|
return None |
|
|
|
dtype, device = input_tensor.dtype, input_tensor.device |
|
min_dtype = torch.finfo(dtype).min |
|
sequence_length = input_tensor.shape[1] |
|
if using_static_cache: |
|
target_length = past_key_values.get_max_length() |
|
else: |
|
target_length = ( |
|
attention_mask.shape[-1] |
|
if isinstance(attention_mask, torch.Tensor) |
|
else past_seen_tokens + sequence_length + 1 |
|
) |
|
|
|
|
|
causal_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
|
attention_mask, |
|
sequence_length=sequence_length, |
|
target_length=target_length, |
|
dtype=dtype, |
|
device=device, |
|
min_dtype=min_dtype, |
|
cache_position=cache_position, |
|
batch_size=input_tensor.shape[0], |
|
) |
|
|
|
if ( |
|
self.config._attn_implementation == "sdpa" |
|
and attention_mask is not None |
|
and attention_mask.device.type == "cuda" |
|
and not output_attentions |
|
): |
|
|
|
|
|
|
|
causal_mask = AttentionMaskConverter._unmask_unattended(causal_mask, min_dtype) |
|
|
|
return causal_mask |
|
|
|
|
|
class LCKVLlamaForCausalLM(LCKVLlamaPreTrainedModel, LlamaForCausalLM): |
|
def __init__(self, config): |
|
LCKVLlamaPreTrainedModel.__init__(self, config) |
|
LlamaForCausalLM.__init__(self, copy.deepcopy(config)) |
|
self.model = LCKVLlamaModel(config) |
|
|
|
|
|
self.post_init() |
|
|
|
def prepare_inputs_for_generation( |
|
self, |
|
input_ids, |
|
past_key_values=None, |
|
attention_mask=None, |
|
inputs_embeds=None, |
|
cache_position=None, |
|
position_ids=None, |
|
use_cache=True, |
|
num_logits_to_keep=None, |
|
**kwargs, |
|
    ):
        """Adapted from `LlamaForCausalLM.prepare_inputs_for_generation` to handle the sink cache."""
|
|
|
|
|
|
|
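        # When a cache is present, keep only the tokens that still need to be processed:
        # generation passes the full `input_ids`, while the cache already covers the prefix.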
if isinstance(past_key_values, Cache): |
|
if inputs_embeds is not None: |
|
input_ids = input_ids[:, -cache_position.shape[0] :] |
|
elif input_ids.shape[1] != cache_position.shape[0]: |
|
input_ids = input_ids[:, cache_position] |
|
|
|
if attention_mask is not None and position_ids is None: |
|
|
|
position_ids = attention_mask.long().cumsum(-1) - 1 |
|
position_ids.masked_fill_(attention_mask == 0, 1) |
|
if isinstance(past_key_values, Cache): |
|
|
|
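                # Some caches (e.g. sink caches) hold fewer entries than the number of tokens seen,
                # so position ids are rebuilt from the current cache length rather than taken from the tail.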
if getattr(past_key_values, "build_position_ids_based_on_cache", False): |
|
cur_cache_length = past_key_values.get_seq_length() |
|
position_ids = position_ids[:, cur_cache_length :cur_cache_length + input_ids.shape[1]] |
|
else: |
|
position_ids = position_ids[:, -input_ids.shape[1] :] |
|
|
|
|
|
position_ids = position_ids.clone(memory_format=torch.contiguous_format) |
|
|
|
|
|
if inputs_embeds is not None and cache_position[0] == 0: |
|
model_inputs = {"inputs_embeds": inputs_embeds, "input_ids": None} |
|
else: |
|
|
|
model_inputs = {"input_ids": input_ids.clone(memory_format=torch.contiguous_format), "inputs_embeds": None} |
|
|
|
if isinstance(past_key_values, StaticCache) and attention_mask.ndim == 2: |
|
if model_inputs["inputs_embeds"] is not None: |
|
batch_size, sequence_length, _ = model_inputs["inputs_embeds"].shape |
|
device = model_inputs["inputs_embeds"].device |
|
else: |
|
batch_size, sequence_length = model_inputs["input_ids"].shape |
|
device = model_inputs["input_ids"].device |
|
|
|
dtype = self.lm_head.weight.dtype |
|
min_dtype = torch.finfo(dtype).min |
|
|
|
attention_mask = _prepare_4d_causal_attention_mask_with_cache_position( |
|
attention_mask, |
|
sequence_length=sequence_length, |
|
target_length=past_key_values.get_max_length(), |
|
dtype=dtype, |
|
device=device, |
|
min_dtype=min_dtype, |
|
cache_position=cache_position, |
|
batch_size=batch_size, |
|
) |
|
|
|
if num_logits_to_keep is not None: |
|
model_inputs["num_logits_to_keep"] = num_logits_to_keep |
|
|
|
model_inputs.update( |
|
{ |
|
"position_ids": position_ids, |
|
"cache_position": cache_position, |
|
"past_key_values": past_key_values, |
|
"use_cache": use_cache, |
|
"attention_mask": attention_mask, |
|
} |
|
) |
|
return model_inputs |
|
|