Files changed (1) hide show
  1. modeling_pharia.py +15 -2
modeling_pharia.py CHANGED
@@ -764,9 +764,22 @@ class PhariaForCausalLM(PhariaPreTrainedModel):
764
 
765
  hidden_states = outputs[0]
766
  logits = self.lm_head(hidden_states)
767
-
 
 
 
 
 
 
 
 
 
 
 
 
 
768
  return CausalLMOutputWithPast(
769
- loss=0.0,
770
  logits=logits,
771
  past_key_values=outputs.past_key_values,
772
  hidden_states=outputs.hidden_states,
 
764
 
765
  hidden_states = outputs[0]
766
  logits = self.lm_head(hidden_states)
767
+ loss = 0.0
768
+ if labels is not None:
769
+ # Shift logits and labels for causal language modeling
770
+ shift_logits = logits[..., :-1, :].contiguous()
771
+ shift_labels = outputs['labels'][..., 1:].contiguous()
772
+
773
+ # Flatten the tokens
774
+ shift_logits = shift_logits.view(-1, shift_logits.size(-1))
775
+ shift_labels = shift_labels.view(-1)
776
+
777
+ # Compute loss
778
+ loss_fct = torch.nn.CrossEntropyLoss(ignore_index=1) # Pad token ID for Pharia is 1
779
+ loss = loss_fct(shift_logits, shift_labels)
780
+
781
  return CausalLMOutputWithPast(
782
+ loss=loss,
783
  logits=logits,
784
  past_key_values=outputs.past_key_values,
785
  hidden_states=outputs.hidden_states,