Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
File size: 4,296 Bytes
c0ec7e6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 |
from itertools import zip_longest
import torch
class MultitaskLoss(torch.nn.Module):
    """Combine several per-task loss functions into one multitask loss.

    Args:
        loss_fns: Sequence of callables, one per task. Each is called as
            ``loss_fn(pred, target)``, or as ``loss_fn(pred)`` when the task
            has no target (e.g. a regularization term on a prediction).
        reduction: How the per-task losses are combined: ``'sum'`` (default),
            ``'mean'`` (sum divided by the number of loss functions) or
            ``'none'`` (return the tuple of per-task losses unreduced).

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    _REDUCTIONS = ('sum', 'mean', 'none')

    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        # One task per loss function.
        self.n_tasks = len(loss_fns)
        # NOTE: stored as given. A plain tuple/list is not registered by
        # nn.Module, so state of loss modules inside it (e.g. class-weight
        # buffers) is not moved by .to() nor saved in state_dict.
        self.loss_fns = loss_fns
        self.reduction = reduction

    def forward(self, preds, target):
        """Compute the combined loss.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task), a sequence of per-task targets, or
                ``None`` when no task has a target. Targets are paired with
                predictions by position; tasks past the end of ``target`` or
                whose target is ``None`` are evaluated as ``loss_fn(pred)``.

        Returns:
            The summed or averaged loss for ``'sum'``/``'mean'``, or a tuple
            with one loss per task for ``'none'``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        # zip_longest pads the shorter sequences with None, which is how
        # target-less tasks are supported.
        losses = tuple(
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        )

        if self.reduction == 'none':
            return losses
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        raise ValueError(f"unsupported reduction: {self.reduction!r}")
class MultitaskWeightedLoss(MultitaskLoss):
    """Multitask loss that scales each task's loss by a fixed weight.

    Args:
        loss_fns: Sequence of per-task loss callables, called as
            ``loss_fn(pred, target)`` or ``loss_fn(pred)`` for target-less tasks.
        weights: Sequence of per-task multipliers, paired with ``loss_fns``
            by position.
        reduction: ``'sum'`` (default), ``'mean'`` (weighted sum divided by the
            number of loss functions) or ``'none'`` (tuple of weighted losses).

    Raises:
        ValueError: If ``weights`` and ``loss_fns`` differ in length.
    """

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        # A length mismatch would otherwise surface later in forward() as an
        # opaque TypeError (None * tensor, or calling None).
        if len(weights) != self.n_tasks:
            raise ValueError(
                f"got {len(weights)} weights for {self.n_tasks} loss functions"
            )
        self.weights = weights

    def forward(self, preds, target):
        """Compute the weighted multitask loss.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task), a sequence of per-task targets, or
                ``None`` when no task has a target.

        Returns:
            The weighted sum or mean for ``'sum'``/``'mean'``, or a tuple of
            weighted per-task losses for ``'none'``.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        losses = tuple(
            weight * (loss_fn(p) if t is None else loss_fn(p, t))
            for weight, loss_fn, p, t in zip_longest(
                self.weights, self.loss_fns, preds, target
            )
        )

        if self.reduction == 'none':
            return losses
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        raise ValueError(f"unsupported reduction: {self.reduction!r}")
class MultitaskUncertaintyLoss(MultitaskLoss):
    """Uncertainty-weighted multitask loss.

    Modified from https://arxiv.org/abs/1705.07115, with the paper's
    task-specific scale factor removed for flexibility. Each task contributes
    ``exp(-s_i) * L_i + s_i / 2``, where ``s_i`` (``self.log_vars[i]``) is a
    learnable log-variance initialised to 0.

    Args:
        loss_fns: Sequence of per-task loss callables, called as
            ``loss_fn(pred, target)`` or ``loss_fn(pred)`` for target-less
            tasks. Each must return a scalar tensor so the results can be
            stacked.
        reduction: ``'none'`` (default) returns the vector of per-task terms;
            ``'sum'`` / ``'mean'`` reduce it to a scalar.

    Raises:
        ValueError: If ``reduction`` is not one of the supported values.
    """

    def __init__(self, loss_fns, reduction='none'):
        if reduction not in ('sum', 'mean', 'none'):
            raise ValueError(
                f"reduction must be 'sum', 'mean' or 'none', got {reduction!r}"
            )
        super().__init__(loss_fns, reduction=reduction)
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks))

    def forward(self, preds, targets, rescale=True):
        """Compute the uncertainty-weighted loss.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            targets: A tensor (single task), a sequence of per-task targets, or
                ``None`` when no task has a target.
            rescale: Accepted for backward compatibility; currently unused.

        Returns:
            A tensor of shape ``(n_tasks,)`` for ``'none'``, else a scalar.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if targets is None:
            targets = ()
        elif isinstance(targets, torch.Tensor):
            targets = (targets,)

        # Computed here rather than via MultitaskLoss.forward so that the
        # per-task losses are always available as a tensor to weight.
        losses = torch.stack([
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, targets)
        ])

        # exp(-s) == 1 / sigma**2 and s / 2 == log(sigma) for s = log(sigma**2);
        # using s directly avoids an exp/log round trip.
        loss = torch.exp(-self.log_vars) * losses + 0.5 * self.log_vars

        if self.reduction == 'sum':
            return loss.sum()
        if self.reduction == 'mean':
            return loss.mean()
        return loss
class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Multitask loss with learnable per-task weighting parameters.

    Each task loss ``L_i`` is combined with a learnable parameter ``p_i``
    (``self.params[i]``, initialised to 1) as
    ``0.5 / p_i**2 * L_i + log(1 + p_i**2)``, and the terms are summed.

    Args:
        loss_fns: Sequence of per-task loss callables, called as
            ``loss_fn(pred, target)`` or ``loss_fn(pred)`` for target-less
            tasks.

    Example::

        awl = MultitaskAutomaticWeightedLoss((mse_loss, l1_loss))
        loss = awl((pred_a, pred_b), (target_a, target_b))
        loss.backward()
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks))

    def forward(self, preds, target):
        """Compute the automatically weighted loss.

        Args:
            preds: A tensor (single task) or a sequence of per-task predictions.
            target: A tensor (single task), a sequence of per-task targets, or
                ``None`` when no task has a target.

        Returns:
            The sum over tasks of the weighted terms.
        """
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if target is None:
            target = ()
        elif isinstance(target, torch.Tensor):
            target = (target,)

        # Computed here rather than via MultitaskLoss.forward so the per-task
        # losses are always available to weight.
        losses = [
            loss_fn(p) if t is None else loss_fn(p, t)
            for loss_fn, p, t in zip_longest(self.loss_fns, preds, target)
        ]

        # log1p(x) is the numerically accurate form of log(1 + x).
        return sum(
            0.5 / (param ** 2) * task_loss + torch.log1p(param ** 2)
            for param, task_loss in zip(self.params, losses)
        )
|