Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
Commit
·
1f53a4c
1
Parent(s):
8630b06
Upload app.py and lib.
Browse files- app.py +61 -0
- lib/__init__.py +24 -0
- lib/component/__init__.py +16 -0
- lib/component/criterion.py +332 -0
- lib/component/likelihood.py +107 -0
- lib/component/loss.py +248 -0
- lib/component/net.py +624 -0
- lib/component/optimizer.py +34 -0
- lib/dataloader.py +400 -0
- lib/framework.py +373 -0
- lib/logger.py +71 -0
- lib/metrics.py +623 -0
- lib/options.py +655 -0
app.py
ADDED
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import gradio as gr
|
6 |
+
from lib import create_model
|
7 |
+
from lib.options import ParamSet, _retrieve_parameter, _dispatch_by_group
|
8 |
+
from lib.dataloader import ImageMixin
|
9 |
+
|
10 |
+
|
11 |
+
test_weight = './weight_epoch-200_best.pt'
|
12 |
+
parameter = './parameters.json'
|
13 |
+
|
14 |
+
class ImageHandler(ImageMixin):
|
15 |
+
def __init__(self, params):
|
16 |
+
self.params = params
|
17 |
+
self.transform = self._make_transforms()
|
18 |
+
|
19 |
+
def set_image(self, image):
|
20 |
+
image = self.transform(image)
|
21 |
+
image = {'image': image.unsqueeze(0)}
|
22 |
+
return image
|
23 |
+
|
24 |
+
def load_parameter(parameter):
|
25 |
+
_args = ParamSet()
|
26 |
+
params = _retrieve_parameter(parameter)
|
27 |
+
for _param, _arg in params.items():
|
28 |
+
setattr(_args, _param, _arg)
|
29 |
+
|
30 |
+
_args.augmentation = 'no'
|
31 |
+
_args.sampler = 'no'
|
32 |
+
_args.pretrained = False
|
33 |
+
_args.mlp = None
|
34 |
+
_args.net = _args.model
|
35 |
+
_args.device = torch.device('cpu')
|
36 |
+
|
37 |
+
args_model = _dispatch_by_group(_args, 'model')
|
38 |
+
args_dataloader = _dispatch_by_group(_args, 'dataloader')
|
39 |
+
return args_model, args_dataloader
|
40 |
+
|
41 |
+
args_model, args_dataloader = load_parameter(parameter)
|
42 |
+
model = create_model(args_model)
|
43 |
+
model.load_weight(test_weight)
|
44 |
+
|
45 |
+
def main(image):
|
46 |
+
model.eval()
|
47 |
+
image_handler = ImageHandler(args_dataloader)
|
48 |
+
image = image_handler.set_image(image)
|
49 |
+
|
50 |
+
with torch.no_grad():
|
51 |
+
outputs = model(image)
|
52 |
+
|
53 |
+
label_name = list(outputs.keys())[0]
|
54 |
+
result = outputs[label_name].detach().numpy().item()
|
55 |
+
result = f"{result:.2f}"
|
56 |
+
return result
|
57 |
+
|
58 |
+
|
59 |
+
# Gradio
|
60 |
+
iface = gr.Interface(fn=main, inputs=[gr.Image(type='pil', image_mode='L')], outputs='text')
|
61 |
+
iface.launch()
|
lib/__init__.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from .options import (
|
5 |
+
ParamSet,
|
6 |
+
set_options,
|
7 |
+
save_parameter,
|
8 |
+
print_parameter
|
9 |
+
)
|
10 |
+
from .dataloader import create_dataloader
|
11 |
+
from .framework import create_model
|
12 |
+
from .metrics import set_eval
|
13 |
+
from .logger import BaseLogger
|
14 |
+
|
15 |
+
__all__ = [
|
16 |
+
'ParamSet',
|
17 |
+
'set_options',
|
18 |
+
'print_parameter',
|
19 |
+
'save_parameter',
|
20 |
+
'create_dataloader',
|
21 |
+
'create_model',
|
22 |
+
'set_eval',
|
23 |
+
'BaseLogger'
|
24 |
+
]
|
lib/component/__init__.py
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from .net import create_net
|
5 |
+
from .criterion import set_criterion
|
6 |
+
from .optimizer import set_optimizer
|
7 |
+
from .loss import set_loss_store
|
8 |
+
from .likelihood import set_likelihood
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'create_net',
|
12 |
+
'set_criterion',
|
13 |
+
'set_optimizer',
|
14 |
+
'set_loss_store',
|
15 |
+
'set_likelihood'
|
16 |
+
]
|
lib/component/criterion.py
ADDED
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn as nn
|
6 |
+
from typing import Dict, Union
|
7 |
+
|
8 |
+
# Alias of typing
|
9 |
+
# eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}}
|
10 |
+
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
|
11 |
+
|
12 |
+
|
13 |
+
class RMSELoss(nn.Module):
|
14 |
+
"""
|
15 |
+
Class to calculate RMSE.
|
16 |
+
"""
|
17 |
+
def __init__(self, eps: float = 1e-7) -> None:
|
18 |
+
"""
|
19 |
+
Args:
|
20 |
+
eps (float, optional): value to avoid 0. Defaults to 1e-7.
|
21 |
+
"""
|
22 |
+
super().__init__()
|
23 |
+
self.mse = nn.MSELoss()
|
24 |
+
self.eps = eps
|
25 |
+
|
26 |
+
def forward(self, yhat: float, y: float) -> torch.FloatTensor:
|
27 |
+
"""
|
28 |
+
Calculate RMSE.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
yhat (float): prediction value
|
32 |
+
y (float): ground truth value
|
33 |
+
|
34 |
+
Returns:
|
35 |
+
float: RMSE
|
36 |
+
"""
|
37 |
+
_loss = self.mse(yhat, y) + self.eps
|
38 |
+
return torch.sqrt(_loss)
|
39 |
+
|
40 |
+
|
41 |
+
class Regularization:
|
42 |
+
"""
|
43 |
+
Class to calculate regularization loss.
|
44 |
+
|
45 |
+
Args:
|
46 |
+
object (object): object
|
47 |
+
"""
|
48 |
+
def __init__(self, order: int, weight_decay: float) -> None:
|
49 |
+
"""
|
50 |
+
The initialization of Regularization class.
|
51 |
+
|
52 |
+
Args:
|
53 |
+
order: (int) norm order number
|
54 |
+
weight_decay: (float) weight decay rate
|
55 |
+
"""
|
56 |
+
super().__init__()
|
57 |
+
self.order = order
|
58 |
+
self.weight_decay = weight_decay
|
59 |
+
|
60 |
+
def __call__(self, network: nn.Module) -> torch.FloatTensor:
|
61 |
+
""""
|
62 |
+
Calculates regularization(self.order) loss for network.
|
63 |
+
|
64 |
+
Args:
|
65 |
+
model: (torch.nn.Module object)
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
torch.FloatTensor: the regularization(self.order) loss
|
69 |
+
"""
|
70 |
+
reg_loss = 0
|
71 |
+
for name, w in network.named_parameters():
|
72 |
+
if 'weight' in name:
|
73 |
+
reg_loss = reg_loss + torch.norm(w, p=self.order)
|
74 |
+
reg_loss = self.weight_decay * reg_loss
|
75 |
+
return reg_loss
|
76 |
+
|
77 |
+
|
78 |
+
class NegativeLogLikelihood(nn.Module):
|
79 |
+
"""
|
80 |
+
Class to calculate RMSE.
|
81 |
+
"""
|
82 |
+
def __init__(self, device: torch.device) -> None:
|
83 |
+
"""
|
84 |
+
Args:
|
85 |
+
device (torch.device): device
|
86 |
+
"""
|
87 |
+
super().__init__()
|
88 |
+
self.L2_reg = 0.05
|
89 |
+
self.reg = Regularization(order=2, weight_decay=self.L2_reg)
|
90 |
+
self.device = device
|
91 |
+
|
92 |
+
def forward(
|
93 |
+
self,
|
94 |
+
output: torch.FloatTensor,
|
95 |
+
label: torch.IntTensor,
|
96 |
+
periods: torch.FloatTensor,
|
97 |
+
network: nn.Module
|
98 |
+
) -> torch.FloatTensor:
|
99 |
+
"""
|
100 |
+
Calculates Negative Log Likelihood.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
output (torch.FloatTensor): prediction value, ie risk prediction
|
104 |
+
label (torch.IntTensor): occurrence of event
|
105 |
+
periods (torch.FloatTensor): period
|
106 |
+
network (nn.Network): network
|
107 |
+
|
108 |
+
Returns:
|
109 |
+
torch.FloatTensor: Negative Log Likelihood
|
110 |
+
"""
|
111 |
+
mask = torch.ones(periods.shape[0], periods.shape[0]).to(self.device) # output and mask should be on the same device.
|
112 |
+
mask[(periods.T - periods) > 0] = 0
|
113 |
+
|
114 |
+
_loss = torch.exp(output) * mask
|
115 |
+
# Note: torch.sum(_loss, dim=0) possibly returns nan, in particular MLP.
|
116 |
+
_loss = torch.sum(_loss, dim=0) / torch.sum(mask, dim=0)
|
117 |
+
_loss = torch.log(_loss).reshape(-1, 1)
|
118 |
+
num_occurs = torch.sum(label)
|
119 |
+
|
120 |
+
if num_occurs.item() == 0.0:
|
121 |
+
loss = torch.tensor([1e-7], requires_grad=True).to(self.device) # To avoid zero division, set small value as loss
|
122 |
+
return loss
|
123 |
+
else:
|
124 |
+
neg_log_loss = -torch.sum((output - _loss) * label) / num_occurs
|
125 |
+
l2_loss = self.reg(network)
|
126 |
+
loss = neg_log_loss + l2_loss
|
127 |
+
return loss
|
128 |
+
|
129 |
+
|
130 |
+
class ClsCriterion:
|
131 |
+
"""
|
132 |
+
Class of criterion for classification.
|
133 |
+
"""
|
134 |
+
def __init__(self, device: torch.device = None) -> None:
|
135 |
+
"""
|
136 |
+
Set CrossEntropyLoss.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
device (torch.device): device
|
140 |
+
"""
|
141 |
+
self.device = device
|
142 |
+
self.criterion = nn.CrossEntropyLoss()
|
143 |
+
|
144 |
+
def __call__(
|
145 |
+
self,
|
146 |
+
outputs: Dict[str, torch.FloatTensor],
|
147 |
+
labels: Dict[str, LabelDict]
|
148 |
+
) -> Dict[str, torch.FloatTensor]:
|
149 |
+
"""
|
150 |
+
Calculate loss.
|
151 |
+
|
152 |
+
Args:
|
153 |
+
outputs (Dict[str, torch.FloatTensor], optional): output
|
154 |
+
labels (Dict[str, LabelDict]): labels
|
155 |
+
|
156 |
+
Returns:
|
157 |
+
Dict[str, torch.FloatTensor]: loss for each label and their total loss
|
158 |
+
|
159 |
+
# No reshape and no cast:
|
160 |
+
output: [64, 2]: torch.float32
|
161 |
+
label: [64] : torch.int64
|
162 |
+
label.dtype should be torch.int64, otherwise nn.CrossEntropyLoss() causes error.
|
163 |
+
|
164 |
+
eg.
|
165 |
+
outputs = {'label_A': [[0.8, 0.2], ...] 'label_B': [[0.7, 0.3]], ...}
|
166 |
+
labels = { 'labels': {'label_A: 1: [1, 1, 0, ...], 'label_B': [0, 0, 1, ...], ...} }
|
167 |
+
|
168 |
+
-> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
|
169 |
+
"""
|
170 |
+
_labels = labels['labels']
|
171 |
+
|
172 |
+
# loss for each label and total of their losses
|
173 |
+
losses = dict()
|
174 |
+
losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
|
175 |
+
for label_name in labels['labels'].keys():
|
176 |
+
_output = outputs[label_name]
|
177 |
+
_label = _labels[label_name]
|
178 |
+
_label_loss = self.criterion(_output, _label)
|
179 |
+
losses[label_name] = _label_loss
|
180 |
+
losses['total'] = torch.add(losses['total'], _label_loss)
|
181 |
+
return losses
|
182 |
+
|
183 |
+
|
184 |
+
class RegCriterion:
|
185 |
+
"""
|
186 |
+
Class of criterion for regression.
|
187 |
+
"""
|
188 |
+
def __init__(self, criterion_name: str = None, device: torch.device = None) -> None:
|
189 |
+
"""
|
190 |
+
Set MSE, RMSE or MAE.
|
191 |
+
|
192 |
+
Args:
|
193 |
+
criterion_name (str): 'MSE', 'RMSE', or 'MAE'
|
194 |
+
device (torch.device): device
|
195 |
+
"""
|
196 |
+
self.device = device
|
197 |
+
|
198 |
+
if criterion_name == 'MSE':
|
199 |
+
self.criterion = nn.MSELoss()
|
200 |
+
elif criterion_name == 'RMSE':
|
201 |
+
self.criterion = RMSELoss()
|
202 |
+
elif criterion_name == 'MAE':
|
203 |
+
self.criterion = nn.L1Loss()
|
204 |
+
else:
|
205 |
+
raise ValueError(f"Invalid criterion for regression: {criterion_name}.")
|
206 |
+
|
207 |
+
def __call__(
|
208 |
+
self,
|
209 |
+
outputs: Dict[str, torch.FloatTensor],
|
210 |
+
labels: Dict[str, LabelDict]
|
211 |
+
) -> Dict[str, torch.FloatTensor]:
|
212 |
+
"""
|
213 |
+
Calculate loss.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
Args:
|
217 |
+
outputs (Dict[str, torch.FloatTensor], optional): output
|
218 |
+
labels (Dict[str, LabelDict]): labels
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
Dict[str, torch.FloatTensor]: loss for each label and their total loss
|
222 |
+
|
223 |
+
# Reshape and cast
|
224 |
+
output: [64, 1] -> [64]: torch.float32
|
225 |
+
label: [64]: torch.float64 -> torch.float32
|
226 |
+
# label.dtype should be torch.float32, otherwise cannot backward.
|
227 |
+
|
228 |
+
eg.
|
229 |
+
outputs = {'label_A': [[10.8], ...] 'label_B': [[15.7]], ...}
|
230 |
+
labels = {'labels': {'label_A: 1: [10, 9, ...], 'label_B': [12, 17,], ...}}
|
231 |
+
-> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
|
232 |
+
"""
|
233 |
+
_outputs = {label_name: _output.squeeze() for label_name, _output in outputs.items()}
|
234 |
+
_labels = {label_name: _label.to(torch.float32) for label_name, _label in labels['labels'].items()}
|
235 |
+
|
236 |
+
# loss for each label and total of their losses
|
237 |
+
losses = dict()
|
238 |
+
losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
|
239 |
+
for label_name in labels['labels'].keys():
|
240 |
+
_output = _outputs[label_name]
|
241 |
+
_label = _labels[label_name]
|
242 |
+
_label_loss = self.criterion(_output, _label)
|
243 |
+
losses[label_name] = _label_loss
|
244 |
+
losses['total'] = torch.add(losses['total'], _label_loss)
|
245 |
+
return losses
|
246 |
+
|
247 |
+
|
248 |
+
class DeepSurvCriterion:
|
249 |
+
"""
|
250 |
+
Class of criterion for deepsurv.
|
251 |
+
"""
|
252 |
+
def __init__(self, device: torch.device = None) -> None:
|
253 |
+
"""
|
254 |
+
Set NegativeLogLikelihood.
|
255 |
+
|
256 |
+
Args:
|
257 |
+
device (torch.device, optional): device
|
258 |
+
"""
|
259 |
+
self.device = device
|
260 |
+
self.criterion = NegativeLogLikelihood(self.device).to(self.device)
|
261 |
+
|
262 |
+
def __call__(
|
263 |
+
self,
|
264 |
+
outputs: Dict[str, torch.FloatTensor],
|
265 |
+
labels: Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
266 |
+
) -> Dict[str, torch.FloatTensor]:
|
267 |
+
"""
|
268 |
+
Calculate loss.
|
269 |
+
|
270 |
+
Args:
|
271 |
+
outputs (Dict[str, torch.FloatTensor], optional): output
|
272 |
+
labels (Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]): labels, periods, and network
|
273 |
+
|
274 |
+
Returns:
|
275 |
+
Dict[str, torch.FloatTensor]: loss for each label and their total loss
|
276 |
+
|
277 |
+
# Reshape and no cast
|
278 |
+
output: [64, 1]: torch.float32
|
279 |
+
label: [64] -> [64, 1]: torch.int64
|
280 |
+
period: [64] -> [64, 1]: torch.float32
|
281 |
+
|
282 |
+
eg.
|
283 |
+
outputs = {'label_A': [[10.8], ...] 'label_B': [[15.7]], ...}
|
284 |
+
labels = {
|
285 |
+
'labels': {'label_A: 1: [1, 0, 1, ...] },
|
286 |
+
'periods': [5, 10, 7, ...],
|
287 |
+
'network': network
|
288 |
+
}
|
289 |
+
-> losses = {total: loss_total, label_A: loss_A, label_B: loss_B, ... }
|
290 |
+
"""
|
291 |
+
_labels = {label_name: _label.reshape(-1, 1) for label_name, _label in labels['labels'].items()}
|
292 |
+
_periods = labels['periods'].reshape(-1, 1)
|
293 |
+
_network = labels['network']
|
294 |
+
|
295 |
+
# loss for each label and total of their losses
|
296 |
+
losses = dict()
|
297 |
+
losses['total'] = torch.tensor([0.0], requires_grad=True).to(self.device)
|
298 |
+
for label_name in labels['labels'].keys():
|
299 |
+
_output = outputs[label_name]
|
300 |
+
_label = _labels[label_name]
|
301 |
+
_label_loss = self.criterion(_output, _label, _periods, _network)
|
302 |
+
losses[label_name] = _label_loss
|
303 |
+
losses['total'] = torch.add(losses['total'], _label_loss)
|
304 |
+
return losses
|
305 |
+
|
306 |
+
|
307 |
+
def set_criterion(
|
308 |
+
criterion_name: str,
|
309 |
+
device: torch.device
|
310 |
+
) -> Union[ClsCriterion, RegCriterion, DeepSurvCriterion]:
|
311 |
+
"""
|
312 |
+
Return criterion class
|
313 |
+
|
314 |
+
Args:
|
315 |
+
criterion_name (str): criterion name
|
316 |
+
device (torch.device): device
|
317 |
+
|
318 |
+
Returns:
|
319 |
+
Union[ClsCriterion, RegCriterion, DeepSurvCriterion]: criterion class
|
320 |
+
"""
|
321 |
+
|
322 |
+
if criterion_name == 'CEL':
|
323 |
+
return ClsCriterion(device=device)
|
324 |
+
|
325 |
+
elif criterion_name in ['MSE', 'RMSE', 'MAE']:
|
326 |
+
return RegCriterion(criterion_name=criterion_name, device=device)
|
327 |
+
|
328 |
+
elif criterion_name == 'NLL':
|
329 |
+
return DeepSurvCriterion(device=device)
|
330 |
+
|
331 |
+
else:
|
332 |
+
raise ValueError(f"Invalid criterion: {criterion_name}.")
|
lib/component/likelihood.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import pandas as pd
|
5 |
+
import torch
|
6 |
+
from typing import List, Dict
|
7 |
+
|
8 |
+
|
9 |
+
class Likelihood:
|
10 |
+
"""
|
11 |
+
Class for making likelihood.
|
12 |
+
"""
|
13 |
+
def __init__(self, task: str, num_outputs_for_label: Dict[str, int]) -> None:
|
14 |
+
"""
|
15 |
+
Args:
|
16 |
+
task (str): task
|
17 |
+
num_outputs_for_label (Dict[str, int]): number of classes for each label
|
18 |
+
"""
|
19 |
+
self.task = task
|
20 |
+
self.num_outputs_for_label = num_outputs_for_label
|
21 |
+
self.base_column_list = self._set_base_columns(self.task)
|
22 |
+
self.pred_column_list = self._make_pred_columns(self.task, self.num_outputs_for_label)
|
23 |
+
|
24 |
+
def _set_base_columns(self, task: str) -> List[str]:
|
25 |
+
"""
|
26 |
+
Return base columns.
|
27 |
+
|
28 |
+
Args:
|
29 |
+
task (str): task
|
30 |
+
|
31 |
+
Returns:
|
32 |
+
List[str]: base columns except columns of label and prediction
|
33 |
+
"""
|
34 |
+
if (task == 'classification') or (task == 'regression'):
|
35 |
+
base_columns = ['uniqID', 'group', 'imgpath', 'split']
|
36 |
+
return base_columns
|
37 |
+
elif task == 'deepsurv':
|
38 |
+
base_columns = ['uniqID', 'group', 'imgpath', 'split', 'periods']
|
39 |
+
return base_columns
|
40 |
+
else:
|
41 |
+
raise ValueError(f"Invalid task: {task}.")
|
42 |
+
|
43 |
+
def _make_pred_columns(self, task: str, num_outputs_for_label: Dict[str, int]) -> Dict[str, List[str]]:
|
44 |
+
"""
|
45 |
+
Make column names of predictions with label name and its number of classes.
|
46 |
+
|
47 |
+
Args:
|
48 |
+
task (str): task
|
49 |
+
num_outputs_for_label (Dict[str, int]): number of classes for each label
|
50 |
+
|
51 |
+
Returns:
|
52 |
+
Dict[str, List[str]]: label and list of columns of predictions with its class number
|
53 |
+
|
54 |
+
eg.
|
55 |
+
{label_A: 2, label_B: 2} -> {label_A: [pred_label_A_0, pred_label_A_1], label_B: [pred_label_B_0, pred_label_B_1]}
|
56 |
+
{label_A: 1, label_B: 1} -> {label_A: [pred_label_A], label_B: [pred_label_B]}
|
57 |
+
"""
|
58 |
+
pred_columns = dict()
|
59 |
+
if task == 'classification':
|
60 |
+
for label_name, num_classes in num_outputs_for_label.items():
|
61 |
+
pred_columns[label_name] = ['pred_' + label_name + '_' + str(i) for i in range(num_classes)]
|
62 |
+
return pred_columns
|
63 |
+
elif (task == 'regression') or (task == 'deepsurv'):
|
64 |
+
for label_name, num_classes in num_outputs_for_label.items():
|
65 |
+
pred_columns[label_name] = ['pred_' + label_name]
|
66 |
+
return pred_columns
|
67 |
+
else:
|
68 |
+
raise ValueError(f"Invalid task: {task}.")
|
69 |
+
|
70 |
+
def make_format(self, data: Dict, output: Dict[str, torch.Tensor]) -> pd.DataFrame:
|
71 |
+
"""
|
72 |
+
Make a new DataFrame of likelihood every batch.
|
73 |
+
|
74 |
+
Args:
|
75 |
+
data (Dict): batch data from dataloader
|
76 |
+
output (Dict[str, torch.Tensor]): output of model
|
77 |
+
"""
|
78 |
+
_likelihood = {column_name: data[column_name] for column_name in self.base_column_list}
|
79 |
+
df_likelihood = pd.DataFrame(_likelihood)
|
80 |
+
|
81 |
+
if any(data['labels']):
|
82 |
+
for label_name, pred in output.items():
|
83 |
+
_df_label = pd.DataFrame({label_name: data['labels'][label_name].tolist()})
|
84 |
+
pred = pred.to('cpu').detach().numpy().copy()
|
85 |
+
_df_pred = pd.DataFrame(pred, columns=self.pred_column_list[label_name])
|
86 |
+
df_likelihood = pd.concat([df_likelihood, _df_label, _df_pred], axis=1)
|
87 |
+
return df_likelihood
|
88 |
+
else:
|
89 |
+
for label_name, pred in output.items():
|
90 |
+
pred = pred.to('cpu').detach().numpy().copy()
|
91 |
+
_df_pred = pd.DataFrame(pred, columns=self.pred_column_list[label_name])
|
92 |
+
df_likelihood = pd.concat([df_likelihood, _df_pred], axis=1)
|
93 |
+
return df_likelihood
|
94 |
+
|
95 |
+
|
96 |
+
def set_likelihood(task: str, num_outputs_for_label: Dict[str, int]) -> Likelihood:
|
97 |
+
"""
|
98 |
+
Set likelihood.
|
99 |
+
|
100 |
+
Args:
|
101 |
+
task (str): task
|
102 |
+
num_outputs_for_label (Dict[str, int]): number of classes for each label
|
103 |
+
|
104 |
+
Returns:
|
105 |
+
Likelihood: instance of class Likelihood
|
106 |
+
"""
|
107 |
+
return Likelihood(task, num_outputs_for_label)
|
lib/component/loss.py
ADDED
@@ -0,0 +1,248 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import torch
|
6 |
+
import pandas as pd
|
7 |
+
from ..logger import BaseLogger
|
8 |
+
from typing import List, Dict, Union
|
9 |
+
|
10 |
+
|
11 |
+
logger = BaseLogger.get_logger(__name__)
|
12 |
+
|
13 |
+
|
14 |
+
class LabelLoss:
|
15 |
+
"""
|
16 |
+
Class to store loss for every bash and epoch loss of each label.
|
17 |
+
"""
|
18 |
+
def __init__(self) -> None:
|
19 |
+
# Accumulate batch_loss(=loss * batch_size)
|
20 |
+
self.train_batch_loss = 0.0
|
21 |
+
self.val_batch_loss = 0.0
|
22 |
+
|
23 |
+
# epoch_loss = batch_loss / dataset_size
|
24 |
+
self.train_epoch_loss = [] # List[float]
|
25 |
+
self.val_epoch_loss = [] # List[float]
|
26 |
+
|
27 |
+
self.best_val_loss = None # float
|
28 |
+
self.best_epoch = None # int
|
29 |
+
self.is_val_loss_updated = None # bool
|
30 |
+
|
31 |
+
def get_loss(self, phase: str, target: str) -> Union[float, List[float]]:
|
32 |
+
"""
|
33 |
+
Return loss depending on phase and target
|
34 |
+
|
35 |
+
Args:
|
36 |
+
phase (str): 'train' or 'val'
|
37 |
+
target (str): 'batch' or 'epoch'
|
38 |
+
|
39 |
+
Returns:
|
40 |
+
Union[float, List[float]]: batch_loss or epoch_loss
|
41 |
+
"""
|
42 |
+
_target = phase + '_' + target + '_loss'
|
43 |
+
return getattr(self, _target)
|
44 |
+
|
45 |
+
def store_batch_loss(self, phase: str, new_batch_loss: torch.FloatTensor, batch_size: int) -> None:
|
46 |
+
"""
|
47 |
+
Add new batch loss to previous one for phase by multiplying by batch_size.
|
48 |
+
|
49 |
+
Args:
|
50 |
+
phase (str): 'train' or 'val'
|
51 |
+
new_batch_loss (torch.FloatTensor): batch loss calculated by criterion
|
52 |
+
batch_size (int): batch size
|
53 |
+
"""
|
54 |
+
_new = new_batch_loss.item() * batch_size # torch.FloatTensor -> float
|
55 |
+
_prev = self.get_loss(phase, 'batch')
|
56 |
+
_added = _prev + _new
|
57 |
+
_target = phase + '_' + 'batch_loss'
|
58 |
+
setattr(self, _target, _added)
|
59 |
+
|
60 |
+
def append_epoch_loss(self, phase: str, new_epoch_loss: float) -> None:
|
61 |
+
"""
|
62 |
+
Append epoch loss depending on phase and target
|
63 |
+
|
64 |
+
Args:
|
65 |
+
phase (str): 'train' or 'val'
|
66 |
+
new_epoch_loss (float): batch loss or epoch loss
|
67 |
+
"""
|
68 |
+
_target = phase + '_' + 'epoch_loss'
|
69 |
+
getattr(self, _target).append(new_epoch_loss)
|
70 |
+
|
71 |
+
def get_latest_epoch_loss(self, phase: str) -> float:
|
72 |
+
"""
|
73 |
+
Return the latest loss of phase.
|
74 |
+
|
75 |
+
Args:
|
76 |
+
phase (str): train or val
|
77 |
+
|
78 |
+
Returns:
|
79 |
+
float: the latest loss
|
80 |
+
"""
|
81 |
+
return self.get_loss(phase, 'epoch')[-1]
|
82 |
+
|
83 |
+
def update_best_val_loss(self, at_epoch: int = None) -> None:
|
84 |
+
"""
|
85 |
+
Update val_epoch_loss is the best.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
at_epoch (int): epoch when checked
|
89 |
+
"""
|
90 |
+
_latest_val_loss = self.get_latest_epoch_loss('val')
|
91 |
+
|
92 |
+
if at_epoch == 1:
|
93 |
+
self.best_val_loss = _latest_val_loss
|
94 |
+
self.best_epoch = at_epoch
|
95 |
+
self.is_val_loss_updated = True
|
96 |
+
else:
|
97 |
+
# When at_epoch > 1
|
98 |
+
if _latest_val_loss < self.best_val_loss:
|
99 |
+
self.best_val_loss = _latest_val_loss
|
100 |
+
self.best_epoch = at_epoch
|
101 |
+
self.is_val_loss_updated = True
|
102 |
+
else:
|
103 |
+
self.is_val_loss_updated = False
|
104 |
+
|
105 |
+
|
106 |
+
class LossStore:
|
107 |
+
"""
|
108 |
+
Class for calculating loss and store it.
|
109 |
+
"""
|
110 |
+
def __init__(self, label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> None:
|
111 |
+
"""
|
112 |
+
Args:
|
113 |
+
label_list (List[str]): list of internal labels
|
114 |
+
num_epochs (int) : number of epochs
|
115 |
+
dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'
|
116 |
+
"""
|
117 |
+
self.label_list = label_list
|
118 |
+
self.num_epochs = num_epochs
|
119 |
+
self.dataset_info = dataset_info
|
120 |
+
|
121 |
+
# Added a special label 'total' to store total of losses of all labels.
|
122 |
+
self.label_losses = {label_name: LabelLoss() for label_name in self.label_list + ['total']}
|
123 |
+
|
124 |
+
def store(self, phase: str, losses: Dict[str, torch.FloatTensor], batch_size: int = None) -> None:
|
125 |
+
"""
|
126 |
+
Store label-wise batch losses of phase to previous one.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
phase (str): 'train' or 'val'
|
130 |
+
losses (Dict[str, torch.FloatTensor]): loss for each label calculated by criterion
|
131 |
+
batch_size (int): batch size
|
132 |
+
|
133 |
+
# Note:
|
134 |
+
self.loss_stores['total'] is already total of losses of all label, which is calculated in criterion.py,
|
135 |
+
therefore, it is OK just to multiply by batch_size. This is done in add_batch_loss().
|
136 |
+
"""
|
137 |
+
for label_name in self.label_list + ['total']:
|
138 |
+
_new_batch_loss = losses[label_name]
|
139 |
+
self.label_losses[label_name].store_batch_loss(phase, _new_batch_loss, batch_size)
|
140 |
+
|
141 |
+
def cal_epoch_loss(self, at_epoch: int = None) -> None:
|
142 |
+
"""
|
143 |
+
Calculate epoch loss for each phase all at once.
|
144 |
+
|
145 |
+
Args:
|
146 |
+
at_epoch (int): epoch number
|
147 |
+
"""
|
148 |
+
# For each label
|
149 |
+
for label_name in self.label_list:
|
150 |
+
for phase in ['train', 'val']:
|
151 |
+
_batch_loss = self.label_losses[label_name].get_loss(phase, 'batch')
|
152 |
+
_dataset_size = self.dataset_info[phase]
|
153 |
+
_new_epoch_loss = _batch_loss / _dataset_size
|
154 |
+
self.label_losses[label_name].append_epoch_loss(phase, _new_epoch_loss)
|
155 |
+
|
156 |
+
# For total, average by dataset_size and the number of labels.
|
157 |
+
for phase in ['train', 'val']:
|
158 |
+
_batch_loss = self.label_losses['total'].get_loss(phase, 'batch')
|
159 |
+
_dataset_size = self.dataset_info[phase]
|
160 |
+
_new_epoch_loss = _batch_loss / (_dataset_size * len(self.label_list))
|
161 |
+
self.label_losses['total'].append_epoch_loss(phase, _new_epoch_loss)
|
162 |
+
|
163 |
+
# Update val_best_loss and best_epoch.
|
164 |
+
for label_name in self.label_list + ['total']:
|
165 |
+
self.label_losses[label_name].update_best_val_loss(at_epoch=at_epoch)
|
166 |
+
|
167 |
+
# Initialize batch_loss after calculating epoch loss.
|
168 |
+
for label_name in self.label_list + ['total']:
|
169 |
+
self.label_losses[label_name].train_batch_loss = 0.0
|
170 |
+
self.label_losses[label_name].val_batch_loss = 0.0
|
171 |
+
|
172 |
+
def is_val_loss_updated(self) -> bool:
|
173 |
+
"""
|
174 |
+
Check if val_loss of 'total' is updated.
|
175 |
+
|
176 |
+
Returns:
|
177 |
+
bool: Updated or not
|
178 |
+
"""
|
179 |
+
return self.label_losses['total'].is_val_loss_updated
|
180 |
+
|
181 |
+
def get_best_epoch(self) -> int:
|
182 |
+
"""
|
183 |
+
Returns best epoch.
|
184 |
+
|
185 |
+
Returns:
|
186 |
+
int: best epoch
|
187 |
+
"""
|
188 |
+
return self.label_losses['total'].best_epoch
|
189 |
+
|
190 |
+
def print_epoch_loss(self, at_epoch: int = None) -> None:
|
191 |
+
"""
|
192 |
+
Print train_loss and val_loss for the ith epoch.
|
193 |
+
|
194 |
+
Args:
|
195 |
+
at_epoch (int): epoch number
|
196 |
+
"""
|
197 |
+
train_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('train')
|
198 |
+
val_epoch_loss = self.label_losses['total'].get_latest_epoch_loss('val')
|
199 |
+
|
200 |
+
_epoch_comm = f"epoch [{at_epoch:>3}/{self.num_epochs:<3}]"
|
201 |
+
_train_comm = f"train_loss: {train_epoch_loss :>8.4f}"
|
202 |
+
_val_comm = f"val_loss: {val_epoch_loss:>8.4f}"
|
203 |
+
_updated_comment = ''
|
204 |
+
if (at_epoch > 1) and (self.is_val_loss_updated()):
|
205 |
+
_updated_comment = ' Updated best val_loss!'
|
206 |
+
comment = _epoch_comm + ', ' + _train_comm + ', ' + _val_comm + _updated_comment
|
207 |
+
logger.info(comment)
|
208 |
+
|
209 |
+
def save_learning_curve(self, save_datetime_dir: str) -> None:
|
210 |
+
"""
|
211 |
+
Save learning curve.
|
212 |
+
|
213 |
+
Args:
|
214 |
+
save_datetime_dir (str): save_datetime_dir
|
215 |
+
"""
|
216 |
+
save_dir = Path(save_datetime_dir, 'learning_curve')
|
217 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
218 |
+
|
219 |
+
for label_name in self.label_list + ['total']:
|
220 |
+
_label_loss = self.label_losses[label_name]
|
221 |
+
_train_epoch_loss = _label_loss.get_loss('train', 'epoch')
|
222 |
+
_val_epoch_loss = _label_loss.get_loss('val', 'epoch')
|
223 |
+
|
224 |
+
df_label_epoch_loss = pd.DataFrame({
|
225 |
+
'train_loss': _train_epoch_loss,
|
226 |
+
'val_loss': _val_epoch_loss
|
227 |
+
})
|
228 |
+
|
229 |
+
_best_epoch = str(_label_loss.best_epoch).zfill(3)
|
230 |
+
_best_val_loss = f"{_label_loss.best_val_loss:.4f}"
|
231 |
+
save_name = 'learning_curve_' + label_name + '_val-best-epoch-' + _best_epoch + '_val-best-loss-' + _best_val_loss + '.csv'
|
232 |
+
save_path = Path(save_dir, save_name)
|
233 |
+
df_label_epoch_loss.to_csv(save_path, index=False)
|
234 |
+
|
235 |
+
|
236 |
+
def set_loss_store(label_list: List[str], num_epochs: int, dataset_info: Dict[str, int]) -> LossStore:
|
237 |
+
"""
|
238 |
+
Return class LossStore.
|
239 |
+
|
240 |
+
Args:
|
241 |
+
label_list (List[str]): label list
|
242 |
+
num_epochs (int) : number of epochs
|
243 |
+
dataset_info (Dict[str, int]): dataset sizes of 'train' and 'val'
|
244 |
+
|
245 |
+
Returns:
|
246 |
+
LossStore: LossStore
|
247 |
+
"""
|
248 |
+
return LossStore(label_list, num_epochs, dataset_info)
|
lib/component/net.py
ADDED
@@ -0,0 +1,624 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-r
|
3 |
+
|
4 |
+
from collections import OrderedDict
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torchvision.ops import MLP
|
8 |
+
import torchvision.models as models
|
9 |
+
from typing import Dict, Optional
|
10 |
+
|
11 |
+
|
12 |
+
class BaseNet:
|
13 |
+
"""
|
14 |
+
Class to construct network
|
15 |
+
"""
|
16 |
+
cnn = {
|
17 |
+
'ResNet18': models.resnet18,
|
18 |
+
'ResNet': models.resnet50,
|
19 |
+
'DenseNet': models.densenet161,
|
20 |
+
'EfficientNetB0': models.efficientnet_b0,
|
21 |
+
'EfficientNetB2': models.efficientnet_b2,
|
22 |
+
'EfficientNetB4': models.efficientnet_b4,
|
23 |
+
'EfficientNetB6': models.efficientnet_b6,
|
24 |
+
'EfficientNetV2s': models.efficientnet_v2_s,
|
25 |
+
'EfficientNetV2m': models.efficientnet_v2_m,
|
26 |
+
'EfficientNetV2l': models.efficientnet_v2_l,
|
27 |
+
'ConvNeXtTiny': models.convnext_tiny,
|
28 |
+
'ConvNeXtSmall': models.convnext_small,
|
29 |
+
'ConvNeXtBase': models.convnext_base,
|
30 |
+
'ConvNeXtLarge': models.convnext_large
|
31 |
+
}
|
32 |
+
|
33 |
+
vit = {
|
34 |
+
'ViTb16': models.vit_b_16,
|
35 |
+
'ViTb32': models.vit_b_32,
|
36 |
+
'ViTl16': models.vit_l_16,
|
37 |
+
'ViTl32': models.vit_l_32,
|
38 |
+
'ViTH14': models.vit_h_14
|
39 |
+
}
|
40 |
+
|
41 |
+
net = {**cnn, **vit}
|
42 |
+
|
43 |
+
_classifier = {
|
44 |
+
'ResNet': 'fc',
|
45 |
+
'DenseNet': 'classifier',
|
46 |
+
'EfficientNet': 'classifier',
|
47 |
+
'ConvNext': 'classifier',
|
48 |
+
'ViT': 'heads'
|
49 |
+
}
|
50 |
+
|
51 |
+
classifier = {
|
52 |
+
'ResNet18': _classifier['ResNet'],
|
53 |
+
'ResNet': _classifier['ResNet'],
|
54 |
+
'DenseNet': _classifier['DenseNet'],
|
55 |
+
'EfficientNetB0': _classifier['EfficientNet'],
|
56 |
+
'EfficientNetB2': _classifier['EfficientNet'],
|
57 |
+
'EfficientNetB4': _classifier['EfficientNet'],
|
58 |
+
'EfficientNetB6': _classifier['EfficientNet'],
|
59 |
+
'EfficientNetV2s': _classifier['EfficientNet'],
|
60 |
+
'EfficientNetV2m': _classifier['EfficientNet'],
|
61 |
+
'EfficientNetV2l': _classifier['EfficientNet'],
|
62 |
+
'ConvNeXtTiny': _classifier['ConvNext'],
|
63 |
+
'ConvNeXtSmall': _classifier['ConvNext'],
|
64 |
+
'ConvNeXtBase': _classifier['ConvNext'],
|
65 |
+
'ConvNeXtLarge': _classifier['ConvNext'],
|
66 |
+
'ViTb16': _classifier['ViT'],
|
67 |
+
'ViTb32': _classifier['ViT'],
|
68 |
+
'ViTl16': _classifier['ViT'],
|
69 |
+
'ViTl32': _classifier['ViT'],
|
70 |
+
'ViTH14': _classifier['ViT']
|
71 |
+
}
|
72 |
+
|
73 |
+
mlp_config = {
|
74 |
+
'hidden_channels': [256, 256, 256],
|
75 |
+
'dropout': 0.2
|
76 |
+
}
|
77 |
+
|
78 |
+
DUMMY = nn.Identity()
|
79 |
+
|
80 |
+
@classmethod
|
81 |
+
def MLPNet(cls, mlp_num_inputs: int = None, inplace: bool = None) -> MLP:
|
82 |
+
"""
|
83 |
+
Construct MLP.
|
84 |
+
|
85 |
+
Args:
|
86 |
+
mlp_num_inputs (int): the number of input of MLP
|
87 |
+
inplace (bool, optional): parameter for the activation layer, which can optionally do the operation in-place. Defaults to None.
|
88 |
+
|
89 |
+
Returns:
|
90 |
+
MLP: MLP
|
91 |
+
"""
|
92 |
+
assert isinstance(mlp_num_inputs, int), f"Invalid number of inputs for MLP: {mlp_num_inputs}."
|
93 |
+
mlp = MLP(in_channels=mlp_num_inputs, hidden_channels=cls.mlp_config['hidden_channels'], inplace=inplace, dropout=cls.mlp_config['dropout'])
|
94 |
+
return mlp
|
95 |
+
|
96 |
+
@classmethod
|
97 |
+
def align_in_channels_1ch(cls, net_name: str = None, net: nn.Module = None) -> nn.Module:
|
98 |
+
"""
|
99 |
+
Modify network to handle gray scale image.
|
100 |
+
|
101 |
+
Args:
|
102 |
+
net_name (str): network name
|
103 |
+
net (nn.Module): network itself
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
nn.Module: network available for gray scale
|
107 |
+
"""
|
108 |
+
if net_name.startswith('ResNet'):
|
109 |
+
net.conv1.in_channels = 1
|
110 |
+
net.conv1.weight = nn.Parameter(net.conv1.weight.sum(dim=1).unsqueeze(1))
|
111 |
+
|
112 |
+
elif net_name.startswith('DenseNet'):
|
113 |
+
net.features.conv0.in_channels = 1
|
114 |
+
net.features.conv0.weight = nn.Parameter(net.features.conv0.weight.sum(dim=1).unsqueeze(1))
|
115 |
+
|
116 |
+
elif net_name.startswith('Efficient'):
|
117 |
+
net.features[0][0].in_channels = 1
|
118 |
+
net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
|
119 |
+
|
120 |
+
elif net_name.startswith('ConvNeXt'):
|
121 |
+
net.features[0][0].in_channels = 1
|
122 |
+
net.features[0][0].weight = nn.Parameter(net.features[0][0].weight.sum(dim=1).unsqueeze(1))
|
123 |
+
|
124 |
+
elif net_name.startswith('ViT'):
|
125 |
+
net.conv_proj.in_channels = 1
|
126 |
+
net.conv_proj.weight = nn.Parameter(net.conv_proj.weight.sum(dim=1).unsqueeze(1))
|
127 |
+
|
128 |
+
else:
|
129 |
+
raise ValueError(f"No specified net: {net_name}.")
|
130 |
+
return net
|
131 |
+
|
132 |
+
@classmethod
|
133 |
+
def set_net(
|
134 |
+
cls,
|
135 |
+
net_name: str = None,
|
136 |
+
in_channel: int = None,
|
137 |
+
vit_image_size: int = None,
|
138 |
+
pretrained: bool = None
|
139 |
+
) -> nn.Module:
|
140 |
+
"""
|
141 |
+
Modify network depending on in_channel and vit_image_size.
|
142 |
+
|
143 |
+
Args:
|
144 |
+
net_name (str): network name
|
145 |
+
in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
|
146 |
+
vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
|
147 |
+
vit_image_size should be power of patch size.
|
148 |
+
pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.
|
149 |
+
|
150 |
+
Returns:
|
151 |
+
nn.Module: modified network
|
152 |
+
"""
|
153 |
+
assert net_name in cls.net, f"No specified net: {net_name}."
|
154 |
+
if net_name in cls.cnn:
|
155 |
+
if pretrained:
|
156 |
+
net = cls.cnn[net_name](weights='DEFAULT')
|
157 |
+
else:
|
158 |
+
net = cls.cnn[net_name]()
|
159 |
+
else:
|
160 |
+
# When ViT
|
161 |
+
# always use pretrained
|
162 |
+
net = cls.set_vit(net_name=net_name, vit_image_size=vit_image_size)
|
163 |
+
|
164 |
+
if in_channel == 1:
|
165 |
+
net = cls.align_in_channels_1ch(net_name=net_name, net=net)
|
166 |
+
return net
|
167 |
+
|
168 |
+
@classmethod
|
169 |
+
def set_vit(cls, net_name: str = None, vit_image_size: int = None) -> nn.Module:
|
170 |
+
"""
|
171 |
+
Modify ViT depending on vit_image_size.
|
172 |
+
|
173 |
+
Args:
|
174 |
+
net_name (str): ViT name
|
175 |
+
vit_image_size (int): image size which ViT handles if ViT is used.
|
176 |
+
|
177 |
+
Returns:
|
178 |
+
nn.Module: modified ViT
|
179 |
+
"""
|
180 |
+
base_vit = cls.vit[net_name]
|
181 |
+
# pretrained_vit = base_vit(weights=cls.vit_weight[net_name])
|
182 |
+
pretrained_vit = base_vit(weights='DEFAULT')
|
183 |
+
|
184 |
+
# Align weight depending on image size
|
185 |
+
weight = pretrained_vit.state_dict()
|
186 |
+
patch_size = int(net_name[-2:]) # 'ViTb16' -> 16
|
187 |
+
aligned_weight = models.vision_transformer.interpolate_embeddings(
|
188 |
+
image_size=vit_image_size,
|
189 |
+
patch_size=patch_size,
|
190 |
+
model_state=weight
|
191 |
+
)
|
192 |
+
aligned_vit = base_vit(image_size=vit_image_size) # Specify new image size.
|
193 |
+
aligned_vit.load_state_dict(aligned_weight) # Load weight which can handle the new image size.
|
194 |
+
return aligned_vit
|
195 |
+
|
196 |
+
@classmethod
|
197 |
+
def construct_extractor(
|
198 |
+
cls,
|
199 |
+
net_name: str = None,
|
200 |
+
mlp_num_inputs: int = None,
|
201 |
+
in_channel: int = None,
|
202 |
+
vit_image_size: int = None,
|
203 |
+
pretrained: bool = None
|
204 |
+
) -> nn.Module:
|
205 |
+
"""
|
206 |
+
Construct extractor of network depending on net_name.
|
207 |
+
|
208 |
+
Args:
|
209 |
+
net_name (str): network name.
|
210 |
+
mlp_num_inputs (int, optional): number of input of MLP. Defaults to None.
|
211 |
+
in_channel (int, optional): image channel(any of 1ch or 3ch). Defaults to None.
|
212 |
+
vit_image_size (int, optional): image size which ViT handles if ViT is used. Defaults to None.
|
213 |
+
pretrained (bool, optional): True when use pretrained CNN or ViT, otherwise False. Defaults to None.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
nn.Module: extractor of network
|
217 |
+
"""
|
218 |
+
if net_name == 'MLP':
|
219 |
+
extractor = cls.MLPNet(mlp_num_inputs=mlp_num_inputs)
|
220 |
+
else:
|
221 |
+
extractor = cls.set_net(net_name=net_name, in_channel=in_channel, vit_image_size=vit_image_size, pretrained=pretrained)
|
222 |
+
setattr(extractor, cls.classifier[net_name], cls.DUMMY) # Replace classifier with DUMMY(=nn.Identity()).
|
223 |
+
return extractor
|
224 |
+
|
225 |
+
@classmethod
|
226 |
+
def get_classifier(cls, net_name: str) -> nn.Module:
|
227 |
+
"""
|
228 |
+
Get classifier of network depending on net_name.
|
229 |
+
|
230 |
+
Args:
|
231 |
+
net_name (str): network name
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
nn.Module: classifier of network
|
235 |
+
"""
|
236 |
+
net = cls.net[net_name]()
|
237 |
+
classifier = getattr(net, cls.classifier[net_name])
|
238 |
+
return classifier
|
239 |
+
|
240 |
+
@classmethod
|
241 |
+
def construct_multi_classifier(cls, net_name: str = None, num_outputs_for_label: Dict[str, int] = None) -> nn.ModuleDict:
|
242 |
+
"""
|
243 |
+
Construct classifier for multi-label.
|
244 |
+
|
245 |
+
Args:
|
246 |
+
net_name (str): network name
|
247 |
+
num_outputs_for_label (Dict[str, int]): number of outputs for each label
|
248 |
+
|
249 |
+
Returns:
|
250 |
+
nn.ModuleDict: classifier for multi-label
|
251 |
+
"""
|
252 |
+
classifiers = dict()
|
253 |
+
if net_name == 'MLP':
|
254 |
+
in_features = cls.mlp_config['hidden_channels'][-1]
|
255 |
+
for label_name, num_outputs in num_outputs_for_label.items():
|
256 |
+
classifiers[label_name] = nn.Linear(in_features, num_outputs)
|
257 |
+
|
258 |
+
elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
|
259 |
+
base_classifier = cls.get_classifier(net_name)
|
260 |
+
in_features = base_classifier.in_features
|
261 |
+
for label_name, num_outputs in num_outputs_for_label.items():
|
262 |
+
classifiers[label_name] = nn.Linear(in_features, num_outputs)
|
263 |
+
|
264 |
+
elif net_name.startswith('EfficientNet'):
|
265 |
+
base_classifier = cls.get_classifier(net_name)
|
266 |
+
dropout = base_classifier[0].p
|
267 |
+
in_features = base_classifier[1].in_features
|
268 |
+
for label_name, num_outputs in num_outputs_for_label.items():
|
269 |
+
classifiers[label_name] = nn.Sequential(
|
270 |
+
nn.Dropout(p=dropout, inplace=False),
|
271 |
+
nn.Linear(in_features, num_outputs)
|
272 |
+
)
|
273 |
+
|
274 |
+
elif net_name.startswith('ConvNeXt'):
|
275 |
+
base_classifier = cls.get_classifier(net_name)
|
276 |
+
layer_norm = base_classifier[0]
|
277 |
+
flatten = base_classifier[1]
|
278 |
+
in_features = base_classifier[2].in_features
|
279 |
+
for label_name, num_outputs in num_outputs_for_label.items():
|
280 |
+
# Shape is changed before nn.Linear.
|
281 |
+
classifiers[label_name] = nn.Sequential(
|
282 |
+
layer_norm,
|
283 |
+
flatten,
|
284 |
+
nn.Linear(in_features, num_outputs)
|
285 |
+
)
|
286 |
+
|
287 |
+
elif net_name.startswith('ViT'):
|
288 |
+
base_classifier = cls.get_classifier(net_name)
|
289 |
+
in_features = base_classifier.head.in_features
|
290 |
+
for label_name, num_outputs in num_outputs_for_label.items():
|
291 |
+
classifiers[label_name] = nn.Sequential(
|
292 |
+
OrderedDict([
|
293 |
+
('head', nn.Linear(in_features, num_outputs))
|
294 |
+
])
|
295 |
+
)
|
296 |
+
|
297 |
+
else:
|
298 |
+
raise ValueError(f"No specified net: {net_name}.")
|
299 |
+
|
300 |
+
multi_classifier = nn.ModuleDict(classifiers)
|
301 |
+
return multi_classifier
|
302 |
+
|
303 |
+
@classmethod
|
304 |
+
def get_classifier_in_features(cls, net_name: str) -> int:
|
305 |
+
"""
|
306 |
+
Return in_feature of network indicating by net_name.
|
307 |
+
This class is used in class MultiNetFusion() only.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
net_name (str): net_name
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
int : in_feature
|
314 |
+
|
315 |
+
Required:
|
316 |
+
classifier.in_feature
|
317 |
+
classifier.[1].in_features
|
318 |
+
classifier.[2].in_features
|
319 |
+
classifier.head.in_features
|
320 |
+
"""
|
321 |
+
if net_name == 'MLP':
|
322 |
+
in_features = cls.mlp_config['hidden_channels'][-1]
|
323 |
+
|
324 |
+
elif net_name.startswith('ResNet') or net_name.startswith('DenseNet'):
|
325 |
+
base_classifier = cls.get_classifier(net_name)
|
326 |
+
in_features = base_classifier.in_features
|
327 |
+
|
328 |
+
elif net_name.startswith('EfficientNet'):
|
329 |
+
base_classifier = cls.get_classifier(net_name)
|
330 |
+
in_features = base_classifier[1].in_features
|
331 |
+
|
332 |
+
elif net_name.startswith('ConvNeXt'):
|
333 |
+
base_classifier = cls.get_classifier(net_name)
|
334 |
+
in_features = base_classifier[2].in_features
|
335 |
+
|
336 |
+
elif net_name.startswith('ViT'):
|
337 |
+
base_classifier = cls.get_classifier(net_name)
|
338 |
+
in_features = base_classifier.head.in_features
|
339 |
+
|
340 |
+
else:
|
341 |
+
raise ValueError(f"No specified net: {net_name}.")
|
342 |
+
return in_features
|
343 |
+
|
344 |
+
@classmethod
|
345 |
+
def construct_aux_module(cls, net_name: str) -> nn.Sequential:
|
346 |
+
"""
|
347 |
+
Construct module to align the shape of feature from extractor depending on network.
|
348 |
+
Actually, only when net_name == 'ConvNeXt'.
|
349 |
+
Because ConvNeXt has the process of aligning the dimensions in its classifier.
|
350 |
+
|
351 |
+
Needs to align shape of the feature extractor when ConvNeXt
|
352 |
+
(classifier):
|
353 |
+
Sequential(
|
354 |
+
(0): LayerNorm2d((768,), eps=1e-06, elementwise_affine=True)
|
355 |
+
(1): Flatten(start_dim=1, end_dim=-1)
|
356 |
+
(2): Linear(in_features=768, out_features=1000, bias=True)
|
357 |
+
)
|
358 |
+
|
359 |
+
Args:
|
360 |
+
net_name (str): net name
|
361 |
+
|
362 |
+
Returns:
|
363 |
+
nn.Module: layers such that they align the dimension of the output from the extractor like the original ConvNeXt.
|
364 |
+
"""
|
365 |
+
aux_module = cls.DUMMY
|
366 |
+
if net_name.startswith('ConvNeXt'):
|
367 |
+
base_classifier = cls.get_classifier(net_name)
|
368 |
+
layer_norm = base_classifier[0]
|
369 |
+
flatten = base_classifier[1]
|
370 |
+
aux_module = nn.Sequential(
|
371 |
+
layer_norm,
|
372 |
+
flatten
|
373 |
+
)
|
374 |
+
return aux_module
|
375 |
+
|
376 |
+
@classmethod
|
377 |
+
def get_last_extractor(cls, net: nn.Module = None, mlp: str = None, net_name: str = None) -> nn.Module:
|
378 |
+
"""
|
379 |
+
Return the last extractor of network.
|
380 |
+
This is for Grad-CAM.
|
381 |
+
net should be one loaded weight.
|
382 |
+
|
383 |
+
Args:
|
384 |
+
net (nn.Module): network itself
|
385 |
+
mlp (str): 'MLP', otherwise None
|
386 |
+
net_name (str): network name
|
387 |
+
|
388 |
+
Returns:
|
389 |
+
nn.Module: last extractor of network
|
390 |
+
"""
|
391 |
+
assert (net_name is not None), f"Network does not contain CNN or ViT: mlp={mlp}, net={net_name}."
|
392 |
+
|
393 |
+
_extractor = net.extractor_net
|
394 |
+
|
395 |
+
if net_name.startswith('ResNet'):
|
396 |
+
last_extractor = _extractor.layer4[-1]
|
397 |
+
elif net_name.startswith('DenseNet'):
|
398 |
+
last_extractor = _extractor.features.denseblock4.denselayer24
|
399 |
+
elif net_name.startswith('EfficientNet'):
|
400 |
+
last_extractor = _extractor.features[-1]
|
401 |
+
elif net_name.startswith('ConvNeXt'):
|
402 |
+
last_extractor = _extractor.features[-1][-1].block
|
403 |
+
elif net_name.startswith('ViT'):
|
404 |
+
last_extractor = _extractor.encoder.layers[-1]
|
405 |
+
else:
|
406 |
+
raise ValueError(f"Cannot get last extractor of net: {net_name}.")
|
407 |
+
return last_extractor
|
408 |
+
|
409 |
+
|
410 |
+
class MultiMixin:
|
411 |
+
"""
|
412 |
+
Class to define auxiliary function to handle multi-label.
|
413 |
+
"""
|
414 |
+
def multi_forward(self, out_features: int) -> Dict[str, float]:
|
415 |
+
"""
|
416 |
+
Forward out_features to classifier for each label.
|
417 |
+
|
418 |
+
Args:
|
419 |
+
out_features (int): output from extractor
|
420 |
+
|
421 |
+
Returns:
|
422 |
+
Dict[str, float]: output of classifier of each label
|
423 |
+
"""
|
424 |
+
output = dict()
|
425 |
+
for label_name, classifier in self.multi_classifier.items():
|
426 |
+
output[label_name] = classifier(out_features)
|
427 |
+
return output
|
428 |
+
|
429 |
+
|
430 |
+
class MultiWidget(nn.Module, BaseNet, MultiMixin):
|
431 |
+
"""
|
432 |
+
Class for a widget to inherit multiple classes simultaneously.
|
433 |
+
"""
|
434 |
+
pass
|
435 |
+
|
436 |
+
|
437 |
+
class MultiNet(MultiWidget):
|
438 |
+
"""
|
439 |
+
Model of MLP, CNN or ViT.
|
440 |
+
"""
|
441 |
+
def __init__(
|
442 |
+
self,
|
443 |
+
net_name: str = None,
|
444 |
+
num_outputs_for_label: Dict[str, int] = None,
|
445 |
+
mlp_num_inputs: int = None,
|
446 |
+
in_channel: int = None,
|
447 |
+
vit_image_size: int = None,
|
448 |
+
pretrained: bool = None
|
449 |
+
) -> None:
|
450 |
+
"""
|
451 |
+
Args:
|
452 |
+
net_name (str): MLP, CNN or ViT name
|
453 |
+
num_outputs_for_label (Dict[str, int]): number of classes for each label
|
454 |
+
mlp_num_inputs (int): number of input of MLP.
|
455 |
+
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
|
456 |
+
vit_image_size (int): image size to be input to ViT.
|
457 |
+
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
|
458 |
+
"""
|
459 |
+
super().__init__()
|
460 |
+
|
461 |
+
self.net_name = net_name
|
462 |
+
self.num_outputs_for_label = num_outputs_for_label
|
463 |
+
self.mlp_num_inputs = mlp_num_inputs
|
464 |
+
self.in_channel = in_channel
|
465 |
+
self.vit_image_size = vit_image_size
|
466 |
+
self.pretrained = pretrained
|
467 |
+
|
468 |
+
# self.extractor_net = MLP or CVmodel
|
469 |
+
self.extractor_net = self.construct_extractor(
|
470 |
+
net_name=self.net_name,
|
471 |
+
mlp_num_inputs=self.mlp_num_inputs,
|
472 |
+
in_channel=self.in_channel,
|
473 |
+
vit_image_size=self.vit_image_size,
|
474 |
+
pretrained=self.pretrained
|
475 |
+
)
|
476 |
+
self.multi_classifier = self.construct_multi_classifier(net_name=self.net_name, num_outputs_for_label=self.num_outputs_for_label)
|
477 |
+
|
478 |
+
def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
|
479 |
+
"""
|
480 |
+
Forward.
|
481 |
+
|
482 |
+
Args:
|
483 |
+
x (torch.Tensor): tabular data or image
|
484 |
+
|
485 |
+
Returns:
|
486 |
+
Dict[str, torch.Tensor]: output
|
487 |
+
"""
|
488 |
+
out_features = self.extractor_net(x)
|
489 |
+
output = self.multi_forward(out_features)
|
490 |
+
return output
|
491 |
+
|
492 |
+
|
493 |
+
class MultiNetFusion(MultiWidget):
|
494 |
+
"""
|
495 |
+
Fusion model of MLP and CNN or ViT.
|
496 |
+
"""
|
497 |
+
def __init__(
|
498 |
+
self,
|
499 |
+
net_name: str = None,
|
500 |
+
num_outputs_for_label: Dict[str, int] = None,
|
501 |
+
mlp_num_inputs: int = None,
|
502 |
+
in_channel: int = None,
|
503 |
+
vit_image_size: int = None,
|
504 |
+
pretrained: bool = None
|
505 |
+
) -> None:
|
506 |
+
"""
|
507 |
+
Args:
|
508 |
+
net_name (str): CNN or ViT name. It is clear that MLP is used in fusion model.
|
509 |
+
num_outputs_for_label (Dict[str, int]): number of classes for each label
|
510 |
+
mlp_num_inputs (int): number of input of MLP. Defaults to None.
|
511 |
+
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
|
512 |
+
vit_image_size (int): image size to be input to ViT.
|
513 |
+
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
|
514 |
+
"""
|
515 |
+
assert (net_name != 'MLP'), 'net_name should not be MLP.'
|
516 |
+
|
517 |
+
super().__init__()
|
518 |
+
|
519 |
+
self.net_name = net_name
|
520 |
+
self.num_outputs_for_label = num_outputs_for_label
|
521 |
+
self.mlp_num_inputs = mlp_num_inputs
|
522 |
+
self.in_channel = in_channel
|
523 |
+
self.vit_image_size = vit_image_size
|
524 |
+
self.pretrained = pretrained
|
525 |
+
|
526 |
+
# Extractor of MLP and Net
|
527 |
+
self.extractor_mlp = self.construct_extractor(net_name='MLP', mlp_num_inputs=self.mlp_num_inputs)
|
528 |
+
self.extractor_net = self.construct_extractor(
|
529 |
+
net_name=self.net_name,
|
530 |
+
in_channel=self.in_channel,
|
531 |
+
vit_image_size=self.vit_image_size,
|
532 |
+
pretrained=self.pretrained
|
533 |
+
)
|
534 |
+
self.aux_module = self.construct_aux_module(self.net_name)
|
535 |
+
|
536 |
+
# Intermediate MLP
|
537 |
+
self.in_features_from_mlp = self.get_classifier_in_features('MLP')
|
538 |
+
self.in_features_from_net = self.get_classifier_in_features(self.net_name)
|
539 |
+
self.inter_mlp_in_feature = self.in_features_from_mlp + self.in_features_from_net
|
540 |
+
self.inter_mlp = self.MLPNet(mlp_num_inputs=self.inter_mlp_in_feature, inplace=False)
|
541 |
+
|
542 |
+
# Multi classifier
|
543 |
+
self.multi_classifier = self.construct_multi_classifier(net_name='MLP', num_outputs_for_label=num_outputs_for_label)
|
544 |
+
|
545 |
+
def forward(self, x_mlp: torch.Tensor, x_net: torch.Tensor) -> Dict[str, torch.Tensor]:
|
546 |
+
"""
|
547 |
+
Forward.
|
548 |
+
|
549 |
+
Args:
|
550 |
+
x_mlp (torch.Tensor): tabular data
|
551 |
+
x_net (torch.Tensor): image
|
552 |
+
|
553 |
+
Returns:
|
554 |
+
Dict[str, torch.Tensor]: output
|
555 |
+
"""
|
556 |
+
out_mlp = self.extractor_mlp(x_mlp)
|
557 |
+
out_net = self.extractor_net(x_net)
|
558 |
+
out_net = self.aux_module(out_net)
|
559 |
+
|
560 |
+
out_features = torch.cat([out_mlp, out_net], dim=1)
|
561 |
+
out_features = self.inter_mlp(out_features)
|
562 |
+
output = self.multi_forward(out_features)
|
563 |
+
return output
|
564 |
+
|
565 |
+
|
566 |
+
def create_net(
|
567 |
+
mlp: Optional[str] = None,
|
568 |
+
net: Optional[str] = None,
|
569 |
+
num_outputs_for_label: Dict[str, int] = None,
|
570 |
+
mlp_num_inputs: int = None,
|
571 |
+
in_channel: int = None,
|
572 |
+
vit_image_size: int = None,
|
573 |
+
pretrained: bool = None
|
574 |
+
) -> nn.Module:
|
575 |
+
"""
|
576 |
+
Create network.
|
577 |
+
|
578 |
+
Args:
|
579 |
+
mlp (Optional[str]): 'MLP' or None
|
580 |
+
net (Optional[str]): CNN, ViT name or None
|
581 |
+
num_outputs_for_label (Dict[str, int]): number of outputs for each label
|
582 |
+
mlp_num_inputs (int): number of input of MLP.
|
583 |
+
in_channel (int): number of image channel, ie gray scale(=1) or color image(=3).
|
584 |
+
vit_image_size (int): image size to be input to ViT.
|
585 |
+
pretrained (bool): True when use pretrained CNN or ViT, otherwise False.
|
586 |
+
|
587 |
+
Returns:
|
588 |
+
nn.Module: network
|
589 |
+
"""
|
590 |
+
_isMLPModel = (mlp is not None) and (net is None)
|
591 |
+
_isCVModel = (mlp is None) and (net is not None)
|
592 |
+
_isFusion = (mlp is not None) and (net is not None)
|
593 |
+
|
594 |
+
if _isMLPModel:
|
595 |
+
multi_net = MultiNet(
|
596 |
+
net_name='MLP',
|
597 |
+
num_outputs_for_label=num_outputs_for_label,
|
598 |
+
mlp_num_inputs=mlp_num_inputs,
|
599 |
+
in_channel=in_channel,
|
600 |
+
vit_image_size=vit_image_size,
|
601 |
+
pretrained=False # No need of pretrained for MLP
|
602 |
+
)
|
603 |
+
elif _isCVModel:
|
604 |
+
multi_net = MultiNet(
|
605 |
+
net_name=net,
|
606 |
+
num_outputs_for_label=num_outputs_for_label,
|
607 |
+
mlp_num_inputs=mlp_num_inputs,
|
608 |
+
in_channel=in_channel,
|
609 |
+
vit_image_size=vit_image_size,
|
610 |
+
pretrained=pretrained
|
611 |
+
)
|
612 |
+
elif _isFusion:
|
613 |
+
multi_net = MultiNetFusion(
|
614 |
+
net_name=net,
|
615 |
+
num_outputs_for_label=num_outputs_for_label,
|
616 |
+
mlp_num_inputs=mlp_num_inputs,
|
617 |
+
in_channel=in_channel,
|
618 |
+
vit_image_size=vit_image_size,
|
619 |
+
pretrained=pretrained
|
620 |
+
)
|
621 |
+
else:
|
622 |
+
raise ValueError(f"Invalid model type: mlp={mlp}, net={net}.")
|
623 |
+
|
624 |
+
return multi_net
|
lib/component/optimizer.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import torch.optim as optim
|
5 |
+
import torch.nn as nn
|
6 |
+
|
7 |
+
|
8 |
+
def set_optimizer(optimizer_name: str, network: nn.Module, lr: float) -> optim:
|
9 |
+
"""
|
10 |
+
Set optimizer.
|
11 |
+
Args:
|
12 |
+
optimizer_name (str): criterion name
|
13 |
+
network (torch.nn.Module): network
|
14 |
+
lr (float): learning rate
|
15 |
+
Returns:
|
16 |
+
torch.optim: optimizer
|
17 |
+
"""
|
18 |
+
optimizers = {
|
19 |
+
'SGD': optim.SGD,
|
20 |
+
'Adadelta': optim.Adadelta,
|
21 |
+
'Adam': optim.Adam,
|
22 |
+
'RMSprop': optim.RMSprop,
|
23 |
+
'RAdam': optim.RAdam
|
24 |
+
}
|
25 |
+
|
26 |
+
assert (optimizer_name in optimizers), f"No specified optimizer: {optimizer_name}."
|
27 |
+
|
28 |
+
_optim = optimizers[optimizer_name]
|
29 |
+
|
30 |
+
if lr is None:
|
31 |
+
optimizer = _optim(network.parameters())
|
32 |
+
else:
|
33 |
+
optimizer = _optim(network.parameters(), lr=lr)
|
34 |
+
return optimizer
|
lib/dataloader.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms as transforms
|
7 |
+
from torch.utils.data.dataset import Dataset
|
8 |
+
from torch.utils.data.dataloader import DataLoader
|
9 |
+
from torch.utils.data.sampler import WeightedRandomSampler
|
10 |
+
from PIL import Image
|
11 |
+
from sklearn.preprocessing import MinMaxScaler
|
12 |
+
import pickle
|
13 |
+
from .logger import BaseLogger
|
14 |
+
from typing import List, Dict, Union
|
15 |
+
import pandas as pd
|
16 |
+
|
17 |
+
|
18 |
+
logger = BaseLogger.get_logger(__name__)
|
19 |
+
|
20 |
+
|
21 |
+
class PrivateAugment(torch.nn.Module):
|
22 |
+
"""
|
23 |
+
Augmentation defined privately.
|
24 |
+
Variety of augmentation can be written in this class if necessary.
|
25 |
+
"""
|
26 |
+
# For X-ray photo.
|
27 |
+
xray_augs_list = [
|
28 |
+
transforms.RandomAffine(degrees=(-3, 3), translate=(0.02, 0.02)),
|
29 |
+
transforms.RandomAdjustSharpness(sharpness_factor=2),
|
30 |
+
transforms.RandomAutocontrast()
|
31 |
+
]
|
32 |
+
|
33 |
+
|
34 |
+
class InputDataMixin:
|
35 |
+
"""
|
36 |
+
Class to normalizes input data.
|
37 |
+
"""
|
38 |
+
def _make_scaler(self) -> MinMaxScaler:
|
39 |
+
"""
|
40 |
+
Make scaler to normalize input data by min-max normalization with train data.
|
41 |
+
|
42 |
+
Returns:
|
43 |
+
MinMaxScaler: scaler
|
44 |
+
"""
|
45 |
+
scaler = MinMaxScaler()
|
46 |
+
_df_train = self.df_source[self.df_source['split'] == 'train'] # should be normalized with min and max of training data
|
47 |
+
_ = scaler.fit(_df_train[self.input_list]) # fit only
|
48 |
+
return scaler
|
49 |
+
|
50 |
+
def save_scaler(self, save_path :str) -> None:
|
51 |
+
"""
|
52 |
+
Save scaler
|
53 |
+
|
54 |
+
Args:
|
55 |
+
save_path (str): path for saving scaler.
|
56 |
+
"""
|
57 |
+
#save_scaler_path = Path(save_datetime_dir, 'scaler.pkl')
|
58 |
+
with open(save_path, 'wb') as f:
|
59 |
+
pickle.dump(self.scaler, f)
|
60 |
+
|
61 |
+
def load_scaler(self, scaler_path :str) -> None:
|
62 |
+
"""
|
63 |
+
Load scaler.
|
64 |
+
|
65 |
+
Args:
|
66 |
+
scaler_path (str): path to scaler
|
67 |
+
"""
|
68 |
+
with open(scaler_path, 'rb') as f:
|
69 |
+
scaler = pickle.load(f)
|
70 |
+
return scaler
|
71 |
+
|
72 |
+
def _normalize_inputs(self, df_inputs: pd.DataFrame) -> torch.FloatTensor:
|
73 |
+
"""
|
74 |
+
Normalize inputs.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
df_inputs (pd.DataFrame): DataFrame of inputs
|
78 |
+
|
79 |
+
Returns:
|
80 |
+
torch.FloatTensor: normalized inputs
|
81 |
+
|
82 |
+
Note:
|
83 |
+
After iloc[[idx], index_input_list], pd.DataFrame is obtained.
|
84 |
+
DataFrame fits the input type of self.scaler.transform.
|
85 |
+
However, after normalizing, the shape of inputs_value is (1, N), where N is the number of input values.
|
86 |
+
Since the shape (1, N) is not acceptable when forwarding, convert (1, N) -> (N,) is needed.
|
87 |
+
"""
|
88 |
+
inputs_value = self.scaler.transform(df_inputs).reshape(-1) # np.float64
|
89 |
+
inputs_value = np.array(inputs_value, dtype=np.float32) # -> np.float32
|
90 |
+
inputs_value = torch.from_numpy(inputs_value).clone() # -> torch.float32
|
91 |
+
return inputs_value
|
92 |
+
|
93 |
+
def _load_input_value_if_mlp(self, idx: int) -> Union[torch.FloatTensor, str]:
|
94 |
+
"""
|
95 |
+
Load input values after converting them into tensor if MLP is used.
|
96 |
+
|
97 |
+
Args:
|
98 |
+
idx (int): index
|
99 |
+
|
100 |
+
Returns:
|
101 |
+
Union[torch.Tensor[float], str]: tensor of input values, or empty string
|
102 |
+
"""
|
103 |
+
inputs_value = ''
|
104 |
+
|
105 |
+
if self.params.mlp is None:
|
106 |
+
return inputs_value
|
107 |
+
|
108 |
+
index_input_list = [self.col_index_dict[input] for input in self.input_list]
|
109 |
+
_df_inputs = self.df_split.iloc[[idx], index_input_list]
|
110 |
+
inputs_value = self._normalize_inputs( _df_inputs)
|
111 |
+
return inputs_value
|
112 |
+
|
113 |
+
|
114 |
+
class ImageMixin:
|
115 |
+
"""
|
116 |
+
Class to normalize and transform image.
|
117 |
+
"""
|
118 |
+
def _make_augmentations(self) -> List:
|
119 |
+
"""
|
120 |
+
Define which augmentation is applied.
|
121 |
+
|
122 |
+
When training, augmentation is needed for train data only.
|
123 |
+
When test, no need of augmentation.
|
124 |
+
"""
|
125 |
+
_augmentation = []
|
126 |
+
if (self.params.isTrain) and (self.split == 'train'):
|
127 |
+
if self.params.augmentation == 'xrayaug':
|
128 |
+
_augmentation = PrivateAugment.xray_augs_list
|
129 |
+
elif self.params.augmentation == 'trivialaugwide':
|
130 |
+
_augmentation.append(transforms.TrivialAugmentWide())
|
131 |
+
elif self.params.augmentation == 'randaug':
|
132 |
+
_augmentation.append(transforms.RandAugment())
|
133 |
+
else:
|
134 |
+
# ie. self.params.augmentation == 'no':
|
135 |
+
pass
|
136 |
+
|
137 |
+
_augmentation = transforms.Compose(_augmentation)
|
138 |
+
return _augmentation
|
139 |
+
|
140 |
+
def _make_transforms(self) -> List:
|
141 |
+
"""
|
142 |
+
Make list of transforms.
|
143 |
+
|
144 |
+
Returns:
|
145 |
+
list of transforms: image normalization
|
146 |
+
"""
|
147 |
+
_transforms = []
|
148 |
+
_transforms.append(transforms.ToTensor())
|
149 |
+
|
150 |
+
if self.params.normalize_image == 'yes':
|
151 |
+
# transforms.Normalize accepts only Tensor.
|
152 |
+
if self.params.in_channel == 1:
|
153 |
+
_transforms.append(transforms.Normalize(mean=(0.5, ), std=(0.5, )))
|
154 |
+
else:
|
155 |
+
# ie. self.params.in_channel == 3
|
156 |
+
_transforms.append(transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]))
|
157 |
+
|
158 |
+
_transforms = transforms.Compose(_transforms)
|
159 |
+
return _transforms
|
160 |
+
|
161 |
+
def _open_image_in_channel(self, imgpath: str, in_channel: int) -> Image:
|
162 |
+
"""
|
163 |
+
Open image in channel.
|
164 |
+
|
165 |
+
Args:
|
166 |
+
imgpath (str): path to image
|
167 |
+
in_channel (int): channel, or 1 or 3
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Image: PIL image
|
171 |
+
"""
|
172 |
+
if in_channel == 1:
|
173 |
+
image = Image.open(imgpath).convert('L') # eg. np.array(image).shape = (64, 64)
|
174 |
+
return image
|
175 |
+
else:
|
176 |
+
# ie. self.params.in_channel == 3
|
177 |
+
image = Image.open(imgpath).convert('RGB') # eg. np.array(image).shape = (64, 64, 3)
|
178 |
+
return image
|
179 |
+
|
180 |
+
def _load_image_if_cnn(self, idx: int) -> Union[torch.Tensor, str]:
|
181 |
+
"""
|
182 |
+
Load image and convert it to tensor if any of CNN or ViT is used.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
idx (int): index
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
Union[torch.Tensor[float], str]: tensor converted from image, or empty string
|
189 |
+
"""
|
190 |
+
image = ''
|
191 |
+
|
192 |
+
if self.params.net is None:
|
193 |
+
return image
|
194 |
+
|
195 |
+
imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
|
196 |
+
image = self._open_image_in_channel(imgpath, self.params.in_channel)
|
197 |
+
image = self.augmentation(image)
|
198 |
+
image = self.transform(image)
|
199 |
+
return image
|
200 |
+
|
201 |
+
|
202 |
+
class DeepSurvMixin:
|
203 |
+
"""
|
204 |
+
Class to handle required data for deepsurv.
|
205 |
+
"""
|
206 |
+
def _load_periods_if_deepsurv(self, idx: int) -> Union[torch.FloatTensor, str]:
|
207 |
+
"""
|
208 |
+
Return period if deepsurv.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
idx (int): index
|
212 |
+
|
213 |
+
Returns:
|
214 |
+
Union[torch.FloatTensor, str]: period, or empty string
|
215 |
+
"""
|
216 |
+
periods = ''
|
217 |
+
|
218 |
+
if self.params.task != 'deepsurv':
|
219 |
+
return periods
|
220 |
+
|
221 |
+
assert (self.params.task == 'deepsurv') and (len(self.label_list) == 1), 'Deepsurv cannot work in multi-label.'
|
222 |
+
periods = self.df_split.iat[idx, self.col_index_dict[self.period_name]] # int64
|
223 |
+
periods = np.array(periods, dtype=np.float32) # -> np.float32
|
224 |
+
periods = torch.from_numpy(periods).clone() # -> torch.float32
|
225 |
+
return periods
|
226 |
+
|
227 |
+
|
228 |
+
class DataSetWidget(InputDataMixin, ImageMixin, DeepSurvMixin):
|
229 |
+
"""
|
230 |
+
Class for a widget to inherit multiple classes simultaneously.
|
231 |
+
"""
|
232 |
+
pass
|
233 |
+
|
234 |
+
|
235 |
+
class LoadDataSet(Dataset, DataSetWidget):
|
236 |
+
"""
|
237 |
+
Dataset for split.
|
238 |
+
"""
|
239 |
+
def __init__(
|
240 |
+
self,
|
241 |
+
params,
|
242 |
+
split: str
|
243 |
+
) -> None:
|
244 |
+
"""
|
245 |
+
Args:
|
246 |
+
params (ParamSet): parameter for model
|
247 |
+
split (str): split
|
248 |
+
"""
|
249 |
+
self.params = params
|
250 |
+
self.df_source = self.params.df_source
|
251 |
+
self.split = split
|
252 |
+
|
253 |
+
self.input_list = self.params.input_list
|
254 |
+
self.label_list = self.params.label_list
|
255 |
+
|
256 |
+
if self.params.task == 'deepsurv':
|
257 |
+
self.period_name = self.params.period_name
|
258 |
+
|
259 |
+
self.df_split = self.df_source[self.df_source['split'] == self.split]
|
260 |
+
self.col_index_dict = {col_name: self.df_split.columns.get_loc(col_name) for col_name in self.df_split.columns}
|
261 |
+
|
262 |
+
# For input data
|
263 |
+
if self.params.mlp is not None:
|
264 |
+
assert (self.input_list != []), f"input list is empty."
|
265 |
+
if params.isTrain:
|
266 |
+
self.scaler = self._make_scaler()
|
267 |
+
else:
|
268 |
+
# load scaler used at training.
|
269 |
+
self.scaler = self.load_scaler(self.params.scaler_path)
|
270 |
+
|
271 |
+
# For image
|
272 |
+
if self.params.net is not None:
|
273 |
+
self.augmentation = self._make_augmentations()
|
274 |
+
self.transform = self._make_transforms()
|
275 |
+
|
276 |
+
def __len__(self) -> int:
|
277 |
+
"""
|
278 |
+
Return length of DataFrame.
|
279 |
+
|
280 |
+
Returns:
|
281 |
+
int: length of DataFrame
|
282 |
+
"""
|
283 |
+
return len(self.df_split)
|
284 |
+
|
285 |
+
def _load_label(self, idx: int) -> Dict[str, Union[int, float]]:
|
286 |
+
"""
|
287 |
+
Return labels.
|
288 |
+
If no column of label when csv of external dataset is used,
|
289 |
+
empty dictionary is returned.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
idx (int): index
|
293 |
+
|
294 |
+
Returns:
|
295 |
+
Dict[str, Union[int, float]]: dictionary of label name and its value
|
296 |
+
"""
|
297 |
+
# For checking if columns of labels exist when used csv for external dataset.
|
298 |
+
label_list_in_split = list(self.df_split.columns[self.df_split.columns.str.startswith('label')])
|
299 |
+
label_dict = dict()
|
300 |
+
if label_list_in_split != []:
|
301 |
+
for label_name in self.label_list:
|
302 |
+
label_dict[label_name] = self.df_split.iat[idx, self.col_index_dict[label_name]]
|
303 |
+
else:
|
304 |
+
# no label
|
305 |
+
pass
|
306 |
+
return label_dict
|
307 |
+
|
308 |
+
def __getitem__(self, idx: int) -> Dict:
|
309 |
+
"""
|
310 |
+
Return data row specified by index.
|
311 |
+
|
312 |
+
Args:
|
313 |
+
idx (int): index
|
314 |
+
|
315 |
+
Returns:
|
316 |
+
Dict: dictionary of data to be passed model
|
317 |
+
"""
|
318 |
+
uniqID = self.df_split.iat[idx, self.col_index_dict['uniqID']]
|
319 |
+
group = self.df_split.iat[idx, self.col_index_dict['group']]
|
320 |
+
imgpath = self.df_split.iat[idx, self.col_index_dict['imgpath']]
|
321 |
+
split = self.df_split.iat[idx, self.col_index_dict['split']]
|
322 |
+
inputs_value = self._load_input_value_if_mlp(idx)
|
323 |
+
image = self._load_image_if_cnn(idx)
|
324 |
+
label_dict = self._load_label(idx)
|
325 |
+
periods = self._load_periods_if_deepsurv(idx)
|
326 |
+
|
327 |
+
_data = {
|
328 |
+
'uniqID': uniqID,
|
329 |
+
'group': group,
|
330 |
+
'imgpath': imgpath,
|
331 |
+
'split': split,
|
332 |
+
'inputs': inputs_value,
|
333 |
+
'image': image,
|
334 |
+
'labels': label_dict,
|
335 |
+
'periods': periods
|
336 |
+
}
|
337 |
+
return _data
|
338 |
+
|
339 |
+
|
340 |
+
def _make_sampler(split_data: LoadDataSet) -> WeightedRandomSampler:
|
341 |
+
"""
|
342 |
+
Make sampler.
|
343 |
+
|
344 |
+
Args:
|
345 |
+
split_data (LoadDataSet): dataset
|
346 |
+
|
347 |
+
Returns:
|
348 |
+
WeightedRandomSampler: sampler
|
349 |
+
"""
|
350 |
+
_target = []
|
351 |
+
for _, data in enumerate(split_data):
|
352 |
+
_target.append(list(data['labels'].values())[0])
|
353 |
+
|
354 |
+
class_sample_count = np.array([len(np.where(_target == t)[0]) for t in np.unique(_target)])
|
355 |
+
weight = 1. / class_sample_count
|
356 |
+
samples_weight = np.array([weight[t] for t in _target])
|
357 |
+
sampler = WeightedRandomSampler(samples_weight, len(samples_weight))
|
358 |
+
return sampler
|
359 |
+
|
360 |
+
|
361 |
+
def create_dataloader(
|
362 |
+
params,
|
363 |
+
split: str = None
|
364 |
+
) -> DataLoader:
|
365 |
+
"""
|
366 |
+
Create data loader ofr split.
|
367 |
+
|
368 |
+
Args:
|
369 |
+
params (ParamSet): parameter for dataloader
|
370 |
+
split (str): split. Defaults to None.
|
371 |
+
|
372 |
+
Returns:
|
373 |
+
DataLoader: data loader
|
374 |
+
"""
|
375 |
+
split_data = LoadDataSet(params, split)
|
376 |
+
|
377 |
+
if params.isTrain:
|
378 |
+
batch_size = params.batch_size
|
379 |
+
shuffle = True
|
380 |
+
else:
|
381 |
+
batch_size = params.test_batch_size
|
382 |
+
shuffle = False
|
383 |
+
|
384 |
+
if params.sampler == 'yes':
|
385 |
+
assert ((params.task == 'classification') or (params.task == 'deepsurv')), 'Cannot make sampler in regression.'
|
386 |
+
assert (len(params.label_list) == 1), 'Cannot make sampler for multi-label.'
|
387 |
+
shuffle = False
|
388 |
+
sampler = _make_sampler(split_data)
|
389 |
+
else:
|
390 |
+
# When params.sampler == 'no'
|
391 |
+
sampler = None
|
392 |
+
|
393 |
+
split_loader = DataLoader(
|
394 |
+
dataset=split_data,
|
395 |
+
batch_size=batch_size,
|
396 |
+
shuffle=shuffle,
|
397 |
+
num_workers=0,
|
398 |
+
sampler=sampler
|
399 |
+
)
|
400 |
+
return split_loader
|
lib/framework.py
ADDED
@@ -0,0 +1,373 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import copy
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
import torch
|
8 |
+
import torch.nn as nn
|
9 |
+
from .component import create_net
|
10 |
+
from .logger import BaseLogger
|
11 |
+
from lib import ParamSet
|
12 |
+
from typing import List, Dict, Tuple, Union
|
13 |
+
|
14 |
+
# Alias of typing
|
15 |
+
# eg. {'labels': {'label_A: torch.Tensor([0, 1, ...]), ...}}
|
16 |
+
LabelDict = Dict[str, Dict[str, Union[torch.IntTensor, torch.FloatTensor]]]
|
17 |
+
|
18 |
+
|
19 |
+
logger = BaseLogger.get_logger(__name__)
|
20 |
+
|
21 |
+
|
22 |
+
class BaseModel(ABC):
|
23 |
+
"""
|
24 |
+
Class to construct model. This class is the base class to construct model.
|
25 |
+
"""
|
26 |
+
def __init__(self, params: ParamSet) -> None:
|
27 |
+
"""
|
28 |
+
Class to define Model
|
29 |
+
|
30 |
+
Args:
|
31 |
+
param (ParamSet): parameters
|
32 |
+
"""
|
33 |
+
self.params = params
|
34 |
+
self.device = self.params.device
|
35 |
+
|
36 |
+
self.network = create_net(
|
37 |
+
mlp=self.params.mlp,
|
38 |
+
net=self.params.net,
|
39 |
+
num_outputs_for_label=self.params.num_outputs_for_label,
|
40 |
+
mlp_num_inputs=self.params.mlp_num_inputs,
|
41 |
+
in_channel=self.params.in_channel,
|
42 |
+
vit_image_size=self.params.vit_image_size,
|
43 |
+
pretrained=self.params.pretrained
|
44 |
+
)
|
45 |
+
self.network.to(self.device)
|
46 |
+
|
47 |
+
# variables to keep temporary best_weight and best_epoch
|
48 |
+
self.acting_best_weight = None
|
49 |
+
self.acting_best_epoch = None
|
50 |
+
|
51 |
+
def train(self) -> None:
|
52 |
+
"""
|
53 |
+
Make network training mode.
|
54 |
+
"""
|
55 |
+
self.network.train()
|
56 |
+
|
57 |
+
def eval(self) -> None:
|
58 |
+
"""
|
59 |
+
Make network evaluation mode.
|
60 |
+
"""
|
61 |
+
self.network.eval()
|
62 |
+
|
63 |
+
@abstractmethod
|
64 |
+
def set_data(
|
65 |
+
self,
|
66 |
+
data: Dict
|
67 |
+
) -> Tuple[
|
68 |
+
Dict[str, torch.FloatTensor],
|
69 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
70 |
+
]:
|
71 |
+
raise NotImplementedError
|
72 |
+
|
73 |
+
def store_weight(self, at_epoch: int = None) -> None:
|
74 |
+
"""
|
75 |
+
Store weight and epoch number when it is saved.
|
76 |
+
|
77 |
+
Args:
|
78 |
+
at_epoch (int): epoch number when save weight
|
79 |
+
"""
|
80 |
+
self.acting_best_epoch = at_epoch
|
81 |
+
|
82 |
+
_network = copy.deepcopy(self.network)
|
83 |
+
if hasattr(_network, 'module'):
|
84 |
+
# When DataParallel used, move weight to CPU.
|
85 |
+
self.acting_best_weight = copy.deepcopy(_network.module.to(torch.device('cpu')).state_dict())
|
86 |
+
else:
|
87 |
+
self.acting_best_weight = copy.deepcopy(_network.state_dict())
|
88 |
+
|
89 |
+
def save_weight(self, save_datetime_dir: str, as_best: bool = None) -> None:
|
90 |
+
"""
|
91 |
+
Save weight.
|
92 |
+
|
93 |
+
Args:
|
94 |
+
save_datetime_dir (str): save_datetime_dir
|
95 |
+
as_best (bool): True if weight is saved as best, otherwise False. Defaults to None.
|
96 |
+
"""
|
97 |
+
|
98 |
+
save_dir = Path(save_datetime_dir, 'weights')
|
99 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
100 |
+
save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt'
|
101 |
+
save_path = Path(save_dir, save_name)
|
102 |
+
|
103 |
+
if as_best:
|
104 |
+
save_name_as_best = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '_best' + '.pt'
|
105 |
+
save_path_as_best = Path(save_dir, save_name_as_best)
|
106 |
+
if save_path.exists():
|
107 |
+
# Check if best weight already saved. If exists, rename with '_best'
|
108 |
+
save_path.rename(save_path_as_best)
|
109 |
+
else:
|
110 |
+
torch.save(self.acting_best_weight, save_path_as_best)
|
111 |
+
else:
|
112 |
+
save_name = 'weight_epoch-' + str(self.acting_best_epoch).zfill(3) + '.pt'
|
113 |
+
torch.save(self.acting_best_weight, save_path)
|
114 |
+
|
115 |
+
def load_weight(self, weight_path: Path) -> None:
|
116 |
+
"""
|
117 |
+
Load wight from weight_path.
|
118 |
+
|
119 |
+
Args:
|
120 |
+
weight_path (Path): path to weight
|
121 |
+
"""
|
122 |
+
logger.info(f"Load weight: {weight_path}.\n")
|
123 |
+
weight = torch.load(weight_path)
|
124 |
+
self.network.load_state_dict(weight)
|
125 |
+
|
126 |
+
|
127 |
+
class ModelMixin:
|
128 |
+
def to_gpu(self, gpu_ids: List[int]) -> None:
|
129 |
+
"""
|
130 |
+
Make model compute on the GPU.
|
131 |
+
|
132 |
+
Args:
|
133 |
+
gpu_ids (List[int]): GPU ids
|
134 |
+
"""
|
135 |
+
if gpu_ids != []:
|
136 |
+
assert torch.cuda.is_available(), 'No available GPU on this machine.'
|
137 |
+
self.network = nn.DataParallel(self.network, device_ids=gpu_ids)
|
138 |
+
|
139 |
+
def init_network(self) -> None:
|
140 |
+
"""
|
141 |
+
Initialize network.
|
142 |
+
This method is used at test to reset the current weight by redefining network.
|
143 |
+
"""
|
144 |
+
self.network = create_net(
|
145 |
+
mlp=self.params.mlp,
|
146 |
+
net=self.params.net,
|
147 |
+
num_outputs_for_label=self.params.num_outputs_for_label,
|
148 |
+
mlp_num_inputs=self.params.mlp_num_inputs,
|
149 |
+
in_channel=self.params.in_channel,
|
150 |
+
vit_image_size=self.params.vit_image_size,
|
151 |
+
pretrained=self.params.pretrained
|
152 |
+
)
|
153 |
+
self.network.to(self.device)
|
154 |
+
|
155 |
+
class ModelWidget(BaseModel, ModelMixin):
|
156 |
+
"""
|
157 |
+
Class for a widget to inherit multiple classes simultaneously
|
158 |
+
"""
|
159 |
+
pass
|
160 |
+
|
161 |
+
|
162 |
+
class MLPModel(ModelWidget):
|
163 |
+
"""
|
164 |
+
Class for MLP model
|
165 |
+
"""
|
166 |
+
|
167 |
+
def __init__(self, params: ParamSet) -> None:
|
168 |
+
"""
|
169 |
+
Args:
|
170 |
+
params: (ParamSet): parameters
|
171 |
+
"""
|
172 |
+
super().__init__(params)
|
173 |
+
|
174 |
+
def set_data(
|
175 |
+
self,
|
176 |
+
data: Dict
|
177 |
+
) -> Tuple[
|
178 |
+
Dict[str, torch.FloatTensor],
|
179 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
180 |
+
]:
|
181 |
+
"""
|
182 |
+
Unpack data for forwarding of MLP and calculating loss
|
183 |
+
by passing them to device.
|
184 |
+
When deepsurv, period and network are also returned.
|
185 |
+
|
186 |
+
Args:
|
187 |
+
data (Dict): dictionary of data
|
188 |
+
|
189 |
+
Returns:
|
190 |
+
Tuple[
|
191 |
+
Dict[str, torch.FloatTensor],
|
192 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
193 |
+
]: input of model and data for calculating loss.
|
194 |
+
eg.
|
195 |
+
([inputs], [labels]), or ([inputs], [labels, periods, network]) when deepsurv
|
196 |
+
"""
|
197 |
+
in_data = {'inputs': data['inputs'].to(self.device)}
|
198 |
+
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
|
199 |
+
|
200 |
+
if not any(data['periods']):
|
201 |
+
return in_data, labels
|
202 |
+
|
203 |
+
# When deepsurv
|
204 |
+
labels = {
|
205 |
+
**labels,
|
206 |
+
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
|
207 |
+
}
|
208 |
+
return in_data, labels
|
209 |
+
|
210 |
+
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
211 |
+
"""
|
212 |
+
Forward
|
213 |
+
|
214 |
+
Args:
|
215 |
+
in_data (Dict[str, torch.Tensor]): data to be input into model
|
216 |
+
|
217 |
+
Returns:
|
218 |
+
Dict[str, torch.Tensor]: output
|
219 |
+
"""
|
220 |
+
inputs = in_data['inputs']
|
221 |
+
output = self.network(inputs)
|
222 |
+
return output
|
223 |
+
|
224 |
+
|
225 |
+
class CVModel(ModelWidget):
|
226 |
+
"""
|
227 |
+
Class for CNN or ViT model
|
228 |
+
"""
|
229 |
+
def __init__(self, params: ParamSet) -> None:
|
230 |
+
"""
|
231 |
+
Args:
|
232 |
+
params: (ParamSet): parameters
|
233 |
+
"""
|
234 |
+
super().__init__(params)
|
235 |
+
|
236 |
+
def set_data(
|
237 |
+
self,
|
238 |
+
data: Dict
|
239 |
+
) -> Tuple[
|
240 |
+
Dict[str, torch.FloatTensor],
|
241 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
242 |
+
]:
|
243 |
+
"""
|
244 |
+
Unpack data for forwarding of CNN or ViT and calculating loss by passing them to device.
|
245 |
+
When deepsurv, period and network are also returned.
|
246 |
+
|
247 |
+
Args:
|
248 |
+
data (Dict): dictionary of data
|
249 |
+
|
250 |
+
Returns:
|
251 |
+
Tuple[
|
252 |
+
Dict[str, torch.FloatTensor],
|
253 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
254 |
+
]: input of model and data for calculating loss.
|
255 |
+
eg.
|
256 |
+
([image], [labels]), or ([image], [labels, periods, network]) when deepsurv
|
257 |
+
"""
|
258 |
+
in_data = {'image': data['image'].to(self.device)}
|
259 |
+
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
|
260 |
+
|
261 |
+
if not any(data['periods']):
|
262 |
+
return in_data, labels
|
263 |
+
|
264 |
+
# When deepsurv
|
265 |
+
labels = {
|
266 |
+
**labels,
|
267 |
+
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
|
268 |
+
}
|
269 |
+
return in_data, labels
|
270 |
+
|
271 |
+
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
272 |
+
"""
|
273 |
+
Forward
|
274 |
+
|
275 |
+
Args:
|
276 |
+
in_data (Dict[str, torch.Tensor]): data to be input into model
|
277 |
+
|
278 |
+
Returns:
|
279 |
+
Dict[str, torch.Tensor]: output
|
280 |
+
"""
|
281 |
+
image = in_data['image']
|
282 |
+
output = self.network(image)
|
283 |
+
return output
|
284 |
+
|
285 |
+
|
286 |
+
class FusionModel(ModelWidget):
|
287 |
+
"""
|
288 |
+
Class for MLP+CNN or MLP+ViT model.
|
289 |
+
"""
|
290 |
+
def __init__(self, params: ParamSet) -> None:
|
291 |
+
"""
|
292 |
+
Args:
|
293 |
+
params: (ParamSet): parameters
|
294 |
+
"""
|
295 |
+
super().__init__(params)
|
296 |
+
|
297 |
+
def set_data(
|
298 |
+
self,
|
299 |
+
data: Dict
|
300 |
+
) -> Tuple[
|
301 |
+
Dict[str, torch.FloatTensor],
|
302 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
303 |
+
]:
|
304 |
+
"""
|
305 |
+
Unpack data for forwarding of MLP+CNN or MLP+ViT and calculating loss
|
306 |
+
by passing them to device.
|
307 |
+
When deepsurv, period and network are also returned.
|
308 |
+
|
309 |
+
Args:
|
310 |
+
data (Dict): dictionary of data
|
311 |
+
|
312 |
+
Returns:
|
313 |
+
Tuple[
|
314 |
+
Dict[str, torch.FloatTensor],
|
315 |
+
Dict[str, Union[LabelDict, torch.IntTensor, nn.Module]]
|
316 |
+
]: input of model and data for calculating loss.
|
317 |
+
eg.
|
318 |
+
([inputs, image], [labels]), or ([inputs, image], [labels, periods, network]) when deepsurv
|
319 |
+
"""
|
320 |
+
in_data = {
|
321 |
+
'inputs': data['inputs'].to(self.device),
|
322 |
+
'image': data['image'].to(self.device)
|
323 |
+
}
|
324 |
+
labels = {'labels': {label_name: label.to(self.device) for label_name, label in data['labels'].items()}}
|
325 |
+
|
326 |
+
if not any(data['periods']):
|
327 |
+
return in_data, labels
|
328 |
+
|
329 |
+
# When deepsurv
|
330 |
+
labels = {
|
331 |
+
**labels,
|
332 |
+
**{'periods': data['periods'].to(self.device), 'network': self.network.to(self.device)}
|
333 |
+
}
|
334 |
+
return in_data, labels
|
335 |
+
|
336 |
+
def __call__(self, in_data: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
|
337 |
+
"""
|
338 |
+
Forward
|
339 |
+
|
340 |
+
Args:
|
341 |
+
in_data (Dict[str, torch.Tensor]): data to be input into model
|
342 |
+
|
343 |
+
Returns:
|
344 |
+
Dict[str, torch.Tensor]: output
|
345 |
+
"""
|
346 |
+
inputs = in_data['inputs']
|
347 |
+
image = in_data['image']
|
348 |
+
output = self.network(inputs, image)
|
349 |
+
return output
|
350 |
+
|
351 |
+
|
352 |
+
def create_model(params: ParamSet) -> nn.Module:
|
353 |
+
"""
|
354 |
+
Construct model.
|
355 |
+
|
356 |
+
Args:
|
357 |
+
params (ParamSet): parameters
|
358 |
+
|
359 |
+
Returns:
|
360 |
+
nn.Module: model
|
361 |
+
"""
|
362 |
+
_isMLPModel = (params.mlp is not None) and (params.net is None)
|
363 |
+
_isCVModel = (params.mlp is None) and (params.net is not None)
|
364 |
+
_isFusionModel = (params.mlp is not None) and (params.net is not None)
|
365 |
+
|
366 |
+
if _isMLPModel:
|
367 |
+
return MLPModel(params)
|
368 |
+
elif _isCVModel:
|
369 |
+
return CVModel(params)
|
370 |
+
elif _isFusionModel:
|
371 |
+
return FusionModel(params)
|
372 |
+
else:
|
373 |
+
raise ValueError(f"Invalid model type: mlp={params.mlp}, net={params.net}.")
|
lib/logger.py
ADDED
@@ -0,0 +1,71 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import logging
|
6 |
+
|
7 |
+
|
8 |
+
class BaseLogger:
|
9 |
+
"""
|
10 |
+
Class for defining logger.
|
11 |
+
"""
|
12 |
+
_unexecuted_configure = True
|
13 |
+
|
14 |
+
@classmethod
|
15 |
+
def get_logger(cls, name: str) -> logging.Logger:
|
16 |
+
"""
|
17 |
+
Set logger.
|
18 |
+
Args:
|
19 |
+
name (str): If needed, potentially hierarchical name is desired, eg. lib.net, lib.dataloader, etc.
|
20 |
+
For the details, see https://docs.python.org/3/library/logging.html?highlight=logging#module-logging.
|
21 |
+
Returns:
|
22 |
+
logging.Logger: logger
|
23 |
+
"""
|
24 |
+
if cls._unexecuted_configure:
|
25 |
+
cls._init_logger()
|
26 |
+
|
27 |
+
return logging.getLogger('nervus.{}'.format(name))
|
28 |
+
|
29 |
+
@classmethod
|
30 |
+
def _init_logger(cls) -> None:
|
31 |
+
"""
|
32 |
+
Configure logger.
|
33 |
+
"""
|
34 |
+
_root_logger = logging.getLogger('nervus')
|
35 |
+
_root_logger.setLevel(logging.DEBUG)
|
36 |
+
formatter = logging.Formatter('%(asctime)s - %(levelname)s - %(message)s')
|
37 |
+
|
38 |
+
log_dir = Path('logs')
|
39 |
+
log_dir.mkdir(parents=True, exist_ok=True)
|
40 |
+
log_path = Path(log_dir, 'log.log')
|
41 |
+
|
42 |
+
# file handler
|
43 |
+
## upper warning
|
44 |
+
fh_err = logging.FileHandler(log_path)
|
45 |
+
fh_err.setLevel(logging.WARNING)
|
46 |
+
fh_err.setFormatter(formatter)
|
47 |
+
fh_err.addFilter(lambda log_record: not ('BdbQuit' in str(log_record.exc_info)) and (log_record.levelno >= logging.WARNING))
|
48 |
+
_root_logger.addHandler(fh_err)
|
49 |
+
|
50 |
+
## lower warning
|
51 |
+
fh = logging.FileHandler(log_path)
|
52 |
+
fh.setLevel(logging.DEBUG)
|
53 |
+
fh.addFilter(lambda log_record: log_record.levelno < logging.WARNING)
|
54 |
+
_root_logger.addHandler(fh)
|
55 |
+
|
56 |
+
# stream handler
|
57 |
+
## upper warning
|
58 |
+
ch_err = logging.StreamHandler()
|
59 |
+
ch_err.setLevel(logging.WARNING)
|
60 |
+
ch_err.setFormatter(formatter)
|
61 |
+
ch_err.addFilter(lambda log_record: log_record.levelno >= logging.WARNING)
|
62 |
+
_root_logger.addHandler(ch_err)
|
63 |
+
|
64 |
+
## lower warning
|
65 |
+
ch = logging.StreamHandler()
|
66 |
+
ch.setLevel(logging.DEBUG)
|
67 |
+
ch.addFilter(lambda log_record: log_record.levelno < logging.WARNING)
|
68 |
+
_root_logger.addHandler(ch)
|
69 |
+
|
70 |
+
cls._unexecuted_configure = False
|
71 |
+
|
lib/metrics.py
ADDED
@@ -0,0 +1,623 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
from pathlib import Path
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
from sklearn import metrics
|
8 |
+
from sklearn.preprocessing import label_binarize
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
from matplotlib import colors as mcolors
|
11 |
+
from .logger import BaseLogger
|
12 |
+
from typing import Dict, Union
|
13 |
+
|
14 |
+
|
15 |
+
logger = BaseLogger.get_logger(__name__)
|
16 |
+
|
17 |
+
|
18 |
+
class MetricsData:
|
19 |
+
"""
|
20 |
+
Class to store metrics as class variable.
|
21 |
+
Metrics are defined depending on task.
|
22 |
+
|
23 |
+
For ROC
|
24 |
+
self.fpr: np.ndarray
|
25 |
+
self.tpr: np.ndarray
|
26 |
+
self.auc: float
|
27 |
+
|
28 |
+
For Regression
|
29 |
+
self.y_obs: np.ndarray
|
30 |
+
self.y_pred: np.ndarray
|
31 |
+
self.r2: float
|
32 |
+
|
33 |
+
For DeepSurv
|
34 |
+
self.c_index: float
|
35 |
+
"""
|
36 |
+
def __init__(self) -> None:
|
37 |
+
pass
|
38 |
+
|
39 |
+
|
40 |
+
class LabelMetrics:
|
41 |
+
"""
|
42 |
+
Class to store metrics of each split for each label.
|
43 |
+
"""
|
44 |
+
def __init__(self) -> None:
|
45 |
+
"""
|
46 |
+
Metrics of split, ie 'val' and 'test'
|
47 |
+
"""
|
48 |
+
self.val = MetricsData()
|
49 |
+
self.test = MetricsData()
|
50 |
+
|
51 |
+
def set_label_metrics(self, split: str, attr: str, value: Union[np.ndarray, float]) -> None:
|
52 |
+
"""
|
53 |
+
Set value as appropriate metrics of split.
|
54 |
+
|
55 |
+
Args:
|
56 |
+
split (str): split
|
57 |
+
attr (str): attribute name as follows:
|
58 |
+
classification: 'fpr', 'tpr', or 'auc',
|
59 |
+
regression: 'y_obs'(ground truth), 'y_pred'(prediction) or 'r2', or
|
60 |
+
deepsurv: 'c_index'
|
61 |
+
value (Union[np.ndarray,float]): value of attr
|
62 |
+
"""
|
63 |
+
setattr(getattr(self, split), attr, value)
|
64 |
+
|
65 |
+
def get_label_metrics(self, split: str, attr: str) -> Union[np.ndarray, float]:
|
66 |
+
"""
|
67 |
+
Return value of metrics of split.
|
68 |
+
|
69 |
+
Args:
|
70 |
+
split (str): split
|
71 |
+
attr (str): metrics name
|
72 |
+
|
73 |
+
Returns:
|
74 |
+
Union[np.ndarray,float]: value of attr
|
75 |
+
"""
|
76 |
+
return getattr(getattr(self, split), attr)
|
77 |
+
|
78 |
+
|
79 |
+
class ROCMixin:
|
80 |
+
"""
|
81 |
+
Class for calculating ROC and AUC.
|
82 |
+
"""
|
83 |
+
def _set_roc(self, label_metrics: LabelMetrics, split: str, fpr: np.ndarray, tpr: np.ndarray) -> None:
|
84 |
+
"""
|
85 |
+
Set fpr, tpr, and auc.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
label_metrics (LabelMetrics): metrics of 'val' and 'test'
|
89 |
+
split (str): 'val' or 'test'
|
90 |
+
fpr (np.ndarray): FPR
|
91 |
+
tpr (np.ndarray): TPR
|
92 |
+
|
93 |
+
self.metrics_kind = 'auc' is defined in class ClsEval below.
|
94 |
+
"""
|
95 |
+
label_metrics.set_label_metrics(split, 'fpr', fpr)
|
96 |
+
label_metrics.set_label_metrics(split, 'tpr', tpr)
|
97 |
+
label_metrics.set_label_metrics(split, self.metrics_kind, metrics.auc(fpr, tpr))
|
98 |
+
|
99 |
+
def _cal_label_roc_binary(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
|
100 |
+
"""
|
101 |
+
Calculate ROC for binary class.
|
102 |
+
|
103 |
+
Args:
|
104 |
+
label_name (str): label name
|
105 |
+
df_group (pd.DataFrame): likelihood for group
|
106 |
+
|
107 |
+
Returns:
|
108 |
+
LabelMetrics: metrics of 'val' and 'test'
|
109 |
+
"""
|
110 |
+
required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
|
111 |
+
df_label = df_group[required_columns]
|
112 |
+
POSITIVE = 1
|
113 |
+
positive_pred_name = 'pred_' + label_name + '_' + str(POSITIVE)
|
114 |
+
|
115 |
+
# ! When splits is 'test' only, ie when external dataset, error occurs.
|
116 |
+
label_metrics = LabelMetrics()
|
117 |
+
for split in ['val', 'test']:
|
118 |
+
df_split = df_label.query('split == @split')
|
119 |
+
y_true = df_split[label_name]
|
120 |
+
y_score = df_split[positive_pred_name]
|
121 |
+
_fpr, _tpr, _ = metrics.roc_curve(y_true, y_score)
|
122 |
+
self._set_roc(label_metrics, split, _fpr, _tpr)
|
123 |
+
return label_metrics
|
124 |
+
|
125 |
+
def _cal_label_roc_multi(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
|
126 |
+
"""
|
127 |
+
Calculate ROC for multi-class by macro average.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
label_name (str): label name
|
131 |
+
df_group (pd.DataFrame): likelihood for group
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
LabelMetrics: metrics of 'val' and 'test'
|
135 |
+
"""
|
136 |
+
required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
|
137 |
+
df_label = df_group[required_columns]
|
138 |
+
|
139 |
+
pred_name_list = list(df_label.columns[df_label.columns.str.startswith('pred')])
|
140 |
+
class_list = [int(pred_name.rsplit('_', 1)[-1]) for pred_name in pred_name_list] # [pred_label_0, pred_label_1, pred_label_2] -> [0, 1, 2]
|
141 |
+
num_classes = len(class_list)
|
142 |
+
|
143 |
+
label_metrics = LabelMetrics()
|
144 |
+
for split in ['val', 'test']:
|
145 |
+
df_split = df_label.query('split == @split')
|
146 |
+
y_true = df_split[label_name]
|
147 |
+
y_true_bin = label_binarize(y_true, classes=class_list) # Since y_true: List[int], should be class_list: List[int]
|
148 |
+
|
149 |
+
# Compute ROC for each class by OneVsRest
|
150 |
+
_fpr = dict()
|
151 |
+
_tpr = dict()
|
152 |
+
for i, class_name in enumerate(class_list):
|
153 |
+
pred_name = 'pred_' + label_name + '_' + str(class_name)
|
154 |
+
_fpr[class_name], _tpr[class_name], _ = metrics.roc_curve(y_true_bin[:, i], df_split[pred_name])
|
155 |
+
|
156 |
+
# First aggregate all false positive rates
|
157 |
+
all_fpr = np.unique(np.concatenate([_fpr[class_name] for class_name in class_list]))
|
158 |
+
|
159 |
+
# Then interpolate all ROC at this points
|
160 |
+
mean_tpr = np.zeros_like(all_fpr)
|
161 |
+
for class_name in class_list:
|
162 |
+
mean_tpr += np.interp(all_fpr, _fpr[class_name], _tpr[class_name])
|
163 |
+
|
164 |
+
# Finally average it and compute AUC
|
165 |
+
mean_tpr /= num_classes
|
166 |
+
|
167 |
+
_fpr['macro'] = all_fpr
|
168 |
+
_tpr['macro'] = mean_tpr
|
169 |
+
self._set_roc(label_metrics, split, _fpr['macro'], _tpr['macro'])
|
170 |
+
return label_metrics
|
171 |
+
|
172 |
+
def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
|
173 |
+
"""
|
174 |
+
Calculate ROC and AUC for label depending on binary or multi-class.
|
175 |
+
|
176 |
+
Args:
|
177 |
+
label_name (str):label name
|
178 |
+
df_group (pd.DataFrame): likelihood for group
|
179 |
+
|
180 |
+
Returns:
|
181 |
+
LabelMetrics: metrics of 'val' and 'test'
|
182 |
+
"""
|
183 |
+
pred_name_list = df_group.columns[df_group.columns.str.startswith('pred_' + label_name)]
|
184 |
+
isMultiClass = (len(pred_name_list) > 2)
|
185 |
+
if isMultiClass:
|
186 |
+
label_metrics = self._cal_label_roc_multi(label_name, df_group)
|
187 |
+
else:
|
188 |
+
label_metrics = self._cal_label_roc_binary(label_name, df_group)
|
189 |
+
return label_metrics
|
190 |
+
|
191 |
+
|
192 |
+
class YYMixin:
|
193 |
+
"""
|
194 |
+
Class for calculating YY and R2.
|
195 |
+
"""
|
196 |
+
def _set_yy(self, label_metrics: LabelMetrics, split: str, y_obs: np.ndarray, y_pred: np.ndarray) -> None:
|
197 |
+
"""
|
198 |
+
Set ground truth, prediction, and R2.
|
199 |
+
|
200 |
+
Args:
|
201 |
+
label_metrics (LabelMetrics): metrics of 'val' and 'test'
|
202 |
+
split (str): 'val' or 'test'
|
203 |
+
y_obs (np.ndarray): ground truth
|
204 |
+
y_pred (np.ndarray): prediction
|
205 |
+
|
206 |
+
self.metrics_kind = 'r2' is defined in class RegEval below.
|
207 |
+
"""
|
208 |
+
label_metrics.set_label_metrics(split, 'y_obs', y_obs.values)
|
209 |
+
label_metrics.set_label_metrics(split, 'y_pred', y_pred.values)
|
210 |
+
label_metrics.set_label_metrics(split, self.metrics_kind, metrics.r2_score(y_obs, y_pred))
|
211 |
+
|
212 |
+
def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
|
213 |
+
"""
|
214 |
+
Calculate YY and R2 for label.
|
215 |
+
|
216 |
+
Args:
|
217 |
+
label_name (str): label name
|
218 |
+
df_group (pd.DataFrame): likelihood for group
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
LabelMetrics: metrics of 'val' and 'test'
|
222 |
+
"""
|
223 |
+
required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['split']
|
224 |
+
df_label = df_group[required_columns]
|
225 |
+
label_metrics = LabelMetrics()
|
226 |
+
for split in ['val', 'test']:
|
227 |
+
df_split = df_label.query('split == @split')
|
228 |
+
y_obs = df_split[label_name]
|
229 |
+
y_pred = df_split['pred_' + label_name]
|
230 |
+
self._set_yy(label_metrics, split, y_obs, y_pred)
|
231 |
+
return label_metrics
|
232 |
+
|
233 |
+
|
234 |
+
class C_IndexMixin:
|
235 |
+
"""
|
236 |
+
Class for calculating C-Index.
|
237 |
+
"""
|
238 |
+
def _set_c_index(
|
239 |
+
self,
|
240 |
+
label_metrics: LabelMetrics,
|
241 |
+
split: str,
|
242 |
+
periods: pd.Series,
|
243 |
+
preds: pd.Series,
|
244 |
+
labels: pd.Series
|
245 |
+
) -> None:
|
246 |
+
"""
|
247 |
+
Set C-Index.
|
248 |
+
|
249 |
+
Args:
|
250 |
+
label_metrics (LabelMetrics): metrics of 'val' and 'test'
|
251 |
+
split (str): 'val' or 'test'
|
252 |
+
periods (pd.Series): periods
|
253 |
+
preds (pd.Series): prediction
|
254 |
+
labels (pd.Series): label
|
255 |
+
|
256 |
+
self.metrics_kind = 'c_index' is defined in class DeepSurvEval below.
|
257 |
+
"""
|
258 |
+
from lifelines.utils import concordance_index
|
259 |
+
value_c_index = concordance_index(periods, (-1)*preds, labels)
|
260 |
+
label_metrics.set_label_metrics(split, self.metrics_kind, value_c_index)
|
261 |
+
|
262 |
+
def cal_label_metrics(self, label_name: str, df_group: pd.DataFrame) -> LabelMetrics:
|
263 |
+
"""
|
264 |
+
Calculate C-Index for label.
|
265 |
+
|
266 |
+
Args:
|
267 |
+
label_name (str): label name
|
268 |
+
df_group (pd.DataFrame): likelihood for group
|
269 |
+
|
270 |
+
Returns:
|
271 |
+
LabelMetrics: metrics of 'val' and 'test'
|
272 |
+
"""
|
273 |
+
required_columns = [column_name for column_name in df_group.columns if label_name in column_name] + ['periods', 'split']
|
274 |
+
df_label = df_group[required_columns]
|
275 |
+
label_metrics = LabelMetrics()
|
276 |
+
for split in ['val', 'test']:
|
277 |
+
df_split = df_label.query('split == @split')
|
278 |
+
periods = df_split['periods']
|
279 |
+
preds = df_split['pred_' + label_name]
|
280 |
+
labels = df_split[label_name]
|
281 |
+
self._set_c_index(label_metrics, split, periods, preds, labels)
|
282 |
+
return label_metrics
|
283 |
+
|
284 |
+
|
285 |
+
class MetricsMixin:
|
286 |
+
"""
|
287 |
+
Class to calculate metrics and make summary.
|
288 |
+
"""
|
289 |
+
def _cal_group_metrics(self, df_group: pd.DataFrame) -> Dict[str, LabelMetrics]:
|
290 |
+
"""
|
291 |
+
Calculate metrics for each group.
|
292 |
+
|
293 |
+
Args:
|
294 |
+
df_group (pd.DataFrame): likelihood for group
|
295 |
+
|
296 |
+
Returns:
|
297 |
+
Dict[str, LabelMetrics]: dictionary of label and its LabelMetrics
|
298 |
+
eg. {{label_1: LabelMetrics(), label_2: LabelMetrics(), ...}
|
299 |
+
"""
|
300 |
+
label_list = list(df_group.columns[df_group.columns.str.startswith('label')])
|
301 |
+
group_metrics = dict()
|
302 |
+
for label_name in label_list:
|
303 |
+
label_metrics = self.cal_label_metrics(label_name, df_group)
|
304 |
+
group_metrics[label_name] = label_metrics
|
305 |
+
return group_metrics
|
306 |
+
|
307 |
+
def cal_whole_metrics(self, df_likelihood: pd.DataFrame) -> Dict[str, Dict[str, LabelMetrics]]:
|
308 |
+
"""
|
309 |
+
Calculate metrics for all groups.
|
310 |
+
|
311 |
+
Args:
|
312 |
+
df_likelihood (pd.DataFrame) : DataFrame of likelihood
|
313 |
+
|
314 |
+
Returns:
|
315 |
+
Dict[str, Dict[str, LabelMetrics]]: dictionary of group and dictionary of label and its LabelMetrics
|
316 |
+
eg. {
|
317 |
+
groupA: {label_1: LabelMetrics(), label_2: LabelMetrics(), ...},
|
318 |
+
groupB: {label_1: LabelMetrics(), label_2: LabelMetrics()}, ...},
|
319 |
+
...}
|
320 |
+
"""
|
321 |
+
whole_metrics = dict()
|
322 |
+
for group in df_likelihood['group'].unique():
|
323 |
+
df_group = df_likelihood.query('group == @group')
|
324 |
+
whole_metrics[group] = self._cal_group_metrics(df_group)
|
325 |
+
return whole_metrics
|
326 |
+
|
327 |
+
def make_summary(
|
328 |
+
self,
|
329 |
+
whole_metrics: Dict[str, Dict[str, LabelMetrics]],
|
330 |
+
likelihood_path: Path,
|
331 |
+
metrics_kind: str
|
332 |
+
) -> pd.DataFrame:
|
333 |
+
"""
|
334 |
+
Make summary.
|
335 |
+
|
336 |
+
Args:
|
337 |
+
whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
|
338 |
+
likelihood_path (Path): path to likelihood
|
339 |
+
metrics_kind (str): kind of metrics, ie, 'auc', 'r2', or 'c_index'
|
340 |
+
|
341 |
+
Returns:
|
342 |
+
pd.DataFrame: summary
|
343 |
+
"""
|
344 |
+
_datetime = likelihood_path.parents[1].name
|
345 |
+
_weight = likelihood_path.stem.replace('likelihood_', '') + '.pt'
|
346 |
+
df_summary = pd.DataFrame()
|
347 |
+
for group, group_metrics in whole_metrics.items():
|
348 |
+
_new = dict()
|
349 |
+
_new['datetime'] = [_datetime]
|
350 |
+
_new['weight'] = [ _weight]
|
351 |
+
_new['group'] = [group]
|
352 |
+
for label_name, label_metrics in group_metrics.items():
|
353 |
+
_val_metrics = label_metrics.get_label_metrics('val', metrics_kind)
|
354 |
+
_test_metrics = label_metrics.get_label_metrics('test', metrics_kind)
|
355 |
+
_new[label_name + '_val_' + metrics_kind] = [f"{_val_metrics:.2f}"]
|
356 |
+
_new[label_name + '_test_' + metrics_kind] = [f"{_test_metrics:.2f}"]
|
357 |
+
df_summary = pd.concat([df_summary, pd.DataFrame(_new)], ignore_index=True)
|
358 |
+
|
359 |
+
df_summary = df_summary.sort_values('group')
|
360 |
+
return df_summary
|
361 |
+
|
362 |
+
def print_metrics(self, df_summary: pd.DataFrame, metrics_kind: str) -> None:
|
363 |
+
"""
|
364 |
+
Print metrics.
|
365 |
+
|
366 |
+
Args:
|
367 |
+
df_summary (pd.DataFrame): summary
|
368 |
+
metrics_kind (str): kind of metrics, ie. 'auc', 'r2', or 'c_index'
|
369 |
+
"""
|
370 |
+
label_list = list(df_summary.columns[df_summary.columns.str.startswith('label')]) # [label_1_val, label_1_test, label_2_val, label_2_test, ...]
|
371 |
+
num_splits = len(['val', 'test'])
|
372 |
+
_column_val_test_list = [label_list[i:i+num_splits] for i in range(0, len(label_list), num_splits)] # [[label_1_val, label_1_test], [label_2_val, label_2_test], ...]
|
373 |
+
for _, row in df_summary.iterrows():
|
374 |
+
logger.info(row['group'])
|
375 |
+
for _column_val_test in _column_val_test_list:
|
376 |
+
_label_name = _column_val_test[0].replace('_val', '')
|
377 |
+
_label_name_val = _column_val_test[0]
|
378 |
+
_label_name_test = _column_val_test[1]
|
379 |
+
logger.info(f"{_label_name:<25} val_{metrics_kind}: {row[_label_name_val]:>7}, test_{metrics_kind}: {row[_label_name_test]:>7}")
|
380 |
+
|
381 |
+
def update_summary(self, df_summary: pd.DataFrame, likelihood_path: Path) -> None:
|
382 |
+
"""
|
383 |
+
Update summary.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
df_summary (pd.DataFrame): summary to be added to the previous summary
|
387 |
+
likelihood_path (Path): path to likelihood
|
388 |
+
"""
|
389 |
+
_project_dir = likelihood_path.parents[3]
|
390 |
+
summary_dir = Path(_project_dir, 'summary')
|
391 |
+
summary_path = Path(summary_dir, 'summary.csv')
|
392 |
+
if summary_path.exists():
|
393 |
+
df_prev = pd.read_csv(summary_path)
|
394 |
+
df_updated = pd.concat([df_prev, df_summary], axis=0)
|
395 |
+
else:
|
396 |
+
summary_dir.mkdir(parents=True, exist_ok=True)
|
397 |
+
df_updated = df_summary
|
398 |
+
df_updated.to_csv(summary_path, index=False)
|
399 |
+
|
400 |
+
def make_metrics(self, likelihood_path: Path) -> None:
|
401 |
+
"""
|
402 |
+
Make metrics.
|
403 |
+
|
404 |
+
Args:
|
405 |
+
likelihood_path (Path): path to likelihood
|
406 |
+
"""
|
407 |
+
df_likelihood = pd.read_csv(likelihood_path)
|
408 |
+
whole_metrics = self.cal_whole_metrics(df_likelihood)
|
409 |
+
self.make_save_fig(whole_metrics, likelihood_path, self.fig_kind)
|
410 |
+
df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
|
411 |
+
self.print_metrics(df_summary, self.metrics_kind)
|
412 |
+
self.update_summary(df_summary, likelihood_path)
|
413 |
+
|
414 |
+
|
415 |
+
class FigROCMixin:
|
416 |
+
"""
|
417 |
+
Class to plot ROC.
|
418 |
+
"""
|
419 |
+
def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
|
420 |
+
"""
|
421 |
+
Plot ROC.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
group (str): group
|
425 |
+
group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics
|
426 |
+
|
427 |
+
Returns:
|
428 |
+
plt: ROC
|
429 |
+
"""
|
430 |
+
label_list = group_metrics.keys()
|
431 |
+
num_rows = 1
|
432 |
+
num_cols = len(label_list)
|
433 |
+
base_size = 7
|
434 |
+
height = num_rows * base_size
|
435 |
+
width = num_cols * height
|
436 |
+
fig = plt.figure(figsize=(width, height))
|
437 |
+
|
438 |
+
for i, label_name in enumerate(label_list):
|
439 |
+
label_metrics = group_metrics[label_name]
|
440 |
+
offset = i + 1
|
441 |
+
ax_i = fig.add_subplot(
|
442 |
+
num_rows,
|
443 |
+
num_cols,
|
444 |
+
offset,
|
445 |
+
title=group + ': ' + label_name,
|
446 |
+
xlabel='1 - Specificity',
|
447 |
+
ylabel='Sensitivity',
|
448 |
+
xmargin=0,
|
449 |
+
ymargin=0
|
450 |
+
)
|
451 |
+
ax_i.plot(label_metrics.val.fpr, label_metrics.val.tpr, label=f"AUC_val = {label_metrics.val.auc:.2f}", marker='x')
|
452 |
+
ax_i.plot(label_metrics.test.fpr, label_metrics.test.tpr, label=f"AUC_test = {label_metrics.test.auc:.2f}", marker='o')
|
453 |
+
ax_i.grid()
|
454 |
+
ax_i.legend()
|
455 |
+
fig.tight_layout()
|
456 |
+
return fig
|
457 |
+
|
458 |
+
|
459 |
+
class FigYYMixin:
|
460 |
+
"""
|
461 |
+
Class to plot YY-graph.
|
462 |
+
"""
|
463 |
+
def _plot_fig_group_metrics(self, group: str, group_metrics: Dict[str, LabelMetrics]) -> plt:
|
464 |
+
"""
|
465 |
+
Plot yy.
|
466 |
+
|
467 |
+
Args:
|
468 |
+
group (str): group
|
469 |
+
group_metrics (Dict[str, LabelMetrics]): dictionary of label and its LabelMetrics
|
470 |
+
|
471 |
+
Returns:
|
472 |
+
plt: YY-graph
|
473 |
+
"""
|
474 |
+
label_list = group_metrics.keys()
|
475 |
+
num_splits = len(['val', 'test'])
|
476 |
+
num_rows = 1
|
477 |
+
num_cols = len(label_list) * num_splits
|
478 |
+
base_size = 7
|
479 |
+
height = num_rows * base_size
|
480 |
+
width = num_cols * height
|
481 |
+
fig = plt.figure(figsize=(width, height))
|
482 |
+
|
483 |
+
for i, label_name in enumerate(label_list):
|
484 |
+
label_metrics = group_metrics[label_name]
|
485 |
+
val_offset = (i * num_splits) + 1
|
486 |
+
test_offset = val_offset + 1
|
487 |
+
|
488 |
+
val_ax = fig.add_subplot(
|
489 |
+
num_rows,
|
490 |
+
num_cols,
|
491 |
+
val_offset,
|
492 |
+
title=group + ': ' + label_name + '\n' + 'val: Observed-Predicted Plot',
|
493 |
+
xlabel='Observed',
|
494 |
+
ylabel='Predicted',
|
495 |
+
xmargin=0,
|
496 |
+
ymargin=0
|
497 |
+
)
|
498 |
+
|
499 |
+
test_ax = fig.add_subplot(
|
500 |
+
num_rows,
|
501 |
+
num_cols,
|
502 |
+
test_offset,
|
503 |
+
title=group + ': ' + label_name + '\n' + 'test: Observed-Predicted Plot',
|
504 |
+
xlabel='Observed',
|
505 |
+
ylabel='Predicted',
|
506 |
+
xmargin=0,
|
507 |
+
ymargin=0
|
508 |
+
)
|
509 |
+
|
510 |
+
y_obs_val = label_metrics.val.y_obs
|
511 |
+
y_pred_val = label_metrics.val.y_pred
|
512 |
+
|
513 |
+
y_obs_test = label_metrics.test.y_obs
|
514 |
+
y_pred_test = label_metrics.test.y_pred
|
515 |
+
|
516 |
+
# Plot
|
517 |
+
color = mcolors.TABLEAU_COLORS
|
518 |
+
val_ax.scatter(y_obs_val, y_pred_val, color=color['tab:blue'], label='val')
|
519 |
+
test_ax.scatter(y_obs_test, y_pred_test, color=color['tab:orange'], label='test')
|
520 |
+
|
521 |
+
# Draw diagonal line
|
522 |
+
y_values_val = np.concatenate([y_obs_val.flatten(), y_pred_val.flatten()])
|
523 |
+
y_values_test = np.concatenate([y_obs_test.flatten(), y_pred_test.flatten()])
|
524 |
+
|
525 |
+
y_values_val_min, y_values_val_max, y_values_val_range = np.amin(y_values_val), np.amax(y_values_val), np.ptp(y_values_val)
|
526 |
+
y_values_test_min, y_values_test_max, y_values_test_range = np.amin(y_values_test), np.amax(y_values_test), np.ptp(y_values_test)
|
527 |
+
|
528 |
+
val_ax.plot([y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)],
|
529 |
+
[y_values_val_min - (y_values_val_range * 0.01), y_values_val_max + (y_values_val_range * 0.01)], color='red')
|
530 |
+
|
531 |
+
test_ax.plot([y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)],
|
532 |
+
[y_values_test_min - (y_values_test_range * 0.01), y_values_test_max + (y_values_test_range * 0.01)], color='red')
|
533 |
+
|
534 |
+
fig.tight_layout()
|
535 |
+
return fig
|
536 |
+
|
537 |
+
|
538 |
+
class FigMixin:
|
539 |
+
"""
|
540 |
+
Class for make and save figure
|
541 |
+
This class is for ROC and YY-graph.
|
542 |
+
"""
|
543 |
+
def make_save_fig(self, whole_metrics: Dict[str, Dict[str, LabelMetrics]], likelihood_path: Path, fig_kind: str) -> None:
|
544 |
+
"""
|
545 |
+
Make and save figure.
|
546 |
+
|
547 |
+
Args:
|
548 |
+
whole_metrics (Dict[str, Dict[str, LabelMetrics]]): metrics for all groups
|
549 |
+
likelihood_path (Path): path to likelihood
|
550 |
+
fig_kind (str): kind of figure, ie. 'roc' or 'yy'
|
551 |
+
"""
|
552 |
+
_datetime_dir = likelihood_path.parents[1]
|
553 |
+
save_dir = Path(_datetime_dir, fig_kind)
|
554 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
555 |
+
_fig_name = fig_kind + '_' + likelihood_path.stem.replace('likelihood_', '')
|
556 |
+
for group, group_metrics in whole_metrics.items():
|
557 |
+
fig = self._plot_fig_group_metrics(group, group_metrics)
|
558 |
+
save_path = Path(save_dir, group + '_' + _fig_name + '.png')
|
559 |
+
fig.savefig(save_path)
|
560 |
+
plt.close()
|
561 |
+
|
562 |
+
|
563 |
+
class ClsEval(MetricsMixin, ROCMixin, FigMixin, FigROCMixin):
|
564 |
+
"""
|
565 |
+
Class for calculation metrics for classification.
|
566 |
+
"""
|
567 |
+
def __init__(self) -> None:
|
568 |
+
self.fig_kind = 'roc'
|
569 |
+
self.metrics_kind = 'auc'
|
570 |
+
|
571 |
+
|
572 |
+
class RegEval(MetricsMixin, YYMixin, FigMixin, FigYYMixin):
|
573 |
+
"""
|
574 |
+
Class for calculation metrics for regression.
|
575 |
+
"""
|
576 |
+
def __init__(self) -> None:
|
577 |
+
self.fig_kind = 'yy'
|
578 |
+
self.metrics_kind = 'r2'
|
579 |
+
|
580 |
+
|
581 |
+
class DeepSurvEval(MetricsMixin, C_IndexMixin):
|
582 |
+
"""
|
583 |
+
Class for calculation metrics for DeepSurv.
|
584 |
+
"""
|
585 |
+
def __init__(self) -> None:
|
586 |
+
self.fig_kind = None
|
587 |
+
self.metrics_kind = 'c_index'
|
588 |
+
|
589 |
+
def make_metrics(self, likelihood_path: Path) -> None:
|
590 |
+
"""
|
591 |
+
Make metrics, substantially this method handles everything all.
|
592 |
+
|
593 |
+
Args:
|
594 |
+
likelihood_path (Path): path to likelihood
|
595 |
+
|
596 |
+
Overwrite def make_metrics() in class MetricsMixin by deleting self.make_save_fig(),
|
597 |
+
because of no need to plot and save figure.
|
598 |
+
"""
|
599 |
+
df_likelihood = pd.read_csv(likelihood_path)
|
600 |
+
whole_metrics = self.cal_whole_metrics(df_likelihood)
|
601 |
+
df_summary = self.make_summary(whole_metrics, likelihood_path, self.metrics_kind)
|
602 |
+
self.print_metrics(df_summary, self.metrics_kind)
|
603 |
+
self.update_summary(df_summary, likelihood_path)
|
604 |
+
|
605 |
+
|
606 |
+
def set_eval(task: str) -> Union[ClsEval, RegEval, DeepSurvEval]:
|
607 |
+
"""
|
608 |
+
Set class for evaluation depending on task depending on task.
|
609 |
+
|
610 |
+
Args:
|
611 |
+
task (str): task
|
612 |
+
|
613 |
+
Returns:
|
614 |
+
Union[ClsEval, RegEval, DeepSurvEval]: class for evaluation
|
615 |
+
"""
|
616 |
+
if task == 'classification':
|
617 |
+
return ClsEval()
|
618 |
+
elif task == 'regression':
|
619 |
+
return RegEval()
|
620 |
+
elif task == 'deepsurv':
|
621 |
+
return DeepSurvEval()
|
622 |
+
else:
|
623 |
+
raise ValueError(f"Invalid task: {task}.")
|
lib/options.py
ADDED
@@ -0,0 +1,655 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
# -*- coding: utf-8 -*-
|
3 |
+
|
4 |
+
import argparse
|
5 |
+
from distutils.util import strtobool
|
6 |
+
from pathlib import Path
|
7 |
+
import pandas as pd
|
8 |
+
import json
|
9 |
+
import torch
|
10 |
+
from .logger import BaseLogger
|
11 |
+
from typing import List, Dict, Tuple, Union
|
12 |
+
|
13 |
+
|
14 |
+
logger = BaseLogger.get_logger(__name__)
|
15 |
+
|
16 |
+
|
17 |
+
class Options:
|
18 |
+
"""
|
19 |
+
Class for options.
|
20 |
+
"""
|
21 |
+
def __init__(self, datetime: str = None, isTrain: bool = None) -> None:
|
22 |
+
"""
|
23 |
+
Args:
|
24 |
+
datetime (str, optional): date time Args:
|
25 |
+
isTrain (bool, optional): Variable indicating whether training or not. Defaults to None.
|
26 |
+
"""
|
27 |
+
self.parser = argparse.ArgumentParser(description='Options for training or test')
|
28 |
+
|
29 |
+
# CSV
|
30 |
+
self.parser.add_argument('--csvpath', type=str, required=True, help='path to csv for training or test')
|
31 |
+
|
32 |
+
# GPU Ids
|
33 |
+
self.parser.add_argument('--gpu_ids', type=str, default='cpu', help='gpu ids: e.g. 0, 0-1-2, 0-2. Use cpu for CPU (Default: cpu)')
|
34 |
+
|
35 |
+
if isTrain:
|
36 |
+
# Task
|
37 |
+
self.parser.add_argument('--task', type=str, required=True, choices=['classification', 'regression', 'deepsurv'], help='Task')
|
38 |
+
|
39 |
+
# Model
|
40 |
+
self.parser.add_argument('--model', type=str, required=True, help='model: MLP, CNN, ViT, or MLP+(CNN or ViT)')
|
41 |
+
self.parser.add_argument('--pretrained', type=strtobool, default=False, help='For use of pretrained model(CNN or ViT)')
|
42 |
+
|
43 |
+
# Training and Internal validation
|
44 |
+
self.parser.add_argument('--criterion', type=str, required=True, choices=['CEL', 'MSE', 'RMSE', 'MAE', 'NLL'], help='criterion')
|
45 |
+
self.parser.add_argument('--optimizer', type=str, default='Adam', choices=['SGD', 'Adadelta', 'RMSprop', 'Adam', 'RAdam'], help='optimizer')
|
46 |
+
self.parser.add_argument('--lr', type=float, metavar='N', help='learning rate')
|
47 |
+
self.parser.add_argument('--epochs', type=int, default=10, metavar='N', help='number of epochs (Default: 10)')
|
48 |
+
|
49 |
+
# Batch size
|
50 |
+
self.parser.add_argument('--batch_size', type=int, required=True, metavar='N', help='batch size in training')
|
51 |
+
|
52 |
+
# Preprocess for image
|
53 |
+
self.parser.add_argument('--augmentation', type=str, default='no', choices=['xrayaug', 'trivialaugwide', 'randaug', 'no'], help='kind of augmentation')
|
54 |
+
self.parser.add_argument('--normalize_image', type=str, choices=['yes', 'no'], default='yes', help='image normalization: yes, no (Default: yes)')
|
55 |
+
|
56 |
+
# Sampler
|
57 |
+
self.parser.add_argument('--sampler', type=str, default='no', choices=['yes', 'no'], help='sample data in training or not, yes or no')
|
58 |
+
|
59 |
+
# Input channel
|
60 |
+
self.parser.add_argument('--in_channel', type=int, required=True, choices=[1, 3], help='channel of input image')
|
61 |
+
self.parser.add_argument('--vit_image_size', type=int, default=0, help='input image size for ViT. Set 0 if not used ViT (Default: 0)')
|
62 |
+
|
63 |
+
# Weight saving strategy
|
64 |
+
self.parser.add_argument('--save_weight_policy', type=str, choices=['best', 'each'], default='best', help='Save weight policy: best, or each(ie. save each time loss decreases when multi-label output) (Default: best)')
|
65 |
+
|
66 |
+
else:
|
67 |
+
# Directory of weight at training
|
68 |
+
self.parser.add_argument('--weight_dir', type=str, default=None, help='directory of weight to be used when test. If None, the latest one is selected')
|
69 |
+
|
70 |
+
# Test bash size
|
71 |
+
self.parser.add_argument('--test_batch_size', type=int, default=1, metavar='N', help='batch size for test (Default: 1)')
|
72 |
+
|
73 |
+
# Splits for test
|
74 |
+
self.parser.add_argument('--test_splits', type=str, default='train-val-test', help='splits for test: e.g. test, val-test, train-val-test. (Default: train-val-test)')
|
75 |
+
|
76 |
+
self.args = self.parser.parse_args()
|
77 |
+
|
78 |
+
if datetime is not None:
|
79 |
+
self.args.datetime = datetime
|
80 |
+
|
81 |
+
assert isinstance(isTrain, bool), 'isTrain should be bool.'
|
82 |
+
self.args.isTrain = isTrain
|
83 |
+
|
84 |
+
def get_args(self) -> argparse.Namespace:
|
85 |
+
"""
|
86 |
+
Return arguments.
|
87 |
+
|
88 |
+
Returns:
|
89 |
+
argparse.Namespace: arguments
|
90 |
+
"""
|
91 |
+
return self.args
|
92 |
+
|
93 |
+
|
94 |
+
class CSVParser:
|
95 |
+
"""
|
96 |
+
Class to get information of csv and cast csv.
|
97 |
+
"""
|
98 |
+
def __init__(self, csvpath: str, task: str, isTrain: bool = None) -> None:
|
99 |
+
"""
|
100 |
+
Args:
|
101 |
+
csvpath (str): path to csv
|
102 |
+
task (str): task
|
103 |
+
isTrain (bool): if training or not
|
104 |
+
"""
|
105 |
+
self.csvpath = csvpath
|
106 |
+
self.task = task
|
107 |
+
|
108 |
+
_df_source = pd.read_csv(self.csvpath)
|
109 |
+
_df_source = _df_source[_df_source['split'] != 'exclude']
|
110 |
+
|
111 |
+
self.input_list = list(_df_source.columns[_df_source.columns.str.startswith('input')])
|
112 |
+
self.label_list = list(_df_source.columns[_df_source.columns.str.startswith('label')])
|
113 |
+
if self.task == 'deepsurv':
|
114 |
+
_period_name_list = list(_df_source.columns[_df_source.columns.str.startswith('period')])
|
115 |
+
assert (len(_period_name_list) == 1), f"One column of period should be contained in {self.csvpath} when deepsurv."
|
116 |
+
self.period_name = _period_name_list[0]
|
117 |
+
|
118 |
+
_df_source = self._cast(_df_source, self.task)
|
119 |
+
|
120 |
+
# If no column of group, add it.
|
121 |
+
if 'group' not in _df_source.columns:
|
122 |
+
_df_source = _df_source.assign(group='all')
|
123 |
+
|
124 |
+
self.df_source = _df_source
|
125 |
+
|
126 |
+
if isTrain:
|
127 |
+
self.mlp_num_inputs = len(self.input_list)
|
128 |
+
self.num_outputs_for_label = self._define_num_outputs_for_label(self.df_source, self.label_list, self.task)
|
129 |
+
|
130 |
+
def _cast(self, df_source: pd.DataFrame, task: str) -> pd.DataFrame:
|
131 |
+
"""
|
132 |
+
Make dictionary of cast depending on task.
|
133 |
+
|
134 |
+
Args:
|
135 |
+
df_source (pd.DataFrame): excluded DataFrame
|
136 |
+
task: (str): task
|
137 |
+
|
138 |
+
Returns:
|
139 |
+
DataFrame: csv excluded and cast depending on task
|
140 |
+
"""
|
141 |
+
_cast_input = {input_name: float for input_name in self.input_list}
|
142 |
+
|
143 |
+
if task == 'classification':
|
144 |
+
_cast_label = {label_name: int for label_name in self.label_list}
|
145 |
+
_casts = {**_cast_input, **_cast_label}
|
146 |
+
df_source = df_source.astype(_casts)
|
147 |
+
return df_source
|
148 |
+
|
149 |
+
elif task == 'regression':
|
150 |
+
_cast_label = {label_name: float for label_name in self.label_list}
|
151 |
+
_casts = {**_cast_input, **_cast_label}
|
152 |
+
df_source = df_source.astype(_casts)
|
153 |
+
return df_source
|
154 |
+
|
155 |
+
elif task == 'deepsurv':
|
156 |
+
_cast_label = {label_name: int for label_name in self.label_list}
|
157 |
+
_cast_period = {self.period_name: int}
|
158 |
+
_casts = {**_cast_input, **_cast_label, **_cast_period}
|
159 |
+
df_source = df_source.astype(_casts)
|
160 |
+
return df_source
|
161 |
+
|
162 |
+
else:
|
163 |
+
raise ValueError(f"Invalid task: {self.task}.")
|
164 |
+
|
165 |
+
def _define_num_outputs_for_label(self, df_source: pd.DataFrame, label_list: List[str], task :str) -> Dict[str, int]:
|
166 |
+
"""
|
167 |
+
Define the number of outputs for each label.
|
168 |
+
|
169 |
+
Args:
|
170 |
+
df_source (pd.DataFrame): DataFrame of csv
|
171 |
+
label_list (List[str]): list of labels
|
172 |
+
task: str
|
173 |
+
|
174 |
+
Returns:
|
175 |
+
Dict[str, int]: dictionary of the number of outputs for each label
|
176 |
+
eg.
|
177 |
+
classification: _num_outputs_for_label = {label_A: 2, label_B: 3, ...}
|
178 |
+
regression, deepsurv: _num_outputs_for_label = {label_A: 1, label_B: 1, ...}
|
179 |
+
deepsurv: _num_outputs_for_label = {label_A: 1}
|
180 |
+
"""
|
181 |
+
if task == 'classification':
|
182 |
+
_num_outputs_for_label = {label_name: df_source[label_name].nunique() for label_name in label_list}
|
183 |
+
return _num_outputs_for_label
|
184 |
+
|
185 |
+
elif (task == 'regression') or (task == 'deepsurv'):
|
186 |
+
_num_outputs_for_label = {label_name: 1 for label_name in label_list}
|
187 |
+
return _num_outputs_for_label
|
188 |
+
|
189 |
+
else:
|
190 |
+
raise ValueError(f"Invalid task: {task}.")
|
191 |
+
|
192 |
+
|
193 |
+
def _parse_model(model_name: str) -> Tuple[Union[str, None], Union[str, None]]:
|
194 |
+
"""
|
195 |
+
Parse model name.
|
196 |
+
|
197 |
+
Args:
|
198 |
+
model_name (str): model name (eg. MLP, ResNey18, or MLP+ResNet18)
|
199 |
+
|
200 |
+
Returns:
|
201 |
+
Tuple[str, str]: MLP, CNN or Vision Transformer name
|
202 |
+
eg. 'MLP', 'ResNet18', 'MLP+ResNet18' ->
|
203 |
+
['MLP'], ['ResNet18'], ['MLP', 'ResNet18']
|
204 |
+
"""
|
205 |
+
_model = model_name.split('+')
|
206 |
+
mlp = 'MLP' if 'MLP' in _model else None
|
207 |
+
_net = [_n for _n in _model if _n != 'MLP']
|
208 |
+
net = _net[0] if _net != [] else None
|
209 |
+
return mlp, net
|
210 |
+
|
211 |
+
|
212 |
+
def _parse_gpu_ids(gpu_ids: str) -> List[int]:
|
213 |
+
"""
|
214 |
+
Parse GPU ids concatenated with '-' to list of integers of GPU ids.
|
215 |
+
eg. '0-1-2' -> [0, 1, 2], '-1' -> []
|
216 |
+
|
217 |
+
Args:
|
218 |
+
gpu_ids (str): GPU Ids
|
219 |
+
|
220 |
+
Returns:
|
221 |
+
List[int]: list of GPU ids
|
222 |
+
"""
|
223 |
+
if (gpu_ids == 'cpu') or (gpu_ids == 'cpu\r'):
|
224 |
+
str_ids = []
|
225 |
+
else:
|
226 |
+
str_ids = gpu_ids.split('-')
|
227 |
+
_gpu_ids = []
|
228 |
+
for str_id in str_ids:
|
229 |
+
id = int(str_id)
|
230 |
+
if id >= 0:
|
231 |
+
_gpu_ids.append(id)
|
232 |
+
return _gpu_ids
|
233 |
+
|
234 |
+
|
235 |
+
def _get_latest_weight_dir() -> str:
|
236 |
+
"""
|
237 |
+
Return the latest path to directory of weight made at training.
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
str: path to directory of the latest weight
|
241 |
+
eg. 'results/<project>/trials/2022-09-30-15-56-60/weights'
|
242 |
+
"""
|
243 |
+
_weight_dirs = list(Path('results').glob('*/trials/*/weights'))
|
244 |
+
assert (_weight_dirs != []), 'No directory of weight.'
|
245 |
+
weight_dir = max(_weight_dirs, key=lambda weight_dir: weight_dir.stat().st_mtime)
|
246 |
+
return str(weight_dir)
|
247 |
+
|
248 |
+
|
249 |
+
def _collect_weight_paths(weight_dir: str) -> List[str]:
|
250 |
+
"""
|
251 |
+
Return list of weight paths.
|
252 |
+
|
253 |
+
Args:
|
254 |
+
weight_dir (str): path to directory of weights
|
255 |
+
|
256 |
+
Returns:
|
257 |
+
List[str]: list of weight paths
|
258 |
+
"""
|
259 |
+
_weight_paths = list(Path(weight_dir).glob('*.pt'))
|
260 |
+
assert _weight_paths != [], f"No weight in {weight_dir}."
|
261 |
+
_weight_paths.sort(key=lambda path: path.stat().st_mtime)
|
262 |
+
_weight_paths = [str(weight_path) for weight_path in _weight_paths]
|
263 |
+
return _weight_paths
|
264 |
+
|
265 |
+
|
266 |
+
class ParamTable:
|
267 |
+
"""
|
268 |
+
Class to make table to dispatch parameters by group.
|
269 |
+
"""
|
270 |
+
def __init__(self) -> None:
|
271 |
+
# groups
|
272 |
+
# key is abbreviation, value is group name
|
273 |
+
self.groups = {
|
274 |
+
'mo': 'model',
|
275 |
+
'dl': 'dataloader',
|
276 |
+
'trc': 'train_conf',
|
277 |
+
'tsc': 'test_conf',
|
278 |
+
'sa': 'save',
|
279 |
+
'lo': 'load',
|
280 |
+
'trp': 'train_print',
|
281 |
+
'tsp': 'test_print'
|
282 |
+
}
|
283 |
+
|
284 |
+
mo = self.groups['mo']
|
285 |
+
dl = self.groups['dl']
|
286 |
+
trc = self.groups['trc']
|
287 |
+
tsc = self.groups['tsc']
|
288 |
+
sa = self.groups['sa']
|
289 |
+
lo = self.groups['lo']
|
290 |
+
trp = self.groups['trp']
|
291 |
+
tsp = self.groups['tsp']
|
292 |
+
|
293 |
+
# The below shows that which group each parameter dispatches to.
|
294 |
+
self.dispatch = {
|
295 |
+
'datetime': [sa],
|
296 |
+
'project': [sa, trp, tsp],
|
297 |
+
'csvpath': [sa, trp, tsp],
|
298 |
+
'task': [dl, tsc, sa, lo, trp, tsp],
|
299 |
+
'isTrain': [dl, trp, tsp],
|
300 |
+
|
301 |
+
'model': [sa, lo, trp, tsp],
|
302 |
+
'vit_image_size': [mo, sa, lo, trp, tsp],
|
303 |
+
'pretrained': [mo, sa, trp],
|
304 |
+
'mlp': [mo, dl],
|
305 |
+
'net': [mo, dl],
|
306 |
+
|
307 |
+
'weight_dir': [tsc, tsp],
|
308 |
+
'weight_paths': [tsc],
|
309 |
+
|
310 |
+
'criterion': [trc, sa, trp],
|
311 |
+
'optimizer': [trc, sa, trp],
|
312 |
+
'lr': [trc, sa, trp],
|
313 |
+
'epochs': [trc, sa, trp],
|
314 |
+
|
315 |
+
'batch_size': [dl, sa, trp],
|
316 |
+
'test_batch_size': [dl, tsp],
|
317 |
+
'test_splits': [tsc, tsp],
|
318 |
+
|
319 |
+
'in_channel': [mo, dl, sa, lo, trp, tsp],
|
320 |
+
'normalize_image': [dl, sa, lo, trp, tsp],
|
321 |
+
'augmentation': [dl, sa, trp],
|
322 |
+
'sampler': [dl, sa, trp],
|
323 |
+
|
324 |
+
'df_source': [dl],
|
325 |
+
'label_list': [dl, trc, sa, lo],
|
326 |
+
'input_list': [dl, sa, lo],
|
327 |
+
'period_name': [dl, sa, lo],
|
328 |
+
'mlp_num_inputs': [mo, sa, lo],
|
329 |
+
'num_outputs_for_label': [mo, sa, lo, tsc],
|
330 |
+
|
331 |
+
'save_weight_policy': [sa, trp, trc],
|
332 |
+
'scaler_path': [dl, tsp],
|
333 |
+
'save_datetime_dir': [trc, tsc, trp, tsp],
|
334 |
+
|
335 |
+
'gpu_ids': [trc, tsc, sa, trp, tsp],
|
336 |
+
'device': [mo, trc, tsc],
|
337 |
+
'dataset_info': [trc, sa, trp, tsp]
|
338 |
+
}
|
339 |
+
|
340 |
+
self.table = self._make_table()
|
341 |
+
|
342 |
+
def _make_table(self) -> pd.DataFrame:
|
343 |
+
"""
|
344 |
+
Make table to dispatch parameters by group.
|
345 |
+
|
346 |
+
Returns:
|
347 |
+
pd.DataFrame: table which shows that which group each parameter belongs to.
|
348 |
+
"""
|
349 |
+
df_table = pd.DataFrame([], index=self.dispatch.keys(), columns=self.groups.values()).fillna('no')
|
350 |
+
for param, grps in self.dispatch.items():
|
351 |
+
for grp in grps:
|
352 |
+
df_table.loc[param, grp] = 'yes'
|
353 |
+
|
354 |
+
df_table = df_table.reset_index()
|
355 |
+
df_table = df_table.rename(columns={'index': 'parameter'})
|
356 |
+
return df_table
|
357 |
+
|
358 |
+
def get_by_group(self, group_name: str) -> List[str]:
|
359 |
+
"""
|
360 |
+
Return list of parameters which belong to group
|
361 |
+
|
362 |
+
Args:
|
363 |
+
group_name (str): group name
|
364 |
+
|
365 |
+
Returns:
|
366 |
+
List[str]: list of parameters
|
367 |
+
"""
|
368 |
+
_df_table = self.table
|
369 |
+
_param_names = _df_table[_df_table[group_name] == 'yes']['parameter'].tolist()
|
370 |
+
return _param_names
|
371 |
+
|
372 |
+
|
373 |
+
Param_Table = ParamTable()
|
374 |
+
|
375 |
+
|
376 |
+
class ParamSet:
|
377 |
+
"""
|
378 |
+
Class to store required parameters for each group.
|
379 |
+
"""
|
380 |
+
pass
|
381 |
+
|
382 |
+
|
383 |
+
def _dispatch_by_group(args: argparse.Namespace, group_name: str) -> ParamSet:
|
384 |
+
"""
|
385 |
+
Dispatch parameters depending on group.
|
386 |
+
|
387 |
+
Args:
|
388 |
+
args (argparse.Namespace): arguments
|
389 |
+
group_name (str): group
|
390 |
+
|
391 |
+
Returns:
|
392 |
+
ParamSet: class containing parameters for group
|
393 |
+
"""
|
394 |
+
_param_names = Param_Table.get_by_group(group_name)
|
395 |
+
param_set = ParamSet()
|
396 |
+
for param_name in _param_names:
|
397 |
+
if hasattr(args, param_name):
|
398 |
+
_arg = getattr(args, param_name)
|
399 |
+
setattr(param_set, param_name, _arg)
|
400 |
+
return param_set
|
401 |
+
|
402 |
+
|
403 |
+
def save_parameter(params: ParamSet, save_path: str) -> None:
|
404 |
+
"""
|
405 |
+
Save parameters.
|
406 |
+
|
407 |
+
Args:
|
408 |
+
params (ParamSet): parameters
|
409 |
+
|
410 |
+
save_path (str): save path for parameters
|
411 |
+
"""
|
412 |
+
_saved = {_param: _arg for _param, _arg in vars(params).items()}
|
413 |
+
save_dir = Path(save_path).parents[0]
|
414 |
+
save_dir.mkdir(parents=True, exist_ok=True)
|
415 |
+
with open(save_path, 'w') as f:
|
416 |
+
json.dump(_saved, f, indent=4)
|
417 |
+
|
418 |
+
|
419 |
+
def _retrieve_parameter(parameter_path: str) -> Dict[str, Union[str, int, float]]:
|
420 |
+
"""
|
421 |
+
Retrieve only parameters required at test from parameters at training.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
parameter_path (str): path to parameter_path
|
425 |
+
|
426 |
+
Returns:
|
427 |
+
Dict[str, Union[str, int, float]]: parameters at training
|
428 |
+
"""
|
429 |
+
with open(parameter_path) as f:
|
430 |
+
params = json.load(f)
|
431 |
+
|
432 |
+
_required = Param_Table.get_by_group('load')
|
433 |
+
params = {p: v for p, v in params.items() if p in _required}
|
434 |
+
return params
|
435 |
+
|
436 |
+
|
437 |
+
def print_parameter(params: ParamSet) -> None:
|
438 |
+
"""
|
439 |
+
Print parameters.
|
440 |
+
|
441 |
+
Args:
|
442 |
+
params (ParamSet): parameters
|
443 |
+
"""
|
444 |
+
|
445 |
+
LINE_LENGTH = 82
|
446 |
+
|
447 |
+
if params.isTrain:
|
448 |
+
phase = 'Training'
|
449 |
+
else:
|
450 |
+
phase = 'Test'
|
451 |
+
|
452 |
+
_header = f" Configuration of {phase} "
|
453 |
+
_padding = (LINE_LENGTH - len(_header) + 1) // 2 # round up
|
454 |
+
_header = ('-' * _padding) + _header + ('-' * _padding) + '\n'
|
455 |
+
|
456 |
+
_footer = ' End '
|
457 |
+
_padding = (LINE_LENGTH - len(_footer) + 1) // 2
|
458 |
+
_footer = ('-' * _padding) + _footer + ('-' * _padding) + '\n'
|
459 |
+
|
460 |
+
message = ''
|
461 |
+
message += _header
|
462 |
+
|
463 |
+
_params_dict = vars(params)
|
464 |
+
del _params_dict['isTrain']
|
465 |
+
for _param, _arg in _params_dict.items():
|
466 |
+
_str_arg = _arg2str(_param, _arg)
|
467 |
+
message += f"{_param:>30}: {_str_arg:<40}\n"
|
468 |
+
|
469 |
+
message += _footer
|
470 |
+
logger.info(message)
|
471 |
+
|
472 |
+
|
473 |
+
def _arg2str(param: str, arg: Union[str, int, float]) -> str:
|
474 |
+
"""
|
475 |
+
Convert argument to string.
|
476 |
+
|
477 |
+
Args:
|
478 |
+
param (str): parameter
|
479 |
+
arg (Union[str, int, float]): argument
|
480 |
+
|
481 |
+
Returns:
|
482 |
+
str: strings of argument
|
483 |
+
"""
|
484 |
+
if param == 'lr':
|
485 |
+
if arg is None:
|
486 |
+
str_arg = 'Default'
|
487 |
+
else:
|
488 |
+
str_arg = str(param)
|
489 |
+
return str_arg
|
490 |
+
elif param == 'gpu_ids':
|
491 |
+
if arg == []:
|
492 |
+
str_arg = 'CPU selected'
|
493 |
+
else:
|
494 |
+
str_arg = f"{arg} (Primary GPU:{arg[0]})"
|
495 |
+
return str_arg
|
496 |
+
elif param == 'test_splits':
|
497 |
+
str_arg = ', '.join(arg)
|
498 |
+
return str_arg
|
499 |
+
elif param == 'dataset_info':
|
500 |
+
str_arg = ''
|
501 |
+
for i, (split, total) in enumerate(arg.items()):
|
502 |
+
if i < len(arg) - 1:
|
503 |
+
str_arg += (f"{split}_data={total}, ")
|
504 |
+
else:
|
505 |
+
str_arg += (f"{split}_data={total}")
|
506 |
+
return str_arg
|
507 |
+
else:
|
508 |
+
if arg is None:
|
509 |
+
str_arg = 'No need'
|
510 |
+
else:
|
511 |
+
str_arg = str(arg)
|
512 |
+
return str_arg
|
513 |
+
|
514 |
+
|
515 |
+
def _check_if_valid_criterion(task: str = None, criterion: str = None) -> None:
|
516 |
+
"""
|
517 |
+
Check if criterion is valid.
|
518 |
+
|
519 |
+
Args:
|
520 |
+
task (str): task
|
521 |
+
criterion (str): criterion
|
522 |
+
"""
|
523 |
+
valid_criterion = {
|
524 |
+
'classification': ['CEL'],
|
525 |
+
'regression': ['MSE', 'RMSE', 'MAE'],
|
526 |
+
'deepsurv': ['NLL']
|
527 |
+
}
|
528 |
+
if criterion in valid_criterion[task]:
|
529 |
+
pass
|
530 |
+
else:
|
531 |
+
raise ValueError(f"Invalid criterion for task: task={task}, criterion={criterion}.")
|
532 |
+
|
533 |
+
|
534 |
+
def _train_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
|
535 |
+
"""
|
536 |
+
Parse parameters required at training.
|
537 |
+
|
538 |
+
Args:
|
539 |
+
args (argparse.Namespace): arguments
|
540 |
+
|
541 |
+
Returns:
|
542 |
+
Dict[str, ParamSet]: parameters dispatched by group
|
543 |
+
"""
|
544 |
+
# Check if criterion is valid.
|
545 |
+
_check_if_valid_criterion(task=args.task, criterion=args.criterion)
|
546 |
+
|
547 |
+
args.project = Path(args.csvpath).stem
|
548 |
+
args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
|
549 |
+
args.device = torch.device(f"cuda:{args.gpu_ids[0]}") if args.gpu_ids != [] else torch.device('cpu')
|
550 |
+
args.mlp, args.net = _parse_model(args.model)
|
551 |
+
args.pretrained = bool(args.pretrained) # strtobool('False') = 0 (== False)
|
552 |
+
args.save_datetime_dir = str(Path('results', args.project, 'trials', args.datetime))
|
553 |
+
|
554 |
+
# Parse csv
|
555 |
+
_csvparser = CSVParser(args.csvpath, args.task, args.isTrain)
|
556 |
+
args.df_source = _csvparser.df_source
|
557 |
+
args.dataset_info = {split: len(args.df_source[args.df_source['split'] == split]) for split in ['train', 'val']}
|
558 |
+
args.input_list = _csvparser.input_list
|
559 |
+
args.label_list = _csvparser.label_list
|
560 |
+
args.mlp_num_inputs = _csvparser.mlp_num_inputs
|
561 |
+
args.num_outputs_for_label = _csvparser.num_outputs_for_label
|
562 |
+
if args.task == 'deepsurv':
|
563 |
+
args.period_name = _csvparser.period_name
|
564 |
+
|
565 |
+
# Dispatch parameters
|
566 |
+
return {
|
567 |
+
'args_model': _dispatch_by_group(args, 'model'),
|
568 |
+
'args_dataloader': _dispatch_by_group(args, 'dataloader'),
|
569 |
+
'args_conf': _dispatch_by_group(args, 'train_conf'),
|
570 |
+
'args_print': _dispatch_by_group(args, 'train_print'),
|
571 |
+
'args_save': _dispatch_by_group(args, 'save')
|
572 |
+
}
|
573 |
+
|
574 |
+
|
575 |
+
def _test_parse(args: argparse.Namespace) -> Dict[str, ParamSet]:
|
576 |
+
"""
|
577 |
+
Parse parameters required at test.
|
578 |
+
|
579 |
+
Args:
|
580 |
+
args (argparse.Namespace): arguments
|
581 |
+
|
582 |
+
Returns:
|
583 |
+
Dict[str, ParamSet]: parameters dispatched by group
|
584 |
+
"""
|
585 |
+
args.project = Path(args.csvpath).stem
|
586 |
+
args.gpu_ids = _parse_gpu_ids(args.gpu_ids)
|
587 |
+
args.device = torch.device(f"cuda:{args.gpu_ids[0]}") if args.gpu_ids != [] else torch.device('cpu')
|
588 |
+
|
589 |
+
# Collect weight paths
|
590 |
+
if args.weight_dir is None:
|
591 |
+
args.weight_dir = _get_latest_weight_dir()
|
592 |
+
args.weight_paths = _collect_weight_paths(args.weight_dir)
|
593 |
+
|
594 |
+
# Get datetime at training
|
595 |
+
_train_datetime_dir = Path(args.weight_dir).parents[0]
|
596 |
+
_train_datetime = _train_datetime_dir.name
|
597 |
+
|
598 |
+
args.save_datetime_dir = str(Path('results', args.project, 'trials', _train_datetime))
|
599 |
+
|
600 |
+
# Retrieve only parameters required at test
|
601 |
+
_parameter_path = str(Path(_train_datetime_dir, 'parameters.json'))
|
602 |
+
params = _retrieve_parameter(_parameter_path)
|
603 |
+
for _param, _arg in params.items():
|
604 |
+
setattr(args, _param, _arg)
|
605 |
+
|
606 |
+
# When test, the followings are always fixed.
|
607 |
+
args.augmentation = 'no'
|
608 |
+
args.sampler = 'no'
|
609 |
+
args.pretrained = False
|
610 |
+
|
611 |
+
args.mlp, args.net = _parse_model(args.model)
|
612 |
+
if args.mlp is not None:
|
613 |
+
args.scaler_path = str(Path(_train_datetime_dir, 'scaler.pkl'))
|
614 |
+
|
615 |
+
# Parse csv
|
616 |
+
_csvparser = CSVParser(args.csvpath, args.task)
|
617 |
+
args.df_source = _csvparser.df_source
|
618 |
+
|
619 |
+
# Align test_splits
|
620 |
+
args.test_splits = args.test_splits.split('-')
|
621 |
+
_splits = args.df_source['split'].unique().tolist()
|
622 |
+
if set(_splits) < set(args.test_splits):
|
623 |
+
args.test_splits = _splits
|
624 |
+
|
625 |
+
args.dataset_info = {split: len(args.df_source[args.df_source['split'] == split]) for split in args.test_splits}
|
626 |
+
|
627 |
+
# Dispatch parameters
|
628 |
+
return {
|
629 |
+
'args_model': _dispatch_by_group(args, 'model'),
|
630 |
+
'args_dataloader': _dispatch_by_group(args, 'dataloader'),
|
631 |
+
'args_conf': _dispatch_by_group(args, 'test_conf'),
|
632 |
+
'args_print': _dispatch_by_group(args, 'test_print')
|
633 |
+
}
|
634 |
+
|
635 |
+
def set_options(datetime_name: str = None, phase: str = None) -> argparse.Namespace:
|
636 |
+
"""
|
637 |
+
Parse options for training or test.
|
638 |
+
|
639 |
+
Args:
|
640 |
+
datetime_name (str, optional): datetime name. Defaults to None.
|
641 |
+
phase (str, optional): train or test. Defaults to None.
|
642 |
+
|
643 |
+
Returns:
|
644 |
+
argparse.Namespace: arguments
|
645 |
+
"""
|
646 |
+
if phase == 'train':
|
647 |
+
opt = Options(datetime=datetime_name, isTrain=True)
|
648 |
+
_args = opt.get_args()
|
649 |
+
args = _train_parse(_args)
|
650 |
+
return args
|
651 |
+
else:
|
652 |
+
opt = Options(isTrain=False)
|
653 |
+
_args = opt.get_args()
|
654 |
+
args = _test_parse(_args)
|
655 |
+
return args
|