libokj's picture
Upload 110 files
c0ec7e6
raw
history blame
4.3 kB
from itertools import zip_longest
import torch
class MultitaskLoss(torch.nn.Module):
    """Generic multitask loss that applies one loss function per task.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when there is
            no target for that task.
        reduction: ``'sum'`` adds the per-task losses; ``'mean'`` divides
            that sum by the number of tasks; any other value (conventionally
            ``'none'``) leaves the per-task losses unreduced.
    """

    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        # The number of tasks is assumed to equal the number of loss functions.
        self.n_tasks = len(loss_fns)
        # NOTE: stored as given. Loss functions that are nn.Modules held in a
        # plain tuple/list are not registered as submodules of this module.
        self.loss_fns = loss_fns
        self.reduction = reduction

    def forward(self, preds, target):
        """Compute the per-task losses and reduce them.

        Args:
            preds: A single tensor (treated as one task) or a sequence of
                per-task predictions.
            target: A single tensor, a sequence of per-task targets, or
                ``None``. Missing trailing targets, or ``None`` entries, make
                the matching loss function receive the prediction only.

        Returns:
            The summed loss (``'sum'``), the sum divided by ``n_tasks``
            (``'mean'``), or a tuple with one loss per task otherwise.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            # No targets at all: every loss function is called on preds only.
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        # zip_longest pads a shorter `target` with None, which selects the
        # prediction-only call below.
        losses = tuple(
            loss_fn(p, t) if t is not None else loss_fn(p)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        )

        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        # No reduction requested: hand back the individual task losses rather
        # than None so callers can weight or combine them themselves.
        return losses
class MultitaskWeightedLoss(MultitaskLoss):
    """Multitask loss that scales each task's loss by a fixed weight.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when there is
            no target for that task.
        weights: Sequence of multipliers, one per task, paired positionally
            with ``loss_fns``.
        reduction: ``'sum'`` adds the weighted losses; ``'mean'`` divides
            that sum by the number of tasks (not by the sum of the weights);
            any other value (conventionally ``'none'``) leaves the weighted
            per-task losses unreduced.
    """

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        self.weights = weights

    def forward(self, preds, target):
        """Compute the weighted per-task losses and reduce them.

        Args:
            preds: A single tensor (treated as one task) or a sequence of
                per-task predictions.
            target: A single tensor, a sequence of per-task targets, or
                ``None``. Missing trailing targets, or ``None`` entries, make
                the matching loss function receive the prediction only.

        Returns:
            The summed weighted loss (``'sum'``), that sum divided by
            ``n_tasks`` (``'mean'``), or a tuple with one weighted loss per
            task otherwise.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            # No targets at all: every loss function is called on preds only.
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        # zip_longest pads a shorter `target` with None, which selects the
        # prediction-only call below.
        losses = tuple(
            weight * (loss_fn(p, t) if t is not None else loss_fn(p))
            for weight, loss_fn, p, t in zip_longest(
                self.weights, self.loss_fns, preds, target
            )
        )

        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        # No reduction requested: return the individual weighted losses
        # instead of None.
        return losses
class MultitaskUncertaintyLoss(MultitaskLoss):
    """Multitask loss with one learnable log-variance weight per task.

    Modified from https://arxiv.org/abs/1705.07115.
    Removed task-specific scale factor for flexibility.

    With ``s_i`` the learnable log-variance of task ``i`` (initialised to 0),
    the task contributes ``exp(-s_i) * loss_i + s_i / 2``.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when there is
            no target for that task. Each is expected to return a scalar
            tensor so the results can be stacked into one vector.
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # nn.Parameter already enables gradients for the wrapped tensor.
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks))

    def forward(self, preds, targets, rescale=True):
        """Return the uncertainty-weighted loss of every task.

        Args:
            preds: A single tensor (treated as one task) or a sequence of
                per-task predictions.
            targets: A single tensor, a sequence of per-task targets, or
                ``None``. Missing trailing targets, or ``None`` entries, make
                the matching loss function receive the prediction only.
            rescale: Currently unused; kept so existing call sites that pass
                it keep working.

        Returns:
            A 1-D tensor with one weighted loss per task. It is not reduced;
            call ``.sum()`` or ``.mean()`` on it to obtain a scalar.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if targets is None:
            targets = ()
        elif isinstance(targets, torch.Tensor):
            targets = (targets,)

        # The per-task values are evaluated here directly and stacked into a
        # tensor: the weighting below needs a vector aligned with `log_vars`,
        # not a reduced value or a Python sequence.
        losses = torch.stack([
            loss_fn(p, t) if t is not None else loss_fn(p)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, targets)
        ])

        # With std = exp(log_var / 2): 1 / std**2 == exp(-log_var) and
        # log(std) == log_var / 2. Using the closed forms avoids an
        # exp -> square -> reciprocal / exp -> log round trip.
        precisions = torch.exp(-self.log_vars)
        return precisions * losses + 0.5 * self.log_vars
class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Automatically weighted multitask loss.

    One learnable parameter ``c_i`` per task (initialised to 1) is used to
    combine the task losses as
    ``sum_i 0.5 / c_i**2 * loss_i + log(1 + c_i**2)``.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when there is
            no target for that task.

    Examples:
        awl = MultitaskAutomaticWeightedLoss((loss_fn1, loss_fn2))
        loss = awl((pred1, pred2), (target1, target2))
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # nn.Parameter already enables gradients for the wrapped tensor.
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks))

    def forward(self, preds, target):
        """Return the sum of the automatically weighted task losses.

        Args:
            preds: A single tensor (treated as one task) or a sequence of
                per-task predictions.
            target: A single tensor, a sequence of per-task targets, or
                ``None``. Missing trailing targets, or ``None`` entries, make
                the matching loss function receive the prediction only.

        Returns:
            The weighted losses plus their regularisation terms, summed over
            all tasks.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        # The per-task values are evaluated here directly because each one
        # must be paired with its own learnable parameter before summing.
        losses = [
            loss_fn(p, t) if t is not None else loss_fn(p)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        ]

        # log(1 + c**2) keeps the optimiser from shrinking a task's weight to
        # zero simply by growing c without bound.
        return sum(
            0.5 / (param ** 2) * loss + torch.log(1 + param ** 2)
            for param, loss in zip(self.params, losses)
        )