JingzeShi committed 38cff35 (verified) · 1 parent: 143cac7

Upload DogeForCausalLM

Files changed (4):
  1. config.json +1 -0
  2. configuration_doge.py +8 -0
  3. model.safetensors +2 -2
  4. modeling_doge.py +126 -29
config.json CHANGED
@@ -9,6 +9,7 @@
     "AutoModelForCausalLM": "modeling_doge.DogeForCausalLM"
   },
   "bos_token_id": 0,
+  "dynamic_mask_ratio": 0.0,
   "eos_token_id": 1,
   "expert_retrieval_size": 256,
   "hidden_act": "silu",
configuration_doge.py CHANGED
@@ -111,6 +111,8 @@ class DogeConfig(PretrainedConfig):
             If it is not specified, will default to `num_attention_heads`.
         attention_dropout (`float`, *optional*, defaults to 0.0):
             The dropout ratio for the attention probabilities.
+        dynamic_mask_ratio (`float`, *optional*, defaults to 0.0, range [0, 1]):
+            The ratio to control the proportion of the dynamic mask filled with the minimum value.
         is_moe (`bool`, *optional*, defaults to `False`):
             Whether to use the Cross Domain Mixture of Experts, if `True`, the MoE will inherit the MLP to initialize
         num_cdmmoe_experts (`int`, *optional*, defaults to 2048):
@@ -154,6 +156,7 @@ class DogeConfig(PretrainedConfig):
         num_attention_heads=8,
         num_key_value_heads=None,
         attention_dropout=0.0,
+        dynamic_mask_ratio=0.0,
         is_moe=False,
         num_cdmmoe_experts=2048,
         num_cdmmoe_heads=4,
@@ -183,6 +186,7 @@ class DogeConfig(PretrainedConfig):
         self.num_attention_heads = num_attention_heads
         self.num_key_value_heads = num_key_value_heads
         self.attention_dropout = attention_dropout
+        self.dynamic_mask_ratio = dynamic_mask_ratio
         self.is_moe = is_moe
         self.num_cdmmoe_experts = num_cdmmoe_experts
         self.num_cdmmoe_heads = num_cdmmoe_heads
@@ -195,6 +199,10 @@ class DogeConfig(PretrainedConfig):
             self.rope_scaling["rope_type"] = self.rope_scaling["type"]
         rope_config_validation(self)
 
+        # for backward compatibility
+        if num_key_value_heads is None:
+            self.num_key_value_heads = num_attention_heads
+
         super().__init__(
             bos_token_id=bos_token_id,
             eos_token_id=eos_token_id,
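
Taken together, the configuration changes add a `dynamic_mask_ratio` argument and a backward-compatibility default for `num_key_value_heads`. A minimal sketch, assuming `configuration_doge.py` is importable from the working directory and `transformers` is installed; the 0.1 value is illustrative:

    from configuration_doge import DogeConfig

    config = DogeConfig(
        num_attention_heads=8,
        num_key_value_heads=None,  # falls back to num_attention_heads via the new backward-compatibility branch
        dynamic_mask_ratio=0.1,    # proportion of the dynamic mask filled with the dtype minimum
    )
    assert config.num_key_value_heads == 8
    print(config.dynamic_mask_ratio)  # 0.1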
model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:09c44ce706b29a9afa8252cd96767076dc44b5ea32b60d832e51559bb26df3ed
-size 52490344
+oid sha256:3632a5c94bc7d3cf66602318b168603ec19f1025e0aef01c286d65e30ed55e8b
+size 52482152
modeling_doge.py CHANGED
@@ -39,6 +39,7 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     add_start_docstrings,
     add_start_docstrings_to_model_forward,
+    is_torch_greater_or_equal,
     logging,
     replace_return_docstrings,
 )
@@ -49,6 +50,9 @@ try:
 except ImportError:
     einx_add = None
 
+if is_torch_greater_or_equal("2.5"):
+    from torch.nn.attention.flex_attention import flex_attention
+
 
 logger = logging.get_logger(__name__)
 
@@ -216,14 +220,15 @@ class DogeDynamicMaskAttention(nn.Module):
         self.num_key_value_heads = config.num_key_value_heads
         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
         self.attention_dropout = config.attention_dropout
+        self.dynamic_mask_ratio = config.dynamic_mask_ratio
 
         # Q K V O projections
         self.q_proj = nn.Linear(self.hidden_dim, self.num_heads * self.head_dim, bias=config.hidden_bias)
         self.k_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
         # dynamic mask for the QK^T attention score matrix
         self.A = nn.Parameter(torch.ones(self.num_heads))
-        self.dt_proj = nn.Linear(self.hidden_dim, self.num_heads, bias=config.hidden_bias)
-        self.v_proj = nn.Linear(self.hidden_dim, self.num_key_value_heads * self.head_dim, bias=config.hidden_bias)
+        self.dt_proj = nn.Linear(self.num_key_value_heads * self.head_dim, self.num_heads, bias=config.hidden_bias)
         self.o_proj = nn.Linear(self.hidden_dim, self.hidden_dim, bias=config.hidden_bias)
 
     def forward(
@@ -254,6 +259,10 @@ class DogeDynamicMaskAttention(nn.Module):
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
 
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
         # repeat key and value states
         key_states = repeat_kv(key_states, self.num_key_value_groups)
         value_states = repeat_kv(value_states, self.num_key_value_groups)
@@ -262,12 +271,13 @@
         attn_weights = torch.matmul(query_states, key_states.transpose(-1, -2)) / math.sqrt(self.head_dim)
 
         # add mask to attention scores
-        if attention_mask is not None:
-            dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-            dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-            dynamic_mask = dynamic_mask < 1.0
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
-            attn_weights = attn_weights + causal_mask
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=0.1,
+            attention_mask=attention_mask,
+        )
+        attn_weights = attn_weights + attn_mask
 
         # upcast attention scores to fp32
         attn_weights = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
@@ -282,8 +292,35 @@
 
         return attn_output, past_key_value
 
+    def prepare_dynamic_mask(
+        self,
+        hidden_states: torch.Tensor,
+        dynamic_mask: torch.Tensor,
+        dynamic_mask_ratio: float = 0.0,
+        attention_mask: Optional[torch.Tensor] = None,
+    ):
+        """
+        Combine `dynamic_mask` with `attention_mask` to generate the final `attn_mask`.
+
+        Args:
+            hidden_states (`torch.Tensor`): The input hidden_states, used to determine the minimum value of the current input precision.
+            dynamic_mask (`torch.Tensor`): dynamic mask of shape `(batch_size, num_heads, key_sequence_length)`.
+            dynamic_mask_ratio (`float`, *optional*): Ratio from 0.0 to 1.0 used to control the proportion of the dynamic mask filled with the minimum value.
+            attention_mask (`torch.Tensor`, *optional*): attention mask of shape `(batch_size, 1, query_sequence_length, key_sequence_length)`.
+        """
+        min_type = torch.finfo(hidden_states.dtype).min
+        attn_mask = dynamic_mask[:, :, None, :]
+        if 0.0 < dynamic_mask_ratio < 1.0:
+            num_dynamic_mask = int(attn_mask.shape[-1] * dynamic_mask_ratio)
+            if num_dynamic_mask > 0:
+                rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
+                attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)
+        if attention_mask is not None:
+            attn_mask = attn_mask.masked_fill(attention_mask[:, :, :, : hidden_states.shape[-2]] == min_type, min_type)
+        return attn_mask
 
-class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
+
+class DogeSdpaDynamicMaskAttention(DogeDynamicMaskAttention):
 
     def forward(
         self,
@@ -312,34 +349,31 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
             # sin and cos are specific to RoPE models; cache_position needed for the static cache
             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        # calculate dynamic mask from value_states
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
 
-        # repeat key and value states
-        key_states = repeat_kv(key_states, self.num_key_value_groups)
-        value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-        causal_mask = attention_mask
-        if attention_mask is not None:
-            dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
-            dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
-            dynamic_mask = dynamic_mask < 1.0
-            causal_mask = attention_mask[:, :, :, : key_states.shape[-2]].masked_fill(dynamic_mask[:, :, None, :], torch.finfo(hidden_states.dtype).min)
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
 
         query_states = query_states.contiguous()
         key_states = key_states.contiguous()
         value_states = value_states.contiguous()
 
-        # We dispatch to SDPA's Flash Attention or Efficient kernels via this `is_causal` if statement instead of an inline conditional assignment in SDPA to support both torch.compile's dynamic shapes and full graph options. An inline conditional prevents dynamic shapes from compiling.
-        is_causal = True if causal_mask is None and q_len > 1 else False
-
         # NOTE: As of pytorch 2.5.1, cuDNN's SDPA backward pass is still incorrect, so we disable cuDNN SDPA (see https://github.com/pytorch/pytorch/issues/138581)
         torch.backends.cuda.enable_cudnn_sdp(False)
         attn_output = F.scaled_dot_product_attention(
             query_states,
             key_states,
             value_states,
-            attn_mask=causal_mask,
+            attn_mask=attn_mask,
             dropout_p=self.attention_dropout if self.training else 0.0,
-            is_causal=is_causal,
+            enable_gqa=True,
         )
 
         attn_output = attn_output.transpose(1, 2).contiguous()
@@ -349,9 +383,70 @@ class DogeSdpaDynamicMaskAttn(DogeDynamicMaskAttention):
         return attn_output, past_key_value
 
 
+class DogeFlexDynamicMaskAttention(DogeDynamicMaskAttention):
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.Tensor] = None,
+        position_ids: Optional[torch.LongTensor] = None,
+        past_key_value: Optional[Cache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Optional[Cache]]:
+        bsz, q_len, _ = hidden_states.shape
+
+        query_states = self.q_proj(hidden_states)
+        key_states = self.k_proj(hidden_states)
+        value_states = self.v_proj(hidden_states)
+
+        query_states = query_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, -1, self.head_dim).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_QK_rotary_pos_emb(query_states, key_states, cos, sin)
+
+        if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
+            key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
+
+        dt_states = self.dt_proj(value_states.transpose(1, 2).reshape(bsz, value_states.shape[-2], -1))
+        dynamic_mask = torch.exp(self.A * F.softplus(dt_states)).transpose(-1, -2)
+
+        attn_mask = self.prepare_dynamic_mask(
+            hidden_states=hidden_states,
+            dynamic_mask=dynamic_mask,
+            dynamic_mask_ratio=self.dynamic_mask_ratio,
+            attention_mask=attention_mask,
+        )
+        # TODO: flex_attention: Captured buffers that require grad are not yet supported.
+        # NOTE: So we only use flex_attention in inference mode.
+        def dynamic_mask_mod(score, batch, head, q_idx, kv_idx):
+            score = score + attn_mask[batch][head][q_idx][kv_idx]
+            return score
+
+        attn_output = flex_attention(
+            query_states,
+            key_states,
+            value_states,
+            score_mod=dynamic_mask_mod,
+            enable_gqa=True,
+        )
+
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.view(bsz, q_len, -1)
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, past_key_value
+
+
 DOGE_ATTENTION_CLASSES = {
+    "flex_attention": DogeFlexDynamicMaskAttention,
     "eager": DogeDynamicMaskAttention,
-    "sdpa": DogeSdpaDynamicMaskAttn,
+    "sdpa": DogeSdpaDynamicMaskAttention,
 }
 
 
@@ -519,6 +614,7 @@ class DogePreTrainedModel(PreTrainedModel):
     supports_gradient_checkpointing = True
     _no_split_modules = ["DogeDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flex_attn = True
     _supports_sdpa = True
     _supports_cache_class = True
    _supports_quantized_cache = True
@@ -693,7 +789,7 @@ class DogeModel(DogePreTrainedModel):
         all_self_attns = () if output_attentions else None
         next_decoder_cache = None
 
-        for decoder_layer in self.layers:
+        for decoder_layer in self.layers[: self.config.num_hidden_layers]:
             if output_hidden_states:
                 all_hidden_states += (hidden_states,)
 
@@ -877,7 +973,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         input_ids: torch.LongTensor = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
-        past_key_values: Optional[torch.Tensor] = None,
+        past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
@@ -886,7 +982,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
         return_dict: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
         num_logits_to_keep: int = 0,
-        **loss_kwargs,
+        **kwargs,
     ) -> Union[Tuple, CausalLMOutputWithPast]:
         r"""
         Args:
@@ -920,6 +1016,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
             output_hidden_states=output_hidden_states,
             return_dict=return_dict,
             cache_position=cache_position,
+            **kwargs,
         )
 
         hidden_states = outputs[0]
@@ -929,7 +1026,7 @@ class DogeForCausalLM(DogePreTrainedModel, GenerationMixin):
 
         loss = None
         if labels is not None:
-            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **loss_kwargs)
+            loss = self.loss_function(logits=logits, labels=labels, vocab_size=self.vocab_size, **kwargs)
 
         if not return_dict:
             output = (logits,) + outputs[1:]
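
For reference, the dynamic-mask path shared by the three attention classes can be exercised in isolation. The sketch below mirrors the `dt_proj` → `A · softplus` → `exp` projection and the ratio-based masking from `prepare_dynamic_mask`; the shapes, the 0.25 ratio, and the random weights are illustrative assumptions, not values from this checkpoint:

    import torch
    import torch.nn.functional as F

    bsz, num_heads, q_len, kv_len, head_dim = 1, 4, 8, 32, 16

    # dt_states -> dynamic_mask, as in the attention forward passes above (random weights, not the trained dt_proj / A)
    dt_proj = torch.nn.Linear(num_heads * head_dim, num_heads)
    A = torch.ones(num_heads)
    value_states = torch.randn(bsz, num_heads, kv_len, head_dim)
    dt_states = dt_proj(value_states.transpose(1, 2).reshape(bsz, kv_len, -1))
    dynamic_mask = torch.exp(A * F.softplus(dt_states)).transpose(-1, -2)  # (bsz, num_heads, kv_len)

    # prepare_dynamic_mask logic: push the smallest dynamic_mask_ratio fraction of keys to the dtype minimum
    dynamic_mask_ratio = 0.25
    min_type = torch.finfo(dynamic_mask.dtype).min
    attn_mask = dynamic_mask[:, :, None, :]  # broadcast over the query axis
    num_dynamic_mask = int(attn_mask.shape[-1] * dynamic_mask_ratio)
    if num_dynamic_mask > 0:
        rate_value = torch.kthvalue(attn_mask, num_dynamic_mask, dim=-1, keepdim=True).values
        attn_mask = attn_mask.masked_fill(attn_mask < rate_value, min_type)

    # the mask is added to QK^T scores (eager) or passed as attn_mask to SDPA / flex_attention
    q = torch.randn(bsz, num_heads, q_len, head_dim)
    k = torch.randn(bsz, num_heads, kv_len, head_dim)
    scores = q @ k.transpose(-1, -2) / head_dim**0.5 + attn_mask
    probs = F.softmax(scores, dim=-1)
    print(probs.shape)  # torch.Size([1, 4, 8, 32])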