from itertools import zip_longest

import torch


class MultitaskLoss(torch.nn.Module):
    """A generic multitask loss class that takes a tuple of loss functions as input"""
    def __init__(self, loss_fns, reduction='sum'):
        super().__init__()
        self.n_tasks = len(loss_fns)  # assuming the number of tasks is equal to the number of loss functions
        self.loss_fns = loss_fns  # store the tuple of loss functions
        self.reduction = reduction

    def forward(self, preds, target):
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(target, torch.Tensor):
            target = (target,)
        # compute the per-task losses by applying each task's loss function;
        # targets may be shorter than preds, in which case the missing targets are
        # None and the loss function is applied to the prediction alone (e.g. a regularizer)
        losses = []
        for loss_fn, p, t in zip_longest(self.loss_fns, preds, target):
            if t is not None:
                loss = loss_fn(p, t)
            else:
                loss = loss_fn(p)
            losses.append(loss)

        # apply the requested reduction, or return the per-task losses unreduced
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        return losses
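
# Usage sketch (illustrative only, not exercised elsewhere in this module):
# combine a regression head and a classification head with a plain sum.
#
#   mtl = MultitaskLoss((torch.nn.MSELoss(), torch.nn.CrossEntropyLoss()))
#   preds = (torch.randn(8, 1), torch.randn(8, 5))
#   targets = (torch.randn(8, 1), torch.randint(0, 5, (8,)))
#   total = mtl(preds, targets)  # scalar tensor: mse + cross-entropy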


class MultitaskWeightedLoss(MultitaskLoss):
    """A multitask loss class that takes a tuple of loss functions and weights as input"""

    def __init__(self, loss_fns, weights, reduction='sum'):
        super().__init__(loss_fns, reduction)
        self.weights = weights  # store the tuple of weights

    def forward(self, preds, target):
        if isinstance(preds, torch.Tensor):
            preds = (preds,)
        if isinstance(target, torch.Tensor):
            target = (target,)
        # compute the weighted per-task losses; as in MultitaskLoss, targets may be
        # shorter than preds, in which case the loss function gets only the prediction
        losses = []
        for weight, loss_fn, p, t in zip_longest(self.weights, self.loss_fns, preds, target):
            if t is not None:
                loss = weight * loss_fn(p, t)
            else:
                loss = weight * loss_fn(p)
            losses.append(loss)

        # apply the requested reduction, or return the per-task losses unreduced
        if self.reduction == 'sum':
            return sum(losses)
        if self.reduction == 'mean':
            return sum(losses) / self.n_tasks
        return losses
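
# Usage sketch (illustrative only): same as MultitaskLoss, but each task's loss
# is scaled by a fixed weight before reduction.
#
#   wmtl = MultitaskWeightedLoss(
#       (torch.nn.MSELoss(), torch.nn.CrossEntropyLoss()), weights=(1.0, 0.5))
#   total = wmtl(preds, targets)  # 1.0 * mse + 0.5 * cross-entropy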


class MultitaskUncertaintyLoss(MultitaskLoss):
    """
    Modified from https://arxiv.org/abs/1705.07115.
    Removed task-specific scale factor for flexibility.
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        # one learnable log-variance per task
        self.log_vars = torch.nn.Parameter(torch.zeros(self.n_tasks, requires_grad=True))

    def forward(self, preds, targets):
        # per-task losses, left unreduced by the base class
        losses = torch.stack(super().forward(preds, targets))
        stds = torch.exp(self.log_vars / 2)
        coeffs = 1 / (stds ** 2)
        # per-task uncertainty-weighted losses; sum them to get a scalar training loss
        return coeffs * losses + torch.log(stds)
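
# Note (illustrative sketch): log_vars is a learnable parameter of this module, so
# the loss instance itself must be passed to the optimizer alongside the model, e.g.
#
#   criterion = MultitaskUncertaintyLoss((torch.nn.MSELoss(), torch.nn.CrossEntropyLoss()))
#   optimizer = torch.optim.Adam(list(model.parameters()) + list(criterion.parameters()))
#   loss = criterion(preds, targets).sum()  # sum the per-task terms for a scalar loss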


class MultitaskAutomaticWeightedLoss(MultitaskLoss):
    """Automatically weighted multitask loss.

    Each task's loss is scaled by a learnable parameter, plus a regularisation
    term that discourages the learned weights from collapsing to zero.

    Params:
        loss_fns: tuple of loss functions, one per task
    Examples:
        awl = MultitaskAutomaticWeightedLoss((torch.nn.MSELoss(), torch.nn.L1Loss()))
        loss = awl(preds, targets)
    """

    def __init__(self, loss_fns):
        super().__init__(loss_fns, reduction='none')
        self.params = torch.nn.Parameter(torch.ones(self.n_tasks, requires_grad=True))

    def forward(self, preds, target):
        losses = super().forward(preds, target)
        loss = sum(
            0.5 / (param ** 2) * loss + torch.log(1 + param ** 2)
            for param, loss in zip(self.params, losses)
        )
        return loss
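

if __name__ == '__main__':
    # Minimal smoke test (illustrative sketch, not part of the module's API):
    # two synthetic heads, one regression and one 5-way classification,
    # run through each loss variant.
    torch.manual_seed(0)
    preds = (torch.randn(8, 1), torch.randn(8, 5))
    targets = (torch.randn(8, 1), torch.randint(0, 5, (8,)))
    loss_fns = (torch.nn.MSELoss(), torch.nn.CrossEntropyLoss())

    print('sum:      ', MultitaskLoss(loss_fns)(preds, targets).item())
    print('weighted: ', MultitaskWeightedLoss(loss_fns, weights=(1.0, 0.5))(preds, targets).item())
    print('uncert.:  ', MultitaskUncertaintyLoss(loss_fns)(preds, targets).sum().item())
    print('auto:     ', MultitaskAutomaticWeightedLoss(loss_fns)(preds, targets).item())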