File size: 10,831 Bytes
1f53a4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Dict, Union

# Type alias for a dictionary of label tensors keyed first by group, then by label name.
# e.g. {'labels': {'label_A': torch.Tensor([0, 1, ...]), ...}}
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]


class RMSELoss(nn.Module):
    """
    Root mean squared error loss.

    Computes sqrt(MSE(yhat, y) + eps) using nn.MSELoss with its default
    settings.
    """
    def __init__(self, eps: float = 1e-7) -> None:
        """
        Args:
            eps (float, optional): small constant added to the MSE before the
                square root is taken, so the argument of the square root is
                never exactly 0. Defaults to 1e-7.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
        """
        Calculate RMSE.

        Args:
            yhat (torch.FloatTensor): predicted values
            y (torch.FloatTensor): ground truth values

        Returns:
            torch.FloatTensor: RMSE between yhat and y
        """
        # eps is added inside the square root, not to the final result.
        return torch.sqrt(self.mse(yhat, y) + self.eps)


class Regularization:
    """
    Class to calculate regularization loss.

    The loss is weight_decay times the sum of the p-norms (p = order) of every
    parameter of a network whose name contains 'weight'.
    """
    def __init__(self, order: int, weight_decay: float) -> None:
        """
        The initialization of Regularization class.

        Args:
            order: (int) norm order number
            weight_decay: (float) weight decay rate
        """
        super().__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, network: nn.Module) -> torch.FloatTensor:
        """
        Calculates regularization(self.order) loss for network.

        Args:
            network: (torch.nn.Module object) network whose parameters are penalized

        Returns:
            torch.FloatTensor: the regularization(self.order) loss. A zero
            scalar tensor is returned when no parameter name contains 'weight'.
        """
        # Parameters are selected by name, so biases and other parameters
        # without 'weight' in their name are not penalized.
        norms = [
            torch.norm(w, p=self.order)
            for name, w in network.named_parameters()
            if 'weight' in name
        ]

        if not norms:
            # Nothing to penalize: still return a tensor (not a Python number)
            # so callers can treat the result uniformly.
            first_param = next(network.parameters(), None)
            device = first_param.device if first_param is not None else None
            return torch.zeros((), device=device)

        return self.weight_decay * sum(norms)


class NegativeLogLikelihood(nn.Module):
    """
    Class to calculate negative log likelihood for risk prediction,
    with an L2 penalty on the network weights added to it.
    """
    def __init__(self, device: torch.device, l2_reg: float = 0.05) -> None:
        """
        Args:
            device (torch.device): device on which the comparison mask is placed
            l2_reg (float, optional): coefficient of the L2 penalty. Defaults to 0.05.
        """
        super().__init__()
        self.L2_reg = l2_reg
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)
        self.device = device

    def forward(
                self,
                output: torch.FloatTensor,
                label: torch.IntTensor,
                periods: torch.FloatTensor,
                network: nn.Module
                ) -> torch.FloatTensor:
        """
        Calculates Negative Log Likelihood.

        For each sample j, the log of the mean of exp(output_i) over the
        samples i whose period is not shorter than that of j is subtracted
        from output_j. The differences are summed over samples weighted by
        label, divided by the sum of label, negated, and the L2 penalty of
        network is added.

        Args:
            output (torch.FloatTensor): prediction value, ie risk prediction
            label (torch.IntTensor): occurrence of event
            periods (torch.FloatTensor): period
            network (nn.Module): network whose weights are regularized

        Returns:
            torch.FloatTensor: Negative Log Likelihood. When label sums to 0,
            a constant tensor [1e-7] is returned instead.
        """
        num_occurs = torch.sum(label)
        if num_occurs.item() == 0.0:
            # To avoid zero division, set small value as loss
            return torch.tensor([1e-7], requires_grad=True).to(self.device)

        # mask[i, j] == 1 when periods[i] >= periods[j], else 0.
        # output and mask should be on the same device.
        mask = torch.ones(periods.shape[0], periods.shape[0]).to(self.device)
        mask[(periods.T - periods) > 0] = 0

        # log(sum(exp(output) * mask, dim=0) / sum(mask, dim=0)), evaluated in
        # log space: exp(output) can overflow to inf, and inf * 0 is nan.
        # log(mask) is 0 where mask is 1 and -inf where mask is 0, so masked
        # entries contribute nothing to the logsumexp.
        log_mean_risk = torch.logsumexp(output + torch.log(mask), dim=0) - torch.log(torch.sum(mask, dim=0))
        log_mean_risk = log_mean_risk.reshape(-1, 1)

        neg_log_loss = -torch.sum((output - log_mean_risk) * label) / num_occurs
        l2_loss = self.reg(network)
        return neg_log_loss + l2_loss


class ClsCriterion:
    """
    Criterion for classification: cross entropy per label plus their total.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set CrossEntropyLoss.

        Args:
            device (torch.device): device on which the total loss is accumulated
        """
        self.criterion = nn.CrossEntropyLoss()
        self.device = device

    def __call__(
                self,
                outputs: Dict[str, torch.FloatTensor],
                labels: Dict[str, LabelDict]
                ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): output of each label, keyed by label name
            labels (Dict[str, LabelDict]): labels; the tensors are read from labels['labels']

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss under 'total'

        # No reshape and no cast:
        output: [64, 2]: torch.float32
        label:  [64]   : torch.int64
        label.dtype should be torch.int64, otherwise nn.CrossEntropyLoss() causes error.

        eg.
        outputs = {'label_A': [[0.8, 0.2], ...], 'label_B': [[0.7, 0.3], ...], ...}
        labels = {'labels': {'label_A': [1, 1, 0, ...], 'label_B': [0, 0, 1, ...], ...}}

        -> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
        """
        # 'total' is inserted first so it precedes the per-label entries.
        losses = {'total': torch.tensor([0.0], requires_grad=True).to(self.device)}

        for label_name, target in labels['labels'].items():
            label_loss = self.criterion(outputs[label_name], target)
            losses[label_name] = label_loss
            losses['total'] = losses['total'] + label_loss

        return losses


class RegCriterion:
    """
    Class of criterion for regression.
    """
    def __init__(self, criterion_name: str = None, device: torch.device = None) -> None:
        """
        Set MSE, RMSE or MAE.

        Args:
            criterion_name (str): 'MSE', 'RMSE', or 'MAE'
            device (torch.device): device on which the total loss is accumulated

        Raises:
            ValueError: if criterion_name is not one of 'MSE', 'RMSE', 'MAE'
        """
        self.device = device

        if criterion_name == 'MSE':
            self.criterion = nn.MSELoss()
        elif criterion_name == 'RMSE':
            self.criterion = RMSELoss()
        elif criterion_name == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Invalid criterion for regression: {criterion_name}.")

    def __call__(
                self,
                outputs: Dict[str, torch.FloatTensor],
                labels: Dict[str, LabelDict]
                ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): output of each label, keyed by label name
            labels (Dict[str, LabelDict]): labels; the tensors are read from labels['labels']

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss under 'total'

        # Reshape and cast
        output: [64, 1] -> [64]: torch.float32
        label:             [64]: torch.float64 -> torch.float32
        # label.dtype should be torch.float32, otherwise cannot backward.

        eg.
        outputs = {'label_A': [[10.8], ...], 'label_B': [[15.7], ...], ...}
        labels = {'labels': {'label_A': [10, 9, ...], 'label_B': [12, 17, ...], ...}}
        -> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
        """
        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name, _label in labels['labels'].items():
            _label = _label.to(torch.float32)
            _output = outputs[label_name].squeeze()

            # squeeze() drops every size-1 dimension, so a batch of one sample
            # ([1, 1]) becomes 0-dim and no longer matches the label ([1]).
            # When only the shape differs, not the number of elements, align
            # the output with the label instead of relying on broadcasting.
            if _output.shape != _label.shape and _output.numel() == _label.numel():
                _output = _output.reshape(_label.shape)

            _label_loss = self.criterion(_output, _label)
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses


class DeepSurvCriterion:
    """
    Criterion for deepsurv: NegativeLogLikelihood per label plus their total.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set NegativeLogLikelihood.

        Args:
            device (torch.device, optional): device passed to NegativeLogLikelihood
                and used for accumulating the total loss
        """
        self.device = device
        nll = NegativeLogLikelihood(self.device)
        self.criterion = nll.to(self.device)

    def __call__(
                self,
                outputs: Dict[str, torch.FloatTensor],
                labels: Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
                ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): output of each label, keyed by label name
            labels (Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]):
                labels under 'labels', periods under 'periods', and network under 'network'

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss under 'total'

        # Reshape and no cast
        output:         [64, 1]: torch.float32
        label:  [64] -> [64, 1]: torch.int64
        period: [64] -> [64, 1]: torch.float32

        eg.
        outputs = {'label_A': [[10.8], ...], 'label_B': [[15.7], ...], ...}
        labels = {
                    'labels': {'label_A': [1, 0, 1, ...]},
                    'periods': [5, 10, 7, ...],
                    'network': network
                }
        -> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
        """
        # Every label is turned into a column vector before the loss is evaluated.
        column_labels = {
            label_name: raw_label.reshape(-1, 1)
            for label_name, raw_label in labels['labels'].items()
        }
        column_periods = labels['periods'].reshape(-1, 1)
        network = labels['network']

        # 'total' is inserted first so it precedes the per-label entries.
        losses = {'total': torch.tensor([0.0], requires_grad=True).to(self.device)}
        for label_name, column_label in column_labels.items():
            label_loss = self.criterion(outputs[label_name], column_label, column_periods, network)
            losses[label_name] = label_loss
            losses['total'] = losses['total'] + label_loss
        return losses


def set_criterion(
                criterion_name: str,
                device: torch.device
                ) -> Union[ClsCriterion, RegCriterion, DeepSurvCriterion]:
    """
    Build the criterion object that matches a criterion name.

    'CEL' gives ClsCriterion, 'MSE'/'RMSE'/'MAE' give RegCriterion,
    and 'NLL' gives DeepSurvCriterion.

    Args:
        criterion_name (str): criterion name
        device (torch.device): device handed to the criterion

    Returns:
        Union[ClsCriterion, RegCriterion, DeepSurvCriterion]: criterion instance

    Raises:
        ValueError: if criterion_name is none of the names above
    """
    # Guard-clause dispatch: each recognized name returns immediately.
    if criterion_name == 'CEL':
        return ClsCriterion(device=device)

    if criterion_name in ('MSE', 'RMSE', 'MAE'):
        return RegCriterion(criterion_name=criterion_name, device=device)

    if criterion_name == 'NLL':
        return DeepSurvCriterion(device=device)

    raise ValueError(f"Invalid criterion: {criterion_name}.")