BAAI
/

shunxing1234 committed on
Commit
65ffa9d
·
1 Parent(s): 9ea6bfd

Update modeling_aquila.py

Browse files
Files changed (1) hide show
  1. modeling_aquila.py +191 -54
modeling_aquila.py CHANGED
@@ -93,34 +93,83 @@ class AquilaRMSNorm(nn.Module):
93
  class AquilaRotaryEmbedding(torch.nn.Module):
94
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
  super().__init__()
96
- inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
97
- self.register_buffer("inv_freq", inv_freq)
 
 
 
 
98
 
99
  # Build here to make `torch.jit.trace` work.
100
- self.max_seq_len_cached = max_position_embeddings
101
- t = torch.arange(self.max_seq_len_cached, device=self.inv_freq.device, dtype=self.inv_freq.dtype)
 
 
 
 
 
 
102
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
103
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
104
  emb = torch.cat((freqs, freqs), dim=-1)
105
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
106
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
107
 
108
  def forward(self, x, seq_len=None):
109
  # x: [bs, num_attention_heads, seq_len, head_size]
110
- # This `if` block is unlikely to be run after we build sin/cos in `__init__`. Keep the logic here just in case.
111
  if seq_len > self.max_seq_len_cached:
112
- self.max_seq_len_cached = seq_len
113
- t = torch.arange(self.max_seq_len_cached, device=x.device, dtype=self.inv_freq.dtype)
114
- freqs = torch.einsum("i,j->ij", t, self.inv_freq)
115
- # Different from paper, but it uses a different permutation in order to obtain the same calculation
116
- emb = torch.cat((freqs, freqs), dim=-1).to(x.device)
117
- self.register_buffer("cos_cached", emb.cos()[None, None, :, :], persistent=False)
118
- self.register_buffer("sin_cached", emb.sin()[None, None, :, :], persistent=False)
119
  return (
120
  self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
121
  self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
122
  )
123
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
124
 
125
  def rotate_half(x):
126
  """Rotates half the hidden dims of the input."""
@@ -142,33 +191,64 @@ def apply_rotary_pos_emb(q, k, cos, sin, position_ids):
142
 
143
  # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila
144
  class AquilaMLP(nn.Module):
145
- def __init__(
146
- self,
147
- hidden_size: int,
148
- intermediate_size: int,
149
- hidden_act: str,
150
- ):
151
  super().__init__()
152
- self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
153
- self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False)
154
- self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False)
155
- self.act_fn = ACT2FN[hidden_act]
 
 
 
156
 
157
  def forward(self, x):
158
- return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
 
161
  # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila
162
  class AquilaAttention(nn.Module):
163
  """Multi-headed attention from 'Attention Is All You Need' paper"""
164
-
165
  def __init__(self, config: AquilaConfig):
166
  super().__init__()
167
  self.config = config
168
  self.hidden_size = config.hidden_size
169
  self.num_heads = config.num_attention_heads
170
  self.head_dim = self.hidden_size // self.num_heads
 
 
171
  self.max_position_embeddings = config.max_position_embeddings
 
172
 
173
  if (self.head_dim * self.num_heads) != self.hidden_size:
174
  raise ValueError(
@@ -176,10 +256,37 @@ class AquilaAttention(nn.Module):
176
  f" and `num_heads`: {self.num_heads})."
177
  )
178
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
179
- self.k_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
180
- self.v_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
181
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
182
- self.rotary_emb = AquilaRotaryEmbedding(self.head_dim, max_position_embeddings=self.max_position_embeddings)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
183
 
184
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
185
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
@@ -195,16 +302,37 @@ class AquilaAttention(nn.Module):
195
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
196
  bsz, q_len, _ = hidden_states.size()
197
 
198
- query_states = self.q_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
199
- key_states = self.k_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
200
- value_states = self.v_proj(hidden_states).view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
201
 
202
  kv_seq_len = key_states.shape[-2]
203
  if past_key_value is not None:
204
  kv_seq_len += past_key_value[0].shape[-2]
205
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
206
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
207
- # [bsz, nh, t, hd]
208
 
209
  if past_key_value is not None:
210
  # reuse k, v, self_attention
@@ -213,9 +341,12 @@ class AquilaAttention(nn.Module):
213
 
214
  past_key_value = (key_states, value_states) if use_cache else None
215
 
 
 
 
 
216
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
217
 
218
- attn_weights = torch.clamp(attn_weights, min=-1024., max=1024.)
219
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
220
  raise ValueError(
221
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
@@ -228,9 +359,6 @@ class AquilaAttention(nn.Module):
228
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
229
  )
230
  attn_weights = attn_weights + attention_mask
231
- attn_weights = torch.max(
232
- attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min, device=attn_weights.device)
233
- )
234
 
235
  # upcast attention to fp32
236
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -242,10 +370,15 @@ class AquilaAttention(nn.Module):
242
  f" {attn_output.size()}"
243
  )
244
 
245
- attn_output = attn_output.transpose(1, 2)
246
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
247
 
248
- attn_output = self.o_proj(attn_output)
 
 
 
 
 
249
 
250
  if not output_attentions:
251
  attn_weights = None
@@ -259,11 +392,7 @@ class AquilaDecoderLayer(nn.Module):
259
  super().__init__()
260
  self.hidden_size = config.hidden_size
261
  self.self_attn = AquilaAttention(config=config)
262
- self.mlp = AquilaMLP(
263
- hidden_size=self.hidden_size,
264
- intermediate_size=config.intermediate_size,
265
- hidden_act=config.hidden_act,
266
- )
267
  self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
268
  self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
269
 
@@ -321,7 +450,6 @@ class AquilaDecoderLayer(nn.Module):
321
 
322
  return outputs
323
 
324
-
325
  AQUILA_START_DOCSTRING = r"""
326
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
327
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
@@ -350,7 +478,6 @@ class AquilaPreTrainedModel(PreTrainedModel):
350
  supports_gradient_checkpointing = True
351
  _no_split_modules = ["AquilaDecoderLayer"]
352
  _skip_keys_device_placement = "past_key_values"
353
- _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
354
 
355
  def _init_weights(self, module):
356
  std = self.config.initializer_range
@@ -570,7 +697,7 @@ class AquilaModel(AquilaPreTrainedModel):
570
  def create_custom_forward(module):
571
  def custom_forward(*inputs):
572
  # None for past_key_value
573
- return module(*inputs, output_attentions, None)
574
 
575
  return custom_forward
576
 
@@ -579,7 +706,6 @@ class AquilaModel(AquilaPreTrainedModel):
579
  hidden_states,
580
  attention_mask,
581
  position_ids,
582
- None,
583
  )
584
  else:
585
  layer_outputs = decoder_layer(
@@ -600,6 +726,7 @@ class AquilaModel(AquilaPreTrainedModel):
600
  all_self_attns += (layer_outputs[1],)
601
 
602
  hidden_states = self.norm(hidden_states)
 
603
  # add hidden states from the last decoder layer
604
  if output_hidden_states:
605
  all_hidden_states += (hidden_states,)
@@ -614,13 +741,14 @@ class AquilaModel(AquilaPreTrainedModel):
614
  attentions=all_self_attns,
615
  )
616
 
617
-
618
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila
619
  class AquilaForCausalLM(AquilaPreTrainedModel):
 
 
620
  def __init__(self, config):
621
  super().__init__(config)
622
  self.model = AquilaModel(config)
623
-
624
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
625
 
626
  # Initialize weights and apply final processing
@@ -705,7 +833,13 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
705
  )
706
 
707
  hidden_states = outputs[0]
708
- logits = self.lm_head(hidden_states)
 
 
 
 
 
 
709
 
710
  loss = None
711
  if labels is not None:
@@ -766,10 +900,11 @@ class AquilaForCausalLM(AquilaPreTrainedModel):
766
  def _reorder_cache(past_key_values, beam_idx):
767
  reordered_past = ()
768
  for layer_past in past_key_values:
769
- reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
 
 
770
  return reordered_past
771
 
772
-
773
  @add_start_docstrings(
774
  """
775
  The LLaMa Model transformer with a sequence classification head on top (linear layer).
@@ -851,7 +986,9 @@ class AquilaForSequenceClassification(AquilaPreTrainedModel):
851
  sequence_lengths = -1
852
  else:
853
  if input_ids is not None:
854
- sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
 
 
855
  else:
856
  sequence_lengths = -1
857
 
 
93
  class AquilaRotaryEmbedding(torch.nn.Module):
94
  def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
95
  super().__init__()
96
+
97
+ self.dim = dim
98
+ self.max_position_embeddings = max_position_embeddings
99
+ self.base = base
100
+ inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
101
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
102
 
103
  # Build here to make `torch.jit.trace` work.
104
+ self._set_cos_sin_cache(
105
+ seq_len=max_position_embeddings, device=self.inv_freq.device, dtype=torch.get_default_dtype()
106
+ )
107
+
108
+ def _set_cos_sin_cache(self, seq_len, device, dtype):
109
+ self.max_seq_len_cached = seq_len
110
+ t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
111
+
112
  freqs = torch.einsum("i,j->ij", t, self.inv_freq)
113
  # Different from paper, but it uses a different permutation in order to obtain the same calculation
114
  emb = torch.cat((freqs, freqs), dim=-1)
115
+ self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
116
+ self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
117
 
118
  def forward(self, x, seq_len=None):
119
  # x: [bs, num_attention_heads, seq_len, head_size]
 
120
  if seq_len > self.max_seq_len_cached:
121
+ self._set_cos_sin_cache(seq_len=seq_len, device=x.device, dtype=x.dtype)
122
+
 
 
 
 
 
123
  return (
124
  self.cos_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
125
  self.sin_cached[:, :, :seq_len, ...].to(dtype=x.dtype),
126
  )
127
 
128
+ # Copied from transformers.models.llama.modeling_llama.LlamaLinearScalingRotaryEmbedding with Llama->Aquila
129
class AquilaLinearScalingRotaryEmbedding(AquilaRotaryEmbedding):
    """AquilaRotaryEmbedding extended with linear scaling. Credits to the Reddit user /u/kaiokendev

    Position indices are divided by ``scaling_factor`` before the cos/sin
    tables are computed; everything else is inherited from the base class.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Assigned before the base constructor runs because that constructor
        # builds the initial cache through `_set_cos_sin_cache`, which reads it.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin cache for ``seq_len`` linearly scaled positions."""
        self.max_seq_len_cached = seq_len
        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)
        t = t / self.scaling_factor

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
146
+
147
+ # Copied from transformers.models.llama.modeling_llama.LlamaDynamicNTKScalingRotaryEmbedding with Llama->Aquila
148
class AquilaDynamicNTKScalingRotaryEmbedding(AquilaRotaryEmbedding):
    """AquilaRotaryEmbedding extended with Dynamic NTK scaling. Credits to the Reddit users /u/bloc97 and /u/emozilla

    When the requested length exceeds ``max_position_embeddings`` the
    frequency base is enlarged and ``inv_freq`` is recomputed before the
    cos/sin tables are rebuilt.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None, scaling_factor=1.0):
        # Assigned before the base constructor runs, since that constructor
        # already calls `_set_cos_sin_cache` to build the initial cache.
        self.scaling_factor = scaling_factor
        super().__init__(dim, max_position_embeddings, base, device)

    def _set_cos_sin_cache(self, seq_len, device, dtype):
        """(Re)build the cos/sin cache, rescaling the base for long sequences."""
        self.max_seq_len_cached = seq_len

        if seq_len > self.max_position_embeddings:
            # Grow the base with the ratio seq_len / max_position_embeddings and
            # replace the inverse-frequency buffer accordingly.
            base = self.base * (
                (self.scaling_factor * seq_len / self.max_position_embeddings) - (self.scaling_factor - 1)
            ) ** (self.dim / (self.dim - 2))
            inv_freq = 1.0 / (base ** (torch.arange(0, self.dim, 2).float().to(device) / self.dim))
            self.register_buffer("inv_freq", inv_freq, persistent=False)

        t = torch.arange(self.max_seq_len_cached, device=device, dtype=self.inv_freq.dtype)

        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
        self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
172
+
173
 
174
  def rotate_half(x):
175
  """Rotates half the hidden dims of the input."""
 
191
 
192
  # Copied from transformers.models.llama.modeling_llama.LlamaMLP with Llama->Aquila
193
  class AquilaMLP(nn.Module):
194
+ def __init__(self, config):
 
 
 
 
 
195
  super().__init__()
196
+ self.config = config
197
+ self.hidden_size = config.hidden_size
198
+ self.intermediate_size = config.intermediate_size
199
+ self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
200
+ self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
201
+ self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
202
+ self.act_fn = ACT2FN[config.hidden_act]
203
 
204
  def forward(self, x):
205
+ if self.config.pretraining_tp > 1:
206
+ slice = self.intermediate_size // self.config.pretraining_tp
207
+ gate_proj_slices = self.gate_proj.weight.split(slice, dim=0)
208
+ up_proj_slices = self.up_proj.weight.split(slice, dim=0)
209
+ down_proj_slices = self.down_proj.weight.split(slice, dim=1)
210
+
211
+ gate_proj = torch.cat(
212
+ [F.linear(x, gate_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1
213
+ )
214
+ up_proj = torch.cat([F.linear(x, up_proj_slices[i]) for i in range(self.config.pretraining_tp)], dim=-1)
215
+
216
+ intermediate_states = (self.act_fn(gate_proj) * up_proj).split(slice, dim=2)
217
+ down_proj = [
218
+ F.linear(intermediate_states[i], down_proj_slices[i]) for i in range(self.config.pretraining_tp)
219
+ ]
220
+ down_proj = sum(down_proj)
221
+ else:
222
+ down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
223
+
224
+ return down_proj
225
+
226
+
227
def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
    """
    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)

    Args:
        hidden_states: 4-D tensor whose second dimension is repeated.
        n_rep: Number of times each entry along that dimension is repeated.

    Returns:
        The input tensor itself when ``n_rep == 1``; otherwise a tensor with
        the second dimension multiplied by ``n_rep``.
    """
    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    # Insert a repeat axis right after the head axis, broadcast it, then fold
    # it into the head axis so the copies of each head stay adjacent.
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
237
 
238
 
239
  # Copied from transformers.models.llama.modeling_llama.LlamaAttention with Llama->Aquila
240
  class AquilaAttention(nn.Module):
241
  """Multi-headed attention from 'Attention Is All You Need' paper"""
 
242
  def __init__(self, config: AquilaConfig):
243
  super().__init__()
244
  self.config = config
245
  self.hidden_size = config.hidden_size
246
  self.num_heads = config.num_attention_heads
247
  self.head_dim = self.hidden_size // self.num_heads
248
+ self.num_key_value_heads = config.num_key_value_heads
249
+ self.num_key_value_groups = self.num_heads // self.num_key_value_heads
250
  self.max_position_embeddings = config.max_position_embeddings
251
+ self.rope_theta = config.rope_theta
252
 
253
  if (self.head_dim * self.num_heads) != self.hidden_size:
254
  raise ValueError(
 
256
  f" and `num_heads`: {self.num_heads})."
257
  )
258
  self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=False)
259
+ self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
260
+ self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=False)
261
  self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False)
262
+ self._init_rope()
263
+
264
+ def _init_rope(self):
265
+ if self.config.rope_scaling is None:
266
+ self.rotary_emb = AquilaRotaryEmbedding(
267
+ self.head_dim,
268
+ max_position_embeddings=self.max_position_embeddings,
269
+ base=self.rope_theta,
270
+ )
271
+ else:
272
+ scaling_type = self.config.rope_scaling["type"]
273
+ scaling_factor = self.config.rope_scaling["factor"]
274
+ if scaling_type == "linear":
275
+ self.rotary_emb = AquilaLinearScalingRotaryEmbedding(
276
+ self.head_dim,
277
+ max_position_embeddings=self.max_position_embeddings,
278
+ scaling_factor=scaling_factor,
279
+ base=self.rope_theta,
280
+ )
281
+ elif scaling_type == "dynamic":
282
+ self.rotary_emb = AquilaDynamicNTKScalingRotaryEmbedding(
283
+ self.head_dim,
284
+ max_position_embeddings=self.max_position_embeddings,
285
+ scaling_factor=scaling_factor,
286
+ base=self.rope_theta,
287
+ )
288
+ else:
289
+ raise ValueError(f"Unknown RoPE scaling type {scaling_type}")
290
 
291
  def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
292
  return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
 
302
  ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
303
  bsz, q_len, _ = hidden_states.size()
304
 
305
+ if self.config.pretraining_tp > 1:
306
+ key_value_slicing = (self.num_key_value_heads * self.head_dim) // self.config.pretraining_tp
307
+ query_slices = self.q_proj.weight.split(
308
+ (self.num_heads * self.head_dim) // self.config.pretraining_tp, dim=0
309
+ )
310
+ key_slices = self.k_proj.weight.split(key_value_slicing, dim=0)
311
+ value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)
312
+
313
+ query_states = [F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp)]
314
+ query_states = torch.cat(query_states, dim=-1)
315
+
316
+ key_states = [F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp)]
317
+ key_states = torch.cat(key_states, dim=-1)
318
+
319
+ value_states = [F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp)]
320
+ value_states = torch.cat(value_states, dim=-1)
321
+
322
+ else:
323
+ query_states = self.q_proj(hidden_states)
324
+ key_states = self.k_proj(hidden_states)
325
+ value_states = self.v_proj(hidden_states)
326
+
327
+ query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
328
+ key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
329
+ value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
330
 
331
  kv_seq_len = key_states.shape[-2]
332
  if past_key_value is not None:
333
  kv_seq_len += past_key_value[0].shape[-2]
334
  cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
335
  query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
 
336
 
337
  if past_key_value is not None:
338
  # reuse k, v, self_attention
 
341
 
342
  past_key_value = (key_states, value_states) if use_cache else None
343
 
344
+ # repeat k/v heads if n_kv_heads < n_heads
345
+ key_states = repeat_kv(key_states, self.num_key_value_groups)
346
+ value_states = repeat_kv(value_states, self.num_key_value_groups)
347
+
348
  attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
349
 
 
350
  if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
351
  raise ValueError(
352
  f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
 
359
  f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
360
  )
361
  attn_weights = attn_weights + attention_mask
 
 
 
362
 
363
  # upcast attention to fp32
364
  attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
 
370
  f" {attn_output.size()}"
371
  )
372
 
373
+ attn_output = attn_output.transpose(1, 2).contiguous()
374
  attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
375
 
376
+ if self.config.pretraining_tp > 1:
377
+ attn_output = attn_output.split(self.hidden_size // self.config.pretraining_tp, dim=2)
378
+ o_proj_slices = self.o_proj.weight.split(self.hidden_size // self.config.pretraining_tp, dim=1)
379
+ attn_output = sum([F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp)])
380
+ else:
381
+ attn_output = self.o_proj(attn_output)
382
 
383
  if not output_attentions:
384
  attn_weights = None
 
392
  super().__init__()
393
  self.hidden_size = config.hidden_size
394
  self.self_attn = AquilaAttention(config=config)
395
+ self.mlp = AquilaMLP(config)
 
 
 
 
396
  self.input_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
397
  self.post_attention_layernorm = AquilaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
398
 
 
450
 
451
  return outputs
452
 
 
453
  AQUILA_START_DOCSTRING = r"""
454
  This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
455
  library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
 
478
  supports_gradient_checkpointing = True
479
  _no_split_modules = ["AquilaDecoderLayer"]
480
  _skip_keys_device_placement = "past_key_values"
 
481
 
482
  def _init_weights(self, module):
483
  std = self.config.initializer_range
 
697
  def create_custom_forward(module):
698
  def custom_forward(*inputs):
699
  # None for past_key_value
700
+ return module(*inputs, past_key_value, output_attentions)
701
 
702
  return custom_forward
703
 
 
706
  hidden_states,
707
  attention_mask,
708
  position_ids,
 
709
  )
710
  else:
711
  layer_outputs = decoder_layer(
 
726
  all_self_attns += (layer_outputs[1],)
727
 
728
  hidden_states = self.norm(hidden_states)
729
+
730
  # add hidden states from the last decoder layer
731
  if output_hidden_states:
732
  all_hidden_states += (hidden_states,)
 
741
  attentions=all_self_attns,
742
  )
743
 
 
744
  # Copied from transformers.models.llama.modeling_llama.LlamaForCausalLM with LLAMA->AQUILA,Llama->Aquila
745
  class AquilaForCausalLM(AquilaPreTrainedModel):
746
+ _tied_weights_keys = ["lm_head.weight"]
747
+
748
  def __init__(self, config):
749
  super().__init__(config)
750
  self.model = AquilaModel(config)
751
+ self.vocab_size = config.vocab_size
752
  self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
753
 
754
  # Initialize weights and apply final processing
 
833
  )
834
 
835
  hidden_states = outputs[0]
836
+ if self.config.pretraining_tp > 1:
837
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
838
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
839
+ logits = torch.cat(logits, dim=-1)
840
+ else:
841
+ logits = self.lm_head(hidden_states)
842
+ logits = logits.float()
843
 
844
  loss = None
845
  if labels is not None:
 
900
  def _reorder_cache(past_key_values, beam_idx):
901
  reordered_past = ()
902
  for layer_past in past_key_values:
903
+ reordered_past += (
904
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
905
+ )
906
  return reordered_past
907
 
 
908
  @add_start_docstrings(
909
  """
910
  The LLaMa Model transformer with a sequence classification head on top (linear layer).
 
986
  sequence_lengths = -1
987
  else:
988
  if input_ids is not None:
989
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).long().argmax(-1) - 1).to(
990
+ logits.device
991
+ )
992
  else:
993
  sequence_lengths = -1
994