from itertools import zip_longest

import torch


class MultitaskLoss(torch.nn.Module):
    """Combine several per-task loss functions into one multitask loss.

    Args:
        loss_fns: sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when the task
            has no target (see ``forward``).
        reduction: ``'sum'`` (default) adds the task losses, ``'mean'``
            divides that sum by the number of tasks, and ``'none'`` (or
            ``None``) returns the individual task losses as a tuple.
    """

    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        # The number of tasks is taken to be the number of loss functions.
        self.n_tasks = len(loss_fns)
        self.loss_fns = loss_fns
        self.reduction = reduction

    @staticmethod
    def _as_tuple(value):
        """Wrap a bare tensor in a 1-tuple so single-task calls work."""
        if isinstance(value, torch.Tensor):
            return (value,)
        return value

    def _task_losses(self, preds, target):
        """Return the list of unreduced, unweighted per-task losses."""
        preds = self._as_tuple(preds)
        target = self._as_tuple(target)

        losses = []
        # zip_longest (rather than zip) lets trailing tasks omit their target:
        # a missing or None target means the loss is computed from the
        # prediction alone.
        for loss_fn, p, t in zip_longest(self.loss_fns, preds, target):
            if loss_fn is None:
                raise ValueError(
                    "received more predictions/targets than loss functions "
                    f"({self.n_tasks})"
                )
            losses.append(loss_fn(p, t) if t is not None else loss_fn(p))
        return losses

    def _reduce(self, losses):
        """Apply ``self.reduction`` to a list of per-task losses."""
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        if self.reduction in ('none', None):
            # Previously this case silently returned None, which broke every
            # subclass that relies on the per-task values.
            return tuple(losses)
        raise ValueError(f"unsupported reduction: {self.reduction!r}")

    def forward(self, preds, target):
        """Compute the multitask loss.

        Args:
            preds: a tensor (single task) or a sequence with one prediction
                per task.
            target: a tensor (single task) or a sequence of targets. It may be
                shorter than ``preds`` or contain ``None`` entries for tasks
                whose loss function takes only the prediction.

        Returns:
            The reduced loss for ``'sum'``/``'mean'``, or a tuple of per-task
            losses for ``'none'``.

        Raises:
            ValueError: if more predictions/targets than loss functions are
                given, or if ``self.reduction`` is not supported.
        """
        return self._reduce(self._task_losses(preds, target))


class MultitaskWeightedLoss(MultitaskLoss):
    """Multitask loss in which every task loss is scaled by a fixed weight.

    Args:
        loss_fns: sequence of per-task loss callables.
        weights: sequence with one multiplicative weight per task.
        reduction: same meaning as in ``MultitaskLoss``.
    """

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        self.weights = weights

    def forward(self, preds, target):
        """Compute the weighted multitask loss.

        Returns:
            The reduced weighted loss, or a tuple of weighted per-task losses
            when the reduction is ``'none'``.

        Raises:
            ValueError: if the number of weights differs from the number of
                task losses, or for the reasons listed on
                ``MultitaskLoss.forward``.
        """
        losses = self._task_losses(preds, target)

        weighted = []
        # zip_longest exposes a length mismatch instead of silently dropping
        # tasks the way plain zip would.
        for weight, loss in zip_longest(self.weights, losses):
            if weight is None or loss is None:
                raise ValueError(
                    "the number of weights must match the number of task "
                    f"losses ({len(losses)})"
                )
            weighted.append(weight * loss)
        return self._reduce(weighted)


class MultitaskUncertaintyLoss(MultitaskLoss):
    """Multitask loss with learnable per-task log-variance weighting.

    Modified from https://arxiv.org/abs/1705.07115.
    Removed task-specific scale factor for flexibility.

    Each task loss ``L_i`` contributes ``exp(-s_i) * L_i + s_i / 2`` where
    ``s_i`` is the learnable log-variance of task ``i`` (initialised to 0);
    the contributions are summed over tasks.

    Args:
        loss_fns: sequence of per-task loss callables.
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # nn.Parameter already tracks gradients; no explicit requires_grad.
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks))

    def forward(self, preds, targets, rescale=True):
        """Compute the uncertainty-weighted loss summed over tasks.

        Args:
            preds: a tensor or a sequence with one prediction per task.
            targets: a tensor or a sequence of per-task targets.
            rescale: accepted for backward compatibility; currently unused.

        Returns:
            The sum over tasks of ``exp(-log_var) * loss + log_var / 2``.
        """
        losses = super().forward(preds, targets)
        # With std = exp(log_var / 2):
        #   1 / std**2 == exp(-log_var)   and   log(std) == log_var / 2
        return sum(
            torch.exp(-log_var) * loss + log_var / 2
            for log_var, loss in zip(self.log_vars, losses)
        )


class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Automatically weighted multitask loss.

    Each task loss ``L_i`` contributes
    ``0.5 / p_i**2 * L_i + log(1 + p_i**2)`` where ``p_i`` is a learnable
    parameter (initialised to 1); the contributions are summed over tasks.

    Args:
        loss_fns: sequence of per-task loss callables.

    Example:
        awl = MultitaskAutomaticWeightedLoss((loss_fn1, loss_fn2))
        loss = awl((pred1, pred2), (target1, target2))
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # nn.Parameter already tracks gradients; no explicit requires_grad.
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks))

    def forward(self, preds, target):
        """Compute the automatically weighted loss summed over tasks."""
        losses = super().forward(preds, target)
        return sum(
            0.5 / (param ** 2) * loss + torch.log(1 + param ** 2)
            for param, loss in zip(self.params, losses)
        )