File size: 12,694 Bytes
2365f8d 0309277 2365f8d |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 |
# coding=utf-8
# Copyright and license here
""" PyTorch DeciLM model."""
import math
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
from torch import nn
from packaging import version
import transformers
if version.parse("4.31.0") > version.parse(transformers.__version__):
    # Fail at import time, before the Llama internals imported below are touched,
    # when the installed transformers is older than the minimum this module targets.
    raise ImportError(
        f"You are using transformers=={transformers.__version__}, but transformers>=4.31.0 is required to use DeciLM. Please upgrade transformers."
    )
from transformers.models.llama.modeling_llama import LlamaMLP, LlamaRMSNorm, LlamaAttention, apply_rotary_pos_emb, \
repeat_kv, LlamaPreTrainedModel, LLAMA_START_DOCSTRING, LlamaDecoderLayer, LlamaForCausalLM, LlamaModel
from transformers.utils import add_start_docstrings
from .configuration_decilm import DeciLMConfig
# Name of the configuration class for this model; presumably consumed by transformers
# docstring helpers (it is not referenced in the visible part of this file) — TODO confirm.
_CONFIG_FOR_DOC = "DeciLMConfig"
class DeciLMAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper, with per-layer KV heads.

    Unlike Llama attention, the number of key/value heads (grouped-query attention) is read
    per layer from ``config.num_key_value_heads_per_layer[layer_idx]``. The attention kernel
    is chosen by the ``naive_attention_*`` config flags:

    * naive: explicit matmul + softmax, which honours ``attention_mask`` and can return the
      attention probabilities;
    * otherwise ``F.scaled_dot_product_attention``, which is called without a mask.
    """

    def __init__(self, config: DeciLMConfig, layer_idx: int):
        # Initialize nn.Module directly (not LlamaAttention.__init__) so every attribute,
        # including the per-layer KV head count, is defined here.
        nn.Module.__init__(self)
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.layer_idx = layer_idx
        self.num_key_value_heads = config.num_key_value_heads_per_layer[layer_idx]
        self.num_key_value_groups = self.num_heads // self.num_key_value_heads
        self.pretraining_tp = config.pretraining_tp
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = getattr(config, 'rope_theta', None)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )
        # repeat_kv expands each KV head num_key_value_groups times; that only yields
        # num_heads heads when the division above is exact.
        if self.num_heads % self.num_key_value_heads != 0:
            raise ValueError(
                f"num_heads must be divisible by num_key_value_heads (got `num_heads`: {self.num_heads}"
                f" and `num_key_value_heads`: {self.num_key_value_heads} for layer {layer_idx})."
            )

        self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
        self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
        self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)

        self.naive_attention_prefill = config.naive_attention_prefill
        self.naive_attention_decode_batched = config.naive_attention_decode_batched
        self.naive_attention_decode_single = config.naive_attention_decode_single
        self._init_rope()

    def _project_qkv(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Apply the q/k/v projections.

        When ``pretraining_tp > 1`` the weight matrices are split row-wise into
        ``pretraining_tp`` slices, each slice is applied separately and the results are
        concatenated on the last dim.
        """
        if self.pretraining_tp <= 1:
            return self.q_proj(hidden_states), self.k_proj(hidden_states), self.v_proj(hidden_states)

        key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.pretraining_tp
        query_slices = self.q_proj.weight.split((self.num_heads * self.head_dim) // self.pretraining_tp, dim=0)
        key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
        value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

        query_states = torch.cat([F.linear(hidden_states, query_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
        key_states = torch.cat([F.linear(hidden_states, key_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
        value_states = torch.cat([F.linear(hidden_states, value_slices[i]) for i in range(self.pretraining_tp)], dim=-1)
        return query_states, key_states, value_states

    def _naive_attention(self, query_states, key_states, value_states, attention_mask, bsz, q_len, kv_seq_len):
        """Explicit scaled dot-product attention that honours an additive ``attention_mask``.

        Returns ``(attn_output, attn_weights)``. ``attn_output`` has shape
        ``(bsz, num_heads, q_len, head_dim)`` (checked by the caller). ``attn_weights``
        are the post-softmax probabilities in ``query_states.dtype``.

        Raises:
            ValueError: if the score matrix or the mask has an unexpected shape.
        """
        attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )
        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            # The mask is additive and must be applied to the logits, i.e. BEFORE the softmax;
            # adding it afterwards would leave masked positions with non-zero probability.
            attn_weights = attn_weights + attention_mask
        # upcast attention to fp32
        attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)
        return attn_output, attn_weights

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
        padding_mask: Optional[torch.LongTensor] = None,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Compute self-attention for ``hidden_states``, optionally extending a KV cache.

        Args:
            hidden_states: input of shape ``(bsz, q_len, hidden_size)``.
            attention_mask: optional additive mask of shape ``(bsz, 1, q_len, kv_seq_len)``;
                only the naive kernels apply it.
            position_ids: positions passed to ``apply_rotary_pos_emb``.
            past_key_value: ``(key, value)`` cache from earlier steps; its presence selects
                the decode path.
            output_attentions: if True, return the attention probabilities. These exist only
                when a naive kernel ran; otherwise ``None`` is returned.
            use_cache: if True, return the updated ``(key, value)`` cache.
            padding_mask: accepted for interface compatibility; unused here.

        Returns:
            ``(attn_output, attn_weights, past_key_value)``. ``attn_output`` has shape
            ``(bsz, q_len, hidden_size)``.
        """
        bsz, q_len, _ = hidden_states.size()
        # A cache from previous steps means we are decoding; otherwise this is prefill.
        is_decode = past_key_value is not None

        query_states, key_states, value_states = self._project_qkv(hidden_states)
        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
        key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
        value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)

        kv_seq_len = key_states.shape[-2]
        if is_decode:
            kv_seq_len += past_key_value[0].shape[-2]
        cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
        query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)

        if is_decode:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)
        # The cache keeps the un-repeated (num_key_value_heads) tensors.
        past_key_value = (key_states, value_states) if use_cache else None

        # repeat k/v heads if n_kv_heads < n_heads
        key_states = repeat_kv(key_states, self.num_key_value_groups)
        value_states = repeat_kv(value_states, self.num_key_value_groups)

        # Only the naive kernels produce explicit weights; SDPA paths leave this as None.
        attn_weights = None
        if is_decode:
            use_naive = (self.naive_attention_decode_batched and bsz > 1) or \
                        (self.naive_attention_decode_single and bsz == 1)
            if use_naive:
                attn_output, attn_weights = self._naive_attention(
                    query_states, key_states, value_states, attention_mask, bsz, q_len, kv_seq_len)
            else:
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=False,
                                                             dropout_p=0.0)
        elif self.naive_attention_prefill:
            attn_output, attn_weights = self._naive_attention(
                query_states, key_states, value_states, attention_mask, bsz, q_len, kv_seq_len)
        else:
            with torch.backends.cuda.sdp_kernel(enable_math=True, enable_flash=False, enable_mem_efficient=False):
                attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True,
                                                             dropout_p=0.0)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )
        # (bsz, heads, q_len, head_dim) -> (bsz, q_len, hidden). The transpose is required
        # whenever q_len > 1; for q_len == 1 it does not change the memory layout.
        attn_output = attn_output.transpose(1, 2).contiguous().view(bsz, q_len, self.hidden_size)

        if self.pretraining_tp > 1:
            attn_output = attn_output.split(self.hidden_size // self.pretraining_tp, dim=2)
            o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.pretraining_tp, dim=1)
            attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.pretraining_tp)])
        else:
            attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None
        return attn_output, attn_weights, past_key_value
class DeciLMDecoderLayer(LlamaDecoderLayer):
    """Llama decoder layer whose self-attention is a ``DeciLMAttention`` built for ``layer_idx``."""

    def __init__(self, config: DeciLMConfig, layer_idx: int):
        # Initialize nn.Module directly instead of LlamaDecoderLayer.__init__, so that the
        # attention submodule constructed below is a DeciLMAttention.
        nn.Module.__init__(self)
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width
        self.layer_idx = layer_idx
        # Submodule assignment order determines parameter / state_dict ordering; keep it stable.
        self.self_attn = DeciLMAttention(config=config, layer_idx=layer_idx)
        self.mlp = LlamaMLP(config)
        self.input_layernorm = LlamaRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=norm_eps)
@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class DeciLMPreTrainedModel(LlamaPreTrainedModel):
    """Llama pre-trained base class re-bound to ``DeciLMConfig`` and DeciLM decoder layers."""

    # Rotary inverse frequencies are listed here so their absence from a checkpoint is ignored on load.
    _keys_to_ignore_on_load_missing = ["self_attn.rotary_emb.inv_freq"]
    # Decoder layers must not be split across devices.
    _no_split_modules = ["DeciLMDecoderLayer"]
    config_class = DeciLMConfig
@add_start_docstrings(
    "The bare DeciLM Model outputting raw hidden-states without any specific head on top.",
    LLAMA_START_DOCSTRING,
)
class DeciLMModel(LlamaModel, DeciLMPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`DeciLMDecoderLayer`]

    Args:
        config: DeciLMConfig
    """

    def __init__(self, config: DeciLMConfig):
        # Bypass LlamaModel.__init__ so the layer stack is built from DeciLMDecoderLayer.
        DeciLMPreTrainedModel.__init__(self, config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList([DeciLMDecoderLayer(config, layer_idx) for layer_idx
                                     in range(config.num_hidden_layers)])
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.gradient_checkpointing = False
        # Initialize weights and apply final processing
        self.post_init()

    def _prepare_decoder_attention_mask(self, attention_mask, input_shape, inputs_embeds, past_key_values_length):
        """Check that the configured attention kernels can honour ``attention_mask``, then
        delegate mask construction to ``LlamaModel._prepare_decoder_attention_mask``."""
        self._validate_config_supports_attention_mask(attention_mask, input_shape, past_key_values_length)
        return LlamaModel._prepare_decoder_attention_mask(
            self, attention_mask, input_shape, inputs_embeds, past_key_values_length)

    def _validate_config_supports_attention_mask(self, attention_mask, input_shape, past_key_values_length):
        """Raise if a non-trivial ``attention_mask`` would be ignored by the selected kernel.

        Only the naive attention kernels in ``DeciLMAttention`` apply the mask. A mask that
        is ``None`` or all ones carries no information and is always accepted.

        Args:
            attention_mask: mask tensor (or ``None``); compared element-wise against 1.
            input_shape: shape whose first entry is treated as the batch size.
            past_key_values_length: cached sequence length; ``> 0`` means decoding.

        Raises:
            ValueError: if the mask masks something and the naive kernel for the current
                phase (prefill / single decode / batched decode) is disabled in the config.
        """
        if attention_mask is None or torch.all(torch.eq(attention_mask, 1)).item():
            return

        if past_key_values_length <= 0:
            if not self.config.naive_attention_prefill:
                raise ValueError("For support of custom attention masks please set naive_attention_prefill to "
                                 "True in the config")
            return

        batch_size = input_shape[0]
        if batch_size == 1 and not self.config.naive_attention_decode_single:
            raise ValueError(
                "For support of custom attention masks please set naive_attention_decode_single to True in the "
                "config")
        if batch_size > 1 and not self.config.naive_attention_decode_batched:
            raise ValueError(
                "For support of custom attention masks please set naive_attention_decode_batched to True in the "
                "config")
class DeciLMForCausalLM(LlamaForCausalLM, DeciLMPreTrainedModel):
    """A ``DeciLMModel`` backbone topped with a bias-free linear language-modeling head."""

    def __init__(self, config):
        # Call the DeciLMPreTrainedModel-level initializer directly (not LlamaForCausalLM.__init__)
        # and build the DeciLMModel backbone here.
        DeciLMPreTrainedModel.__init__(self, config)
        self.vocab_size = config.vocab_size
        self.pretraining_tp = config.pretraining_tp

        # Submodules: backbone first, then the output head (registration order preserved).
        self.model = DeciLMModel(config)
        self.lm_head = nn.Linear(config.hidden_size, self.vocab_size, bias=False)

        # Initialize weights and apply final processing
        self.post_init()
|