# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub

from .triton_kernels import apply_rotary_emb
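
# flash-attn is optional: ``internlm_attn_forward`` falls back to PyTorch's
# ``scaled_dot_product_attention`` when it is unavailable, while
# ``internlm_varlen_attn_forward`` requires it (see the assert there).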
SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    SUPPORT_FLASH2 = True
except ImportError:
    pass


class InternLMRotaryEmbedding(torch.nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        self.inv_freq = 1.0 / (
            base**(torch.arange(0, dim, 2).float().to(device) / dim))

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if (seq_len > self.max_seq_len_cached
                or self.cos_cached.device != x.device
                or self.cos_cached.dtype != x.dtype):
            self.max_seq_len_cached = seq_len
            assert self.inv_freq.dtype == torch.float32
            t = torch.arange(
                self.max_seq_len_cached,
                device=x.device,
                dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos().to(x.dtype)
            self.sin_cached = emb.sin().to(x.dtype)
        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
        )
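

# ``InternLMRotaryEmbedding.forward`` returns ``cos`` and ``sin`` of shape
# [seq_len, dim] (dim == head_dim); ``apply_rotary_pos_emb`` below gathers them
# with ``position_ids`` and broadcasts over the head axis.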


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
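

# Worked example (illustrative): for a head dim of 4, rotate_half maps
# [x1, x2, x3, x4] -> [-x3, -x4, x1, x2], so apply_rotary_pos_emb computes
# q_embed = q * cos + rotate_half(q) * sin (likewise for k), i.e. the standard
# RoPE rotation at each position in ``position_ids``.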


def internlm_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim).transpose(
                                                       1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                 self.head_dim).transpose(
                                                     1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim).transpose(
                                                       1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    past_key_value = (key_states, value_states) if use_cache else None

    if SUPPORT_FLASH2:
        # flash_attn_func expects inputs of shape [bsz, seq_len, nh, hd]
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = flash_attn_func(
            query_states, key_states, value_states, causal=True)
        attn_output = attn_output.contiguous()
    else:
        # Fall back to the flash attention implemented in PyTorch
        # (F.scaled_dot_product_attention).
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask)
        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Neither flash-attn nor PyTorch's scaled_dot_product_attention returns
    # the attention weights, so attn_weights cannot be provided even when
    # output_attentions is set to True.
    return attn_output, None, past_key_value
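

# NOTE (illustrative): the varlen forward below expects a training hook to
# have published the packed-sequence metadata for the current rank to the
# 'varlen_attn_args' MessageHub, e.g. with keys mirroring those read below:
#   hub = MessageHub.get_instance('varlen_attn_args')
#   hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
#   hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)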


def internlm_varlen_attn_forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161 # noqa:E501
    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
    use_varlen_atten = (cumulative_len is not None)

    bsz, q_len, _ = hidden_states.size()
    assert bsz == 1, (f'If utilizing varlen attention, the batch size should'
                      f' be set to 1, but got {bsz}')

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                 self.head_dim)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim)

    kv_seq_len = key_states.shape[-3]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if use_varlen_atten:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        query_states = apply_rotary_emb(query_states,
                                        cos[position_ids].squeeze(0),
                                        sin[position_ids].squeeze(0))
        key_states = apply_rotary_emb(key_states,
                                      cos[position_ids].squeeze(0),
                                      sin[position_ids].squeeze(0))
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        cos, sin = self.rotary_emb(value_states, kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None

        # flash_attn_func expects inputs of shape [bsz, seq_len, nh, hd]
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    assert SUPPORT_FLASH2
    if use_varlen_atten:
        q_unpad = query_states.flatten(0, 1)
        k_unpad = key_states.flatten(0, 1)
        v_unpad = value_states.flatten(0, 1)
        cumulative_len = torch.cat(cumulative_len, dim=0)
        attn_output = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cumulative_len,
            cumulative_len,
            max_seqlen,
            max_seqlen,
            0,  # dropout_p
            return_attn_probs=False,
            causal=True,
        )
    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, causal=True)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # The flash attention kernels used above never materialize the attention
    # weights, so attn_weights cannot be provided even when output_attentions
    # is set to True.
    return attn_output, None, past_key_value
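

# ---------------------------------------------------------------------------
# Illustrative usage sketch (not part of the original module): one way to bind
# the forwards above onto a Hugging Face InternLM model. It assumes the remote
# attention class is named ``InternLMAttention``; adjust the name check for
# other revisions of the modeling code.
# ---------------------------------------------------------------------------
def _example_dispatch_internlm_attn(model, use_varlen_attn=False):
    """Minimal sketch: rebind ``forward`` on every attention module."""
    import types

    forward_fn = (internlm_varlen_attn_forward
                  if use_varlen_attn else internlm_attn_forward)
    for module in model.modules():
        # Match by class name so the remote modeling code need not be imported.
        if type(module).__name__ == 'InternLMAttention':
            module.forward = types.MethodType(forward_fn, module)
    return model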