MedicalAILabo committed on
Commit 1f53a4c · 1 Parent(s): 8630b06

Upload app.py and lib.

app.py ADDED
@@ -0,0 +1,61 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import gradio as gr
from lib import create_model
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
from lib.dataloader import ImageMixin


test_weight = './weight_epoch-200_best.pt'
parameter = './parameters.json'

class ImageHandler(ImageMixin):
    def __init__(self, params):
        self.params = params
        self.transform = self._make_transforms()

    def set_image(self, image):
        image = self.transform(image)
        image = {'image': image.unsqueeze(0)}
        return image

def load_parameter(parameter):
    _args = ParamSet()
    params = _retrieve_parameter(parameter)
    for _param, _arg in params.items():
        setattr(_args, _param, _arg)

    _args.augmentation = 'no'
    _args.sampler = 'no'
    _args.pretrained = False
    _args.mlp = None
    _args.net = _args.model
    _args.device = torch.device('cpu')

    args_model = _dispatch_by_group(_args, 'model')
    args_dataloader = _dispatch_by_group(_args, 'dataloader')
    return args_model, args_dataloader

args_model, args_dataloader = load_parameter(parameter)
model = create_model(args_model)
model.load_weight(test_weight)

def main(image):
    model.eval()
    image_handler = ImageHandler(args_dataloader)
    image = image_handler.set_image(image)

    with torch.no_grad():
        outputs = model(image)

    label_name = list(outputs.keys())[0]
    result = outputs[label_name].detach().numpy().item()
    result = f"{result:.2f}"
    return result


# Gradio
iface = gr.Interface(fn=main, inputs=[gr.Image(type='pil', image_mode='L')], outputs='text')
iface.launch()
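Illustrative note (not part of the committed files): main() can be smoke-tested without launching the Gradio UI, provided weight_epoch-200_best.pt and parameters.json are present locally; 'sample_xray.png' below is a placeholder path for any grayscale image, mirroring what gr.Image(type='pil', image_mode='L') would pass in.

from PIL import Image

sample = Image.open('sample_xray.png').convert('L')  # grayscale PIL image, matching image_mode='L'
print(main(sample))                                  # a single prediction formatted to two decimals, eg. '12.34'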
lib/__init__.py ADDED
@@ -0,0 +1,24 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .options import (
    ParamSet,
    set_options,
    save_parameter,
    print_parameter
)
from .dataloader import create_dataloader
from .framework import create_model
from .metrics import set_eval
from .logger import BaseLogger

__all__ = [
    'ParamSet',
    'set_options',
    'print_parameter',
    'save_parameter',
    'create_dataloader',
    'create_model',
    'set_eval',
    'BaseLogger'
]
lib/component/__init__.py ADDED
@@ -0,0 +1,16 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from .net import create_net
from .criterion import set_criterion
from .optimizer import set_optimizer
from .loss import set_loss_store
from .likelihood import set_likelihood

__all__ = [
    'create_net',
    'set_criterion',
    'set_optimizer',
    'set_loss_store',
    'set_likelihood'
]
lib/component/criterion.py ADDED
@@ -0,0 +1,332 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch
import torch.nn as nn
from typing import Dict, Union

# Alias of typing
# eg. {'labels': {'label_A': torch.Tensor([0, 1, ...]), ...}}
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]


class RMSELoss(nn.Module):
    """
    Class to calculate RMSE.
    """
    def __init__(self, eps: float = 1e-7) -> None:
        """
        Args:
            eps (float, optional): value to avoid 0. Defaults to 1e-7.
        """
        super().__init__()
        self.mse = nn.MSELoss()
        self.eps = eps

    def forward(self, yhat: torch.FloatTensor, y: torch.FloatTensor) -> torch.FloatTensor:
        """
        Calculate RMSE.

        Args:
            yhat (torch.FloatTensor): prediction values
            y (torch.FloatTensor): ground truth values

        Returns:
            torch.FloatTensor: RMSE
        """
        _loss = self.mse(yhat, y) + self.eps
        return torch.sqrt(_loss)


class Regularization:
    """
    Class to calculate regularization loss.
    """
    def __init__(self, order: int, weight_decay: float) -> None:
        """
        The initialization of Regularization class.

        Args:
            order (int): norm order number
            weight_decay (float): weight decay rate
        """
        super().__init__()
        self.order = order
        self.weight_decay = weight_decay

    def __call__(self, network: nn.Module) -> torch.FloatTensor:
        """
        Calculate regularization (self.order) loss for network.

        Args:
            network (nn.Module): network

        Returns:
            torch.FloatTensor: the regularization (self.order) loss
        """
        reg_loss = 0
        for name, w in network.named_parameters():
            if 'weight' in name:
                reg_loss = reg_loss + torch.norm(w, p=self.order)
        reg_loss = self.weight_decay * reg_loss
        return reg_loss


class NegativeLogLikelihood(nn.Module):
    """
    Class to calculate Negative Log Likelihood.
    """
    def __init__(self, device: torch.device) -> None:
        """
        Args:
            device (torch.device): device
        """
        super().__init__()
        self.L2_reg = 0.05
        self.reg = Regularization(order=2, weight_decay=self.L2_reg)
        self.device = device

    def forward(
        self,
        output: torch.FloatTensor,
        label: torch.IntTensor,
        periods: torch.FloatTensor,
        network: nn.Module
    ) -> torch.FloatTensor:
        """
        Calculate Negative Log Likelihood.

        Args:
            output (torch.FloatTensor): prediction value, ie. risk prediction
            label (torch.IntTensor): occurrence of event
            periods (torch.FloatTensor): period
            network (nn.Module): network

        Returns:
            torch.FloatTensor: Negative Log Likelihood
        """
        mask = torch.ones(periods.shape[0], periods.shape[0]).to(self.device)  # output and mask should be on the same device.
        mask[(periods.T - periods) > 0] = 0

        _loss = torch.exp(output) * mask
        # Note: torch.sum(_loss, dim=0) possibly returns nan, in particular MLP.
        _loss = torch.sum(_loss, dim=0) / torch.sum(mask, dim=0)
        _loss = torch.log(_loss).reshape(-1, 1)
        num_occurs = torch.sum(label)

        if num_occurs.item() == 0.0:
            loss = torch.tensor([1e-7], requires_grad=True).to(self.device)  # To avoid zero division, set a small value as loss.
            return loss
        else:
            neg_log_loss = -torch.sum((output - _loss) * label) / num_occurs
            l2_loss = self.reg(network)
            loss = neg_log_loss + l2_loss
            return loss


class ClsCriterion:
    """
    Class of criterion for classification.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set CrossEntropyLoss.

        Args:
            device (torch.device): device
        """
        self.device = device
        self.criterion = nn.CrossEntropyLoss()

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): outputs
            labels (Dict[str, LabelDict]): labels

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss

        # No reshape and no cast:
        output: [64, 2]: torch.float32
        label:  [64]   : torch.int64
        label.dtype should be torch.int64, otherwise nn.CrossEntropyLoss() causes an error.

        eg.
        outputs = {'label_A': [[0.8, 0.2], ...], 'label_B': [[0.7, 0.3], ...], ...}
        labels = {'labels': {'label_A': [1, 1, 0, ...], 'label_B': [0, 0, 1, ...], ...}}
        -> losses = {'total': loss_total, 'label_A': loss_A, 'label_B': loss_B, ...}
        """
        _labels = labels['labels']

        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name in labels['labels'].keys():
            _output = outputs[label_name]
            _label = _labels[label_name]
            _label_loss = self.criterion(_output, _label)
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses


class RegCriterion:
    """
    Class of criterion for regression.
    """
    def __init__(self, criterion_name: str = None, device: torch.device = None) -> None:
        """
        Set MSE, RMSE or MAE.

        Args:
            criterion_name (str): 'MSE', 'RMSE', or 'MAE'
            device (torch.device): device
        """
        self.device = device

        if criterion_name == 'MSE':
            self.criterion = nn.MSELoss()
        elif criterion_name == 'RMSE':
            self.criterion = RMSELoss()
        elif criterion_name == 'MAE':
            self.criterion = nn.L1Loss()
        else:
            raise ValueError(f"Invalid criterion for regression: {criterion_name}.")

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, LabelDict]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): outputs
            labels (Dict[str, LabelDict]): labels

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss

        # Reshape and cast
        output: [64, 1] -> [64]: torch.float32
        label:  [64]: torch.float64 -> torch.float32
        # label.dtype should be torch.float32, otherwise backward is not possible.

        eg.
        outputs = {'label_A': [[10.8], ...], 'label_B': [[15.7], ...], ...}
        labels = {'labels': {'label_A': [10, 9, ...], 'label_B': [12, 17, ...], ...}}
        -> losses = {'total': loss_total, 'label_A': loss_A, 'label_B': loss_B, ...}
        """
        _outputs = {label_name: _output.squeeze() for label_name, _output in outputs.items()}
        _labels = {label_name: _label.to(torch.float32) for label_name, _label in labels['labels'].items()}

        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name in labels['labels'].keys():
            _output = _outputs[label_name]
            _label = _labels[label_name]
            _label_loss = self.criterion(_output, _label)
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses


class DeepSurvCriterion:
    """
    Class of criterion for deepsurv.
    """
    def __init__(self, device: torch.device = None) -> None:
        """
        Set NegativeLogLikelihood.

        Args:
            device (torch.device, optional): device
        """
        self.device = device
        self.criterion = NegativeLogLikelihood(self.device).to(self.device)

    def __call__(
        self,
        outputs: Dict[str, torch.FloatTensor],
        labels: Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
    ) -> Dict[str, torch.FloatTensor]:
        """
        Calculate loss.

        Args:
            outputs (Dict[str, torch.FloatTensor]): outputs
            labels (Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]): labels, periods, and network

        Returns:
            Dict[str, torch.FloatTensor]: loss for each label and their total loss

        # Reshape and no cast
        output: [64, 1]: torch.float32
        label:  [64] -> [64, 1]: torch.int64
        period: [64] -> [64, 1]: torch.float32

        eg.
        outputs = {'label_A': [[10.8], ...]}
        labels = {
            'labels': {'label_A': [1, 0, 1, ...]},
            'periods': [5, 10, 7, ...],
            'network': network
        }
        -> losses = {'total': loss_total, 'label_A': loss_A, ...}
        """
        _labels = {label_name: _label.reshape(-1, 1) for label_name, _label in labels['labels'].items()}
        _periods = labels['periods'].reshape(-1, 1)
        _network = labels['network']

        # loss for each label and total of their losses
        losses = dict()
        losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
        for label_name in labels['labels'].keys():
            _output = outputs[label_name]
            _label = _labels[label_name]
            _label_loss = self.criterion(_output, _label, _periods, _network)
            losses[label_name] = _label_loss
            losses['total'] = torch.add(losses['total'], _label_loss)
        return losses


def set_criterion(
    criterion_name: str,
    device: torch.device
) -> Union[ClsCriterion, RegCriterion, DeepSurvCriterion]:
    """
    Return criterion class.

    Args:
        criterion_name (str): criterion name
        device (torch.device): device

    Returns:
        Union[ClsCriterion, RegCriterion, DeepSurvCriterion]: criterion class
    """
    if criterion_name == 'CEL':
        return ClsCriterion(device=device)

    elif criterion_name in ['MSE', 'RMSE', 'MAE']:
        return RegCriterion(criterion_name=criterion_name, device=device)

    elif criterion_name == 'NLL':
        return DeepSurvCriterion(device=device)

    else:
        raise ValueError(f"Invalid criterion: {criterion_name}.")
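Illustrative note (not part of the committed files): the dict formats documented in the docstrings above can be exercised directly; here is a minimal sketch of the classification path on CPU with dummy tensors.

import torch
from lib.component.criterion import set_criterion

criterion = set_criterion('CEL', device=torch.device('cpu'))
outputs = {'label_A': torch.randn(4, 2),                         # logits per label, shape [batch, num_classes]
           'label_B': torch.randn(4, 2)}
labels = {'labels': {'label_A': torch.tensor([0, 1, 1, 0]),      # int64 class indices
                     'label_B': torch.tensor([1, 0, 0, 1])}}
losses = criterion(outputs, labels)
# losses -> {'total': ..., 'label_A': ..., 'label_B': ...}, each a scalar tensor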
lib/component/likelihood.py ADDED
@@ -0,0 +1,107 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import pandas as pd
import torch
from typing import List, Dict


class Likelihood:
    """
    Class for making likelihood.
    """
    def __init__(self, task: str, num_outputs_for_label: Dict[str, int]) -> None:
        """
        Args:
            task (str): task
            num_outputs_for_label (Dict[str, int]): number of classes for each label
        """
        self.task = task
        self.num_outputs_for_label = num_outputs_for_label
        self.base_column_list = self._set_base_columns(self.task)
        self.pred_column_list = self._make_pred_columns(self.task, self.num_outputs_for_label)

    def _set_base_columns(self, task: str) -> List[str]:
        """
        Return base columns.

        Args:
            task (str): task

        Returns:
            List[str]: base columns except columns of label and prediction
        """
        if (task == 'classification') or (task == 'regression'):
            base_columns = ['uniqID', 'group', 'imgpath', 'split']
            return base_columns
        elif task == 'deepsurv':
            base_columns = ['uniqID', 'group', 'imgpath', 'split', 'periods']
            return base_columns
        else:
            raise ValueError(f"Invalid task: {task}.")

    def _make_pred_columns(self, task: str, num_outputs_for_label: Dict[str, int]) -> Dict[str, List[str]]:
        """
        Make column names of predictions with label name and its number of classes.

        Args:
            task (str): task
            num_outputs_for_label (Dict[str, int]): number of classes for each label

        Returns:
            Dict[str, List[str]]: label and list of columns of predictions with its class number

        eg.
        {label_A: 2, label_B: 2} -> {label_A: [pred_label_A_0, pred_label_A_1], label_B: [pred_label_B_0, pred_label_B_1]}
        {label_A: 1, label_B: 1} -> {label_A: [pred_label_A], label_B: [pred_label_B]}
        """
        pred_columns = dict()
        if task == 'classification':
            for label_name, num_classes in num_outputs_for_label.items():
                pred_columns[label_name] = ['pred_' + label_name + '_' + str(i) for i in range(num_classes)]
            return pred_columns
        elif (task == 'regression') or (task == 'deepsurv'):
            for label_name, num_classes in num_outputs_for_label.items():
                pred_columns[label_name] = ['pred_' + label_name]
            return pred_columns
        else:
            raise ValueError(f"Invalid task: {task}.")

    def make_format(self, data: Dict, output: Dict[str, torch.Tensor]) -> pd.DataFrame:
        """
        Make a new DataFrame of likelihood every batch.

        Args:
            data (Dict): batch data from dataloader
            output (Dict[str, torch.Tensor]): output of model

        Returns:
            pd.DataFrame: DataFrame of likelihood for the batch
        """
        _likelihood = {column_name: data[column_name] for column_name in self.base_column_list}
        df_likelihood = pd.DataFrame(_likelihood)

        if any(data['labels']):
            for label_name, pred in output.items():
                _df_label = pd.DataFrame({label_name: data['labels'][label_name].tolist()})
                pred = pred.to('cpu').detach().numpy().copy()
                _df_pred = pd.DataFrame(pred, columns=self.pred_column_list[label_name])
                df_likelihood = pd.concat([df_likelihood, _df_label, _df_pred], axis=1)
            return df_likelihood
        else:
            for label_name, pred in output.items():
                pred = pred.to('cpu').detach().numpy().copy()
                _df_pred = pd.DataFrame(pred, columns=self.pred_column_list[label_name])
                df_likelihood = pd.concat([df_likelihood, _df_pred], axis=1)
            return df_likelihood


def set_likelihood(task: str, num_outputs_for_label: Dict[str, int]) -> Likelihood:
    """
    Set likelihood.

    Args:
        task (str): task
        num_outputs_for_label (Dict[str, int]): number of classes for each label

    Returns:
        Likelihood: instance of class Likelihood
    """
    return Likelihood(task, num_outputs_for_label)
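Illustrative note (not part of the committed files): a hypothetical two-sample batch showing the DataFrame layout make_format() produces for a 2-class classification label; all field values below are made up.

import torch
from lib.component.likelihood import set_likelihood

likelihood = set_likelihood(task='classification', num_outputs_for_label={'label_A': 2})
data = {
    'uniqID': ['001', '002'],
    'group': ['g1', 'g1'],
    'imgpath': ['a.png', 'b.png'],
    'split': ['test', 'test'],
    'labels': {'label_A': torch.tensor([1, 0])},
}
output = {'label_A': torch.tensor([[0.2, 0.8], [0.9, 0.1]])}
df = likelihood.make_format(data, output)
# columns: uniqID, group, imgpath, split, label_A, pred_label_A_0, pred_label_A_1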
lib/component/loss.py ADDED
@@ -0,0 +1,248 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from pathlib import Path
import torch
import pandas as pd
from ..logger import BaseLogger
from typing import List, Dict, Union


logger = BaseLogger.get_logger(__name__)


class LabelLoss:
    """
    Class to store the loss of every batch and the epoch loss of each label.
    """
    def __init__(self) -> None:
        # Accumulate batch_loss(=loss * batch_size)
        self.train_batch_loss = 0.0
        self.val_batch_loss = 0.0

        # epoch_loss = batch_loss / dataset_size
        self.train_epoch_loss = []  # List[float]
        self.val_epoch_loss = []    # List[float]

        self.best_val_loss = None        # float
        self.best_epoch = None           # int
        self.is_val_loss_updated = None  # bool

    def get_loss(self, phase: str, target: str) -> Union[float, List[float]]:
        """
        Return loss depending on phase and target.

        Args:
            phase (str): 'train' or 'val'
            target (str): 'batch' or 'epoch'

        Returns:
            Union[float, List[float]]: batch_loss or epoch_loss
        """
        _target = phase + '_' + target + '_loss'
        return getattr(self, _target)

    def store_batch_loss(self, phase: str, new_batch_loss: torch.FloatTensor, batch_size: int) -> None:
        """
        Add new batch loss to the previous one for phase after multiplying by batch_size.

        Args:
            phase (str): 'train' or 'val'
            new_batch_loss (torch.FloatTensor): batch loss calculated by criterion
            batch_size (int): batch size
        """
        _new = new_batch_loss.item() * batch_size  # torch.FloatTensor -> float
        _prev = self.get_loss(phase, 'batch')
        _added = _prev + _new
        _target = phase + '_' + 'batch_loss'
        setattr(self, _target, _added)

    def append_epoch_loss(self, phase: str, new_epoch_loss: float) -> None:
        """
        Append epoch loss depending on phase.

        Args:
            phase (str): 'train' or 'val'
            new_epoch_loss (float): epoch loss
        """
        _target = phase + '_' + 'epoch_loss'
        getattr(self, _target).append(new_epoch_loss)

    def get_latest_epoch_loss(self, phase: str) -> float:
        """
        Return the latest epoch loss of phase.

        Args:
            phase (str): 'train' or 'val'

        Returns:
            float: the latest loss
        """
        return self.get_loss(phase, 'epoch')[-1]

    def update_best_val_loss(self, at_epoch: int = None) -> None:
        """
        Update best_val_loss if the latest val_epoch_loss is the best so far.

        Args:
            at_epoch (int): epoch when checked
        """
        _latest_val_loss = self.get_latest_epoch_loss('val')

        if at_epoch == 1:
            self.best_val_loss = _latest_val_loss
            self.best_epoch = at_epoch
            self.is_val_loss_updated = True
        else:
            # When at_epoch > 1
            if _latest_val_loss < self.best_val_loss:
                self.best_val_loss = _latest_val_loss
                self.best_epoch = at_epoch
                self.is_val_loss_updated = True
            else:
                self.is_val_loss_updated = False


class LossStore:
    """
    Class for calculating loss and storing it.
    """
    def __init__(self, label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> None:
        """
        Args:
            label_list (List[str]): list of internal labels
            num_epochs (int): number of epochs
            dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'
        """
        self.label_list = label_list
        self.num_epochs = num_epochs
        self.dataset_info = dataset_info

        # Add a special label 'total' to store the total of the losses of all labels.
        self.label_losses = {label_name: LabelLoss() for label_name in self.label_list + ['total']}

    def store(self, phase: str, losses: Dict[str, torch.FloatTensor], batch_size: int = None) -> None:
        """
        Store label-wise batch losses of phase by adding them to the previous ones.

        Args:
            phase (str): 'train' or 'val'
            losses (Dict[str, torch.FloatTensor]): loss for each label calculated by criterion
            batch_size (int): batch size

        # Note:
        self.label_losses['total'] is already the total of the losses of all labels, which is calculated in criterion.py,
        therefore it is OK just to multiply by batch_size. This is done in store_batch_loss().
        """
        for label_name in self.label_list + ['total']:
            _new_batch_loss = losses[label_name]
            self.label_losses[label_name].store_batch_loss(phase, _new_batch_loss, batch_size)

    def cal_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Calculate epoch loss for each phase all at once.

        Args:
            at_epoch (int): epoch number
        """
        # For each label
        for label_name in self.label_list:
            for phase in ['train', 'val']:
                _batch_loss = self.label_losses[label_name].get_loss(phase, 'batch')
                _dataset_size = self.dataset_info[phase]
                _new_epoch_loss = _batch_loss / _dataset_size
                self.label_losses[label_name].append_epoch_loss(phase, _new_epoch_loss)

        # For total, average by dataset_size and the number of labels.
        for phase in ['train', 'val']:
            _batch_loss = self.label_losses['total'].get_loss(phase, 'batch')
            _dataset_size = self.dataset_info[phase]
            _new_epoch_loss = _batch_loss / (_dataset_size * len(self.label_list))
            self.label_losses['total'].append_epoch_loss(phase, _new_epoch_loss)

        # Update best_val_loss and best_epoch.
        for label_name in self.label_list + ['total']:
            self.label_losses[label_name].update_best_val_loss(at_epoch=at_epoch)

        # Initialize batch_loss after calculating epoch loss.
        for label_name in self.label_list + ['total']:
            self.label_losses[label_name].train_batch_loss = 0.0
            self.label_losses[label_name].val_batch_loss = 0.0

    def is_val_loss_updated(self) -> bool:
        """
        Check if val_loss of 'total' is updated.

        Returns:
            bool: updated or not
        """
        return self.label_losses['total'].is_val_loss_updated

    def get_best_epoch(self) -> int:
        """
        Return the best epoch.

        Returns:
            int: best epoch
        """
        return self.label_losses['total'].best_epoch

    def print_epoch_loss(self, at_epoch: int = None) -> None:
        """
        Print train_loss and val_loss for the at_epoch-th epoch.

        Args:
            at_epoch (int): epoch number
        """
        train_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('train')
        val_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('val')

        _epoch_comm = f"epoch [{at_epoch:>3}/{self.num_epochs:<3}]"
        _train_comm = f"train_loss: {train_epoch_loss:>8.4f}"
        _val_comm = f"val_loss: {val_epoch_loss:>8.4f}"
        _updated_comment = ''
        if (at_epoch > 1) and (self.is_val_loss_updated()):
            _updated_comment = ' Updated best val_loss!'
        comment = _epoch_comm + ', ' + _train_comm + ', ' + _val_comm + _updated_comment
        logger.info(comment)

    def save_learning_curve(self, save_datetime_dir: str) -> None:
        """
        Save learning curve.

        Args:
            save_datetime_dir (str): directory in which the learning curve is saved
        """
        save_dir = Path(save_datetime_dir, 'learning_curve')
        save_dir.mkdir(parents=True, exist_ok=True)

        for label_name in self.label_list + ['total']:
            _label_loss = self.label_losses[label_name]
            _train_epoch_loss = _label_loss.get_loss('train', 'epoch')
            _val_epoch_loss = _label_loss.get_loss('val', 'epoch')

            df_label_epoch_loss = pd.DataFrame({
                'train_loss': _train_epoch_loss,
                'val_loss': _val_epoch_loss
            })

            _best_epoch = str(_label_loss.best_epoch).zfill(3)
            _best_val_loss = f"{_label_loss.best_val_loss:.4f}"
            save_name = 'learning_curve_' + label_name + '_val-best-epoch-' + _best_epoch + '_val-best-loss-' + _best_val_loss + '.csv'
            save_path = Path(save_dir, save_name)
            df_label_epoch_loss.to_csv(save_path, index=False)


def set_loss_store(label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> LossStore:
    """
    Return class LossStore.

    Args:
        label_list (List[str]): label list
        num_epochs (int): number of epochs
        dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'

    Returns:
        LossStore: LossStore
    """
    return LossStore(label_list, num_epochs, dataset_info)
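Illustrative note (not part of the committed files): the intended per-epoch bookkeeping, sketched with a single dummy "batch" covering each split; in real training the losses come from one of the criterion classes above and dataset_info from the dataloaders. All sizes and values here are arbitrary.

import torch
from lib.component.loss import set_loss_store

loss_store = set_loss_store(label_list=['label_A'], num_epochs=2, dataset_info={'train': 8, 'val': 4})
for epoch in range(1, 3):
    for phase, size in [('train', 8), ('val', 4)]:
        losses = {'label_A': torch.tensor(0.5), 'total': torch.tensor(0.5)}  # dummy criterion output
        loss_store.store(phase, losses, batch_size=size)                     # accumulate loss * batch_size
    loss_store.cal_epoch_loss(at_epoch=epoch)                                # divide by dataset size, track best val loss
    loss_store.print_epoch_loss(at_epoch=epoch)
# loss_store.save_learning_curve('<save_datetime_dir>') would then write one CSV per label.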
lib/component/net.py ADDED
@@ -0,0 +1,624 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

from collections import OrderedDict
import torch
import torch.nn as nn
from torchvision.ops import MLP
import torchvision.models as models
from typing import Dict, Optional


class BaseNet:
    """
    Class to construct a network.
    """
    cnn = {
        'ResNet18': models.resnet18,
        'ResNet': models.resnet50,
        'DenseNet': models.densenet161,
        'EfficientNetB0': models.efficientnet_b0,
        'EfficientNetB2': models.efficientnet_b2,
        'EfficientNetB4': models.efficientnet_b4,
        'EfficientNetB6': models.efficientnet_b6,
        'EfficientNetV2s': models.efficientnet_v2_s,
        'EfficientNetV2m': models.efficientnet_v2_m,
        'EfficientNetV2l': models.efficientnet_v2_l,
        'ConvNeXtTiny': models.convnext_tiny,
        'ConvNeXtSmall': models.convnext_small,
        'ConvNeXtBase': models.convnext_base,
        'ConvNeXtLarge': models.convnext_large
    }

    vit = {
        'ViTb16': models.vit_b_16,
        'ViTb32': models.vit_b_32,
        'ViTl16': models.vit_l_16,
        'ViTl32': models.vit_l_32,
        'ViTH14': models.vit_h_14
    }

    net = {**cnn, **vit}

    _classifier = {
        'ResNet': 'fc',
        'DenseNet': 'classifier',
        'EfficientNet': 'classifier',
        'ConvNext': 'classifier',
        'ViT': 'heads'
    }

    classifier = {
        'ResNet18': _classifier['ResNet'],
        'ResNet': _classifier['ResNet'],
        'DenseNet': _classifier['DenseNet'],
        'EfficientNetB0': _classifier['EfficientNet'],
        'EfficientNetB2': _classifier['EfficientNet'],
        'EfficientNetB4': _classifier['EfficientNet'],
        'EfficientNetB6': _classifier['EfficientNet'],
        'EfficientNetV2s': _classifier['EfficientNet'],
        'EfficientNetV2m': _classifier['EfficientNet'],
        'EfficientNetV2l': _classifier['EfficientNet'],
        'ConvNeXtTiny': _classifier['ConvNext'],
        'ConvNeXtSmall': _classifier['ConvNext'],
        'ConvNeXtBase': _classifier['ConvNext'],
        'ConvNeXtLarge': _classifier['ConvNext'],
        'ViTb16': _classifier['ViT'],
        'ViTb32': _classifier['ViT'],
        'ViTl16': _classifier['ViT'],
        'ViTl32': _classifier['ViT'],
        'ViTH14': _classifier['ViT']
    }

    mlp_config = {
        'hidden_channels': [256, 256, 256],
        'dropout': 0.2
    }

    DUMMY = nn.Identity()

    @classmethod
    def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
        """
        Construct MLP.

        Args:
            mlp_num_inputs (int): number of inputs of MLP
            inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.

        Returns:
            MLP: MLP
        """
        assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
        mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
        return mlp

    @classmethod
    def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
        """
        Modify network to handle grayscale images.

        Args:
            net_name (str): network name
            net (nn.Module): network itself

        Returns:
            nn.Module: network available for grayscale
        """
        if net_name.startswith('ResNet'):
            net.conv1.in_channels = 1
            net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('DenseNet'):
            net.features.conv0.in_channels = 1
            net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('Efficient'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ConvNeXt'):
            net.features[0][0].in_channels = 1
            net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))

        elif net_name.startswith('ViT'):
            net.conv_proj.in_channels = 1
            net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return net

    @classmethod
    def set_net(
        cls,
        net_name: str = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Modify network depending on in_channel and vit_image_size.

        Args:
            net_name (str): network name
            in_channel (int, optional): image channel (1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
                                            vit_image_size should be a multiple of the patch size.
            pretrained (bool, optional): True when using a pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: modified network
        """
        assert net_name in cls.net, f"No specified net: {net_name}."
        if net_name in cls.cnn:
            if pretrained:
                net = cls.cnn[net_name](weights='DEFAULT')
            else:
                net = cls.cnn[net_name]()
        else:
            # When ViT
            # always use pretrained
            net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)

        if in_channel == 1:
            net = cls.align_in_channels_1ch(net_name=net_name, net=net)
        return net

    @classmethod
    def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
        """
        Modify ViT depending on vit_image_size.

        Args:
            net_name (str): ViT name
            vit_image_size (int): image size which ViT handles

        Returns:
            nn.Module: modified ViT
        """
        base_vit = cls.vit[net_name]
        # pretrained_vit = base_vit(weights=cls.vit_weight[net_name])
        pretrained_vit = base_vit(weights='DEFAULT')

        # Align weight depending on image size
        weight = pretrained_vit.state_dict()
        patch_size = int(net_name[-2:])  # 'ViTb16' -> 16
        aligned_weight = models.vision_transformer.interpolate_embeddings(
            image_size=vit_image_size,
            patch_size=patch_size,
            model_state=weight
        )
        aligned_vit = base_vit(image_size=vit_image_size)  # Specify new image size.
        aligned_vit.load_state_dict(aligned_weight)        # Load weight which can handle the new image size.
        return aligned_vit

    @classmethod
    def construct_extractor(
        cls,
        net_name: str = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> nn.Module:
        """
        Construct extractor of network depending on net_name.

        Args:
            net_name (str): network name
            mlp_num_inputs (int, optional): number of inputs of MLP. Defaults to None.
            in_channel (int, optional): image channel (1ch or 3ch). Defaults to None.
            vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
            pretrained (bool, optional): True when using a pretrained CNN or ViT, otherwise False. Defaults to None.

        Returns:
            nn.Module: extractor of network
        """
        if net_name == 'MLP':
            extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
        else:
            extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
            setattr(extractor, cls.classifier[net_name], cls.DUMMY)  # Replace classifier with DUMMY(=nn.Identity()).
        return extractor

    @classmethod
    def get_classifier(cls, net_name: str) -> nn.Module:
        """
        Get classifier of network depending on net_name.

        Args:
            net_name (str): network name

        Returns:
            nn.Module: classifier of network
        """
        net = cls.net[net_name]()
        classifier = getattr(net, cls.classifier[net_name])
        return classifier

    @classmethod
    def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
        """
        Construct classifier for multi-label.

        Args:
            net_name (str): network name
            num_outputs_for_label (Dict[str, int]): number of outputs for each label

        Returns:
            nn.ModuleDict: classifier for multi-label
        """
        classifiers = dict()
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Linear(in_features, num_outputs)

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            dropout = base_classifier[0].p
            in_features = base_classifier[1].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    nn.Dropout(p=dropout, inplace=False),
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            in_features = base_classifier[2].in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                # Shape is changed before nn.Linear.
                classifiers[label_name] = nn.Sequential(
                    layer_norm,
                    flatten,
                    nn.Linear(in_features, num_outputs)
                )

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features
            for label_name, num_outputs in num_outputs_for_label.items():
                classifiers[label_name] = nn.Sequential(
                    OrderedDict([
                        ('head', nn.Linear(in_features, num_outputs))
                    ])
                )

        else:
            raise ValueError(f"No specified net: {net_name}.")

        multi_classifier = nn.ModuleDict(classifiers)
        return multi_classifier

    @classmethod
    def get_classifier_in_features(cls, net_name: str) -> int:
        """
        Return in_features of the classifier of the network indicated by net_name.
        This method is used in class MultiNetFusion() only.

        Args:
            net_name (str): net_name

        Returns:
            int: in_features

        Required:
            classifier.in_features
            classifier[1].in_features
            classifier[2].in_features
            classifier.head.in_features
        """
        if net_name == 'MLP':
            in_features = cls.mlp_config['hidden_channels'][-1]

        elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.in_features

        elif net_name.startswith('EfficientNet'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[1].in_features

        elif net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier[2].in_features

        elif net_name.startswith('ViT'):
            base_classifier = cls.get_classifier(net_name)
            in_features = base_classifier.head.in_features

        else:
            raise ValueError(f"No specified net: {net_name}.")
        return in_features

    @classmethod
    def construct_aux_module(cls, net_name: str) -> nn.Sequential:
        """
        Construct a module to align the shape of the feature from the extractor depending on the network.
        Actually, this is needed only when net_name == 'ConvNeXt',
        because ConvNeXt aligns the dimensions in its classifier:
        (classifier):
        Sequential(
            (0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
            (1): Flatten(start_dim=1, end_dim=-1)
            (2): Linear(in_features=768, out_features=1000, bias=True)
        )

        Args:
            net_name (str): net name

        Returns:
            nn.Module: layers that align the dimension of the output from the extractor like the original ConvNeXt.
        """
        aux_module = cls.DUMMY
        if net_name.startswith('ConvNeXt'):
            base_classifier = cls.get_classifier(net_name)
            layer_norm = base_classifier[0]
            flatten = base_classifier[1]
            aux_module = nn.Sequential(
                layer_norm,
                flatten
            )
        return aux_module

    @classmethod
    def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
        """
        Return the last extractor of the network.
        This is for Grad-CAM.
        net should be one with loaded weight.

        Args:
            net (nn.Module): network itself
            mlp (str): 'MLP', otherwise None
            net_name (str): network name

        Returns:
            nn.Module: last extractor of network
        """
        assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."

        _extractor = net.extractor_net

        if net_name.startswith('ResNet'):
            last_extractor = _extractor.layer4[-1]
        elif net_name.startswith('DenseNet'):
            last_extractor = _extractor.features.denseblock4.denselayer24
        elif net_name.startswith('EfficientNet'):
            last_extractor = _extractor.features[-1]
        elif net_name.startswith('ConvNeXt'):
            last_extractor = _extractor.features[-1][-1].block
        elif net_name.startswith('ViT'):
            last_extractor = _extractor.encoder.layers[-1]
        else:
            raise ValueError(f"Cannot get last extractor of net: {net_name}.")
        return last_extractor


class MultiMixin:
    """
    Class to define an auxiliary function to handle multi-label.
    """
    def multi_forward(self, out_features: int) -> Dict[str, float]:
        """
        Forward out_features to the classifier of each label.

        Args:
            out_features (int): output from extractor

        Returns:
            Dict[str, float]: output of classifier of each label
        """
        output = dict()
        for label_name, classifier in self.multi_classifier.items():
            output[label_name] = classifier(out_features)
        return output


class MultiWidget(nn.Module, BaseNet, MultiMixin):
    """
    Class for a widget to inherit multiple classes simultaneously.
    """
    pass


class MultiNet(MultiWidget):
    """
    Model of MLP, CNN or ViT.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): MLP, CNN or ViT name
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of MLP
            in_channel (int): number of image channels, ie. grayscale(=1) or color image(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True when using a pretrained CNN or ViT, otherwise False
        """
        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # self.extractor_net = MLP or CVmodel
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            mlp_num_inputs=self.mlp_num_inputs,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        self.multi_classifier = self.construct_multi_classifier(net_name=self.net_name, num_outputs_for_label=self.num_outputs_for_label)

    def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x (torch.Tensor): tabular data or image

        Returns:
            Dict[str, torch.Tensor]: output
        """
        out_features = self.extractor_net(x)
        output = self.multi_forward(out_features)
        return output


class MultiNetFusion(MultiWidget):
    """
    Fusion model of MLP and CNN or ViT.
    """
    def __init__(
        self,
        net_name: str = None,
        num_outputs_for_label: Dict[str, int] = None,
        mlp_num_inputs: int = None,
        in_channel: int = None,
        vit_image_size: int = None,
        pretrained: bool = None
    ) -> None:
        """
        Args:
            net_name (str): CNN or ViT name. It is implicit that MLP is used in the fusion model.
            num_outputs_for_label (Dict[str, int]): number of classes for each label
            mlp_num_inputs (int): number of inputs of MLP. Defaults to None.
            in_channel (int): number of image channels, ie. grayscale(=1) or color image(=3)
            vit_image_size (int): image size to be input to ViT
            pretrained (bool): True when using a pretrained CNN or ViT, otherwise False
        """
        assert (net_name != 'MLP'), 'net_name should not be MLP.'

        super().__init__()

        self.net_name = net_name
        self.num_outputs_for_label = num_outputs_for_label
        self.mlp_num_inputs = mlp_num_inputs
        self.in_channel = in_channel
        self.vit_image_size = vit_image_size
        self.pretrained = pretrained

        # Extractor of MLP and Net
        self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
        self.extractor_net = self.construct_extractor(
            net_name=self.net_name,
            in_channel=self.in_channel,
            vit_image_size=self.vit_image_size,
            pretrained=self.pretrained
        )
        self.aux_module = self.construct_aux_module(self.net_name)

        # Intermediate MLP
        self.in_features_from_mlp = self.get_classifier_in_features('MLP')
        self.in_features_from_net = self.get_classifier_in_features(self.net_name)
        self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
        self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)

        # Multi classifier
        self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)

    def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
        """
        Forward.

        Args:
            x_mlp (torch.Tensor): tabular data
            x_net (torch.Tensor): image

        Returns:
            Dict[str, torch.Tensor]: output
        """
        out_mlp = self.extractor_mlp(x_mlp)
        out_net = self.extractor_net(x_net)
        out_net = self.aux_module(out_net)

        out_features = torch.cat([out_mlp, out_net], dim=1)
        out_features = self.inter_mlp(out_features)
        output = self.multi_forward(out_features)
        return output


def create_net(
    mlp: Optional[str] = None,
    net: Optional[str] = None,
    num_outputs_for_label: Dict[str, int] = None,
    mlp_num_inputs: int = None,
    in_channel: int = None,
    vit_image_size: int = None,
    pretrained: bool = None
) -> nn.Module:
    """
    Create network.

    Args:
        mlp (Optional[str]): 'MLP' or None
        net (Optional[str]): CNN, ViT name or None
        num_outputs_for_label (Dict[str, int]): number of outputs for each label
        mlp_num_inputs (int): number of inputs of MLP
        in_channel (int): number of image channels, ie. grayscale(=1) or color image(=3)
        vit_image_size (int): image size to be input to ViT
        pretrained (bool): True when using a pretrained CNN or ViT, otherwise False

    Returns:
        nn.Module: network
    """
    _isMLPModel = (mlp is not None) and (net is None)
    _isCVModel = (mlp is None) and (net is not None)
    _isFusion = (mlp is not None) and (net is not None)

    if _isMLPModel:
        multi_net = MultiNet(
            net_name='MLP',
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=False  # No need of pretrained for MLP
        )
    elif _isCVModel:
        multi_net = MultiNet(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    elif _isFusion:
        multi_net = MultiNetFusion(
            net_name=net,
            num_outputs_for_label=num_outputs_for_label,
            mlp_num_inputs=mlp_num_inputs,
            in_channel=in_channel,
            vit_image_size=vit_image_size,
            pretrained=pretrained
        )
    else:
        raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")

    return multi_net
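Illustrative note (not part of the committed files): a sketch of two of the variants create_net() can return, with pretrained=False so no weights are downloaded. The 256x256 single-channel image, the 10 tabular inputs, and the two 2-class labels are arbitrary choices for illustration.

import torch
from lib.component.net import create_net

num_outputs_for_label = {'label_A': 2, 'label_B': 2}

# Image-only CNN (MultiNet): one classifier head per label
cnn = create_net(mlp=None, net='ResNet18', num_outputs_for_label=num_outputs_for_label,
                 mlp_num_inputs=None, in_channel=1, vit_image_size=None, pretrained=False)
out = cnn(torch.randn(2, 1, 256, 256))
# out['label_A'].shape == torch.Size([2, 2])

# Fusion of tabular MLP (10 inputs) and CNN (MultiNetFusion)
fusion = create_net(mlp='MLP', net='ResNet18', num_outputs_for_label=num_outputs_for_label,
                    mlp_num_inputs=10, in_channel=1, vit_image_size=None, pretrained=False)
out = fusion(torch.randn(2, 10), torch.randn(2, 1, 256, 256))
# again a dict of per-label logits of shape [2, 2]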
lib/component/optimizer.py ADDED
@@ -0,0 +1,34 @@
#!/usr/bin/env python
# -*- coding: utf-8 -*-

import torch.optim as optim
import torch.nn as nn


def set_optimizer(optimizer_name: str, network: nn.Module, lr: float) -> optim.Optimizer:
    """
    Set optimizer.
    Args:
        optimizer_name (str): optimizer name
        network (torch.nn.Module): network
        lr (float): learning rate
    Returns:
        torch.optim.Optimizer: optimizer
    """
    optimizers = {
        'SGD': optim.SGD,
        'Adadelta': optim.Adadelta,
        'Adam': optim.Adam,
        'RMSprop': optim.RMSprop,
        'RAdam': optim.RAdam
    }

    assert (optimizer_name in optimizers), f"No specified optimizer: {optimizer_name}."

    _optim = optimizers[optimizer_name]

    if lr is None:
        optimizer = _optim(network.parameters())
    else:
        optimizer = _optim(network.parameters(), lr=lr)
    return optimizer
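Illustrative note (not part of the committed files): any nn.Module works; passing lr=None falls back to the optimizer's default learning rate, where the optimizer defines one.

import torch.nn as nn
from lib.component.optimizer import set_optimizer

net = nn.Linear(4, 2)                              # stand-in for a real network
optimizer = set_optimizer('Adam', net, lr=0.001)   # optim.Adam over net.parameters()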
lib/dataloader.py ADDED
@@ -0,0 +1,400 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import numpy as np
5
+ import torch
6
+ import torchvision.transforms as transforms
7
+ from torch.utils.data.dataset import Dataset
8
+ from torch.utils.data.dataloader import DataLoader
9
+ from torch.utils.data.sampler import WeightedRandomSampler
10
+ from PIL import Image
11
+ from sklearn.preprocessing import MinMaxScaler
12
+ import pickle
13
+ from .logger import BaseLogger
14
+ from typing import List, Dict, Union
15
+ import pandas as pd
16
+
17
+
18
+ logger = BaseLogger.get_logger(__name__)
19
+
20
+
21
+ class PrivateAugment(torch.nn.Module):
22
+ """
23
+ Augmentation defined privately.
24
+ Variety of augmentation can be written in this class if necessary.
25
+ """
26
+ # For X-ray photo.
27
+ xray_augs_list = [
28
+ transforms.RandomAffine(degrees=(-3, 3), translate=(0.02, 0.02)),
29
+ transforms.RandomAdjustSharpness(sharpness_factor=2),
30
+ transforms.RandomAutocontrast()
31
+ ]
32
+
33
+
34
+ class InputDataMixin:
35
+ """
36
+ Class to normalizes input data.
37
+ """
38
+ def _make_scaler(self) -> MinMaxScaler:
39
+ """
40
+ Make scaler to normalize input data by min-max normalization with train data.
41
+
42
+ Returns:
43
+ MinMaxScaler: scaler
44
+ """
45
+ scaler = MinMaxScaler()
46
+ _df_train = self.df_source[self.df_source['split'] == 'train'] # should be normalized with min and max of training data
47
+ _ = scaler.fit(_df_train[self.input_list]) # fit only
48
+ return scaler
49
+
50
+ def save_scaler(self, save_path :str) -> None:
51
+ """
52
+ Save scaler
53
+
54
+ Args:
55
+ save_path (str): path for saving scaler.
56
+ """
57
+ #save_scaler_path = Path(save_datetime_dir, 'scaler.pkl')
58
+ with open(save_path, 'wb') as f:
59
+ pickle.dump(self.scaler, f)
60
+
61
+ def load_scaler(self, scaler_path :str) -> None:
62
+ """
63
+ Load scaler.
64
+
65
+ Args:
66
+ scaler_path (str): path to scaler
67
+ """
68
+ with open(scaler_path, 'rb') as f:
69
+ scaler = pickle.load(f)
70
+ return scaler
71
+
72
+ def _normalize_inputs(self, df_inputs: pd.DataFrame) -> torch.FloatTensor:
73
+ """
74
+ Normalize inputs.
75
+
76
+ Args:
77
+ df_inputs (pd.DataFrame): DataFrame of inputs
78
+
79
+ Returns:
80
+ torch.FloatTensor: normalized inputs
81
+
82
+ Note:
83
+ After iloc[[idx], index_input_list], pd.DataFrame is obtained.
84
+ DataFrame fits the input type of self.scaler.transform.
85
+ However, after normalizing, the shape of inputs_value is (1, N), where N is the number of input values.
86
+ Since the shape (1, N) is not acceptable when forwarding, convert (1, N) -> (N,) is needed.
87
+ """
88
+ inputs_value = self.scaler.transform(df_inputs).reshape(-1) # np.float64
89
+ inputs_value = np.array(inputs_value, dtype=np.float32) # -> np.float32
90
+ inputs_value = torch.from_numpy(inputs_value).clone() # -> torch.float32
91
+ return inputs_value
92
+
93
+ def _load_input_value_if_mlp(self, idx: int) -> Union[torch.FloatTensor, str]:
94
+ """
95
+ Load input values after converting them into tensor if MLP is used.
96
+
97
+ Args:
98
+ idx (int): index
99
+
100
+ Returns:
101
+ Union[torch.Tensor[float], str]: tensor of input values, or empty string
102
+ """
103
+ inputs_value = ''
104
+
105
+ if self.params.mlp is None:
106
+ return inputs_value
107
+
108
+ index_input_list = [self.col_index_dict[input] for input in self.input_list]
109
+ _df_inputs = self.df_split.iloc[[idx], index_input_list]
110
+ inputs_value = self._normalize_inputs( _df_inputs)
111
+ return inputs_value
112
+
113
+
114
+ class ImageMixin:
115
+ """
116
+ Class to normalize and transform image.
117
+ """
118
+ def _make_augmentations(self) -> List:
119
+ """
120
+ Define which augmentation is applied.
121
+
122
+ When training, augmentation is needed for train data only.
123
+ When test, no need of augmentation.
124
+ """
125
+ _augmentation = []
126
+ if (self.params.isTrain) and (self.split == 'train'):
127
+ if self.params.augmentation == 'xrayaug':
128
+ _augmentation = PrivateAugment.xray_augs_list
129
+ elif self.params.augmentation == 'trivialaugwide':
130
+ _augmentation.append(transforms.TrivialAugmentWide())
131
+ elif self.params.augmentation == 'randaug':
132
+ _augmentation.append(transforms.RandAugment())
133
+ else:
134
+ # ie. self.params.augmentation == 'no':
135
+ pass
136
+
137
+ _augmentation = transforms.Compose(_augmentation)
138
+ return _augmentation
139
+
140
+ def _make_transforms(self) -> List:
141
+ """
142
+ Make list of transforms.
143
+
144
+ Returns:
145
+ list of transforms: image normalization
146
+ """
147
+ _transforms = []
148
+ _transforms.append(transforms.ToTensor())
149
+
150
+ if self.params.normalize_image == 'yes':
151
+ # transforms.Normalize accepts only Tensor.
152
+ if self.params.in_channel == 1:
153
+ _transforms.append(transforms.Normalize(mean=(0.5, ), std=(0.5, )))
154
+ else:
155
+ # ie. self.params.in_channel == 3
156
+ _transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
157
+
158
+ _transforms = transforms.Compose(_transforms)
159
+ return _transforms
160
+
161
+ def _open_image_in_channel(self, imgpath: str, in_channel: int) -> Image:
162
+ """
163
+ Open image in channel.
164
+
165
+ Args:
166
+ imgpath (str): path to image
167
+ in_channel (int): number of channels, 1 or 3
168
+
169
+ Returns:
170
+ Image: PIL image
171
+ """
172
+ if in_channel == 1:
173
+ image = Image.open(imgpath).convert('L') # eg. np.array(image).shape = (64, 64)
174
+ return image
175
+ else:
176
+ # ie. self.params.in_channel == 3
177
+ image = Image.open(imgpath).convert('RGB') # eg. np.array(image).shape = (64, 64, 3)
178
+ return image
179
+
180
+ def _load_image_if_cnn(self, idx: int) -> Union[torch.Tensor, str]:
181
+ """
182
+ Load image and convert it to tensor if a CNN or ViT is used.
183
+
184
+ Args:
185
+ idx (int): index
186
+
187
+ Returns:
188
+ Union[torch.Tensor[float], str]: tensor converted from image, or empty string
189
+ """
190
+ image = ''
191
+
192
+ if self.params.net is None:
193
+ return image
194
+
195
+ imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
196
+ image = self._open_image_in_channel(imgpath, self.params.in_channel)
197
+ image = self.augmentation(image)
198
+ image = self.transform(image)
199
+ return image
200
+
201
+
202
+ class DeepSurvMixin:
203
+ """
204
+ Class to handle required data for deepsurv.
205
+ """
206
+ def _load_periods_if_deepsurv(self, idx: int) -> Union[torch.FloatTensor, str]:
207
+ """
208
+ Return period if deepsurv.
209
+
210
+ Args:
211
+ idx (int): index
212
+
213
+ Returns:
214
+ Union[torch.FloatTensor, str]: period, or empty string
215
+ """
216
+ periods = ''
217
+
218
+ if self.params.task != 'deepsurv':
219
+ return periods
220
+
221
+ assert (self.params.task == 'deepsurv') and (len(self.label_list) == 1), 'Deepsurv cannot work in multi-label.'
222
+ periods = self.df_split.iat[idx, self.col_index_dict[self.period_name]] # int64
223
+ periods = np.array(periods, dtype=np.float32) # -> np.float32
224
+ periods = torch.from_numpy(periods).clone() # -> torch.float32
225
+ return periods
226
+
227
+
228
+ class DataSetWidget(InputDataMixin, ImageMixin, DeepSurvMixin):
229
+ """
230
+ Class for a widget to inherit multiple classes simultaneously.
231
+ """
232
+ pass
233
+
234
+
235
+ class LoadDataSet(Dataset, DataSetWidget):
236
+ """
237
+ Dataset for split.
238
+ """
239
+ def __init__(
240
+ self,
241
+ params,
242
+ split: str
243
+ ) -> None:
244
+ """
245
+ Args:
246
+ params (ParamSet): parameter for model
247
+ split (str): split
248
+ """
249
+ self.params = params
250
+ self.df_source = self.params.df_source
251
+ self.split = split
252
+
253
+ self.input_list = self.params.input_list
254
+ self.label_list = self.params.label_list
255
+
256
+ if self.params.task == 'deepsurv':
257
+ self.period_name = self.params.period_name
258
+
259
+ self.df_split = self.df_source[self.df_source['split'] == self.split]
260
+ self.col_index_dict = {col_name: self.df_split.columns.get_loc(col_name) for col_name in self.df_split.columns}
261
+
262
+ # For input data
263
+ if self.params.mlp is not None:
264
+ assert (self.input_list != []), 'input list is empty.'
265
+ if params.isTrain:
266
+ self.scaler = self._make_scaler()
267
+ else:
268
+ # load scaler used at training.
269
+ self.scaler = self.load_scaler(self.params.scaler_path)
270
+
271
+ # For image
272
+ if self.params.net is not None:
273
+ self.augmentation = self._make_augmentations()
274
+ self.transform = self._make_transforms()
275
+
276
+ def __len__(self) -> int:
277
+ """
278
+ Return length of DataFrame.
279
+
280
+ Returns:
281
+ int: length of DataFrame
282
+ """
283
+ return len(self.df_split)
284
+
285
+ def _load_label(self, idx: int) -> Dict[str, Union[int, float]]:
286
+ """
287
+ Return labels.
288
+ If the csv of an external dataset has no label columns,
289
+ an empty dictionary is returned.
290
+
291
+ Args:
292
+ idx (int): index
293
+
294
+ Returns:
295
+ Dict[str, Union[int, float]]: dictionary of label name and its value
296
+ """
297
+ # Check if label columns exist when a csv for an external dataset is used.
298
+ label_list_in_split = list(self.df_split.columns[self.df_split.columns.str.startswith('label')])
299
+ label_dict = dict()
300
+ if label_list_in_split != []:
301
+ for label_name in self.label_list:
302
+ label_dict[label_name] = self.df_split.iat[idx, self.col_index_dict[label_name]]
303
+ else:
304
+ # no label
305
+ pass
306
+ return label_dict
307
+
308
+ def __getitem__(self, idx: int) -> Dict:
309
+ """
310
+ Return data row specified by index.
311
+
312
+ Args:
313
+ idx (int): index
314
+
315
+ Returns:
316
+ Dict: dictionary of data to be passed to the model
317
+ """
318
+ uniqID = self.df_split.iat[idx, self.col_index_dict['uniqID']]
319
+ group = self.df_split.iat[idx, self.col_index_dict['group']]
320
+ imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
321
+ split = self.df_split.iat[idx, self.col_index_dict['split']]
322
+ inputs_value = self._load_input_value_if_mlp(idx)
323
+ image = self._load_image_if_cnn(idx)
324
+ label_dict = self._load_label(idx)
325
+ periods = self._load_periods_if_deepsurv(idx)
326
+
327
+ _data = {
328
+ 'uniqID': uniqID,
329
+ 'group': group,
330
+ 'imgpath': imgpath,
331
+ 'split': split,
332
+ 'inputs': inputs_value,
333
+ 'image': image,
334
+ 'labels': label_dict,
335
+ 'periods': periods
336
+ }
337
+ return _data
338
+
339
+
340
+ def _make_sampler(split_data: LoadDataSet) -> WeightedRandomSampler:
341
+ """
342
+ Make sampler.
343
+
344
+ Args:
345
+ split_data (LoadDataSet): dataset
346
+
347
+ Returns:
348
+ WeightedRandomSampler: sampler
349
+ """
350
+ _target = []
351
+ for data in split_data:
352
+ _target.append(list(data['labels'].values())[0])
353
+ _target = np.array(_target)
354
+ class_sample_count = np.array([len(np.where(_target == t)[0]) for t in np.unique(_target)])
355
+ weight = 1. / class_sample_count
356
+ samples_weight = np.array([weight[t] for t in _target])
357
+ sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
358
+ return sampler
359
+
360
+
361
+ def create_dataloader(
362
+ params,
363
+ split: str = None
364
+ ) -> DataLoader:
365
+ """
366
+ Create data loader for split.
367
+
368
+ Args:
369
+ params (ParamSet): parameter for dataloader
370
+ split (str): split. Defaults to None.
371
+
372
+ Returns:
373
+ DataLoader: data loader
374
+ """
375
+ split_data = LoadDataSet(params, split)
376
+
377
+ if params.isTrain:
378
+ batch_size = params.batch_size
379
+ shuffle = True
380
+ else:
381
+ batch_size = params.test_batch_size
382
+ shuffle = False
383
+
384
+ if params.sampler == 'yes':
385
+ assert ((params.task == 'classification') or (params.task == 'deepsurv')), 'Cannot make sampler in regression.'
386
+ assert (len(params.label_list) == 1), 'Cannot make sampler for multi-label.'
387
+ shuffle = False
388
+ sampler = _make_sampler(split_data)
389
+ else:
390
+ # When params.sampler == 'no'
391
+ sampler = None
392
+
393
+ split_loader = DataLoader(
394
+ dataset=split_data,
395
+ batch_size=batch_size,
396
+ shuffle=shuffle,
397
+ num_workers=0,
398
+ sampler=sampler
399
+ )
400
+ return split_loader
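For reference, a minimal usage sketch of create_dataloader above; the params object and the calling script are assumptions (a ParamSet built by lib/options.py carrying df_source, task, mlp/net, batch_size, sampler, isTrain, etc.), not part of this commit:

from lib.dataloader import create_dataloader

train_loader = create_dataloader(params, split='train')  # params: assumed ParamSet
val_loader = create_dataloader(params, split='val')

for batch in train_loader:
    # each batch is a dict with the keys produced by __getitem__ above:
    # 'uniqID', 'group', 'imgpath', 'split', 'inputs', 'image', 'labels', 'periods'
    images = batch['image']
    labels = batch['labels']
    break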
lib/framework.py ADDED
@@ -0,0 +1,373 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from pathlib import Path
5
+ import copy
6
+ from abc import ABC, abstractmethod
7
+ import torch
8
+ import torch.nn as nn
9
+ from .component import create_net
10
+ from .logger import BaseLogger
11
+ from .options import ParamSet
12
+ from typing import List, Dict, Tuple, Union
13
+
14
+ # Alias of typing
15
+ # eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}}
16
+ LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
17
+
18
+
19
+ logger = BaseLogger.get_logger(__name__)
20
+
21
+
22
+ class BaseModel(ABC):
23
+ """
24
+ Class to construct model. This class is the base class to construct model.
25
+ """
26
+ def __init__(self, params: ParamSet) -> None:
27
+ """
28
+ Class to define Model
29
+
30
+ Args:
31
+ params (ParamSet): parameters
32
+ """
33
+ self.params = params
34
+ self.device = self.params.device
35
+
36
+ self.network = create_net(
37
+ mlp=self.params.mlp,
38
+ net=self.params.net,
39
+ num_outputs_for_label=self.params.num_outputs_for_label,
40
+ mlp_num_inputs=self.params.mlp_num_inputs,
41
+ in_channel=self.params.in_channel,
42
+ vit_image_size=self.params.vit_image_size,
43
+ pretrained=self.params.pretrained
44
+ )
45
+ self.network.to(self.device)
46
+
47
+ # variables to keep temporary best_weight and best_epoch
48
+ self.acting_best_weight = None
49
+ self.acting_best_epoch = None
50
+
51
+ def train(self) -> None:
52
+ """
53
+ Make network training mode.
54
+ """
55
+ self.network.train()
56
+
57
+ def eval(self) -> None:
58
+ """
59
+ Make network evaluation mode.
60
+ """
61
+ self.network.eval()
62
+
63
+ @abstractmethod
64
+ def set_data(
65
+ self,
66
+ data: Dict
67
+ ) -> Tuple[
68
+ Dict[str, torch.FloatTensor],
69
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
70
+ ]:
71
+ raise NotImplementedError
72
+
73
+ def store_weight(self, at_epoch: int = None) -> None:
74
+ """
75
+ Store weight and epoch number when it is saved.
76
+
77
+ Args:
78
+ at_epoch (int): epoch number at which the weight is stored
79
+ """
80
+ self.acting_best_epoch = at_epoch
81
+
82
+ _network = copy.deepcopy(self.network)
83
+ if hasattr(_network, 'module'):
84
+ # When DataParallel used, move weight to CPU.
85
+ self.acting_best_weight = copy.deepcopy(_network.module.to(torch.device('cpu')).state_dict())
86
+ else:
87
+ self.acting_best_weight = copy.deepcopy(_network.state_dict())
88
+
89
+ def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None:
90
+ """
91
+ Save weight.
92
+
93
+ Args:
94
+ save_datetime_dir (str): save_datetime_dir
95
+ as_best (bool): True if weight is saved as best, otherwise False. Defaults to None.
96
+ """
97
+
98
+ save_dir = Path(save_datetime_dir, 'weights')
99
+ save_dir.mkdir(parents=True, exist_ok=True)
100
+ save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt'
101
+ save_path = Path(save_dir, save_name)
102
+
103
+ if as_best:
104
+ save_name_as_best = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '_best' + '.pt'
105
+ save_path_as_best = Path(save_dir, save_name_as_best)
106
+ if save_path.exists():
107
+ # Check if best weight already saved. If exists, rename with '_best'
108
+ save_path.rename(save_path_as_best)
109
+ else:
110
+ torch.save(self.acting_best_weight, save_path_as_best)
111
+ else:
112
+ save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt'
113
+ torch.save(self.acting_best_weight, save_path)
114
+
115
+ def load_weight(self, weight_path: Path) -> None:
116
+ """
117
+ Load weight from weight_path.
118
+
119
+ Args:
120
+ weight_path (Path): path to weight
121
+ """
122
+ logger.info(f"Load weight: {weight_path}.\n")
123
+ weight = torch.load(weight_path)
124
+ self.network.load_state_dict(weight)
125
+
126
+
127
+ class ModelMixin:
128
+ def to_gpu(self, gpu_ids: List[int]) -> None:
129
+ """
130
+ Make model compute on the GPU.
131
+
132
+ Args:
133
+ gpu_ids (List[int]): GPU ids
134
+ """
135
+ if gpu_ids != []:
136
+ assert torch.cuda.is_available(), 'No available GPU on this machine.'
137
+ self.network = nn.DataParallel(self.network, device_ids=gpu_ids)
138
+
139
+ def init_network(self) -> None:
140
+ """
141
+ Initialize network.
142
+ This method is used at test to reset the current weight by redefining network.
143
+ """
144
+ self.network = create_net(
145
+ mlp=self.params.mlp,
146
+ net=self.params.net,
147
+ num_outputs_for_label=self.params.num_outputs_for_label,
148
+ mlp_num_inputs=self.params.mlp_num_inputs,
149
+ in_channel=self.params.in_channel,
150
+ vit_image_size=self.params.vit_image_size,
151
+ pretrained=self.params.pretrained
152
+ )
153
+ self.network.to(self.device)
154
+
155
+ class ModelWidget(BaseModel, ModelMixin):
156
+ """
157
+ Class for a widget to inherit multiple classes simultaneously
158
+ """
159
+ pass
160
+
161
+
162
+ class MLPModel(ModelWidget):
163
+ """
164
+ Class for MLP model
165
+ """
166
+
167
+ def __init__(self, params: ParamSet) -> None:
168
+ """
169
+ Args:
170
+ params (ParamSet): parameters
171
+ """
172
+ super().__init__(params)
173
+
174
+ def set_data(
175
+ self,
176
+ data: Dict
177
+ ) -> Tuple[
178
+ Dict[str, torch.FloatTensor],
179
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
180
+ ]:
181
+ """
182
+ Unpack data for forwarding through the MLP and for calculating loss,
183
+ moving tensors to the device.
184
+ When deepsurv, period and network are also returned.
185
+
186
+ Args:
187
+ data (Dict): dictionary of data
188
+
189
+ Returns:
190
+ Tuple[
191
+ Dict[str, torch.FloatTensor],
192
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
193
+ ]: input of model and data for calculating loss.
194
+ eg.
195
+ ([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
196
+ """
197
+ in_data = {'inputs': data['inputs'].to(self.device)}
198
+ labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
199
+
200
+ if not any(data['periods']):
201
+ return in_data, labels
202
+
203
+ # When deepsurv
204
+ labels = {
205
+ **labels,
206
+ **{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
207
+ }
208
+ return in_data, labels
209
+
210
+ def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
211
+ """
212
+ Forward
213
+
214
+ Args:
215
+ in_data (Dict[str, torch.Tensor]): data to be input into model
216
+
217
+ Returns:
218
+ Dict[str, torch.Tensor]: output
219
+ """
220
+ inputs = in_data['inputs']
221
+ output = self.network(inputs)
222
+ return output
223
+
224
+
225
+ class CVModel(ModelWidget):
226
+ """
227
+ Class for CNN or ViT model
228
+ """
229
+ def __init__(self, params: ParamSet) -> None:
230
+ """
231
+ Args:
232
+ params (ParamSet): parameters
233
+ """
234
+ super().__init__(params)
235
+
236
+ def set_data(
237
+ self,
238
+ data: Dict
239
+ ) -> Tuple[
240
+ Dict[str, torch.FloatTensor],
241
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
242
+ ]:
243
+ """
244
+ Unpack data for forwarding through the CNN or ViT and for calculating loss, moving tensors to the device.
245
+ When deepsurv, period and network are also returned.
246
+
247
+ Args:
248
+ data (Dict): dictionary of data
249
+
250
+ Returns:
251
+ Tuple[
252
+ Dict[str, torch.FloatTensor],
253
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
254
+ ]: input of model and data for calculating loss.
255
+ eg.
256
+ ([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
257
+ """
258
+ in_data = {'image': data['image'].to(self.device)}
259
+ labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
260
+
261
+ if not any(data['periods']):
262
+ return in_data, labels
263
+
264
+ # When deepsurv
265
+ labels = {
266
+ **labels,
267
+ **{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
268
+ }
269
+ return in_data, labels
270
+
271
+ def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
272
+ """
273
+ Forward
274
+
275
+ Args:
276
+ in_data (Dict[str, torch.Tensor]): data to be input into model
277
+
278
+ Returns:
279
+ Dict[str, torch.Tensor]: output
280
+ """
281
+ image = in_data['image']
282
+ output = self.network(image)
283
+ return output
284
+
285
+
286
+ class FusionModel(ModelWidget):
287
+ """
288
+ Class for MLP+CNN or MLP+ViT model.
289
+ """
290
+ def __init__(self, params: ParamSet) -> None:
291
+ """
292
+ Args:
293
+ params (ParamSet): parameters
294
+ """
295
+ super().__init__(params)
296
+
297
+ def set_data(
298
+ self,
299
+ data: Dict
300
+ ) -> Tuple[
301
+ Dict[str, torch.FloatTensor],
302
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
303
+ ]:
304
+ """
305
+ Unpack data for forwarding through the MLP+CNN or MLP+ViT and for calculating loss,
306
+ moving tensors to the device.
307
+ When deepsurv, period and network are also returned.
308
+
309
+ Args:
310
+ data (Dict): dictionary of data
311
+
312
+ Returns:
313
+ Tuple[
314
+ Dict[str, torch.FloatTensor],
315
+ Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
316
+ ]: input of model and data for calculating loss.
317
+ eg.
318
+ ([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
319
+ """
320
+ in_data = {
321
+ 'inputs': data['inputs'].to(self.device),
322
+ 'image': data['image'].to(self.device)
323
+ }
324
+ labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
325
+
326
+ if not any(data['periods']):
327
+ return in_data, labels
328
+
329
+ # When deepsurv
330
+ labels = {
331
+ **labels,
332
+ **{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
333
+ }
334
+ return in_data, labels
335
+
336
+ def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
337
+ """
338
+ Forward
339
+
340
+ Args:
341
+ in_data (Dict[str, torch.Tensor]): data to be input into model
342
+
343
+ Returns:
344
+ Dict[str, torch.Tensor]: output
345
+ """
346
+ inputs = in_data['inputs']
347
+ image = in_data['image']
348
+ output = self.network(inputs, image)
349
+ return output
350
+
351
+
352
+ def create_model(params: ParamSet) -> nn.Module:
353
+ """
354
+ Construct model.
355
+
356
+ Args:
357
+ params (ParamSet): parameters
358
+
359
+ Returns:
360
+ nn.Module: model
361
+ """
362
+ _isMLPModel = (params.mlp is not None) and (params.net is None)
363
+ _isCVModel = (params.mlp is None) and (params.net is not None)
364
+ _isFusionModel = (params.mlp is not None) and (params.net is not None)
365
+
366
+ if _isMLPModel:
367
+ return MLPModel(params)
368
+ elif _isCVModel:
369
+ return CVModel(params)
370
+ elif _isFusionModel:
371
+ return FusionModel(params)
372
+ else:
373
+ raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")
lib/logger.py ADDED
@@ -0,0 +1,71 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from pathlib import Path
5
+ import logging
6
+
7
+
8
+ class BaseLogger:
9
+ """
10
+ Class for defining logger.
11
+ """
12
+ _unexecuted_configure = True
13
+
14
+ @classmethod
15
+ def get_logger(cls, name: str) -> logging.Logger:
16
+ """
17
+ Return logger.
18
+ Args:
19
+ name (str): logger name, potentially hierarchical, eg. lib.net, lib.dataloader, etc.
20
+ For the details, see https://docs.python.org/3/library/logging.html?highlight=logging#module-logging.
21
+ Returns:
22
+ logging.Logger: logger
23
+ """
24
+ if cls._unexecuted_configure:
25
+ cls._init_logger()
26
+
27
+ return logging.getLogger('nervus.{}'.format(name))
28
+
29
+ @classmethod
30
+ def _init_logger(cls) -> None:
31
+ """
32
+ Configure logger.
33
+ """
34
+ _root_logger = logging.getLogger('nervus')
35
+ _root_logger.setLevel(logging.DEBUG)
36
+ formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
37
+
38
+ log_dir = Path('logs')
39
+ log_dir.mkdir(parents=True, exist_ok=True)
40
+ log_path = Path(log_dir, 'log.log')
41
+
42
+ # file handler
43
+ ## upper warning
44
+ fh_err = logging.FileHandler(log_path)
45
+ fh_err.setLevel(logging.WARNING)
46
+ fh_err.setFormatter(formatter)
47
+ fh_err.addFilter(lambda log_record: not ('BdbQuit' in str(log_record.exc_info)) and (log_record.levelno >= logging.WARNING))
48
+ _root_logger.addHandler(fh_err)
49
+
50
+ ## lower warning
51
+ fh = logging.FileHandler(log_path)
52
+ fh.setLevel(logging.DEBUG)
53
+ fh.addFilter(lambda log_record: log_record.levelno < logging.WARNING)
54
+ _root_logger.addHandler(fh)
55
+
56
+ # stream handler
57
+ ## upper warning
58
+ ch_err = logging.StreamHandler()
59
+ ch_err.setLevel(logging.WARNING)
60
+ ch_err.setFormatter(formatter)
61
+ ch_err.addFilter(lambda log_record: log_record.levelno >= logging.WARNING)
62
+ _root_logger.addHandler(ch_err)
63
+
64
+ ## lower warning
65
+ ch = logging.StreamHandler()
66
+ ch.setLevel(logging.DEBUG)
67
+ ch.addFilter(lambda log_record: log_record.levelno < logging.WARNING)
68
+ _root_logger.addHandler(ch)
69
+
70
+ cls._unexecuted_configure = False
71
+
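Typical usage of BaseLogger, as the other modules in this commit already do:

from lib.logger import BaseLogger

logger = BaseLogger.get_logger(__name__)
logger.info('below WARNING: written to logs/log.log and the console without a prefix')
logger.warning('WARNING and above: written with timestamp and level')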
lib/metrics.py ADDED
@@ -0,0 +1,623 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ from pathlib import Path
5
+ import numpy as np
6
+ import pandas as pd
7
+ from sklearn import metrics
8
+ from sklearn.preprocessing import label_binarize
9
+ import matplotlib.pyplot as plt
10
+ from matplotlib import colors as mcolors
11
+ from .logger import BaseLogger
12
+ from typing import Dict, Union
13
+
14
+
15
+ logger = BaseLogger.get_logger(__name__)
16
+
17
+
18
+ class MetricsData:
19
+ """
20
+ Class to store metrics as attributes.
21
+ Which metrics are defined depends on the task.
22
+
23
+ For ROC
24
+ self.fpr: np.ndarray
25
+ self.tpr: np.ndarray
26
+ self.auc: float
27
+
28
+ For Regression
29
+ self.y_obs: np.ndarray
30
+ self.y_pred: np.ndarray
31
+ self.r2: float
32
+
33
+ For DeepSurv
34
+ self.c_index: float
35
+ """
36
+ def __init__(self) -> None:
37
+ pass
38
+
39
+
40
+ class LabelMetrics:
41
+ """
42
+ Class to store metrics of each split for each label.
43
+ """
44
+ def __init__(self) -> None:
45
+ """
46
+ Metrics of split, ie 'val' and 'test'
47
+ """
48
+ self.val = MetricsData()
49
+ self.test = MetricsData()
50
+
51
+ def set_label_metrics(self, split: str, attr: str, value: Union[np.ndarray, float]) -> None:
52
+ """
53
+ Set value as appropriate metrics of split.
54
+
55
+ Args:
56
+ split (str): split
57
+ attr (str): attribute name as follows:
58
+ classification: 'fpr', 'tpr', or 'auc',
59
+ regression: 'y_obs'(ground truth), 'y_pred'(prediction) or 'r2', or
60
+ deepsurv: 'c_index'
61
+ value (Union[np.ndarray,float]): value of attr
62
+ """
63
+ setattr(getattr(self, split), attr, value)
64
+
65
+ def get_label_metrics(self, split: str, attr: str) -> Union[np.ndarray, float]:
66
+ """
67
+ Return value of metrics of split.
68
+
69
+ Args:
70
+ split (str): split
71
+ attr (str): metrics name
72
+
73
+ Returns:
74
+ Union[np.ndarray,float]: value of attr
75
+ """
76
+ return getattr(getattr(self, split), attr)
77
+
78
+
79
+ class ROCMixin:
80
+ """
81
+ Class for calculating ROC and AUC.
82
+ """
83
+ def _set_roc(self, label_metrics: LabelMetrics, split: str, fpr: np.ndarray, tpr: np.ndarray) -> None:
84
+ """
85
+ Set fpr, tpr, and auc.
86
+
87
+ Args:
88
+ label_metrics (LabelMetrics): metrics of 'val' and 'test'
89
+ split (str): 'val' or 'test'
90
+ fpr (np.ndarray): FPR
91
+ tpr (np.ndarray): TPR
92
+
93
+ self.metrics_kind = 'auc' is defined in class ClsEval below.
94
+ """
95
+ label_metrics.set_label_metrics(split, 'fpr', fpr)
96
+ label_metrics.set_label_metrics(split, 'tpr', tpr)
97
+ label_metrics.set_label_metrics(split, self.metrics_kind, metrics.auc(fpr, tpr))
98
+
99
+ def _cal_label_roc_binary(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
100
+ """
101
+ Calculate ROC for binary class.
102
+
103
+ Args:
104
+ label_name (str): label name
105
+ df_group (pd.DataFrame): likelihood for group
106
+
107
+ Returns:
108
+ LabelMetrics: metrics of 'val' and 'test'
109
+ """
110
+ required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
111
+ df_label = df_group[required_columns]
112
+ POSITIVE = 1
113
+ positive_pred_name = 'pred_' + label_name + '_' + str(POSITIVE)
114
+
115
+ # ! When the split is 'test' only, ie. for an external dataset, an error occurs here.
116
+ label_metrics = LabelMetrics()
117
+ for split in ['val', 'test']:
118
+ df_split = df_label.query('split == @split')
119
+ y_true = df_split[label_name]
120
+ y_score = df_split[positive_pred_name]
121
+ _fpr, _tpr, _ = metrics.roc_curve(y_true, y_score)
122
+ self._set_roc(label_metrics, split, _fpr, _tpr)
123
+ return label_metrics
124
+
125
+ def _cal_label_roc_multi(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
126
+ """
127
+ Calculate ROC for multi-class by macro average.
128
+
129
+ Args:
130
+ label_name (str): label name
131
+ df_group (pd.DataFrame): likelihood for group
132
+
133
+ Returns:
134
+ LabelMetrics: metrics of 'val' and 'test'
135
+ """
136
+ required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
137
+ df_label = df_group[required_columns]
138
+
139
+ pred_name_list = list(df_label.columns[df_label.columns.str.startswith('pred')])
140
+ class_list = [int(pred_name.rsplit('_', 1)[-1]) for pred_name in pred_name_list] # [pred_label_0, pred_label_1, pred_label_2] -> [0, 1, 2]
141
+ num_classes = len(class_list)
142
+
143
+ label_metrics = LabelMetrics()
144
+ for split in ['val', 'test']:
145
+ df_split = df_label.query('split == @split')
146
+ y_true = df_split[label_name]
147
+ y_true_bin = label_binarize(y_true, classes=class_list) # Since y_true: List[int], should be class_list: List[int]
148
+
149
+ # Compute ROC for each class by OneVsRest
150
+ _fpr = dict()
151
+ _tpr = dict()
152
+ for i, class_name in enumerate(class_list):
153
+ pred_name = 'pred_' + label_name + '_' + str(class_name)
154
+ _fpr[class_name], _tpr[class_name], _ = metrics.roc_curve(y_true_bin[:, i], df_split[pred_name])
155
+
156
+ # First aggregate all false positive rates
157
+ all_fpr = np.unique(np.concatenate([_fpr[class_name] for class_name in class_list]))
158
+
159
+ # Then interpolate all ROC curves at these points
160
+ mean_tpr = np.zeros_like(all_fpr)
161
+ for class_name in class_list:
162
+ mean_tpr += np.interp(all_fpr, _fpr[class_name], _tpr[class_name])
163
+
164
+ # Finally average it and compute AUC
165
+ mean_tpr /= num_classes
166
+
167
+ _fpr['macro'] = all_fpr
168
+ _tpr['macro'] = mean_tpr
169
+ self._set_roc(label_metrics, split, _fpr['macro'], _tpr['macro'])
170
+ return label_metrics
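A standalone sketch of the macro-average step used above, with two hypothetical classes and made-up scores (numpy and scikit-learn only; not part of the commit):

import numpy as np
from sklearn import metrics

# per-class ROC curves (one-vs-rest), hypothetical labels and scores
fpr0, tpr0, _ = metrics.roc_curve([1, 1, 0, 0], [0.9, 0.6, 0.3, 0.1])
fpr1, tpr1, _ = metrics.roc_curve([0, 0, 1, 1], [0.2, 0.4, 0.7, 0.8])

# aggregate all FPR points, interpolate each TPR there, then average and integrate
all_fpr = np.unique(np.concatenate([fpr0, fpr1]))
mean_tpr = (np.interp(all_fpr, fpr0, tpr0) + np.interp(all_fpr, fpr1, tpr1)) / 2
macro_auc = metrics.auc(all_fpr, mean_tpr)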
171
+
172
+ def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
173
+ """
174
+ Calculate ROC and AUC for label depending on binary or multi-class.
175
+
176
+ Args:
177
+ label_name (str):label name
178
+ df_group (pd.DataFrame): likelihood for group
179
+
180
+ Returns:
181
+ LabelMetrics: metrics of 'val' and 'test'
182
+ """
183
+ pred_name_list = df_group.columns[df_group.columns.str.startswith('pred_' + label_name)]
184
+ isMultiClass = (len(pred_name_list) > 2)
185
+ if isMultiClass:
186
+ label_metrics = self._cal_label_roc_multi(label_name, df_group)
187
+ else:
188
+ label_metrics = self._cal_label_roc_binary(label_name, df_group)
189
+ return label_metrics
190
+
191
+
192
+ class YYMixin:
193
+ """
194
+ Class for calculating YY and R2.
195
+ """
196
+ def _set_yy(self, label_metrics: LabelMetrics, split: str, y_obs: np.ndarray, y_pred: np.ndarray) -> None:
197
+ """
198
+ Set ground truth, prediction, and R2.
199
+
200
+ Args:
201
+ label_metrics (LabelMetrics): metrics of 'val' and 'test'
202
+ split (str): 'val' or 'test'
203
+ y_obs (np.ndarray): ground truth
204
+ y_pred (np.ndarray): prediction
205
+
206
+ self.metrics_kind = 'r2' is defined in class RegEval below.
207
+ """
208
+ label_metrics.set_label_metrics(split, 'y_obs', y_obs.values)
209
+ label_metrics.set_label_metrics(split, 'y_pred', y_pred.values)
210
+ label_metrics.set_label_metrics(split, self.metrics_kind, metrics.r2_score(y_obs, y_pred))
211
+
212
+ def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
213
+ """
214
+ Calculate YY and R2 for label.
215
+
216
+ Args:
217
+ label_name (str): label name
218
+ df_group (pd.DataFrame): likelihood for group
219
+
220
+ Returns:
221
+ LabelMetrics: metrics of 'val' and 'test'
222
+ """
223
+ required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
224
+ df_label = df_group[required_columns]
225
+ label_metrics = LabelMetrics()
226
+ for split in ['val', 'test']:
227
+ df_split = df_label.query('split == @split')
228
+ y_obs = df_split[label_name]
229
+ y_pred = df_split['pred_' + label_name]
230
+ self._set_yy(label_metrics, split, y_obs, y_pred)
231
+ return label_metrics
232
+
233
+
234
+ class C_IndexMixin:
235
+ """
236
+ Class for calculating C-Index.
237
+ """
238
+ def _set_c_index(
239
+ self,
240
+ label_metrics: LabelMetrics,
241
+ split: str,
242
+ periods: pd.Series,
243
+ preds: pd.Series,
244
+ labels: pd.Series
245
+ ) -> None:
246
+ """
247
+ Set C-Index.
248
+
249
+ Args:
250
+ label_metrics (LabelMetrics): metrics of 'val' and 'test'
251
+ split (str): 'val' or 'test'
252
+ periods (pd.Series): periods
253
+ preds (pd.Series): prediction
254
+ labels (pd.Series): label
255
+
256
+ self.metrics_kind = 'c_index' is defined in class DeepSurvEval below.
257
+ """
258
+ from lifelines.utils import concordance_index
259
+ value_c_index = concordance_index(periods, (-1)*preds, labels)
260
+ label_metrics.set_label_metrics(split, self.metrics_kind, value_c_index)
261
+
262
+ def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
263
+ """
264
+ Calculate C-Index for label.
265
+
266
+ Args:
267
+ label_name (str): label name
268
+ df_group (pd.DataFrame): likelihood for group
269
+
270
+ Returns:
271
+ LabelMetrics: metrics of 'val' and 'test'
272
+ """
273
+ required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['periods', 'split']
274
+ df_label = df_group[required_columns]
275
+ label_metrics = LabelMetrics()
276
+ for split in ['val', 'test']:
277
+ df_split = df_label.query('split == @split')
278
+ periods = df_split['periods']
279
+ preds = df_split['pred_' + label_name]
280
+ labels = df_split[label_name]
281
+ self._set_c_index(label_metrics, split, periods, preds, labels)
282
+ return label_metrics
283
+
284
+
285
+ class MetricsMixin:
286
+ """
287
+ Class to calculate metrics and make summary.
288
+ """
289
+ def _cal_group_metrics(self, df_group: pd.DataFrame) -> Dict[str, LabelMetrics]:
290
+ """
291
+ Calculate metrics for each group.
292
+
293
+ Args:
294
+ df_group (pd.DataFrame): likelihood for group
295
+
296
+ Returns:
297
+ Dict[str, LabelMetrics]: dictionary of label and its LabelMetrics
298
+ eg. {label_1: LabelMetrics(), label_2: LabelMetrics(), ...}
299
+ """
300
+ label_list = list(df_group.columns[df_group.columns.str.startswith('label')])
301
+ group_metrics = dict()
302
+ for label_name in label_list:
303
+ label_metrics = self.cal_label_metrics(label_name, df_group)
304
+ group_metrics[label_name] = label_metrics
305
+ return group_metrics
306
+
307
+ def cal_whole_metrics(self, df_likelihood: pd.DataFrame) -> Dict[str, Dict[str, LabelMetrics]]:
308
+ """
309
+ Calculate metrics for all groups.
310
+
311
+ Args:
312
+ df_likelihood (pd.DataFrame) : DataFrame of likelihood
313
+
314
+ Returns:
315
+ Dict[str, Dict[str, LabelMetrics]]: dictionary of group and dictionary of label and its LabelMetrics
316
+ eg. {
317
+ groupA: {label_1: LabelMetrics(), label_2: LabelMetrics(), ...},
318
+ groupB: {label_1: LabelMetrics(), label_2: LabelMetrics(), ...},
319
+ ...}
320
+ """
321
+ whole_metrics = dict()
322
+ for group in df_likelihood['group'].unique():
323
+ df_group = df_likelihood.query('group == @group')
324
+ whole_metrics[group] = self._cal_group_metrics(df_group)
325
+ return whole_metrics
326
+
327
+ def make_summary(
328
+ self,
329
+ whole_metrics: Dict[str, Dict[str, LabelMetrics]],
330
+ likelihood_path: Path,
331
+ metrics_kind: str
332
+ ) -> pd.DataFrame:
333
+ """
334
+ Make summary.
335
+
336
+ Args:
337
+ whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
338
+ likelihood_path (Path): path to likelihood
339
+ metrics_kind (str): kind of metrics, ie, 'auc', 'r2', or 'c_index'
340
+
341
+ Returns:
342
+ pd.DataFrame: summary
343
+ """
344
+ _datetime = likelihood_path.parents[1].name
345
+ _weight = likelihood_path.stem.replace('likelihood_', '') + '.pt'
346
+ df_summary = pd.DataFrame()
347
+ for group, group_metrics in whole_metrics.items():
348
+ _new = dict()
349
+ _new['datetime'] = [_datetime]
350
+ _new['weight'] = [_weight]
351
+ _new['group'] = [group]
352
+ for label_name, label_metrics in group_metrics.items():
353
+ _val_metrics = label_metrics.get_label_metrics('val', metrics_kind)
354
+ _test_metrics = label_metrics.get_label_metrics('test', metrics_kind)
355
+ _new[label_name + '_val_' + metrics_kind] = [f"{_val_metrics:.2f}"]
356
+ _new[label_name + '_test_' + metrics_kind] = [f"{_test_metrics:.2f}"]
357
+ df_summary = pd.concat([df_summary, pd.DataFrame(_new)], ignore_index=True)
358
+
359
+ df_summary = df_summary.sort_values('group')
360
+ return df_summary
361
+
362
+ def print_metrics(self, df_summary: pd.DataFrame, metrics_kind: str) -> None:
363
+ """
364
+ Print metrics.
365
+
366
+ Args:
367
+ df_summary (pd.DataFrame): summary
368
+ metrics_kind (str): kind of metrics, ie. 'auc', 'r2', or 'c_index'
369
+ """
370
+ label_list = list(df_summary.columns[df_summary.columns.str.startswith('label')]) # [label_1_val, label_1_test, label_2_val, label_2_test, ...]
371
+ num_splits = len(['val', 'test'])
372
+ _column_val_test_list = [label_list[i:i+num_splits] for i in range(0, len(label_list), num_splits)] # [[label_1_val, label_1_test], [label_2_val, label_2_test], ...]
373
+ for _, row in df_summary.iterrows():
374
+ logger.info(row['group'])
375
+ for _column_val_test in _column_val_test_list:
376
+ _label_name = _column_val_test[0].replace('_val', '')
377
+ _label_name_val = _column_val_test[0]
378
+ _label_name_test = _column_val_test[1]
379
+ logger.info(f"{_label_name:<25} val_{metrics_kind}: {row[_label_name_val]:>7}, test_{metrics_kind}: {row[_label_name_test]:>7}")
380
+
381
+ def update_summary(self, df_summary: pd.DataFrame, likelihood_path: Path) -> None:
382
+ """
383
+ Update summary.
384
+
385
+ Args:
386
+ df_summary (pd.DataFrame): summary to be added to the previous summary
387
+ likelihood_path (Path): path to likelihood
388
+ """
389
+ _project_dir = likelihood_path.parents[3]
390
+ summary_dir = Path(_project_dir, 'summary')
391
+ summary_path = Path(summary_dir, 'summary.csv')
392
+ if summary_path.exists():
393
+ df_prev = pd.read_csv(summary_path)
394
+ df_updated = pd.concat([df_prev, df_summary], axis=0)
395
+ else:
396
+ summary_dir.mkdir(parents=True, exist_ok=True)
397
+ df_updated = df_summary
398
+ df_updated.to_csv(summary_path, index=False)
399
+
400
+ def make_metrics(self, likelihood_path: Path) -> None:
401
+ """
402
+ Make metrics.
403
+
404
+ Args:
405
+ likelihood_path (Path): path to likelihood
406
+ """
407
+ df_likelihood = pd.read_csv(likelihood_path)
408
+ whole_metrics = self.cal_whole_metrics(df_likelihood)
409
+ self.make_save_fig(whole_metrics, likelihood_path, self.fig_kind)
410
+ df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
411
+ self.print_metrics(df_summary, self.metrics_kind)
412
+ self.update_summary(df_summary, likelihood_path)
413
+
414
+
415
+ class FigROCMixin:
416
+ """
417
+ Class to plot ROC.
418
+ """
419
+ def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
420
+ """
421
+ Plot ROC.
422
+
423
+ Args:
424
+ group (str): group
425
+ group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics
426
+
427
+ Returns:
428
+ plt: ROC
429
+ """
430
+ label_list = group_metrics.keys()
431
+ num_rows = 1
432
+ num_cols = len(label_list)
433
+ base_size = 7
434
+ height = num_rows * base_size
435
+ width = num_cols * height
436
+ fig = plt.figure(figsize=(width, height))
437
+
438
+ for i, label_name in enumerate(label_list):
439
+ label_metrics = group_metrics[label_name]
440
+ offset = i + 1
441
+ ax_i = fig.add_subplot(
442
+ num_rows,
443
+ num_cols,
444
+ offset,
445
+ title=group + ': ' + label_name,
446
+ xlabel='1 - Specificity',
447
+ ylabel='Sensitivity',
448
+ xmargin=0,
449
+ ymargin=0
450
+ )
451
+ ax_i.plot(label_metrics.val.fpr, label_metrics.val.tpr, label=f"AUC_val = {label_metrics.val.auc:.2f}", marker='x')
452
+ ax_i.plot(label_metrics.test.fpr, label_metrics.test.tpr, label=f"AUC_test = {label_metrics.test.auc:.2f}", marker='o')
453
+ ax_i.grid()
454
+ ax_i.legend()
455
+ fig.tight_layout()
456
+ return fig
457
+
458
+
459
+ class FigYYMixin:
460
+ """
461
+ Class to plot YY-graph.
462
+ """
463
+ def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
464
+ """
465
+ Plot yy.
466
+
467
+ Args:
468
+ group (str): group
469
+ group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics
470
+
471
+ Returns:
472
+ plt: YY-graph
473
+ """
474
+ label_list = group_metrics.keys()
475
+ num_splits = len(['val', 'test'])
476
+ num_rows = 1
477
+ num_cols = len(label_list) * num_splits
478
+ base_size = 7
479
+ height = num_rows * base_size
480
+ width = num_cols * height
481
+ fig = plt.figure(figsize=(width, height))
482
+
483
+ for i, label_name in enumerate(label_list):
484
+ label_metrics = group_metrics[label_name]
485
+ val_offset = (i * num_splits) + 1
486
+ test_offset = val_offset + 1
487
+
488
+ val_ax = fig.add_subplot(
489
+ num_rows,
490
+ num_cols,
491
+ val_offset,
492
+ title=group + ': ' + label_name + '\n' + 'val: Observed-Predicted Plot',
493
+ xlabel='Observed',
494
+ ylabel='Predicted',
495
+ xmargin=0,
496
+ ymargin=0
497
+ )
498
+
499
+ test_ax = fig.add_subplot(
500
+ num_rows,
501
+ num_cols,
502
+ test_offset,
503
+ title=group + ': ' + label_name + '\n' + 'test: Observed-Predicted Plot',
504
+ xlabel='Observed',
505
+ ylabel='Predicted',
506
+ xmargin=0,
507
+ ymargin=0
508
+ )
509
+
510
+ y_obs_val = label_metrics.val.y_obs
511
+ y_pred_val = label_metrics.val.y_pred
512
+
513
+ y_obs_test = label_metrics.test.y_obs
514
+ y_pred_test = label_metrics.test.y_pred
515
+
516
+ # Plot
517
+ color = mcolors.TABLEAU_COLORS
518
+ val_ax.scatter(y_obs_val, y_pred_val, color=color['tab:blue'], label='val')
519
+ test_ax.scatter(y_obs_test, y_pred_test, color=color['tab:orange'], label='test')
520
+
521
+ # Draw diagonal line
522
+ y_values_val = np.concatenate([y_obs_val.flatten(), y_pred_val.flatten()])
523
+ y_values_test = np.concatenate([y_obs_test.flatten(), y_pred_test.flatten()])
524
+
525
+ y_values_val_min, y_values_val_max, y_values_val_range = np.amin(y_values_val), np.amax(y_values_val), np.ptp(y_values_val)
526
+ y_values_test_min, y_values_test_max, y_values_test_range = np.amin(y_values_test), np.amax(y_values_test), np.ptp(y_values_test)
527
+
528
+ val_ax.plot([y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)],
529
+ [y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)], color='red')
530
+
531
+ test_ax.plot([y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)],
532
+ [y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)], color='red')
533
+
534
+ fig.tight_layout()
535
+ return fig
536
+
537
+
538
+ class FigMixin:
539
+ """
540
+ Class to make and save figures.
541
+ This class is for ROC and YY-graph.
542
+ """
543
+ def make_save_fig(self, whole_metrics: Dict[str, Dict[str, LabelMetrics]], likelihood_path: Path, fig_kind: str) -> None:
544
+ """
545
+ Make and save figure.
546
+
547
+ Args:
548
+ whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
549
+ likelihood_path (Path): path to likelihood
550
+ fig_kind (str): kind of figure, ie. 'roc' or 'yy'
551
+ """
552
+ _datetime_dir = likelihood_path.parents[1]
553
+ save_dir = Path(_datetime_dir, fig_kind)
554
+ save_dir.mkdir(parents=True, exist_ok=True)
555
+ _fig_name = fig_kind + '_' + likelihood_path.stem.replace('likelihood_', '')
556
+ for group, group_metrics in whole_metrics.items():
557
+ fig = self._plot_fig_group_metrics(group, group_metrics)
558
+ save_path = Path(save_dir, group + '_' + _fig_name + '.png')
559
+ fig.savefig(save_path)
560
+ plt.close()
561
+
562
+
563
+ class ClsEval(MetricsMixin, ROCMixin, FigMixin, FigROCMixin):
564
+ """
565
+ Class for calculating metrics for classification.
566
+ """
567
+ def __init__(self) -> None:
568
+ self.fig_kind = 'roc'
569
+ self.metrics_kind = 'auc'
570
+
571
+
572
+ class RegEval(MetricsMixin, YYMixin, FigMixin, FigYYMixin):
573
+ """
574
+ Class for calculating metrics for regression.
575
+ """
576
+ def __init__(self) -> None:
577
+ self.fig_kind = 'yy'
578
+ self.metrics_kind = 'r2'
579
+
580
+
581
+ class DeepSurvEval(MetricsMixin, C_IndexMixin):
582
+ """
583
+ Class for calculating metrics for DeepSurv.
584
+ """
585
+ def __init__(self) -> None:
586
+ self.fig_kind = None
587
+ self.metrics_kind = 'c_index'
588
+
589
+ def make_metrics(self, likelihood_path: Path) -> None:
590
+ """
591
+ Make metrics; this method handles the whole evaluation.
592
+
593
+ Args:
594
+ likelihood_path (Path): path to likelihood
595
+
596
+ Overrides make_metrics() in class MetricsMixin, omitting self.make_save_fig(),
597
+ because no figure needs to be plotted or saved for DeepSurv.
598
+ """
599
+ df_likelihood = pd.read_csv(likelihood_path)
600
+ whole_metrics = self.cal_whole_metrics(df_likelihood)
601
+ df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
602
+ self.print_metrics(df_summary, self.metrics_kind)
603
+ self.update_summary(df_summary, likelihood_path)
604
+
605
+
606
+ def set_eval(task: str) -> Union[ClsEval, RegEval, DeepSurvEval]:
607
+ """
608
+ Set class for evaluation depending on task.
609
+
610
+ Args:
611
+ task (str): task
612
+
613
+ Returns:
614
+ Union[ClsEval, RegEval, DeepSurvEval]: class for evaluation
615
+ """
616
+ if task == 'classification':
617
+ return ClsEval()
618
+ elif task == 'regression':
619
+ return RegEval()
620
+ elif task == 'deepsurv':
621
+ return DeepSurvEval()
622
+ else:
623
+ raise ValueError(f"Invalid task: {task}.")
lib/options.py ADDED
@@ -0,0 +1,655 @@
1
+ #!/usr/bin/env python
2
+ # -*- coding: utf-8 -*-
3
+
4
+ import argparse
5
+ from distutils.util import strtobool
6
+ from pathlib import Path
7
+ import pandas as pd
8
+ import json
9
+ import torch
10
+ from .logger import BaseLogger
11
+ from typing import List, Dict, Tuple, Union
12
+
13
+
14
+ logger = BaseLogger.get_logger(__name__)
15
+
16
+
17
+ class Options:
18
+ """
19
+ Class for options.
20
+ """
21
+ def __init__(self, datetime: str = None, isTrain: bool = None) -> None:
22
+ """
23
+ Args:
24
+ datetime (str, optional): datetime string. Defaults to None.
25
+ isTrain (bool, optional): whether training or not. Defaults to None.
26
+ """
27
+ self.parser = argparse.ArgumentParser(description='Options for training or test')
28
+
29
+ # CSV
30
+ self.parser.add_argument('--csvpath', type=str, required=True, help='path to csv for training or test')
31
+
32
+ # GPU Ids
33
+ self.parser.add_argument('--gpu_ids', type=str, default='cpu', help='gpu ids: e.g. 0, 0-1-2, 0-2. Use cpu for CPU (Default: cpu)')
34
+
35
+ if isTrain:
36
+ # Task
37
+ self.parser.add_argument('--task', type=str, required=True, choices=['classification', 'regression', 'deepsurv'], help='Task')
38
+
39
+ # Model
40
+ self.parser.add_argument('--model', type=str, required=True, help='model: MLP, CNN, ViT, or MLP+(CNN or ViT)')
41
+ self.parser.add_argument('--pretrained', type=strtobool, default=False, help='For use of pretrained model (CNN or ViT)')
42
+
43
+ # Training and Internal validation
44
+ self.parser.add_argument('--criterion', type=str, required=True, choices=['CEL', 'MSE', 'RMSE', 'MAE', 'NLL'], help='criterion')
45
+ self.parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adadelta', 'RMSprop', 'Adam', 'RAdam'], help='optimizer')
46
+ self.parser.add_argument('--lr', type=float, metavar='N', help='learning rate')
47
+ self.parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs (Default: 10)')
48
+
49
+ # Batch size
50
+ self.parser.add_argument('--batch_size', type=int, required=True, metavar='N', help='batch size in training')
51
+
52
+ # Preprocess for image
53
+ self.parser.add_argument('--augmentation', type=str, default='no', choices=['xrayaug', 'trivialaugwide', 'randaug', 'no'], help='kind of augmentation')
54
+ self.parser.add_argument('--normalize_image', type=str, choices=['yes', 'no'], default='yes', help='image normalization: yes, no (Default: yes)')
55
+
56
+ # Sampler
57
+ self.parser.add_argument('--sampler', type=str, default='no', choices=['yes', 'no'], help='sample data in training or not, yes or no')
58
+
59
+ # Input channel
60
+ self.parser.add_argument('--in_channel', type=int, required=True, choices=[1, 3], help='channel of input image')
61
+ self.parser.add_argument('--vit_image_size', type=int, default=0, help='input image size for ViT. Set 0 if not used ViT (Default: 0)')
62
+
63
+ # Weight saving strategy
64
+ self.parser.add_argument('--save_weight_policy', type=str, choices=['best', 'each'], default='best', help='Save weight policy: best, or each(ie. save each time loss decreases when multi-label output) (Default: best)')
65
+
66
+ else:
67
+ # Directory of weight at training
68
+ self.parser.add_argument('--weight_dir', type=str, default=None, help='directory of weight to be used when test. If None, the latest one is selected')
69
+
70
+ # Test batch size
71
+ self.parser.add_argument('--test_batch_size', type=int, default=1, metavar='N', help='batch size for test (Default: 1)')
72
+
73
+ # Splits for test
74
+ self.parser.add_argument('--test_splits', type=str, default='train-val-test', help='splits for test: e.g. test, val-test, train-val-test. (Default: train-val-test)')
75
+
76
+ self.args = self.parser.parse_args()
77
+
78
+ if datetime is not None:
79
+ self.args.datetime = datetime
80
+
81
+ assert isinstance(isTrain, bool), 'isTrain should be bool.'
82
+ self.args.isTrain = isTrain
83
+
84
+ def get_args(self) -> argparse.Namespace:
85
+ """
86
+ Return arguments.
87
+
88
+ Returns:
89
+ argparse.Namespace: arguments
90
+ """
91
+ return self.args
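A small, hypothetical way to exercise Options from Python; the entry-point name 'train.py' and the argument values are illustrative only, while the flags themselves are the ones defined above:

import sys
from lib.options import Options

sys.argv = [
    'train.py',
    '--csvpath', 'data/source.csv',
    '--task', 'classification',
    '--model', 'MLP+ResNet18',
    '--criterion', 'CEL',
    '--batch_size', '64',
    '--in_channel', '3',
    '--gpu_ids', '0-1',
]
opt = Options(datetime='2022-01-01-00-00-00', isTrain=True)
args = opt.get_args()   # argparse.Namespace with datetime and isTrain attached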
92
+
93
+
94
+ class CSVParser:
95
+ """
96
+ Class to get information of csv and cast csv.
97
+ """
98
+ def __init__(self, csvpath: str, task: str, isTrain: bool = None) -> None:
99
+ """
100
+ Args:
101
+ csvpath (str): path to csv
102
+ task (str): task
103
+ isTrain (bool): if training or not
104
+ """
105
+ self.csvpath = csvpath
106
+ self.task = task
107
+
108
+ _df_source = pd.read_csv(self.csvpath)
109
+ _df_source = _df_source[_df_source['split'] != 'exclude']
110
+
111
+ self.input_list = list(_df_source.columns[_df_source.columns.str.startswith('input')])
112
+ self.label_list = list(_df_source.columns[_df_source.columns.str.startswith('label')])
113
+ if self.task == 'deepsurv':
114
+ _period_name_list = list(_df_source.columns[_df_source.columns.str.startswith('period')])
115
+ assert (len(_period_name_list) == 1), f"One column of period should be contained in {self.csvpath} when deepsurv."
116
+ self.period_name = _period_name_list[0]
117
+
118
+ _df_source = self._cast(_df_source, self.task)
119
+
120
+ # If no column of group, add it.
121
+ if 'group' not in _df_source.columns:
122
+ _df_source = _df_source.assign(group='all')
123
+
124
+ self.df_source = _df_source
125
+
126
+ if isTrain:
127
+ self.mlp_num_inputs = len(self.input_list)
128
+ self.num_outputs_for_label = self._define_num_outputs_for_label(self.df_source, self.label_list, self.task)
129
+
130
+ def _cast(self, df_source: pd.DataFrame, task: str) -> pd.DataFrame:
131
+ """
132
+ Make dictionary of cast depending on task.
133
+
134
+ Args:
135
+ df_source (pd.DataFrame): excluded DataFrame
136
+ task: (str): task
137
+
138
+ Returns:
139
+ DataFrame: csv excluded and cast depending on task
140
+ """
141
+ _cast_input = {input_name: float for input_name in self.input_list}
142
+
143
+ if task == 'classification':
144
+ _cast_label = {label_name: int for label_name in self.label_list}
145
+ _casts = {**_cast_input, **_cast_label}
146
+ df_source = df_source.astype(_casts)
147
+ return df_source
148
+
149
+ elif task == 'regression':
150
+ _cast_label = {label_name: float for label_name in self.label_list}
151
+ _casts = {**_cast_input, **_cast_label}
152
+ df_source = df_source.astype(_casts)
153
+ return df_source
154
+
155
+ elif task == 'deepsurv':
156
+ _cast_label = {label_name: int for label_name in self.label_list}
157
+ _cast_period = {self.period_name: int}
158
+ _casts = {**_cast_input, **_cast_label, **_cast_period}
159
+ df_source = df_source.astype(_casts)
160
+ return df_source
161
+
162
+ else:
163
+ raise ValueError(f"Invalid task: {self.task}.")
164
+
165
+ def _define_num_outputs_for_label(self, df_source: pd.DataFrame, label_list: List[str], task :str) -> Dict[str, int]:
166
+ """
167
+ Define the number of outputs for each label.
168
+
169
+ Args:
170
+ df_source (pd.DataFrame): DataFrame of csv
171
+ label_list (List[str]): list of labels
172
+ task (str): task
173
+
174
+ Returns:
175
+ Dict[str, int]: dictionary of the number of outputs for each label
176
+ eg.
177
+ classification: _num_outputs_for_label = {label_A: 2, label_B: 3, ...}
178
+ regression, deepsurv: _num_outputs_for_label = {label_A: 1, label_B: 1, ...}
179
+ deepsurv: _num_outputs_for_label = {label_A: 1}
180
+ """
181
+ if task == 'classification':
182
+ _num_outputs_for_label = {label_name: df_source[label_name].nunique() for label_name in label_list}
183
+ return _num_outputs_for_label
184
+
185
+ elif (task == 'regression') or (task == 'deepsurv'):
186
+ _num_outputs_for_label = {label_name: 1 for label_name in label_list}
187
+ return _num_outputs_for_label
188
+
189
+ else:
190
+ raise ValueError(f"Invalid task: {task}.")
191
+
192
+
193
+ def _parse_model(model_name: str) -> Tuple[Union[str, None], Union[str, None]]:
194
+ """
195
+ Parse model name.
196
+
197
+ Args:
198
+ model_name (str): model name (eg. MLP, ResNet18, or MLP+ResNet18)
199
+
200
+ Returns:
201
+ Tuple[Union[str, None], Union[str, None]]: MLP name and CNN or Vision Transformer name
202
+ eg. 'MLP', 'ResNet18', 'MLP+ResNet18' ->
203
+ ('MLP', None), (None, 'ResNet18'), ('MLP', 'ResNet18')
204
+ """
205
+ _model = model_name.split('+')
206
+ mlp = 'MLP' if 'MLP' in _model else None
207
+ _net = [_n for _n in _model if _n != 'MLP']
208
+ net = _net[0] if _net != [] else None
209
+ return mlp, net
210
+
211
+
212
+ def _parse_gpu_ids(gpu_ids: str) -> List[int]:
213
+ """
214
+ Parse GPU ids concatenated with '-' to list of integers of GPU ids.
215
+ eg. '0-1-2' -> [0, 1, 2], 'cpu' -> []
216
+
217
+ Args:
218
+ gpu_ids (str): GPU Ids
219
+
220
+ Returns:
221
+ List[int]: list of GPU ids
222
+ """
223
+ if (gpu_ids == 'cpu') or (gpu_ids == 'cpu\r'):
224
+ str_ids = []
225
+ else:
226
+ str_ids = gpu_ids.split('-')
227
+ _gpu_ids = []
228
+ for str_id in str_ids:
229
+ id = int(str_id)
230
+ if id >= 0:
231
+ _gpu_ids.append(id)
232
+ return _gpu_ids
233
+
234
+
235
+ def _get_latest_weight_dir() -> str:
236
+ """
237
+ Return the latest path to directory of weight made at training.
238
+
239
+ Returns:
240
+ str: path to directory of the latest weight
241
+ eg. 'results/<project>/trials/2022-09-30-15-56-60/weights'
242
+ """
243
+ _weight_dirs = list(Path('results').glob('*/trials/*/weights'))
244
+ assert (_weight_dirs != []), 'No directory of weight.'
245
+ weight_dir = max(_weight_dirs, key=lambda weight_dir: weight_dir.stat().st_mtime)
246
+ return str(weight_dir)
247
+
248
+
249
+ def _collect_weight_paths(weight_dir: str) -> List[str]:
250
+ """
251
+ Return list of weight paths.
252
+
253
+ Args:
254
+ weight_dir (str): path to directory of weights
255
+
256
+ Returns:
257
+ List[str]: list of weight paths
258
+ """
259
+ _weight_paths = list(Path(weight_dir).glob('*.pt'))
260
+ assert _weight_paths != [], f"No weight in {weight_dir}."
261
+ _weight_paths.sort(key=lambda path: path.stat().st_mtime)
262
+ _weight_paths = [str(weight_path) for weight_path in _weight_paths]
263
+ return _weight_paths
264
+
265
+
266
+ class ParamTable:
267
+ """
268
+ Class to make table to dispatch parameters by group.
269
+ """
270
+ def __init__(self) -> None:
271
+ # groups
272
+ # key is abbreviation, value is group name
273
+ self.groups = {
274
+ 'mo': 'model',
275
+ 'dl': 'dataloader',
276
+ 'trc': 'train_conf',
277
+ 'tsc': 'test_conf',
278
+ 'sa': 'save',
279
+ 'lo': 'load',
280
+ 'trp': 'train_print',
281
+ 'tsp': 'test_print'
282
+ }
283
+
284
+ mo = self.groups['mo']
285
+ dl = self.groups['dl']
286
+ trc = self.groups['trc']
287
+ tsc = self.groups['tsc']
288
+ sa = self.groups['sa']
289
+ lo = self.groups['lo']
290
+ trp = self.groups['trp']
291
+ tsp = self.groups['tsp']
292
+
293
+ # The below shows that which group each parameter dispatches to.
294
+ self.dispatch = {
295
+ 'datetime': [sa],
296
+ 'project': [sa, trp, tsp],
297
+ 'csvpath': [sa, trp, tsp],
298
+ 'task': [dl, tsc, sa, lo, trp, tsp],
299
+ 'isTrain': [dl, trp, tsp],
300
+
301
+ 'model': [sa, lo, trp, tsp],
302
+ 'vit_image_size': [mo, sa, lo, trp, tsp],
303
+ 'pretrained': [mo, sa, trp],
304
+ 'mlp': [mo, dl],
305
+ 'net': [mo, dl],
306
+
307
+ 'weight_dir': [tsc, tsp],
308
+ 'weight_paths': [tsc],
309
+
310
+ 'criterion': [trc, sa, trp],
311
+ 'optimizer': [trc, sa, trp],
312
+ 'lr': [trc, sa, trp],
313
+ 'epochs': [trc, sa, trp],
314
+
315
+ 'batch_size': [dl, sa, trp],
316
+ 'test_batch_size': [dl, tsp],
317
+ 'test_splits': [tsc, tsp],
318
+
319
+ 'in_channel': [mo, dl, sa, lo, trp, tsp],
320
+ 'normalize_image': [dl, sa, lo, trp, tsp],
321
+ 'augmentation': [dl, sa, trp],
322
+ 'sampler': [dl, sa, trp],
323
+
324
+ 'df_source': [dl],
325
+ 'label_list': [dl, trc, sa, lo],
326
+ 'input_list': [dl, sa, lo],
327
+ 'period_name': [dl, sa, lo],
328
+ 'mlp_num_inputs': [mo, sa, lo],
329
+ 'num_outputs_for_label': [mo, sa, lo, tsc],
330
+
331
+ 'save_weight_policy': [sa, trp, trc],
332
+ 'scaler_path': [dl, tsp],
333
+ 'save_datetime_dir': [trc, tsc, trp, tsp],
334
+
335
+ 'gpu_ids': [trc, tsc, sa, trp, tsp],
336
+ 'device': [mo, trc, tsc],
337
+ 'dataset_info': [trc, sa, trp, tsp]
338
+ }
339
+
340
+ self.table = self._make_table()
341
+
342
+     def _make_table(self) -> pd.DataFrame:
+         """
+         Make table to dispatch parameters by group.
+ 
+         Returns:
+             pd.DataFrame: table showing which group each parameter belongs to
+         """
+         df_table = pd.DataFrame([], index=self.dispatch.keys(), columns=self.groups.values()).fillna('no')
+         for param, grps in self.dispatch.items():
+             for grp in grps:
+                 df_table.loc[param, grp] = 'yes'
+ 
+         df_table = df_table.reset_index()
+         df_table = df_table.rename(columns={'index': 'parameter'})
+         return df_table
+ 
+     def get_by_group(self, group_name: str) -> List[str]:
+         """
+         Return the list of parameters which belong to the given group.
+ 
+         Args:
+             group_name (str): group name
+ 
+         Returns:
+             List[str]: list of parameters
+         """
+         _df_table = self.table
+         _param_names = _df_table[_df_table[group_name] == 'yes']['parameter'].tolist()
+         return _param_names
+ 
+ 
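A short sketch of how the dispatch table can be queried (illustrative only; import path assumed). The table holds one row per parameter and one 'yes'/'no' column per group, and get_by_group filters the 'yes' rows:

from lib.options import ParamTable  # assumed import path for this sketch

table = ParamTable()
print(table.get_by_group('model'))
# ['vit_image_size', 'pretrained', 'mlp', 'net', 'in_channel',
#  'mlp_num_inputs', 'num_outputs_for_label', 'device']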
+ Param_Table = ParamTable()
+ 
+ 
+ class ParamSet:
+     """
+     Class to store required parameters for each group.
+     """
+     pass
+ 
+ 
+ def _dispatch_by_group(args: argparse.Namespace, group_name: str) -> ParamSet:
+     """
+     Dispatch parameters depending on group.
+ 
+     Args:
+         args (argparse.Namespace): arguments
+         group_name (str): group
+ 
+     Returns:
+         ParamSet: class containing parameters for group
+     """
+     _param_names = Param_Table.get_by_group(group_name)
+     param_set = ParamSet()
+     for param_name in _param_names:
+         if hasattr(args, param_name):
+             _arg = getattr(args, param_name)
+             setattr(param_set, param_name, _arg)
+     return param_set
+ 
+ 
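A hedged sketch of the dispatching step (illustrative only; import path assumed). Only attributes that belong to the requested group, and that actually exist on the namespace, are copied into the new ParamSet:

import argparse
from lib.options import _dispatch_by_group  # assumed import path for this sketch

args = argparse.Namespace(batch_size=64, sampler='no', lr=0.001)
args_dataloader = _dispatch_by_group(args, 'dataloader')
print(vars(args_dataloader))   # {'batch_size': 64, 'sampler': 'no'}  ('lr' is not in the 'dataloader' group)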
+ def save_parameter(params: ParamSet, save_path: str) -> None:
+     """
+     Save parameters.
+ 
+     Args:
+         params (ParamSet): parameters
+         save_path (str): save path for parameters
+     """
+     _saved = {_param: _arg for _param, _arg in vars(params).items()}
+     save_dir = Path(save_path).parents[0]
+     save_dir.mkdir(parents=True, exist_ok=True)
+     with open(save_path, 'w') as f:
+         json.dump(_saved, f, indent=4)
+ 
+ 
+ def _retrieve_parameter(parameter_path: str) -> Dict[str, Union[str, int, float]]:
+     """
+     Retrieve only the parameters required at test from the parameters saved at training.
+ 
+     Args:
+         parameter_path (str): path to parameters.json saved at training
+ 
+     Returns:
+         Dict[str, Union[str, int, float]]: parameters saved at training
+     """
+     with open(parameter_path) as f:
+         params = json.load(f)
+ 
+     _required = Param_Table.get_by_group('load')
+     params = {p: v for p, v in params.items() if p in _required}
+     return params
+ 
+ 
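A sketch of the save/retrieve round trip (illustrative only; the path below is hypothetical). Everything in the ParamSet is written to JSON, but _retrieve_parameter reads back only the 'load'-group parameters:

from lib.options import ParamSet, save_parameter, _retrieve_parameter  # assumed import path

params = ParamSet()
params.task = 'classification'
params.model = 'ResNet18'
params.lr = 0.001   # not in the 'load' group, so dropped on retrieval
save_parameter(params, 'results/demo/trials/2022-01-01/parameters.json')   # hypothetical path
restored = _retrieve_parameter('results/demo/trials/2022-01-01/parameters.json')
# restored == {'task': 'classification', 'model': 'ResNet18'}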
+ def print_parameter(params: ParamSet) -> None:
+     """
+     Print parameters.
+ 
+     Args:
+         params (ParamSet): parameters
+     """
+ 
+     LINE_LENGTH = 82
+ 
+     if params.isTrain:
+         phase = 'Training'
+     else:
+         phase = 'Test'
+ 
+     _header = f" Configuration of {phase} "
+     _padding = (LINE_LENGTH - len(_header) + 1) // 2  # round up
+     _header = ('-' * _padding) + _header + ('-' * _padding) + '\n'
+ 
+     _footer = ' End '
+     _padding = (LINE_LENGTH - len(_footer) + 1) // 2
+     _footer = ('-' * _padding) + _footer + ('-' * _padding) + '\n'
+ 
+     message = ''
+     message += _header
+ 
+     _params_dict = vars(params)
+     del _params_dict['isTrain']
+     for _param, _arg in _params_dict.items():
+         _str_arg = _arg2str(_param, _arg)
+         message += f"{_param:>30}: {_str_arg:<40}\n"
+ 
+     message += _footer
+     logger.info(message)
+ 
+ 
+ def _arg2str(param: str, arg: Union[str, int, float]) -> str:
+     """
+     Convert argument to string.
+ 
+     Args:
+         param (str): parameter
+         arg (Union[str, int, float]): argument
+ 
+     Returns:
+         str: string representation of the argument
+     """
+     if param == 'lr':
+         if arg is None:
+             str_arg = 'Default'
+         else:
+             str_arg = str(arg)
+         return str_arg
+     elif param == 'gpu_ids':
+         if arg == []:
+             str_arg = 'CPU selected'
+         else:
+             str_arg = f"{arg} (Primary GPU:{arg[0]})"
+         return str_arg
+     elif param == 'test_splits':
+         str_arg = ', '.join(arg)
+         return str_arg
+     elif param == 'dataset_info':
+         str_arg = ''
+         for i, (split, total) in enumerate(arg.items()):
+             if i < len(arg) - 1:
+                 str_arg += (f"{split}_data={total}, ")
+             else:
+                 str_arg += (f"{split}_data={total}")
+         return str_arg
+     else:
+         if arg is None:
+             str_arg = 'No need'
+         else:
+             str_arg = str(arg)
+         return str_arg
+ 
+ 
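The formatting produced for a few parameter kinds, as an illustrative sketch (import path assumed):

from lib.options import _arg2str  # assumed import path for this sketch

print(_arg2str('lr', None))                      # 'Default'
print(_arg2str('gpu_ids', [0, 1]))               # '[0, 1] (Primary GPU:0)'
print(_arg2str('test_splits', ['val', 'test']))  # 'val, test'
print(_arg2str('epochs', None))                  # 'No need'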
+ def _check_if_valid_criterion(task: str = None, criterion: str = None) -> None:
+     """
+     Check if criterion is valid.
+ 
+     Args:
+         task (str): task
+         criterion (str): criterion
+     """
+     valid_criterion = {
+         'classification': ['CEL'],
+         'regression': ['MSE', 'RMSE', 'MAE'],
+         'deepsurv': ['NLL']
+         }
+     if criterion in valid_criterion[task]:
+         pass
+     else:
+         raise ValueError(f"Invalid criterion for task: task={task}, criterion={criterion}.")
+ 
+ 
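A brief sketch of the validation behaviour (illustrative only; import path assumed). A task/criterion mismatch raises ValueError before any training starts:

from lib.options import _check_if_valid_criterion  # assumed import path for this sketch

_check_if_valid_criterion(task='regression', criterion='MSE')   # OK, returns None
try:
    _check_if_valid_criterion(task='classification', criterion='MSE')
except ValueError as e:
    print(e)   # Invalid criterion for task: task=classification, criterion=MSE.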
+ def _train_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
+     """
+     Parse parameters required at training.
+ 
+     Args:
+         args (argparse.Namespace): arguments
+ 
+     Returns:
+         Dict[str, ParamSet]: parameters dispatched by group
+     """
+     # Check if criterion is valid.
+     _check_if_valid_criterion(task=args.task, criterion=args.criterion)
+ 
+     args.project = Path(args.csvpath).stem
+     args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
+     args.device = torch.device(f"cuda:{args.gpu_ids[0]}") if args.gpu_ids != [] else torch.device('cpu')
+     args.mlp, args.net = _parse_model(args.model)
+     args.pretrained = bool(args.pretrained)  # strtobool('False') = 0 (== False)
+     args.save_datetime_dir = str(Path('results', args.project, 'trials', args.datetime))
+ 
+     # Parse csv
+     _csvparser = CSVParser(args.csvpath, args.task, args.isTrain)
+     args.df_source = _csvparser.df_source
+     args.dataset_info = {split: len(args.df_source[args.df_source['split'] == split]) for split in ['train', 'val']}
+     args.input_list = _csvparser.input_list
+     args.label_list = _csvparser.label_list
+     args.mlp_num_inputs = _csvparser.mlp_num_inputs
+     args.num_outputs_for_label = _csvparser.num_outputs_for_label
+     if args.task == 'deepsurv':
+         args.period_name = _csvparser.period_name
+ 
+     # Dispatch parameters
+     return {
+         'args_model': _dispatch_by_group(args, 'model'),
+         'args_dataloader': _dispatch_by_group(args, 'dataloader'),
+         'args_conf': _dispatch_by_group(args, 'train_conf'),
+         'args_print': _dispatch_by_group(args, 'train_print'),
+         'args_save': _dispatch_by_group(args, 'save')
+         }
+ 
+ 
+ def _test_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
+     """
+     Parse parameters required at test.
+ 
+     Args:
+         args (argparse.Namespace): arguments
+ 
+     Returns:
+         Dict[str, ParamSet]: parameters dispatched by group
+     """
+     args.project = Path(args.csvpath).stem
+     args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
+     args.device = torch.device(f"cuda:{args.gpu_ids[0]}") if args.gpu_ids != [] else torch.device('cpu')
+ 
+     # Collect weight paths
+     if args.weight_dir is None:
+         args.weight_dir = _get_latest_weight_dir()
+     args.weight_paths = _collect_weight_paths(args.weight_dir)
+ 
+     # Get datetime at training
+     _train_datetime_dir = Path(args.weight_dir).parents[0]
+     _train_datetime = _train_datetime_dir.name
+ 
+     args.save_datetime_dir = str(Path('results', args.project, 'trials', _train_datetime))
+ 
+     # Retrieve only parameters required at test
+     _parameter_path = str(Path(_train_datetime_dir, 'parameters.json'))
+     params = _retrieve_parameter(_parameter_path)
+     for _param, _arg in params.items():
+         setattr(args, _param, _arg)
+ 
+     # At test time, the following are always fixed.
+     args.augmentation = 'no'
+     args.sampler = 'no'
+     args.pretrained = False
+ 
+     args.mlp, args.net = _parse_model(args.model)
+     if args.mlp is not None:
+         args.scaler_path = str(Path(_train_datetime_dir, 'scaler.pkl'))
+ 
+     # Parse csv
+     _csvparser = CSVParser(args.csvpath, args.task)
+     args.df_source = _csvparser.df_source
+ 
+     # Align test_splits with the splits actually present in the csv
+     args.test_splits = args.test_splits.split('-')
+     _splits = args.df_source['split'].unique().tolist()
+     if set(_splits) < set(args.test_splits):
+         args.test_splits = _splits
+ 
+     args.dataset_info = {split: len(args.df_source[args.df_source['split'] == split]) for split in args.test_splits}
+ 
+     # Dispatch parameters
+     return {
+         'args_model': _dispatch_by_group(args, 'model'),
+         'args_dataloader': _dispatch_by_group(args, 'dataloader'),
+         'args_conf': _dispatch_by_group(args, 'test_conf'),
+         'args_print': _dispatch_by_group(args, 'test_print')
+         }
+ 
+ 
+ def set_options(datetime_name: str = None, phase: str = None) -> Dict[str, ParamSet]:
+     """
+     Parse options for training or test.
+ 
+     Args:
+         datetime_name (str, optional): datetime name. Defaults to None.
+         phase (str, optional): train or test. Defaults to None.
+ 
+     Returns:
+         Dict[str, ParamSet]: parameters dispatched by group
+     """
+     if phase == 'train':
+         opt = Options(datetime=datetime_name, isTrain=True)
+         _args = opt.get_args()
+         args = _train_parse(_args)
+         return args
+     else:
+         opt = Options(isTrain=False)
+         _args = opt.get_args()
+         args = _test_parse(_args)
+         return args
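Finally, a hedged end-to-end sketch of how a training script might call set_options (illustrative only; it assumes the script is launched with the usual command-line options, such as the csv path and task, which Options reads from sys.argv):

import datetime
from lib import set_options, print_parameter   # exported via lib/__init__.py

datetime_name = datetime.datetime.now().strftime('%Y-%m-%d-%H-%M-%S')
args = set_options(datetime_name=datetime_name, phase='train')
print_parameter(args['args_print'])
# args['args_model'], args['args_dataloader'], args['args_conf'] and args['args_save']
# hold the remaining per-group ParamSets.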