Add in labels to forward
Browse files- mosaic_gpt.py +22 -3
mosaic_gpt.py
CHANGED
@@ -238,6 +238,7 @@ class MosaicGPT(PreTrainedModel):
|
|
238 |
input_ids: torch.LongTensor,
|
239 |
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
240 |
attention_mask: Optional[torch.ByteTensor] = None,
|
|
|
241 |
prefix_mask: Optional[torch.ByteTensor] = None,
|
242 |
sequence_id: Optional[torch.LongTensor] = None,
|
243 |
return_dict: Optional[bool] = None,
|
@@ -370,9 +371,27 @@ class MosaicGPT(PreTrainedModel):
|
|
370 |
)
|
371 |
logits *= self.logit_scale
|
372 |
|
373 |
-
|
374 |
-
|
375 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
376 |
|
377 |
# Param Initialization, needed for device='meta' fast initialization
|
378 |
def param_init_fn(self, module):
|
|
|
238 |
input_ids: torch.LongTensor,
|
239 |
past_key_values: Optional[List[Tuple[torch.FloatTensor]]] = None,
|
240 |
attention_mask: Optional[torch.ByteTensor] = None,
|
241 |
+
labels: Optional[torch.LongTensor] = None,
|
242 |
prefix_mask: Optional[torch.ByteTensor] = None,
|
243 |
sequence_id: Optional[torch.LongTensor] = None,
|
244 |
return_dict: Optional[bool] = None,
|
|
|
371 |
)
|
372 |
logits *= self.logit_scale
|
373 |
|
374 |
+
# compute loss from logits
|
375 |
+
if labels is not None:
|
376 |
+
# Shift so that tokens < n predict n
|
377 |
+
shift_logits = logits[..., :-1, :].contiguous()
|
378 |
+
shift_labels = labels[..., 1:].contiguous()
|
379 |
+
# Flatten the tokens
|
380 |
+
loss_fct = nn.CrossEntropyLoss()
|
381 |
+
loss = loss_fct(
|
382 |
+
shift_logits.view(
|
383 |
+
-1, self.transformer.wte.num_embeddings
|
384 |
+
),
|
385 |
+
shift_labels.view(-1),
|
386 |
+
)
|
387 |
+
return CausalLMOutputWithPast(loss=loss, logits=logits,
|
388 |
+
past_key_values=past_key_values,
|
389 |
+
hidden_states=all_hidden_states)
|
390 |
+
|
391 |
+
else:
|
392 |
+
return CausalLMOutputWithPast(logits=logits,
|
393 |
+
past_key_values=past_key_values,
|
394 |
+
hidden_states=all_hidden_states)
|
395 |
|
396 |
# Param Initialization, needed for device='meta' fast initialization
|
397 |
def param_init_fn(self, module):
|