Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
from itertools import zip_longest | |
import torch | |
class MultitaskLoss(torch.nn.Module):
    """Combine one loss function per task into a single multitask loss.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when no target
            is supplied for that task (see ``forward``).
        reduction: How the per-task losses are combined: ``'sum'`` (default),
            ``'mean'`` (the sum divided by the number of loss functions) or
            ``'none'`` (the per-task losses are returned as a tuple).

    Raises:
        ValueError: If ``reduction`` is not one of the values above.
    """

    _REDUCTIONS = ('sum', 'mean', 'none')

    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        # The number of tasks is taken to be the number of loss functions.
        self.n_tasks = len(loss_fns)
        # NOTE(review): kept as a plain sequence, so loss functions that are
        # torch.nn.Module instances are not registered as submodules (.to() and
        # state_dict() will not reach them) — confirm this is intended.
        self.loss_fns = loss_fns
        self.reduction = reduction

    def forward(self, preds, target):
        """Apply each task's loss function and combine the results.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task) or a sequence of per-task targets.
                The sequence may be shorter than ``preds``. Tasks without a
                target call ``loss_fn(pred)``.

        Returns:
            The summed or averaged loss for ``'sum'`` / ``'mean'``, or a tuple
            with one loss per task for ``'none'``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(target, torch.Tensor):
            target = (target,)
        # zip_longest pads the shorter sequences with None. A missing target
        # selects the single-argument form of the loss function.
        losses = tuple(
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        )
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        # reduction == 'none': hand the individual losses back to the caller.
        return losses
class MultitaskWeightedLoss(MultitaskLoss):
    """Multitask loss that scales each task's loss by a fixed weight.

    Args:
        loss_fns: Sequence of per-task loss callables. Each is called as
            ``loss_fn(pred, target)``, or ``loss_fn(pred)`` when the task has
            no target.
        weights: Sequence of per-task multipliers, one per loss function.
        reduction: ``'sum'`` (default), ``'mean'`` (sum divided by the number
            of tasks) or ``'none'`` (tuple of weighted per-task losses).

    Raises:
        ValueError: If ``weights`` and ``loss_fns`` have different lengths.
    """

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        # Fail fast: a mismatch would otherwise surface in forward() as an
        # opaque ``None * Tensor`` error produced by zip_longest padding.
        if len(weights) != self.n_tasks:
            raise ValueError(
                f"expected {self.n_tasks} weights (one per loss function), "
                f"got {len(weights)}"
            )
        self.weights = weights

    def forward(self, preds, target):
        """Compute the weighted per-task losses and combine them.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task) or a sequence of per-task targets.
                Tasks without a target call ``loss_fn(pred)``.

        Returns:
            The summed or averaged weighted loss for ``'sum'`` / ``'mean'``, or
            a tuple of weighted per-task losses for ``'none'``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(target, torch.Tensor):
            target = (target,)
        losses = tuple(
            weight * (loss_fn(p) if t is None else loss_fn(p, t))
            for weight, loss_fn, p, t in zip_longest(
                self.weights, self.loss_fns, preds, target
            )
        )
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        # reduction == 'none': return the weighted losses instead of None.
        return losses
class MultitaskUncertaintyLoss(MultitaskLoss):
    """Homoscedastic-uncertainty weighting of per-task losses.

    Modified from https://arxiv.org/abs/1705.07115.
    Removed task-specific scale factor for flexibility.

    Each task ``i`` has a learnable log-variance ``s_i`` (initialised to 0)
    and contributes ``L_i / sigma_i**2 + log(sigma_i)``, where
    ``sigma_i = exp(s_i / 2)``.

    Args:
        loss_fns: Sequence of per-task loss callables. Each is called as
            ``loss_fn(pred, target)``, or ``loss_fn(pred)`` when the task has
            no target.
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks, requires_grad=True))

    def forward(self, preds, targets, rescale=True):
        """Return the uncertainty-weighted loss of every task.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            targets: A tensor (single task) or a sequence of per-task targets.
                Tasks without a target call ``loss_fn(pred)``.
            rescale: Unused; kept for interface compatibility.

        Returns:
            A tensor with one weighted loss per task (not reduced; call
            ``.sum()`` for a scalar objective). Assumes each loss function
            returns a scalar tensor — TODO confirm for non-reduced losses.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(targets, torch.Tensor):
            targets = (targets,)
        # Compute the per-task losses directly and stack them so they can be
        # weighted element-wise by the log-variances.
        losses = torch.stack([
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, targets)
        ])
        # 1 / sigma**2 == exp(-s) and log(sigma) == s / 2. This avoids the
        # exp -> square -> reciprocal -> log round trip of the naive form.
        return torch.exp(-self.log_vars) * losses + 0.5 * self.log_vars
class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Automatically weighted multitask loss.

    Each task ``i`` has a learnable parameter ``p_i`` (initialised to 1). The
    returned loss is ``sum_i 0.5 / p_i**2 * L_i + log(1 + p_i**2)``.

    Args:
        loss_fns: Sequence of per-task loss callables. Each is called as
            ``loss_fn(pred, target)``, or ``loss_fn(pred)`` when the task has
            no target.

    Example::

        mse = torch.nn.functional.mse_loss
        awl = MultitaskAutomaticWeightedLoss((mse, mse))
        total = awl((pred1, pred2), (target1, target2))
        total.backward()
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks, requires_grad=True))

    def forward(self, preds, target):
        """Return the automatically weighted sum of the per-task losses.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task) or a sequence of per-task targets.
                Tasks without a target call ``loss_fn(pred)``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(target, torch.Tensor):
            target = (target,)
        # Compute the per-task losses directly; each is paired with its own
        # learnable parameter below.
        losses = [
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        ]
        # The log(1 + p**2) term regularises p so the weights cannot grow without bound.
        return sum(
            0.5 / (param ** 2) * loss + torch.log(1 + param ** 2)
            for param, loss in zip(self.params, losses)
        )