return past hidden states when `output_hidden_states` provided

#59
by noahtren - opened
Files changed (1) hide show
  1. modeling_phi.py +2 -1
modeling_phi.py CHANGED
@@ -947,6 +947,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
947
  input_ids: torch.LongTensor,
948
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
949
  attention_mask: Optional[torch.BoolTensor] = None,
 
950
  labels: Optional[torch.LongTensor] = None,
951
  **kwargs,
952
  ) -> CausalLMOutputWithPast:
@@ -957,4 +958,4 @@ class PhiForCausalLM(PhiPreTrainedModel):
957
  if labels is not None:
958
  loss = self.loss(lm_logits, labels)
959
 
960
- return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
 
947
  input_ids: torch.LongTensor,
948
  past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
949
  attention_mask: Optional[torch.BoolTensor] = None,
950
+ output_hidden_states: Optional[bool] = None,
951
  labels: Optional[torch.LongTensor] = None,
952
  **kwargs,
953
  ) -> CausalLMOutputWithPast:
 
958
  if labels is not None:
959
  loss = self.loss(lm_logits, labels)
960
 
961
+ return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values, hidden_states=hidden_states if output_hidden_states else None)