# Spaces:
# Build error
# Build error
import torch
from transformers import Trainer
def get_custom_trainer(weights: torch.Tensor) -> type:
    """Return a ``Trainer`` subclass that trains with class-weighted cross-entropy.

    Adapted from the custom-loss example in the Hugging Face ``Trainer``
    documentation (https://huggingface.co/docs/transformers/main_classes/trainer).

    Args:
        weights: Per-class weights passed as ``weight`` to
            ``torch.nn.CrossEntropyLoss``. Presumably one entry per label
            (``model.config.num_labels``) -- confirm against the caller. Any
            value accepted by ``torch.as_tensor`` works; ``None`` yields the
            ordinary unweighted loss.

    Returns:
        The ``CustomTrainer`` class itself (not an instance); construct it
        exactly like a regular ``Trainer``.
    """

    class CustomTrainer(Trainer):
        def compute_loss(
            self, model, inputs, return_outputs=False, num_items_in_batch=None
        ):
            """Compute the weighted cross-entropy loss for one batch.

            Args:
                model: The model to run on the batch.
                inputs: Batch mapping; must contain a ``"labels"`` entry.
                return_outputs: When true, also return the model outputs.
                num_items_in_batch: Passed by newer ``transformers`` releases.
                    Accepted for signature compatibility and deliberately
                    unused: this loss keeps the default ``"mean"`` reduction.

            Returns:
                ``loss``, or ``(loss, outputs)`` when ``return_outputs`` is true.
            """
            labels = inputs.get("labels")
            # Forward pass without the labels so the model does not also
            # compute (and then discard) its own unweighted loss. A new dict
            # is built so the caller's ``inputs`` is left untouched.
            model_inputs = {k: v for k, v in inputs.items() if k != "labels"}
            outputs = model(**model_inputs)
            logits = outputs.get("logits")

            # Put the weights on the logits' device/dtype; a CPU or
            # float32 tensor would otherwise fail once the model runs on an
            # accelerator or in half precision.
            weight = (
                None
                if weights is None
                else torch.as_tensor(weights, dtype=logits.dtype, device=logits.device)
            )
            loss_fct = torch.nn.CrossEntropyLoss(weight=weight)
            loss = loss_fct(
                logits.view(-1, self.model.config.num_labels), labels.view(-1)
            )
            return (loss, outputs) if return_outputs else loss

    return CustomTrainer