# CRYSTAL-Mac/Perceptrix/finetune/build/lib/llmfoundry/models/layers/llama_attention_monkeypatch.py

# Copyright 2022 MosaicML LLM Foundry authors
# SPDX-License-Identifier: Apache-2.0

# This file is copied and modified from
# https://github.com/huggingface/transformers/blob/fe3c8ab1af558b95f67f5fafc0c55f09fd2b09db/src/transformers/models/llama/modeling_llama.py
# See the clearly denoted code blocks for the main modifications (there are a few others like type ignores, and error messages)

import logging
from typing import Callable, Optional, Tuple

import torch
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaAttention

from llmfoundry.models.layers.attention import (
    scaled_multihead_dot_product_attention, triton_flash_attn_fn)

log = logging.getLogger(__name__)


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim) to
    (batch, num_attention_heads, seqlen, head_dim)
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :,
                                  None, :, :].expand(batch, num_key_value_heads,
                                                     n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)
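
# A minimal usage sketch (not part of the original file), showing the shape change
# repeat_kv performs for grouped-query attention. The concrete sizes below are
# illustrative assumptions:
#
#   kv = torch.randn(1, 2, 16, 64)  # (batch, num_key_value_heads, seqlen, head_dim)
#   out = repeat_kv(kv, n_rep=4)    # 2 KV heads repeated 4x -> 8 attention heads
#   assert out.shape == (1, 8, 16, 64)
#   assert torch.equal(out, torch.repeat_interleave(kv, repeats=4, dim=1))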


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)
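
# A minimal sketch (not part of the original file): rotate_half maps the two halves
# of the last dimension (x1, x2) to (-x2, x1), the rotation RoPE applies to each
# feature pair:
#
#   x = torch.tensor([[1., 2., 3., 4.]])
#   assert torch.equal(rotate_half(x), torch.tensor([[-3., -4., 1., 2.]]))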


def apply_rotary_pos_emb(
    q: torch.Tensor,
    k: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    position_ids: Optional[torch.Tensor] = None
) -> Tuple[torch.Tensor, torch.Tensor]:
    # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
    cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
    sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
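
# A minimal shape sketch (not part of the original file), assuming cos/sin tables in
# the [1, 1, seq_len, head_dim] layout this helper expects (in practice they come
# from self.rotary_emb inside the patched attention forward):
#
#   bsz, n_heads, seq_len, head_dim = 2, 8, 16, 64
#   q = torch.randn(bsz, n_heads, seq_len, head_dim)
#   k = torch.randn(bsz, n_heads, seq_len, head_dim)
#   cos = torch.randn(1, 1, seq_len, head_dim)
#   sin = torch.randn(1, 1, seq_len, head_dim)
#   position_ids = torch.arange(seq_len).unsqueeze(0).expand(bsz, seq_len)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
#   assert q_rot.shape == q.shape and k_rot.shape == k.shape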


def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:
    if patch_fn_name == 'torch':
        return llama_attention_patch_torch
    elif patch_fn_name == 'triton':
        return llama_attention_patch_triton
    else:
        raise ValueError(
            f'Unrecognized llama attention patch function: {patch_fn_name}')
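
# An illustrative sketch (not part of the original file) of how this selector might
# be used to monkeypatch Hugging Face's Llama attention; the exact call site in
# llmfoundry's model-building code is an assumption here:
#
#   from transformers.models.llama import modeling_llama
#   modeling_llama.LlamaAttention.forward = get_llama_attention_patch_fn('torch')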


def llama_attention_patch_torch(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
            'use_cache is not yet supported when patching Llama attention.')

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads *
                             self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp,
            dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    ### MAIN MODIFICATIONS START HERE ###
    query_states = query_states.transpose(1, 2).view(
        bsz, q_len, self.num_heads * self.head_dim)
    key_states = key_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)
    value_states = value_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)

    attn_output, attn_weights, _ = scaled_multihead_dot_product_attention(
        query=query_states,
        key=key_states,
        value=value_states,
        n_heads=self.num_heads,
        kv_n_heads=self.num_key_value_heads,
        past_key_value=None,
        softmax_scale=None,
        attn_bias=attention_mask,
        key_padding_mask=None,
        is_causal=False,  # The causal mask is propagated from LlamaForCausalLM
        dropout_p=0,
        training=self.training,
        needs_weights=False,
    )
    ### MAIN MODIFICATIONS END HERE ###

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size //
                                        self.config.pretraining_tp,
                                        dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size //
                                                 self.config.pretraining_tp,
                                                 dim=1)
        attn_output = sum([
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.config.pretraining_tp)
        ])
    else:
        attn_output = self.o_proj(attn_output)

    assert isinstance(attn_output, torch.Tensor)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, None
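
# Illustrative shape walkthrough for the patch above (not part of the original file),
# assuming a 7B-style config with num_heads = num_key_value_heads = 32 and
# head_dim = 128:
#
#   hidden_states:                       (bsz, q_len, 4096)
#   q/k/v after proj + view + transpose: (bsz, 32, q_len, 128)
#   after the re-flattening above:       (bsz, q_len, 4096)
#   attn_output:                         (bsz, q_len, 4096)
#
# The re-flattening is needed because llmfoundry's attention functions take
# (batch, seq, n_heads * head_dim) inputs rather than the per-head layout used by
# the upstream Hugging Face implementation.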


def llama_attention_patch_triton(
    self: LlamaAttention,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
    if use_cache:
        raise NotImplementedError(
            'use_cache is not yet supported when patching Llama attention.')
    # output_attentions is not supported for triton attention
    if output_attentions:
        raise NotImplementedError(
            'output_attentions is not supported when patching Llama attention with triton attention.'
        )

    bsz, q_len, _ = hidden_states.size()

    if self.config.pretraining_tp > 1:
        key_value_slicing = (self.num_key_value_heads *
                             self.head_dim) // self.config.pretraining_tp
        query_slices = self.q_proj.weight.split(
            (self.num_heads * self.head_dim) // self.config.pretraining_tp,
            dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = [
            F.linear(hidden_states, query_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        query_states = torch.cat(query_states, dim=-1)

        key_states = [
            F.linear(hidden_states, key_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        key_states = torch.cat(key_states, dim=-1)

        value_states = [
            F.linear(hidden_states, value_slices[i])
            for i in range(self.config.pretraining_tp)
        ]
        value_states = torch.cat(value_states, dim=-1)
    else:
        query_states = self.q_proj(hidden_states)
        key_states = self.k_proj(hidden_states)
        value_states = self.v_proj(hidden_states)

    query_states = query_states.view(bsz, q_len, self.num_heads,
                                     self.head_dim).transpose(1, 2)
    key_states = key_states.view(bsz, q_len, self.num_key_value_heads,
                                 self.head_dim).transpose(1, 2)
    value_states = value_states.view(bsz, q_len, self.num_key_value_heads,
                                     self.head_dim).transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)

    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    ### MAIN MODIFICATIONS START HERE ###
    query_states = query_states.transpose(1, 2).view(
        bsz, q_len, self.num_heads * self.head_dim)
    key_states = key_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)
    value_states = value_states.transpose(1, 2).view(
        bsz, q_len, self.num_key_value_heads * self.head_dim)

    attn_output, _, _ = triton_flash_attn_fn(
        query=query_states,
        key=key_states,
        value=value_states,
        n_heads=self.num_heads,
        kv_n_heads=self.num_key_value_heads,
        past_key_value=None,
        softmax_scale=None,
        attn_bias=attention_mask,
        key_padding_mask=None,
        is_causal=False,  # The causal mask is propagated from LlamaForCausalLM
        dropout_p=0,
        training=self.training,
        needs_weights=False,
    )
    ### MAIN MODIFICATIONS END HERE ###

    if self.config.pretraining_tp > 1:
        attn_output = attn_output.split(self.hidden_size //
                                        self.config.pretraining_tp,
                                        dim=2)
        o_proj_slices = self.o_proj.weight.split(self.hidden_size //
                                                 self.config.pretraining_tp,
                                                 dim=1)
        attn_output = sum([
            F.linear(attn_output[i], o_proj_slices[i])
            for i in range(self.config.pretraining_tp)
        ])
    else:
        attn_output = self.o_proj(attn_output)

    assert isinstance(attn_output, torch.Tensor)

    return attn_output, None, None
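
# An illustrative sketch (not part of the original file): the triton variant is
# applied the same way as the torch one, but (per the checks above) only when the
# model runs with use_cache=False and output_attentions=False, and it assumes the
# triton flash-attention kernel backing triton_flash_attn_fn is installed:
#
#   from transformers.models.llama import modeling_llama
#   modeling_llama.LlamaAttention.forward = get_llama_attention_patch_fn('triton')
#   model.config.use_cache = False  # `model` is a hypothetical LlamaForCausalLM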