eksemyashkina committed (verified)
Commit af720c2 · Parent(s): 656b785

Upload 9 files

src/dataset.py ADDED
@@ -0,0 +1,50 @@
from typing import List, Tuple, Callable
import datasets
import torch
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(
        self,
        dataset: datasets.Dataset,
        train: bool = True,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
        test_size: float = 0.25,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test_size = test_size

        total_size = len(dataset)
        indices = list(range(total_size))
        split = int(self.test_size * total_size)

        # the first test_size fraction of the dataset is held out for validation
        if train:
            self.indices = indices[split:]
        else:
            self.indices = indices[:split]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        item = self.dataset[self.indices[idx]]
        image = item["image"]
        mask = item["mask"]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    images = torch.stack([item[0] for item in items])
    masks = torch.stack([item[1] for item in items])
    return images, masks
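
A minimal usage sketch (not one of the uploaded files): it wires SegmentationDataset and collate_fn into a DataLoader, reusing the dataset id and the transforms that the training script in this commit points at; the batch size is illustrative.

from datasets import load_dataset
from torch.utils.data import DataLoader

from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform

hf_ds = load_dataset("mattmdjaga/human_parsing_dataset", split="train")
train_ds = SegmentationDataset(
    hf_ds,
    train=True,
    transform=get_transform([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
    target_transform=mask_transform,
)
loader = DataLoader(train_ds, batch_size=4, shuffle=True, collate_fn=collate_fn)

images, masks = next(iter(loader))
print(images.shape, masks.shape)  # expected: [4, 3, 224, 224] and [4, 1, 224, 224]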
src/get_loss.py ADDED
@@ -0,0 +1,79 @@
from typing import Dict, Callable
import torch.nn as nn
import torch

from losses import SoftDiceLoss, SSLoss, IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, ExpLog_loss, FocalLoss, LovaszSoftmax, TopKLoss, WeightedCrossEntropyLoss, SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2, AsymLoss_v2, SSLoss_v2


def get_loss(loss_type: str) -> Callable | None:
    if loss_type == "cross_entropy":
        return nn.CrossEntropyLoss()
    elif loss_type == "SoftDiceLoss":
        return SoftDiceLoss()
    elif loss_type == "SSLoss":
        return SSLoss()
    elif loss_type == "IoULoss":
        return IoULoss()
    elif loss_type == "TverskyLoss":
        return TverskyLoss()
    elif loss_type == "FocalTversky_loss":
        tversky_kwargs = {
            "apply_nonlin": None,
            "batch_dice": False,
            "do_bg": True,
            "smooth": 1.0,
            "square": False
        }
        return FocalTversky_loss(tversky_kwargs=tversky_kwargs)
    elif loss_type == "AsymLoss":
        return AsymLoss()
    elif loss_type == "ExpLog_loss":
        soft_dice_kwargs = {
            "smooth": 1.0
        }
        wce_kwargs = {
            "weight": None
        }
        return ExpLog_loss(soft_dice_kwargs=soft_dice_kwargs, wce_kwargs=wce_kwargs)
    elif loss_type == "FocalLoss":
        return FocalLoss()
    elif loss_type == "LovaszSoftmax":
        return LovaszSoftmax()
    elif loss_type == "TopKLoss":
        return TopKLoss()
    elif loss_type == "WeightedCrossEntropyLoss":
        return WeightedCrossEntropyLoss()
    elif loss_type == "SoftDiceLoss_v2":
        return SoftDiceLoss_v2()
    elif loss_type == "IoULoss_v2":
        return IoULoss_v2()
    elif loss_type == "TverskyLoss_v2":
        return TverskyLoss_v2()
    elif loss_type == "FocalTversky_loss_v2":
        return FocalTversky_loss_v2()
    elif loss_type == "AsymLoss_v2":
        return AsymLoss_v2()
    elif loss_type == "SSLoss_v2":
        return SSLoss_v2()
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")


def get_composite_criterion(losses_config: Dict[str, float]) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    losses = []
    weights = []

    for loss_name, weight in losses_config.items():
        if weight != 0.0:
            loss_fn = get_loss(loss_name)
            if loss_fn is not None:
                losses.append(loss_fn)
                weights.append(weight)

    def composite_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        total_loss = 0.0
        for loss_fn, weight in zip(losses, weights):
            total_loss += weight * loss_fn(output, target)
        return total_loss

    return composite_loss
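
A small sketch of how get_composite_criterion is meant to be consumed; in train.py the weights come from a losses_config.json file, the values below are made up.

import torch

from get_loss import get_composite_criterion

# zero-weight entries are skipped entirely
losses_config = {"cross_entropy": 1.0, "SoftDiceLoss_v2": 0.5, "FocalLoss": 0.0}
criterion = get_composite_criterion(losses_config)

logits = torch.randn(2, 18, 224, 224)          # (batch, classes, H, W)
targets = torch.randint(0, 18, (2, 224, 224))  # integer class index per pixel
print(criterion(logits, targets).item())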
src/losses.py ADDED
@@ -0,0 +1,498 @@
from typing import Callable, List, Tuple, Dict
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def sum_tensor(inp: torch.Tensor, axes: int | List[int], keepdim: bool = False) -> torch.Tensor:
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp


def get_tp_fp_fn(net_output: torch.Tensor, gt: torch.Tensor, axes: int | Tuple[int, ...] | None = None, mask: torch.Tensor | None = None, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))
    shp_x = net_output.shape
    shp_y = gt.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)
    return tp, fp, fn


def softmax_helper(x: torch.Tensor) -> torch.Tensor:
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


def flatten(tensor: torch.Tensor) -> torch.Tensor:
    C = tensor.size(1)
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    transposed = tensor.permute(axis_order).contiguous()
    return transposed.view(C, -1)


class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()
        return -dc


class SoftDiceLoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3))
        dl = 1 - (2.0 * intersection + self.smooth) / (union + self.smooth)
        dice_loss = torch.mean(dl)
        return dice_loss


class SSLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.r = 0.1

    def forward(self, net_output: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        shp_x = net_output.shape
        shp_y = gt.shape
        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                gt = gt.view((shp_y[0], 1, *shp_y[1:]))
            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(shp_x)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(net_output)
        bg_onehot = 1 - y_onehot
        squared_error = (y_onehot - net_output) ** 2
        specificity_part = sum_tensor(squared_error * y_onehot, axes) / (sum_tensor(y_onehot, axes) + self.smooth)
        sensitivity_part = sum_tensor(squared_error * bg_onehot, axes) / (sum_tensor(bg_onehot, axes) + self.smooth)
        ss = self.r * specificity_part + (1 - self.r) * sensitivity_part
        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]
        ss = ss.mean()
        return ss


class SSLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5) -> None:
        super().__init__()
        self.alpha = alpha

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        cardinality = torch.sum(probs + targets, dim=(0, 2, 3))
        dice_loss = 1 - (2.0 * intersection + 1e-6) / (cardinality + 1e-6)
        ce_loss = F.cross_entropy(probs, targets, reduction="mean")
        # mix the two terms with the configurable alpha weight
        loss = self.alpha * dice_loss.mean() + (1 - self.alpha) * ce_loss
        return loss


class IoULoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                iou = iou[1:]
            else:
                iou = iou[:, 1:]
        iou = iou.mean()
        return -iou


class IoULoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3)) - intersection
        iou = 1 - (intersection + self.smooth) / (union + self.smooth)
        iou_loss = torch.mean(iou)
        return iou_loss


class TverskyLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.alpha = 0.3
        self.beta = 0.7

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        tversky = (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()
        return -tversky


class TverskyLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        tversky = 1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        tversky_loss = torch.mean(tversky)
        return tversky_loss


class FocalTversky_loss(nn.Module):
    def __init__(self, tversky_kwargs: Dict, gamma: float = 0.75) -> None:
        super().__init__()
        self.gamma = gamma
        self.tversky = TverskyLoss(**tversky_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # TverskyLoss returns the negative Tversky index, so 1 + (-index) is the Tversky loss
        tversky_loss = 1 + self.tversky(net_output, target)
        focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky


class FocalTversky_loss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        focal_tversky = (1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)) ** self.gamma
        focal_tversky_loss = torch.mean(focal_tversky)
        return focal_tversky_loss


class AsymLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.beta = 1.5

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        weight = (self.beta ** 2) / (1 + self.beta ** 2)
        asym = (tp + self.smooth) / (tp + weight * fn + (1 - weight) * fp + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                asym = asym[1:]
            else:
                asym = asym[:, 1:]
        asym = asym.mean()
        return -asym


class AsymLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, gamma: float = 2.0, smooth: float = 1e-5) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        pos_loss = -self.alpha * (1 - probs) ** self.gamma * targets_one_hot * torch.log(probs + self.smooth)
        neg_loss = -(1 - self.alpha) * probs ** self.gamma * (1 - targets_one_hot) * torch.log(1 - probs + self.smooth)
        loss = pos_loss + neg_loss
        return loss.mean()


class ExpLog_loss(nn.Module):
    def __init__(self, soft_dice_kwargs: Dict, wce_kwargs: Dict, gamma: float = 0.3) -> None:
        super().__init__()
        self.wce = WeightedCrossEntropyLoss(**wce_kwargs)
        self.dc = SoftDiceLoss_v2(**soft_dice_kwargs)
        self.gamma = gamma

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # SoftDiceLoss_v2 returns 1 - dice, so recover the dice score before taking the log
        dice_score = 1 - self.dc(net_output, target)
        wce_loss = self.wce(net_output, target)
        explog_loss = 0.8 * torch.pow(-torch.log(torch.clamp(dice_score, min=1e-6)), self.gamma) + 0.2 * wce_loss
        return explog_loss


class FocalLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, alpha: float | List[float] | np.ndarray | None = None, gamma: int = 2, balance_index: int = 0, smooth: float = 1e-4, size_average: bool = True) -> None:
        super().__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average
        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError("smooth value should be in [0,1]")

    def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]
        if logit.dim() > 2:
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError("Unsupported alpha type")
        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)
        idx = target.cpu().long()
        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)
        if self.smooth:
            one_hot_key = torch.clamp(one_hot_key, self.smooth / (num_class - 1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()
        gamma = self.gamma
        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


class LovaszSoftmax(nn.Module):
    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    def prob_flatten(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert input.dim() in [4, 5]
        num_class = input.size(1)
        if input.dim() == 4:
            input = input.permute(0, 2, 3, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        elif input.dim() == 5:
            input = input.permute(0, 2, 3, 4, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        target_flatten = target.view(-1)
        return input_flatten, target_flatten

    def lovasz_softmax_flat(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        num_classes = inputs.size(1)
        losses = []
        for c in range(num_classes):
            target_c = (targets == c).float()
            if num_classes == 1:
                input_c = inputs[:, 0]
            else:
                input_c = inputs[:, c]
            loss_c = (target_c - input_c).abs()
            loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
            target_c_sorted = target_c[loss_index]
            losses.append(torch.dot(loss_c_sorted, lovasz_grad(target_c_sorted)))
        losses = torch.stack(losses)
        if self.reduction == "none":
            loss = losses
        elif self.reduction == "sum":
            loss = losses.sum()
        else:
            loss = losses.mean()
        return loss

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        inputs, targets = self.prob_flatten(inputs, targets)
        losses = self.lovasz_softmax_flat(inputs, targets)
        return losses


class TopKLoss(nn.Module):
    def __init__(self, weight: torch.Tensor | None = None, ignore_index: int = -100, k: int = 10) -> None:
        super().__init__()
        self.k = k
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction="none")

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pixel_losses = self.cross_entropy(inp, target)
        pixel_losses = pixel_losses.view(-1)
        num_voxels = pixel_losses.numel()
        res, _ = torch.topk(pixel_losses, int(num_voxels * self.k / 100), sorted=False)
        return res.mean()


class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    def __init__(self, weight: torch.Tensor | None = None) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = target.long()
        num_classes = inp.size()[1]
        i0 = 1
        i1 = 2
        while i1 < len(inp.shape):
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1
        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        target = target.view(-1,)
        wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
        return wce_loss(inp, target)
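
A quick smoke-test sketch for the *_v2 variants, assuming 18 classes as elsewhere in this commit; each of them takes raw logits and an integer mask and returns a scalar loss that gradients can flow through.

import torch

from losses import SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2

logits = torch.randn(2, 18, 64, 64, requires_grad=True)
targets = torch.randint(0, 18, (2, 64, 64))

for loss_fn in (SoftDiceLoss_v2(), IoULoss_v2(), TverskyLoss_v2(alpha=0.3, beta=0.7), FocalTversky_loss_v2()):
    value = loss_fn(logits, targets)
    value.backward()  # gradients flow back to the logits
    print(type(loss_fn).__name__, round(value.item(), 4))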
src/models/dino.py ADDED
@@ -0,0 +1,37 @@
from transformers import Dinov2Backbone
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.segmentation_head import SegmentationHead


class DINOSegmentationModel(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.image_size = image_size
        model_name = "facebook/dinov2-small"
        self.backbone = Dinov2Backbone.from_pretrained(model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.segmentation_head = SegmentationHead(in_channels=384, num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.image_size, "The image must match the size required by the DINO model"
        features = self.backbone(pixel_values=x).feature_maps[0]
        masks = self.segmentation_head(features)
        return masks


def main() -> None:
    # model = DINOSegmentationModel()
    model = SegmentationHead(384, 18)
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
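
A hedged shape-check sketch (it downloads the facebook/dinov2-small weights on first run): with a 224x224 input the frozen backbone should yield a 384-channel 16x16 feature map, which the head upsamples back to full resolution.

import torch

from src.models.dino import DINOSegmentationModel

model = DINOSegmentationModel(image_size=224, num_classes=18).eval()
with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 18, 224, 224])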
src/models/segmentation_head.py ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class SegmentationHead(nn.Module):
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(size=(64, 64), mode="bilinear"),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(size=(128, 128), mode="bilinear"),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(size=(224, 224), mode="bilinear"),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)
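
The head hard-codes its upsampling targets (64, 128, then 224 pixels), so it always emits 224x224 logits regardless of the input feature resolution; a quick sketch:

import torch

from src.models.segmentation_head import SegmentationHead

head = SegmentationHead(in_channels=384, num_classes=18)
features = torch.randn(1, 384, 16, 16)  # e.g. DINOv2-small features for a 224x224 image
print(head(features).shape)             # torch.Size([1, 18, 224, 224])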
src/models/unet.py ADDED
@@ -0,0 +1,171 @@
import torch
from torch import nn


class UNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Downsampler
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64)
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024)
        )

        # Upsampler

        self.upsample0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
        )

        self.dec_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )

        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
        )

        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )

        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
        )

        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )

        self.upsample3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
        )

        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=18, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.pool0(e0)
        e1 = self.enc_conv1(e1)
        e2 = self.pool1(e1)
        e2 = self.enc_conv2(e2)
        e3 = self.pool2(e2)
        e3 = self.enc_conv3(e3)

        # bottleneck
        b = self.pool3(e3)
        b = self.bottleneck_conv(b)

        # decoder
        d0 = self.upsample0(b)
        d0 = torch.cat([d0, e3], dim=1)
        d0 = self.dec_conv0(d0)

        d1 = self.upsample1(d0)
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec_conv1(d1)

        d2 = self.upsample2(d1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec_conv2(d2)

        d3 = self.upsample3(d2)
        d3 = torch.cat([d3, e0], dim=1)
        d3 = self.dec_conv3(d3)
        return d3
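
A minimal forward-pass sketch; with four 2x poolings the input height and width need to be divisible by 16, which the 224x224 crops used elsewhere in this commit satisfy.

import torch

from src.models.unet import UNet

model = UNet()
x = torch.randn(1, 3, 224, 224)
print(model(x).shape)  # torch.Size([1, 18, 224, 224]), one channel per class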
src/models/vit.py ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from transformers import ViTModel

from src.models.segmentation_head import SegmentationHead


class ViTSegmentation(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.segmentation_head = SegmentationHead(in_channels=768, num_classes=num_classes)
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.backbone.config.image_size, "The image must match the size required by the ViT model"
        outputs = self.backbone(pixel_values=x).last_hidden_state
        patch_dim = int(height / self.backbone.config.patch_size)
        outputs = outputs[:, 1:, :]
        outputs = outputs.permute(0, 2, 1).view(batch_size, -1, patch_dim, patch_dim)
        masks = self.segmentation_head(outputs)
        return masks


def main() -> None:
    model = ViTSegmentation(image_size=224, num_classes=18)
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
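
Same idea for the ViT variant (downloads google/vit-base-patch16-224 on first run): the CLS token is dropped and the remaining 14x14 patch tokens are reshaped into a 768-channel feature map for the head.

import torch

from src.models.vit import ViTSegmentation

model = ViTSegmentation(image_size=224, num_classes=18).eval()
with torch.inference_mode():
    out = model(torch.randn(1, 3, 224, 224))
print(out.shape)  # expected: torch.Size([1, 18, 224, 224])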
src/train.py ADDED
@@ -0,0 +1,264 @@
from typing import Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel


color_map = cm.get_cmap('tab20', 18)
fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255


def mask_to_color(mask: np.ndarray) -> np.ndarray:
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx in range(18):
        color_mask[mask == class_idx] = fixed_colors[class_idx]
    return color_mask


def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: list[float] = [0.485, 0.456, 0.406],
    std: list[float] = [0.229, 0.224, 0.225]
) -> np.ndarray:
    batch_size, _, height, width = x.shape
    combined_height = height * 3
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    # rows: input image, ground-truth mask, predicted mask; columns: batch items
    for i in range(batch_size):
        image = x[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        true_mask = y[i].cpu().numpy()
        true_mask_color = mask_to_color(true_mask)
        pred_mask = y_pred[i].cpu().numpy()
        pred_mask_color = mask_to_color(pred_mask)
        combined_image[:height, i * width:(i + 1) * width, :] = image
        combined_image[height:2 * height, i * width:(i + 1) * width, :] = true_mask_color
        combined_image[2 * height:, i * width:(i + 1) * width, :] = pred_mask_color
    return combined_image


def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 18) -> Tuple[float, float, float, float, float, float]:
    pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1)
    target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1)
    class_present = (target_mask.sum(dim=(0, 1, 2)) > 0).float()
    tp = (pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    fp = (pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    fn = (~pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    tn = (~pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    overall_tp = tp.sum()
    overall_fp = fp.sum()
    overall_fn = fn.sum()
    overall_tn = tn.sum()
    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    macro_precision = ((precision * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_recall = ((recall * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_accuracy = accuracy.mean().item()
    micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
    micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()
    global_accuracy = ((y_pred == y).sum() / (y.shape[0] * y.shape[1] * y.shape[2])).item()
    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)

    with open(args.losses_path, "r") as fp:
        losses_config = json.load(fp)

    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)

    optimizer_class = getattr(torch.optim, args.optimizer)

    if args.model == "unet":
        model = UNet().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "vit":
        model = ViTSegmentation().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "dino":
        model = DINOSegmentationModel().to(accelerator.device)
        # only the segmentation head is trained; the DINOv2 backbone stays frozen
        optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate)
    else:
        raise NotImplementedError("Incorrect model name")

    transform = get_transform(model.mean, model.std)

    dataset = load_dataset(args.data_path, split="train")
    train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)

    criterion = get_composite_criterion(losses_config)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))

    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)

    best_accuracy = 0

    print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M")
    print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M")

    train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                images_to_log = []
                combined_image = create_combined_image(x, y, y_pred)
                images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})"))
                wandb.log({"train_samples": images_to_log})
            pbar.set_postfix({"loss": train_loss_ema(train_loss), "macro_precision": train_macro_precision_ema(train_macro_precision), "macro_recall": train_macro_recall_ema(train_macro_recall), "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy), "micro_precision": train_micro_precision_ema(train_micro_precision), "micro_recall": train_micro_recall_ema(train_micro_recall), "global_accuracy": train_global_accuracy_ema(train_global_accuracy)})
            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/macro_accuracy": train_macro_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/macro_precision": train_macro_precision,
                "train/macro_recall": train_macro_recall,
                "train/micro_precision": train_micro_precision,
                "train/micro_recall": train_micro_recall,
                "train/global_accuracy": train_global_accuracy,
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()

        model.eval()
        valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(x, y, y_pred)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})"))
                    wandb.log({"valid_samples": images_to_log})
                valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y)
                valid_macro_precisions += valid_macro_precision
                valid_macro_recalls += valid_macro_recall
                valid_macro_accuracies += valid_macro_accuracy
                valid_micro_precisions += valid_micro_precision
                valid_micro_recalls += valid_micro_recall
                valid_global_accuracies += valid_global_accuracy
                loss = criterion(output, y)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_loader)
        valid_macro_accuracies = valid_macro_accuracies / len(valid_loader)
        valid_macro_precisions = valid_macro_precisions / len(valid_loader)
        valid_macro_recalls = valid_macro_recalls / len(valid_loader)
        valid_global_accuracies = valid_global_accuracies / len(valid_loader)
        valid_micro_precisions = valid_micro_precisions / len(valid_loader)
        valid_micro_recalls = valid_micro_recalls / len(valid_loader)
        accelerator.print(f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, valid_global_accuracy: {valid_global_accuracies:.3f}")
        if accelerator.is_main_process:
            wandb.log(
                {
                    "val/epoch": epoch,
                    "val/loss": valid_loss,
                    "val/macro_accuracy": valid_macro_accuracies,
                    "val/macro_precision": valid_macro_precisions,
                    "val/macro_recall": valid_macro_recalls,
                    "val/global_accuracy": valid_global_accuracies,
                    "val/micro_precision": valid_micro_precisions,
                    "val/micro_recall": valid_micro_recalls,
                }
            )
        if valid_global_accuracies > best_accuracy:
            best_accuracy = valid_global_accuracies
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / "checkpoint-best.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / "checkpoint-best.pth")
            accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
        if epoch % args.save_frequency == 0:
            if args.model in ["dino", "vit"]:
                accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
            else:
                accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
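
A small sanity-check sketch for compute_metrics on synthetic masks; it only needs torch, but importing it from train assumes the other training dependencies (wandb, accelerate, matplotlib) are installed since they are imported at module level.

import torch

from train import compute_metrics

y = torch.randint(0, 18, (2, 224, 224))
y_pred = y.clone()
y_pred[:, :16, :] = 0  # corrupt a few rows so the metrics are not all 1.0

names = ["macro_precision", "macro_recall", "macro_accuracy", "micro_precision", "micro_recall", "global_accuracy"]
values = compute_metrics(y_pred, y)
print({name: round(value, 3) for name, value in zip(names, values)})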
src/utils.py ADDED
@@ -0,0 +1,58 @@
import torch
import torchvision.transforms as T
import PIL.Image
from typing import List


size = (224, 224)


class ResizeWithPadding:
    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)
        resized_image = image.resize((new_width, new_height), PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST)
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image


def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    # round before casting so float scaling cannot shift a class index down by one
    T.Lambda(lambda x: (x * 255).round().long()),
])


class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
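
A short sketch of the transforms on dummy PIL images; ResizeWithPadding keeps the aspect ratio and pads to a 224x224 square, so both outputs are fixed-size tensors.

import PIL.Image

from utils import get_transform, mask_transform

image = PIL.Image.new("RGB", (300, 500), (127, 127, 127))
mask = PIL.Image.new("L", (300, 500), 3)  # every pixel belongs to class 3

transform = get_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
print(transform(image).shape)      # torch.Size([3, 224, 224])
print(mask_transform(mask).shape)  # torch.Size([1, 224, 224]); values are integer class indices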