# Copyright (c) OpenMMLab. All rights reserved.
import warnings
from typing import Optional, Tuple

import torch
import torch.distributed as dist
import torch.nn.functional as F
from einops import rearrange
from mmengine import MessageHub

from .attention import (SUPPORT_FLASH2, flash_attn_w_mask, flash_attn_wo_mask,
                        varlen_flash_attn)
from .triton_kernels import apply_rotary_emb


class InternLM2RotaryEmbedding(torch.nn.Module):

    def __init__(self,
                 dim,
                 max_position_embeddings=2048,
                 base=1000000,
                 device=None):
        super().__init__()
        self.dim = dim
        self.max_position_embeddings = max_position_embeddings
        self.base = base
        self.inv_freq = 1.0 / (
            base**(torch.arange(0, dim, 2).float().to(device) / dim))

        # Build here to make `torch.jit.trace` work.
        self.max_seq_len_cached = max_position_embeddings
        t = torch.arange(
            self.max_seq_len_cached,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype)
        freqs = torch.einsum('i,j->ij', t, self.inv_freq)
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()
        self.sin_cached = emb.sin()

    def forward(self, x, seq_len):
        # x: [bs, num_attention_heads, seq_len, head_size]
        if (seq_len > self.max_seq_len_cached
                or self.cos_cached.device != x.device
                or self.cos_cached.dtype != x.dtype):
            self.max_seq_len_cached = seq_len
            assert self.inv_freq.dtype == torch.float32
            t = torch.arange(
                self.max_seq_len_cached,
                device=x.device,
                dtype=self.inv_freq.dtype)
            freqs = torch.einsum('i,j->ij', t, self.inv_freq.to(t.device))
            emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
            self.cos_cached = emb.cos().to(x.dtype)
            self.sin_cached = emb.sin().to(x.dtype)
        return (
            self.cos_cached[:seq_len, ...],
            self.sin_cached[:seq_len, ...],
        )
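
# Illustrative sketch (not part of the module): `InternLM2RotaryEmbedding`
# returns cos/sin tables of shape (seq_len, dim), where `dim` is the per-head
# dimension. The shapes below follow directly from the code above; the
# concrete sizes are made up for the example.
#
#   rope = InternLM2RotaryEmbedding(dim=128, max_position_embeddings=2048)
#   x = torch.randn(1, 8, 16, 128)  # (bs, num_heads, seq_len, head_dim)
#   cos, sin = rope(x, seq_len=16)
#   assert cos.shape == (16, 128) and sin.shape == (16, 128)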


def rotate_half(x):
    """Rotates half the hidden dims of the input."""
    x1 = x[..., :x.shape[-1] // 2]
    x2 = x[..., x.shape[-1] // 2:]
    return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q, k, cos, sin, position_ids, unsqueeze_dim=1):
    """Applies rotary position embedding to the query and key tensors."""
    cos = cos[position_ids].unsqueeze(unsqueeze_dim)
    sin = sin[position_ids].unsqueeze(unsqueeze_dim)
    q_embed = (q * cos) + (rotate_half(q) * sin)
    k_embed = (k * cos) + (rotate_half(k) * sin)
    return q_embed, k_embed
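
# Illustrative sketch: `apply_rotary_pos_emb` expects q/k in
# (bs, num_heads, seq_len, head_dim) layout and cos/sin tables indexed by
# `position_ids`; `unsqueeze_dim=1` broadcasts the tables over the head axis.
# `rope` refers to the embedding sketched above; sizes are for illustration.
#
#   q = torch.randn(1, 8, 16, 128)
#   k = torch.randn(1, 2, 16, 128)   # fewer KV heads (GQA) also broadcasts
#   position_ids = torch.arange(16).unsqueeze(0)  # (bs, seq_len)
#   cos, sin = rope(q, seq_len=16)
#   q_rot, k_rot = apply_rotary_pos_emb(q, k, cos, sin, position_ids)
#   assert q_rot.shape == q.shape and k_rot.shape == k.shape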


def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """This is the equivalent of torch.repeat_interleave(x, dim=1,
    repeats=n_rep).

    The hidden states go from (batch, num_key_value_heads, seqlen, head_dim)
    to (batch, num_attention_heads, seqlen, head_dim).
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(
        batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen,
                                 head_dim)
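
# Illustrative sketch: `repeat_kv` expands grouped KV heads so every query
# head has a matching key/value head, e.g. 2 KV heads with n_rep=4 become 8.
#
#   kv = torch.randn(1, 2, 16, 128)           # (bs, num_kv_heads, seq, dim)
#   assert repeat_kv(kv, 4).shape == (1, 8, 16, 128)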


def repeat_kv_bshd(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """The hidden states go from (batch, seqlen, num_key_value_heads,
    head_dim) to (batch, seqlen, num_attention_heads, head_dim)."""
    batch, slen, num_key_value_heads, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, :, None, :].expand(
        batch, slen, num_key_value_heads, n_rep, head_dim)
    return hidden_states.reshape(batch, slen, num_key_value_heads * n_rep,
                                 head_dim)
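
# Illustrative sketch: the same expansion as `repeat_kv`, but for the
# (batch, seqlen, heads, head_dim) layout that flash attention consumes.
#
#   kv = torch.randn(1, 16, 2, 128)           # (bs, seq, num_kv_heads, dim)
#   assert repeat_kv_bshd(kv, 4).shape == (1, 16, 8, 128)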


def internlm2_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.LongTensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
    **kwargs,
):
    if 'padding_mask' in kwargs:
        warnings.warn(
            'Passing `padding_mask` is deprecated and will be removed in '
            'v4.37. Please make sure to use `attention_mask` instead.')
        # overwrite attention_mask with padding_mask
        attention_mask = kwargs.pop('padding_mask')

    output_attentions = False

    bsz, q_len, _ = hidden_states.size()

    qkv_states = self.wqkv(hidden_states)
    qkv_states = rearrange(
        qkv_states,
        'b q (h gs d) -> b q h gs d',
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )
    query_states = qkv_states[..., :self.num_key_value_groups, :]
    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    query_states = query_states.transpose(1, 2)
    key_states = key_states.transpose(1, 2)
    value_states = value_states.transpose(1, 2)

    kv_seq_len = key_states.shape[-2]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    # This modification is necessary for sequential parallel
    assert position_ids is not None and (position_ids.max() + 1) >= kv_seq_len
    cos, sin = self.rotary_emb(value_states, seq_len=position_ids.max() + 1)
    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
                                                    cos, sin, position_ids)

    if past_key_value is not None:
        # reuse k, v, self_attention
        key_states = torch.cat([past_key_value[0], key_states], dim=2)
        value_states = torch.cat([past_key_value[1], value_states], dim=2)

    past_key_value = (key_states, value_states) if use_cache else None

    # repeat kv for sequence parallel
    key_states = repeat_kv(key_states, self.num_key_value_groups)
    value_states = repeat_kv(value_states, self.num_key_value_groups)

    if SUPPORT_FLASH2:
        # the shape of attention_mask used by flash_attn and
        # F.scaled_dot_product_attention are different
        assert attention_mask is None or attention_mask.ndim == 2, \
            ('When using flash_attn, attention_mask.ndim should equal to 2. '
             f'But got attention_mask.shape = {attention_mask.shape}. '
             'We can pass the `attn_implementation="flash_attention_2"` flag '
             'to the `.from_pretrained` method when instantiating an '
             'InternLM2 model.')
        # flash attn 2 needs (bs, seq_len, nhead, h_dim)
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

        causal = self.is_causal and q_len != 1

        if attention_mask is not None:
            attn_output = flash_attn_w_mask(
                query_states,
                key_states,
                value_states,
                attention_mask,
                causal=causal,
                training=self.training)
        else:
            attn_output = flash_attn_wo_mask(
                query_states,
                key_states,
                value_states,
                causal=causal,
                training=self.training)
    else:
        # use the flash attention implemented by pytorch;
        # this path does not support sequence parallel
        attn_output = F.scaled_dot_product_attention(
            query_states, key_states, value_states, attn_mask=attention_mask)
        attn_output = attn_output.transpose(1, 2)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.wo(attn_output)

    if not output_attentions:
        attn_weights = None

    return attn_output, attn_weights, past_key_value
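
# A minimal monkey-patching sketch (an assumption, not shown in this file:
# xtuner normally installs this forward through its own dispatch utilities).
# The `model.model.layers` / `layer.attention` attribute names are guesses
# about the remote-code InternLM2 implementation being patched.
#
#   import types
#   for layer in model.model.layers:
#       layer.attention.forward = types.MethodType(
#           internlm2_attn_forward, layer.attention)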


def internlm2_varlen_attn_forward(
    self,
    hidden_states: torch.Tensor,
    attention_mask: Optional[torch.Tensor] = None,
    position_ids: Optional[torch.LongTensor] = None,
    past_key_value: Optional[Tuple[torch.Tensor]] = None,
    output_attentions: bool = False,
    use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor],
           Optional[Tuple[torch.Tensor]]]:
    # Modified from https://huggingface.co/internlm/internlm-7b/blob/939a68c0dc1bd5f35b63c87d44af05ce33379061/modeling_internlm.py#L161  # noqa:E501
    message_hub = MessageHub.get_instance('varlen_attn_args')
    rank = dist.get_rank()
    cumulative_len = message_hub.get_info(f'cumulative_len_rank_{rank}')
    max_seqlen = message_hub.get_info(f'max_seqlen_rank_{rank}')
    use_varlen_atten = (cumulative_len is not None)

    bsz, q_len, _ = hidden_states.size()
    assert bsz == 1, ('If utilizing local attention, the batch size should be'
                      f' set to 1, but got {bsz}')

    qkv_states = self.wqkv(hidden_states)
    qkv_states = rearrange(
        qkv_states,
        'b q (h gs d) -> b q h gs d',
        gs=2 + self.num_key_value_groups,
        d=self.head_dim,
    )
    query_states = qkv_states[..., :self.num_key_value_groups, :]
    query_states = rearrange(query_states, 'b q h gs d -> b q (h gs) d')
    key_states = qkv_states[..., -2, :]
    value_states = qkv_states[..., -1, :]

    kv_seq_len = key_states.shape[-3]
    if past_key_value is not None:
        kv_seq_len += past_key_value[0].shape[-2]

    if use_varlen_atten:
        cos, sin = self.rotary_emb(value_states, max_seqlen)
        query_states = apply_rotary_emb(query_states,
                                        cos[position_ids].squeeze(0),
                                        sin[position_ids].squeeze(0))
        key_states = apply_rotary_emb(key_states, cos[position_ids].squeeze(0),
                                      sin[position_ids].squeeze(0))
    else:
        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)
        cos, sin = self.rotary_emb(value_states, kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, position_ids)

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        query_states = query_states.transpose(1, 2)
        key_states = key_states.transpose(1, 2)
        value_states = value_states.transpose(1, 2)

    # repeat kv for sequence parallel
    key_states = repeat_kv_bshd(key_states, self.num_key_value_groups)
    value_states = repeat_kv_bshd(value_states, self.num_key_value_groups)

    assert SUPPORT_FLASH2
    if use_varlen_atten:
        attn_output = varlen_flash_attn(
            query_states,
            key_states,
            value_states,
            cumulative_len,
            max_seqlen,
            training=self.training)
    else:
        attn_output = flash_attn_wo_mask(
            query_states,
            key_states,
            value_states,
            causal=True,
            training=False)

    attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
    attn_output = self.wo(attn_output)

    # Due to the implementation of the PyTorch version of flash attention,
    # even when the output_attentions flag is set to True, it is not possible
    # to return the attn_weights.
    return attn_output, None, past_key_value
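
# A minimal sketch of the producer side (an assumption: elsewhere in the
# training pipeline these keys are published before each forward pass; the key
# names below simply mirror the `get_info` calls in the function above).
#
#   from mmengine import MessageHub
#   message_hub = MessageHub.get_instance('varlen_attn_args')
#   rank = dist.get_rank()
#   # `cumulative_len` is the cu_seqlens tensor of the packed sub-sequences,
#   # `max_seqlen` is the longest packed sub-sequence on this rank.
#   message_hub.update_info(f'cumulative_len_rank_{rank}', cumulative_len)
#   message_hub.update_info(f'max_seqlen_rank_{rank}', max_seqlen)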