return past hidden states when `output_hidden_states` provided
#59
by
noahtren
- opened
- modeling_phi.py +2 -1
modeling_phi.py
CHANGED
@@ -947,6 +947,7 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
947 |
input_ids: torch.LongTensor,
|
948 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
949 |
attention_mask: Optional[torch.BoolTensor] = None,
|
|
|
950 |
labels: Optional[torch.LongTensor] = None,
|
951 |
**kwargs,
|
952 |
) -> CausalLMOutputWithPast:
|
@@ -957,4 +958,4 @@ class PhiForCausalLM(PhiPreTrainedModel):
|
|
957 |
if labels is not None:
|
958 |
loss = self.loss(lm_logits, labels)
|
959 |
|
960 |
-
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values)
|
|
|
947 |
input_ids: torch.LongTensor,
|
948 |
past_key_values: Optional[Union[torch.FloatTensor, InferenceParams]] = None,
|
949 |
attention_mask: Optional[torch.BoolTensor] = None,
|
950 |
+
output_hidden_states: Optional[bool] = None,
|
951 |
labels: Optional[torch.LongTensor] = None,
|
952 |
**kwargs,
|
953 |
) -> CausalLMOutputWithPast:
|
|
|
958 |
if labels is not None:
|
959 |
loss = self.loss(lm_logits, labels)
|
960 |
|
961 |
+
return CausalLMOutputWithPast(loss=loss, logits=lm_logits, past_key_values=past_key_values, hidden_states=hidden_states if output_hidden_states else None)
|