# Copyright (c) OpenMMLab. All rights reserved.
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from mmengine import MessageHub

from .triton_kernels import apply_rotary_emb

SUPPORT_FLASH2 = False

try:
    from flash_attn import flash_attn_func, flash_attn_varlen_func
    SUPPORT_FLASH2 = True
except ImportError:
    pass
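
# NOTE: `internlm_attn_forward` below falls back to PyTorch's
# `F.scaled_dot_product_attention` when flash-attn is not installed, while
# `internlm_varlen_attn_forward` hard-requires it (`assert SUPPORT_FLASH2`).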

class InternLMRotaryEmbedding(torch.nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=10000,
                 device=None):
        super().__init__()
        self.inv_freq = 1.0 / (
            base**(torch.arange(0, dim, 2).float().to(device) / dim))

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len):
        # x: [bs, num_attention_heads, seq_len, head_size]
        # Rebuild the cache if the request outgrows it, or if x lives on a
        # different device/dtype than the cached tables.
        if (seq_len > self.max_seq_len_cached
                or self.cos_cached.device != x.device
                or self.cos_cached.dtype != x.dtype):
            self.max_seq_len_cached = seq_len
            assert self.inv_freq.dtype == torch.float32
            t = torch.arange(
                self.max_seq_len_cached,
                device=x.device,
                dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos().to(x.dtype)
            self.sin_cached = emb.sin().to(x.dtype)
        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
        )
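
# A minimal sketch of how the cos/sin cache above is consumed (toy shapes,
# not part of the original module):
#
#   rotary = InternLMRotaryEmbedding(dim=8, max_position_embeddings=16)
#   x = torch.randn(1, 4, 10, 8)  # [bs, num_heads, seq_len, head_dim]
#   cos, sin = rotary(x, seq_len=10)  # each of shape [10, 8]
#   # asking for seq_len > 16 rebuilds the cache on x's device/dtype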

def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
    cos = cos[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    sin = sin[position_ids].unsqueeze(1)  # [bs, 1, seq_len, dim]
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
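
# Worked example for `rotate_half` on a hypothetical 4-dim head:
#   rotate_half(torch.tensor([1., 2., 3., 4.])) -> tensor([-3., -4., 1., 2.])
# so q * cos + rotate_half(q) * sin applies the standard RoPE rotation,
# pairing dimension i with dimension i + dim // 2.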

def internlm_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161  # noqa:E501
    bsz, q_len, _ = hidden_states.size()

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim).transpose(
                                                       1, 2)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                 self.head_dim).transpose(
                                                     1, 2)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim).transpose(
                                                       1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]
    cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)
    # [bsz, nh, t, hd]

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)
    past_key_value = (key_states, value_states) if use_cache else None

    if SUPPORT_FLASH2:
        # flash_attn expects [bs, seq_len, num_heads, head_dim]
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        attn_output = flash_attn_func(
            query_states, key_states, value_states, causal=True)
        attn_output = attn_output.contiguous()
    else:
        # use flash attention implemented by pytorch
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask)
        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Due to the implementation of the PyTorch version of flash attention,
    # even when the output_attentions flag is set to True, it is not possible
    # to return the attn_weights.
    return attn_output, None, past_key_value
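
# A hedged sketch of how a forward like the one above is typically swapped in
# for the Hugging Face attention module (the `model.model.layers[i].self_attn`
# layout is an assumption about the InternLM checkpoint, not defined here):
#
#   import types
#   for layer in model.model.layers:
#       layer.self_attn.forward = types.MethodType(
#           internlm_attn_forward, layer.self_attn)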

def internlm_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161  # noqa:E501
    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    # position_ids = message_hub.get_info(f'position_ids_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
    use_varlen_atten = (cumulative_len is not None)

    bsz, q_len, _ = hidden_states.size()
    assert bsz == 1, (f'If utilizing local attention, the batch size should be'
                      f' set to 1, but got {bsz}')

    query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim)
    key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                 self.head_dim)
    value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
                                                   self.head_dim)

    kv_seq_len = key_states.shape[-3]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if use_varlen_atten:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        query_states = apply_rotary_emb(query_states,
                                        cos[position_ids].squeeze(0),
                                        sin[position_ids].squeeze(0))
        key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
                                      sin[position_ids].squeeze(0))
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        cos, sin = self.rotary_emb(value_states, kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        past_key_value = (key_states, value_states) if use_cache else None
        # back to [bs, seq_len, num_heads, head_dim] for flash_attn
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    assert SUPPORT_FLASH2
    if use_varlen_atten:
        q_unpad, k_unpad, v_unpad = query_states.flatten(
            0, 1), key_states.flatten(0, 1), value_states.flatten(0, 1)
        cumulative_len = torch.cat(cumulative_len, dim=0)
        attn_output = flash_attn_varlen_func(
            q_unpad,
            k_unpad,
            v_unpad,
            cumulative_len,
            cumulative_len,
            max_seqlen,
            max_seqlen,
            0,
            return_attn_probs=False,
            causal=True,
        )
    else:
        attn_output = flash_attn_func(
            query_states, key_states, value_states, causal=True)
    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.o_proj(attn_output)

    # Due to the implementation of the PyTorch version of flash attention,
    # even when the output_attentions flag is set to True, it is not possible
    # to return the attn_weights.
    return attn_output, None, past_key_value
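
# A hedged sketch of the producer side for the MessageHub reads above (the
# key names match this file; where they are written, e.g. a data collator or
# training hook, is an assumption):
#
#   message_hub = MessageHub.get_instance('varlen_attn_args')
#   rank = dist.get_rank()
#   message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
#   message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)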