Take input attention masks to support left-padded sequences

#1
by hiyouga - opened
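This patch forwards the tokenizer's attention_mask into the model, so batches padded on the left attend only over real tokens and derive their ALiBi offsets from the real positions. A minimal sketch of the intended usage; the checkpoint name and the pad-token fallback below are illustrative assumptions, not part of this PR:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

# Illustrative checkpoint name; any Baichuan checkpoint shipping this modeling file would apply.
name = "baichuan-inc/Baichuan-13B-Base"
tokenizer = AutoTokenizer.from_pretrained(name, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(name, trust_remote_code=True, torch_dtype=torch.float16)

# Left-pad a batch of prompts of different lengths.
tokenizer.padding_side = "left"
if tokenizer.pad_token is None:  # assumption: fall back to EOS when no pad token is defined
    tokenizer.pad_token = tokenizer.eos_token
batch = tokenizer(["Hello", "The capital of France is"], return_tensors="pt", padding=True)

# With this change the padded positions are masked out of attention and
# excluded from the ALiBi position offsets during generation.
outputs = model.generate(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    max_new_tokens=16,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))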
Files changed (1)
  1. modeling_baichuan.py +329 -140
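The core of the change is that the ALiBi tensor is now built from the attention mask itself (the Bloom-style build_alibi_tensor), so left padding shifts token positions correctly instead of assuming positions 0..n-1 as the old precomputed mask did. A small standalone check of that expression, mirroring the line ((attention_mask.cumsum(dim=-1) - 1) * attention_mask) in the diff below; the example values are illustrative:

import torch

# Two sequences of the same padded length; the first is left-padded by two tokens.
attention_mask = torch.tensor([[0, 0, 1, 1, 1],
                               [1, 1, 1, 1, 1]])

# Cumulative position of each real token, with padded positions zeroed out.
positions = (attention_mask.cumsum(dim=-1) - 1) * attention_mask
print(positions)
# tensor([[0, 0, 0, 1, 2],
#         [0, 1, 2, 3, 4]])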
modeling_baichuan.py CHANGED
@@ -5,6 +5,8 @@ from typing import List, Optional, Tuple, Union
5
 
6
  import torch
7
  import torch.utils.checkpoint
8
  from torch.nn import CrossEntropyLoss
9
  from transformers import PreTrainedModel
10
  from transformers.activations import ACT2FN
@@ -14,72 +16,117 @@ from transformers.generation.utils import GenerationConfig
14
 
15
  from .configuration_baichuan import BaichuanConfig
16
 
 
17
  logger = logging.get_logger(__name__)
18
 
19
- def _get_interleave(n):
20
- def _get_interleave_power_of_2(n):
21
- start = (2 ** (-2 ** -(math.log2(n) - 3)))
22
- ratio = start
23
- return [start * ratio ** i for i in range(n)]
24
-
25
- if math.log2(n).is_integer():
26
- return _get_interleave_power_of_2(n)
27
- else:
28
- closest_power_of_2 = 2 ** math.floor(math.log2(n))
29
- return _get_interleave_power_of_2(closest_power_of_2) + \
30
- _get_interleave(2 * closest_power_of_2)[0::2][:n - closest_power_of_2]
31
-
32
- def _fill_with_neg_inf(t):
33
- """FP16-compatible function that fills a tensor with -inf."""
34
- return t.float().fill_(float("-inf")).type_as(t)
35
-
36
- def _gen_alibi_mask(n_head, max_pos):
37
- slopes = torch.Tensor(_get_interleave(n_head))
38
- alibi = slopes.unsqueeze(1).unsqueeze(1) * torch.arange(max_pos).unsqueeze(0).unsqueeze(0).expand(
39
- n_head, -1, -1)
40
- alibi = alibi.view(n_head, 1, max_pos)
41
- alibi_mask = torch.triu(
42
- _fill_with_neg_inf(torch.zeros([max_pos, max_pos])), 1
43
  )
44
- alibi_mask = alibi_mask.unsqueeze(0) + alibi
45
- return alibi_mask
46
 
47
 
48
- class RMSNorm(torch.nn.Module):
49
  def __init__(self, hidden_size, epsilon=1e-6):
50
  super().__init__()
51
- self.weight = torch.nn.Parameter(torch.empty(hidden_size))
52
  self.epsilon = epsilon
53
 
54
- def forward(self, hidden_states):
 
55
  variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
56
  hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
57
 
58
- # convert into half-precision
59
- if self.weight.dtype in [torch.float16, torch.bfloat16]:
60
- hidden_states = hidden_states.to(self.weight.dtype)
61
 
62
- return self.weight * hidden_states
63
 
 
64
 
65
- class MLP(torch.nn.Module):
66
  def __init__(
67
- self,
68
- hidden_size: int,
69
- intermediate_size: int,
70
- hidden_act: str,
71
  ):
72
  super().__init__()
73
- self.gate_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
74
- self.down_proj = torch.nn.Linear(intermediate_size, hidden_size, bias=False)
75
- self.up_proj = torch.nn.Linear(hidden_size, intermediate_size, bias=False)
76
  self.act_fn = ACT2FN[hidden_act]
77
 
78
  def forward(self, x):
79
  return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
80
 
81
 
82
- class BaichuanAttention(torch.nn.Module):
83
 
84
  def __init__(self, config: BaichuanConfig):
85
  super().__init__()
@@ -93,62 +140,89 @@ class BaichuanAttention(torch.nn.Module):
93
  raise ValueError(
94
  f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
95
  )
96
- self.W_pack = torch.nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
97
- self.o_proj = torch.nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
98
 
99
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
100
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
101
 
102
  def forward(
103
- self,
104
- hidden_states: torch.Tensor,
105
- attention_mask: Optional[torch.Tensor] = None,
106
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
107
- output_attentions: bool = False,
108
- use_cache: bool = False,
 
109
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
110
 
111
  bsz, q_len, _ = hidden_states.size()
112
 
113
- proj = self.W_pack(hidden_states)
114
  proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
115
- query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
116
- key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
117
- value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
118
 
119
- kv_seq_len = key_states.shape[-2]
120
- if past_key_value is not None:
121
- kv_seq_len += past_key_value[0].shape[-2]
122
 
123
  if past_key_value is not None:
124
  # reuse k, v, self_attention
125
- key_states = torch.cat([past_key_value[0], key_states], dim=2)
126
- value_states = torch.cat([past_key_value[1], value_states], dim=2)
127
 
128
  past_key_value = (key_states, value_states) if use_cache else None
129
 
130
- attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
131
 
132
- if attention_mask is not None:
133
- if attn_weights.size(-2) == 1:
134
- attention_mask = attention_mask[:, -1:, :]
135
- attn_weights = attn_weights + attention_mask.unsqueeze(0)
136
- attn_weights = torch.max(attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
137
 
138
- attn_weights = torch.nn.functional.softmax(attn_weights, dim=-1)
139
- attn_output = torch.matmul(attn_weights, value_states)
140
 
141
- attn_output = attn_output.transpose(1, 2)
142
- attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
143
  attn_output = self.o_proj(attn_output)
144
 
145
  if not output_attentions:
146
- attn_weights = None
147
 
148
- return attn_output, attn_weights, past_key_value
149
 
 
150
 
151
- class BaichuanLayer(torch.nn.Module):
152
  def __init__(self, config: BaichuanConfig):
153
  super().__init__()
154
  self.hidden_size = config.hidden_size
@@ -162,12 +236,13 @@ class BaichuanLayer(torch.nn.Module):
162
  self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
163
 
164
  def forward(
165
- self,
166
- hidden_states: torch.Tensor,
167
- attention_mask: Optional[torch.Tensor] = None,
168
- past_key_value: Optional[Tuple[torch.Tensor]] = None,
169
- output_attentions: Optional[bool] = False,
170
- use_cache: Optional[bool] = False,
 
171
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
172
 
173
  residual = hidden_states
@@ -177,6 +252,7 @@ class BaichuanLayer(torch.nn.Module):
177
  # Self Attention
178
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
179
  hidden_states=hidden_states,
 
180
  attention_mask=attention_mask,
181
  past_key_value=past_key_value,
182
  output_attentions=output_attentions,
@@ -192,6 +268,9 @@ class BaichuanLayer(torch.nn.Module):
192
 
193
  outputs = (hidden_states,)
194
 
195
  if use_cache:
196
  outputs += (present_key_value,)
197
 
@@ -203,15 +282,16 @@ class BaichuanPreTrainedModel(PreTrainedModel):
203
  base_model_prefix = "model"
204
  supports_gradient_checkpointing = True
205
  _no_split_modules = ["BaichuanLayer"]
 
206
  _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
207
 
208
  def _init_weights(self, module):
209
  std = self.config.initializer_range
210
- if isinstance(module, torch.nn.Linear):
211
  module.weight.data.normal_(mean=0.0, std=std)
212
  if module.bias is not None:
213
  module.bias.data.zero_()
214
- elif isinstance(module, torch.nn.Embedding):
215
  module.weight.data.normal_(mean=0.0, std=std)
216
  if module.padding_idx is not None:
217
  module.weight.data[module.padding_idx].zero_()
@@ -220,50 +300,109 @@ class BaichuanPreTrainedModel(PreTrainedModel):
220
  if isinstance(module, BaichuanModel):
221
  module.gradient_checkpointing = value
222
 
223
 
224
 
225
  class BaichuanModel(BaichuanPreTrainedModel):
 
226
  def __init__(self, config: BaichuanConfig):
227
  super().__init__(config)
228
  self.padding_idx = config.pad_token_id
229
  self.vocab_size = config.vocab_size
230
  self.n_head = config.num_attention_heads
231
- self.embed_tokens = torch.nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
232
- self.layers = torch.nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
 
233
  self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
234
 
235
  self.gradient_checkpointing = config.gradient_checkpointing
236
  self.post_init()
237
- self.max_cache_pos = config.model_max_length
238
- self.first_run = True
239
 
240
  def get_input_embeddings(self):
241
  return self.embed_tokens
242
-
243
  def set_input_embeddings(self, value):
244
- self.embed_tokens = value
245
-
246
- def get_alibi_mask(self, tensor, seq_length_with_past):
247
- if self.first_run:
248
- self.first_run = False
249
- self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
250
- if seq_length_with_past > self.max_cache_pos:
251
- self.max_cache_pos = seq_length_with_past
252
- self.register_buffer("future_mask", _gen_alibi_mask(self.n_head, self.max_cache_pos).to(tensor), persistent=False)
253
- mask = self.future_mask[:self.n_head, :seq_length_with_past, :seq_length_with_past]
254
- return mask
255
 
256
  def forward(
257
- self,
258
- input_ids: torch.LongTensor = None,
259
- past_key_values: Optional[List[torch.FloatTensor]] = None,
260
- inputs_embeds: Optional[torch.FloatTensor] = None,
261
- use_cache: Optional[bool] = False,
262
- output_attentions: Optional[bool] = False,
263
- output_hidden_states: Optional[bool] = False,
264
- return_dict: Optional[bool] = True,
 
265
  ) -> Union[Tuple, BaseModelOutputWithPast]:
266
-
267
 
268
  if input_ids is not None and inputs_embeds is not None:
269
  raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
@@ -275,19 +414,21 @@ class BaichuanModel(BaichuanPreTrainedModel):
275
  raise ValueError("You need to provide input_ids or inputs_embeds")
276
 
277
  seq_length_with_past = seq_length
278
-
279
  if past_key_values is not None:
280
- past_key_values_length = past_key_values[0][0].shape[2]
281
  seq_length_with_past = seq_length_with_past + past_key_values_length
282
 
283
  if inputs_embeds is None:
284
  inputs_embeds = self.embed_tokens(input_ids)
285
 
286
- # embed positions
287
- attention_mask = self.get_alibi_mask(inputs_embeds, seq_length_with_past)
288
-
289
  hidden_states = inputs_embeds
290
 
 
291
  if self.gradient_checkpointing and self.training:
292
  if use_cache:
293
  logger.warning_once(
@@ -295,6 +436,15 @@ class BaichuanModel(BaichuanPreTrainedModel):
295
  )
296
  use_cache = False
297
 
298
  # decoder layers
299
  all_hidden_states = () if output_hidden_states else None
300
  all_self_attns = () if output_attentions else None
@@ -318,13 +468,15 @@ class BaichuanModel(BaichuanPreTrainedModel):
318
  layer_outputs = torch.utils.checkpoint.checkpoint(
319
  create_custom_forward(decoder_layer),
320
  hidden_states,
321
- attention_mask,
 
322
  None,
323
  )
324
  else:
325
  layer_outputs = decoder_layer(
326
  hidden_states,
327
- attention_mask=attention_mask,
 
328
  past_key_value=past_key_value,
329
  output_attentions=output_attentions,
330
  use_cache=use_cache,
@@ -345,21 +497,25 @@ class BaichuanModel(BaichuanPreTrainedModel):
345
  all_hidden_states += (hidden_states,)
346
 
347
  next_cache = next_decoder_cache if use_cache else None
 
348
  if not return_dict:
349
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
 
350
  return BaseModelOutputWithPast(
351
  last_hidden_state=hidden_states,
352
  past_key_values=next_cache,
353
  hidden_states=all_hidden_states,
354
  attentions=all_self_attns,
355
  )
356
-
357
 
358
  class BaichuanForCausalLM(BaichuanPreTrainedModel):
 
359
  def __init__(self, config):
360
  super().__init__(config)
361
  self.model = BaichuanModel(config)
362
- self.lm_head = torch.nn.Linear(config.hidden_size, config.vocab_size, bias=False)
 
363
 
364
  # Initialize weights and apply final processing
365
  self.post_init()
@@ -381,31 +537,37 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
381
 
382
  def get_decoder(self):
383
  return self.model
384
-
385
  def forward(
386
- self,
387
- input_ids: torch.LongTensor = None,
388
- past_key_values: Optional[List[torch.FloatTensor]] = None,
389
- inputs_embeds: Optional[torch.FloatTensor] = None,
390
- labels: Optional[torch.LongTensor] = None,
391
- use_cache: Optional[bool] = None,
392
- output_attentions: Optional[bool] = False,
393
- output_hidden_states: Optional[bool] = False,
394
- return_dict: Optional[bool] = True,
395
- **kwargs
 
396
  ) -> Union[Tuple, CausalLMOutputWithPast]:
397
-
398
 
399
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
400
  outputs = self.model(
401
  input_ids=input_ids,
 
402
  past_key_values=past_key_values,
403
  inputs_embeds=inputs_embeds,
404
  use_cache=use_cache,
405
  output_attentions=output_attentions,
406
  output_hidden_states=output_hidden_states,
407
  return_dict=return_dict,
408
- )
409
 
410
  hidden_states = outputs[0]
411
  logits = self.lm_head(hidden_states)
@@ -436,11 +598,20 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
436
  )
437
 
438
  def prepare_inputs_for_generation(
439
- self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
440
- ):
441
  if past_key_values:
442
  input_ids = input_ids[:, -1:]
443
 
 
444
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
445
  if inputs_embeds is not None and past_key_values is None:
446
  model_inputs = {"inputs_embeds": inputs_embeds}
@@ -448,20 +619,38 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
448
  model_inputs = {"input_ids": input_ids}
449
 
450
  model_inputs.update(
451
- {
452
  "past_key_values": past_key_values,
453
  "use_cache": kwargs.get("use_cache"),
454
- }
455
- )
 
456
  return model_inputs
457
 
458
- @staticmethod
459
- def _reorder_cache(past_key_values, beam_idx):
460
- return tuple(
461
- tuple(past_state.index_select(0, beam_idx) for past_state in layer_past)
462
- for layer_past in past_key_values
463
  )
464
-
465
 
466
  def quantize(self, bits: int):
467
  try:
@@ -470,7 +659,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
470
  raise ImportError(
471
  f"Needs QLinear to run quantize."
472
  )
473
-
474
  for layer in self.model.layers:
475
  layer.self_attn.W_pack = QLinear(
476
  bits=bits,
@@ -497,7 +686,7 @@ class BaichuanForCausalLM(BaichuanPreTrainedModel):
497
  weight=layer.mlp.up_proj.weight,
498
  bias = None,
499
  )
500
- return self
501
 
502
  def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
503
  max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens
 
5
 
6
  import torch
7
  import torch.utils.checkpoint
8
+ import torch.nn.functional as F
9
+ from torch import nn
10
  from torch.nn import CrossEntropyLoss
11
  from transformers import PreTrainedModel
12
  from transformers.activations import ACT2FN
 
16
 
17
  from .configuration_baichuan import BaichuanConfig
18
 
19
+
20
  logger = logging.get_logger(__name__)
21
 
22
+
23
+ # Copied from transformers.models.bloom.modeling_bloom._make_causal_mask
24
+ def _make_causal_mask(
25
+ input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int
26
+ ) -> torch.BoolTensor:
27
+ """
28
+ Make causal mask used for self-attention.
29
+ """
30
+ batch_size, target_length = input_ids_shape
31
+ mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device)
32
+ # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround
33
+ seq_ids = torch.arange(target_length, device=device)
34
+ mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :]
35
+
36
+ if past_key_values_length > 0:
37
+ mask[:, :past_key_values_length] = False
38
+
39
+ expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length)
40
+ return expanded_mask
41
+
42
+
43
+ # Copied from transformers.models.bloom.modeling_bloom._expand_mask
44
+ def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor:
45
+ """
46
+ Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`.
47
+ """
48
+ batch_size, src_length = mask.shape
49
+ tgt_length = tgt_length if tgt_length is not None else src_length
50
+
51
+ expanded_mask = ~(mask[:, None, None, :].to(torch.bool))
52
+ return expanded_mask.expand(batch_size, 1, tgt_length, src_length)
53
+
54
+
55
+ # Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor
56
+ def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
57
+ """
58
+ Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it
59
+ relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value
60
+ `softmax(l+a) = softmax(l)`.
61
+
62
+ Args:
63
+ Returns tensor shaped (batch_size * num_heads, 1, max_seq_len)
64
+ attention_mask (`torch.Tensor`):
65
+ Token-wise attention mask, this should be of shape (batch_size, max_seq_len).
66
+ num_heads (`int`, *required*):
67
+ number of heads
68
+ dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`):
69
+ dtype of the output tensor
70
+ """
71
+ batch_size, seq_length = attention_mask.shape
72
+ closest_power_of_2 = 2 ** math.floor(math.log2(num_heads))
73
+ base = torch.tensor(
74
+ 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
75
  )
76
+ powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32)
77
+ slopes = torch.pow(base, powers)
78
+
79
+ if closest_power_of_2 != num_heads:
80
+ extra_base = torch.tensor(
81
+ 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32
82
+ )
83
+ num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2)
84
+ extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32)
85
+ slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0)
86
 
87
+ # Note: alibi will added to the attention bias that will be applied to the query, key product of attention
88
+ # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length)
89
+ # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length)
90
+ # => the query_length dimension will then be broadcasted correctly
91
+ arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :]
92
+ alibi = slopes[..., None] * arange_tensor
93
+ return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype)
94
+
95
+
96
+ class RMSNorm(nn.Module):
97
 
 
98
  def __init__(self, hidden_size, epsilon=1e-6):
99
  super().__init__()
100
+ self.weight = nn.Parameter(torch.ones(hidden_size))
101
  self.epsilon = epsilon
102
 
103
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
104
+ input_dtype = hidden_states.dtype
105
  variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
106
  hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon)
107
 
108
+ return (self.weight * hidden_states).to(input_dtype)
 
 
109
 
 
110
 
111
+ class MLP(nn.Module):
112
 
 
113
  def __init__(
114
+ self,
115
+ hidden_size: int,
116
+ intermediate_size: int,
117
+ hidden_act: str,
118
  ):
119
  super().__init__()
120
+ self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
121
+ self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
122
+ self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
123
  self.act_fn = ACT2FN[hidden_act]
124
 
125
  def forward(self, x):
126
  return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
127
 
128
 
129
+ class BaichuanAttention(nn.Module):
130
 
131
  def __init__(self, config: BaichuanConfig):
132
  super().__init__()
 
140
  raise ValueError(
141
  f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}"
142
  )
143
+
144
+ # Layer-wise attention scaling
145
+ self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim)
146
+ self.beta = 1.0
147
+
148
+ self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False)
149
+ self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
150
 
151
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
152
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
153
 
154
  def forward(
155
+ self,
156
+ hidden_states: torch.Tensor,
157
+ alibi: torch.Tensor,
158
+ attention_mask: torch.Tensor,
159
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
160
+ output_attentions: bool = False,
161
+ use_cache: bool = False,
162
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
163
 
164
  bsz, q_len, _ = hidden_states.size()
165
 
166
+ proj = self.W_pack(hidden_states) # [batch_size, seq_length, 3 x hidden_size]
167
  proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2)
168
+ query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim)
169
+ key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim)
170
+ value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim)
171
 
172
+ query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
173
+ key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len)
174
+ value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim)
175
 
176
  if past_key_value is not None:
177
  # reuse k, v, self_attention
178
+ past_key, past_value = past_key_value
179
+ key_states = torch.cat([past_key, key_states], dim=2)
180
+ value_states = torch.cat([past_value, value_states], dim=1)
181
+
182
+ _, _, kv_seq_len = key_states.shape
183
 
184
  past_key_value = (key_states, value_states) if use_cache else None
185
 
186
+ # [batch_size * num_heads, q_length, kv_length]
187
+ # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11
188
+ matmul_result = alibi.baddbmm(
189
+ batch1=query_states,
190
+ batch2=key_states,
191
+ beta=self.beta,
192
+ alpha=self.inv_norm_factor,
193
+ )
194
+
195
+ # change view to [batch_size, num_heads, q_length, kv_length]
196
+ attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len)
197
 
198
+ # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype
199
+ # [batch_size, num_heads, q_length, kv_length]
200
+ input_dtype = attention_scores.dtype
201
+ # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38`
202
+ if input_dtype == torch.float16:
203
+ attention_scores = attention_scores.to(torch.float)
204
+ attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min)
205
+ attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype)
206
 
207
+ # change view [batch_size x num_heads, q_length, kv_length]
208
+ attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len)
209
 
210
+ # matmul: [batch_size * num_heads, q_length, head_dim]
211
+ attn_output = torch.bmm(attention_probs_reshaped, value_states)
212
+
213
+ attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim)
214
+
215
+ attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
216
  attn_output = self.o_proj(attn_output)
217
 
218
  if not output_attentions:
219
+ attention_probs = None
220
+
221
+ return attn_output, attention_probs, past_key_value
222
 
 
223
 
224
+ class BaichuanLayer(nn.Module):
225
 
 
226
  def __init__(self, config: BaichuanConfig):
227
  super().__init__()
228
  self.hidden_size = config.hidden_size
 
236
  self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
237
 
238
  def forward(
239
+ self,
240
+ hidden_states: torch.Tensor,
241
+ alibi: torch.Tensor,
242
+ attention_mask: torch.Tensor,
243
+ past_key_value: Optional[Tuple[torch.Tensor]] = None,
244
+ output_attentions: Optional[bool] = False,
245
+ use_cache: Optional[bool] = False,
246
  ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
247
 
248
  residual = hidden_states
 
252
  # Self Attention
253
  hidden_states, self_attn_weights, present_key_value = self.self_attn(
254
  hidden_states=hidden_states,
255
+ alibi=alibi,
256
  attention_mask=attention_mask,
257
  past_key_value=past_key_value,
258
  output_attentions=output_attentions,
 
268
 
269
  outputs = (hidden_states,)
270
 
271
+ if output_attentions:
272
+ outputs += (self_attn_weights,)
273
+
274
  if use_cache:
275
  outputs += (present_key_value,)
276
 
 
282
  base_model_prefix = "model"
283
  supports_gradient_checkpointing = True
284
  _no_split_modules = ["BaichuanLayer"]
285
+ _skip_keys_device_placement = "past_key_values"
286
  _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
287
 
288
  def _init_weights(self, module):
289
  std = self.config.initializer_range
290
+ if isinstance(module, nn.Linear):
291
  module.weight.data.normal_(mean=0.0, std=std)
292
  if module.bias is not None:
293
  module.bias.data.zero_()
294
+ elif isinstance(module, nn.Embedding):
295
  module.weight.data.normal_(mean=0.0, std=std)
296
  if module.padding_idx is not None:
297
  module.weight.data[module.padding_idx].zero_()
 
300
  if isinstance(module, BaichuanModel):
301
  module.gradient_checkpointing = value
302
 
303
+ @staticmethod
304
+ def _convert_to_standard_cache(
305
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]], batch_size: int
306
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
307
+ """
308
+ Standardizes the format of the cache so as to match most implementations, i.e. to tuple(tuple([batch_size,
309
+ num_heads, ...]))
310
+ """
311
+ batch_size_times_num_heads, head_dim, seq_length = past_key_value[0][0].shape
312
+ num_heads = batch_size_times_num_heads // batch_size
313
+ # key: [batch_size * num_heads, head_dim, seq_length] -> [batch_size, num_heads, head_dim, seq_length]
314
+ # value: [batch_size * num_heads, seq_length, head_dim] -> [batch_size, num_heads, seq_length, head_dim]
315
+ return tuple(
316
+ (
317
+ layer_past[0].view(batch_size, num_heads, head_dim, seq_length),
318
+ layer_past[1].view(batch_size, num_heads, seq_length, head_dim),
319
+ )
320
+ for layer_past in past_key_value
321
+ )
322
+
323
+ @staticmethod
324
+ def _convert_to_baichuan_cache(
325
+ past_key_value: Tuple[Tuple[torch.Tensor, torch.Tensor]]
326
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor]]:
327
+ """
328
+ Converts the cache to the format expected by Baichuan, i.e. to tuple(tuple([batch_size * num_heads, ...]))
329
+ """
330
+ batch_size, num_heads, head_dim, seq_length = past_key_value[0][0].shape
331
+ batch_size_times_num_heads = batch_size * num_heads
332
+ # key: [batch_size, num_heads, head_dim, seq_length] -> [batch_size * num_heads, head_dim, seq_length]
333
+ # value: [batch_size, num_heads, seq_length, head_dim] -> [batch_size * num_heads, seq_length, head_dim]
334
+ return tuple(
335
+ (
336
+ layer_past[0].view(batch_size_times_num_heads, head_dim, seq_length),
337
+ layer_past[1].view(batch_size_times_num_heads, seq_length, head_dim),
338
+ )
339
+ for layer_past in past_key_value
340
+ )
341
 
342
 
343
  class BaichuanModel(BaichuanPreTrainedModel):
344
+
345
  def __init__(self, config: BaichuanConfig):
346
  super().__init__(config)
347
  self.padding_idx = config.pad_token_id
348
  self.vocab_size = config.vocab_size
349
  self.n_head = config.num_attention_heads
350
+
351
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
352
+ self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)])
353
  self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps)
354
 
355
  self.gradient_checkpointing = config.gradient_checkpointing
356
  self.post_init()
 
 
357
 
358
  def get_input_embeddings(self):
359
  return self.embed_tokens
360
+
361
  def set_input_embeddings(self, value):
362
+ self.embed_tokens = value
363
+
364
+ def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor:
365
+ return build_alibi_tensor(attention_mask, num_heads, dtype)
366
+
367
+ def _prepare_attn_mask(
368
+ self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int
369
+ ) -> torch.BoolTensor:
370
+ # create causal mask
371
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
372
+ combined_attention_mask = None
373
+ device = attention_mask.device
374
+ _, src_length = input_shape
375
+
376
+ if src_length > 1:
377
+ combined_attention_mask = _make_causal_mask(
378
+ input_shape, device=device, past_key_values_length=past_key_values_length
379
+ )
380
+
381
+ # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length]
382
+ expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length)
383
+ combined_attention_mask = (
384
+ expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask
385
+ )
386
+
387
+ return combined_attention_mask
388
 
389
  def forward(
390
+ self,
391
+ input_ids: torch.LongTensor = None,
392
+ attention_mask: Optional[torch.Tensor] = None,
393
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
394
+ inputs_embeds: Optional[torch.FloatTensor] = None,
395
+ use_cache: Optional[bool] = None,
396
+ output_attentions: Optional[bool] = None,
397
+ output_hidden_states: Optional[bool] = None,
398
+ return_dict: Optional[bool] = None,
399
  ) -> Union[Tuple, BaseModelOutputWithPast]:
400
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
401
+ output_hidden_states = (
402
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
403
+ )
404
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
405
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
406
 
407
  if input_ids is not None and inputs_embeds is not None:
408
  raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously")
 
414
  raise ValueError("You need to provide input_ids or inputs_embeds")
415
 
416
  seq_length_with_past = seq_length
417
+ past_key_values_length = 0
418
  if past_key_values is not None:
419
+ past_key_values_length = past_key_values[0][0].shape[1]
420
  seq_length_with_past = seq_length_with_past + past_key_values_length
421
 
422
  if inputs_embeds is None:
423
  inputs_embeds = self.embed_tokens(input_ids)
424
 
 
 
 
425
  hidden_states = inputs_embeds
426
 
427
+ if attention_mask is None:
428
+ attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device)
429
+ else:
430
+ attention_mask = attention_mask.to(hidden_states.device)
431
+
432
  if self.gradient_checkpointing and self.training:
433
  if use_cache:
434
  logger.warning_once(
 
436
  )
437
  use_cache = False
438
 
439
+ # Compute alibi tensor: check build_alibi_tensor documentation
440
+ alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype)
441
+
442
+ causal_mask = self._prepare_attn_mask(
443
+ attention_mask,
444
+ input_shape=(batch_size, seq_length),
445
+ past_key_values_length=past_key_values_length,
446
+ )
447
+
448
  # decoder layers
449
  all_hidden_states = () if output_hidden_states else None
450
  all_self_attns = () if output_attentions else None
 
468
  layer_outputs = torch.utils.checkpoint.checkpoint(
469
  create_custom_forward(decoder_layer),
470
  hidden_states,
471
+ alibi,
472
+ causal_mask,
473
  None,
474
  )
475
  else:
476
  layer_outputs = decoder_layer(
477
  hidden_states,
478
+ alibi=alibi,
479
+ attention_mask=causal_mask,
480
  past_key_value=past_key_value,
481
  output_attentions=output_attentions,
482
  use_cache=use_cache,
 
497
  all_hidden_states += (hidden_states,)
498
 
499
  next_cache = next_decoder_cache if use_cache else None
500
+
501
  if not return_dict:
502
  return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
503
+
504
  return BaseModelOutputWithPast(
505
  last_hidden_state=hidden_states,
506
  past_key_values=next_cache,
507
  hidden_states=all_hidden_states,
508
  attentions=all_self_attns,
509
  )
510
+
511
 
512
  class BaichuanForCausalLM(BaichuanPreTrainedModel):
513
+
514
  def __init__(self, config):
515
  super().__init__(config)
516
  self.model = BaichuanModel(config)
517
+
518
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
519
 
520
  # Initialize weights and apply final processing
521
  self.post_init()
 
537
 
538
  def get_decoder(self):
539
  return self.model
540
+
541
  def forward(
542
+ self,
543
+ input_ids: torch.LongTensor = None,
544
+ attention_mask: Optional[torch.Tensor] = None,
545
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
546
+ inputs_embeds: Optional[torch.FloatTensor] = None,
547
+ labels: Optional[torch.LongTensor] = None,
548
+ use_cache: Optional[bool] = None,
549
+ output_attentions: Optional[bool] = None,
550
+ output_hidden_states: Optional[bool] = None,
551
+ return_dict: Optional[bool] = None,
552
+ **kwargs
553
  ) -> Union[Tuple, CausalLMOutputWithPast]:
554
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
555
+ output_hidden_states = (
556
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
557
+ )
558
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
559
 
560
  # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
561
  outputs = self.model(
562
  input_ids=input_ids,
563
+ attention_mask=attention_mask,
564
  past_key_values=past_key_values,
565
  inputs_embeds=inputs_embeds,
566
  use_cache=use_cache,
567
  output_attentions=output_attentions,
568
  output_hidden_states=output_hidden_states,
569
  return_dict=return_dict,
570
+ )
571
 
572
  hidden_states = outputs[0]
573
  logits = self.lm_head(hidden_states)
 
598
  )
599
 
600
  def prepare_inputs_for_generation(
601
+ self,
602
+ input_ids: torch.LongTensor,
603
+ past_key_values: Optional[torch.Tensor] = None,
604
+ attention_mask: Optional[torch.Tensor] = None,
605
+ inputs_embeds: Optional[torch.Tensor] = None,
606
+ **kwargs
607
+ ) -> dict:
608
  if past_key_values:
609
  input_ids = input_ids[:, -1:]
610
 
611
+ # the cache may be in the standard format (e.g. in contrastive search)
612
+ if past_key_values[0][0].shape[0] == input_ids.shape[0]:
613
+ past_key_values = self._convert_to_baichuan_cache(past_key_values)
614
+
615
  # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
616
  if inputs_embeds is not None and past_key_values is None:
617
  model_inputs = {"inputs_embeds": inputs_embeds}
 
619
  model_inputs = {"input_ids": input_ids}
620
 
621
  model_inputs.update(
622
+ {
623
  "past_key_values": past_key_values,
624
  "use_cache": kwargs.get("use_cache"),
625
+ "attention_mask": attention_mask,
626
+ }
627
+ )
628
  return model_inputs
629
 
630
+ def _reorder_cache(
631
+ self, past: Tuple[Tuple[torch.Tensor, torch.Tensor], ...], beam_idx: torch.LongTensor
632
+ ) -> Tuple[Tuple[torch.Tensor, torch.Tensor], ...]:
633
+ """
634
+ This function is used to re-order the `past_key_values` cache if [`~PreTrainedModel.beam_search`] or
635
+ [`~PreTrainedModel.beam_sample`] is called. This is required to match `past_key_values` with the correct
636
+ beam_idx at every generation step.
637
+
638
+ Output shares the same memory storage as `past`.
639
+ """
640
+ standardized_past = self._convert_to_standard_cache(past, batch_size=len(beam_idx))
641
+
642
+ # Get a copy of `beam_idx` on all the devices where we need those indices.
643
+ device_to_beam_idx = {
644
+ past_state.device: beam_idx.to(past_state.device) for layer_past in past for past_state in layer_past
645
+ }
646
+ reordered_past = tuple(
647
+ (
648
+ layer_past[0].index_select(0, device_to_beam_idx[layer_past[0].device]),
649
+ layer_past[1].index_select(0, device_to_beam_idx[layer_past[0].device]),
650
+ )
651
+ for layer_past in standardized_past
652
  )
653
+ return self._convert_to_baichuan_cache(reordered_past)
654
 
655
  def quantize(self, bits: int):
656
  try:
 
659
  raise ImportError(
660
  f"Needs QLinear to run quantize."
661
  )
662
+
663
  for layer in self.model.layers:
664
  layer.self_attn.W_pack = QLinear(
665
  bits=bits,
 
686
  weight=layer.mlp.up_proj.weight,
687
  bias = None,
688
  )
689
+ return self
690
 
691
  def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0):
692
  max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens