=
adding package
592bfb5
raw
history blame contribute delete
866 Bytes
from transformers import Trainer
import torch
def get_custom_trainer(weights: torch.Tensor):
    """Build a ``Trainer`` subclass whose loss is class-weighted cross-entropy.

    Args:
        weights: Per-class rescaling weights handed to
            ``torch.nn.CrossEntropyLoss(weight=...)``. Expected to hold one
            entry per label (``model.config.num_labels``) -- not validated
            here.

    Returns:
        The ``CustomTrainer`` class (not an instance); instantiate it exactly
        like a regular ``Trainer``.
    """

    # Adapted from https://huggingface.co/docs/transformers/main_classes/trainer
    class CustomTrainer(Trainer):
        """``Trainer`` that swaps the default loss for weighted cross-entropy."""

        def compute_loss(
            self, model, inputs, return_outputs=False, num_items_in_batch=None
        ):
            """Run a forward pass and compute the weighted cross-entropy loss.

            Args:
                model: Callable invoked as ``model(**inputs)``; its result must
                    expose the logits through ``.get("logits")``.
                inputs: Mapping of model inputs; must contain ``"labels"``.
                return_outputs: When true, return ``(loss, outputs)`` instead
                    of only the loss.
                num_items_in_batch: Accepted for compatibility with newer
                    ``transformers`` releases, which pass this keyword from
                    the training loop; unused here.

            Returns:
                The loss tensor, or ``(loss, outputs)`` if ``return_outputs``.

            Raises:
                ValueError: If ``inputs`` carries no ``"labels"`` entry.
            """
            labels = inputs.get("labels")
            if labels is None:
                raise ValueError(
                    "CustomTrainer.compute_loss requires 'labels' in inputs"
                )

            # Forward pass; labels stay in `inputs`, as in the original code.
            outputs = model(**inputs)
            logits = outputs.get("logits")

            # The weights are captured from the enclosing factory call. They
            # are moved to the logits' device so that CPU-created weights work
            # with a model living on an accelerator.
            loss_fct = torch.nn.CrossEntropyLoss(weight=weights.to(logits.device))
            loss = loss_fct(
                logits.view(-1, self.model.config.num_labels),
                labels.view(-1),
            )
            return (loss, outputs) if return_outputs else loss

    return CustomTrainer