Upload 9 files
- src/dataset.py +50 -0
- src/get_loss.py +79 -0
- src/losses.py +498 -0
- src/models/dino.py +37 -0
- src/models/segmentation_head.py +40 -0
- src/models/unet.py +171 -0
- src/models/vit.py +36 -0
- src/train.py +264 -0
- src/utils.py +58 -0
src/dataset.py
ADDED
@@ -0,0 +1,50 @@
from typing import List, Tuple, Callable
from pathlib import Path
import datasets
import torch
from torch.utils.data import Dataset


class SegmentationDataset(Dataset):
    def __init__(
        self,
        dataset: datasets.Dataset,
        train: bool = True,
        transform: Callable | None = None,
        target_transform: Callable | None = None,
        test_size: float = 0.25,
    ) -> None:
        super().__init__()
        self.dataset = dataset
        self.train = train
        self.transform = transform
        self.target_transform = target_transform
        self.test_size = test_size

        total_size = len(dataset)
        indices = list(range(total_size))
        split = int(self.test_size * total_size)

        if train:
            self.indices = indices[split:]
        else:
            self.indices = indices[:split]

    def __len__(self) -> int:
        return len(self.indices)

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        item = self.dataset[self.indices[idx]]
        image = item["image"]
        mask = item["mask"]
        if self.transform:
            image = self.transform(image)
        if self.target_transform:
            mask = self.target_transform(mask)
        return image, mask


def collate_fn(items: List[Tuple[torch.Tensor, torch.Tensor]]) -> Tuple[torch.Tensor, torch.Tensor]:
    images = torch.stack([item[0] for item in items])
    masks = torch.stack([item[1] for item in items])
    return images, masks
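A quick usage sketch for this wrapper (illustrative, assuming src/ is on the import path; the transforms come from src/utils.py and the dataset id is the default from src/train.py):

    from datasets import load_dataset
    from torch.utils.data import DataLoader

    from dataset import SegmentationDataset, collate_fn
    from utils import get_transform, mask_transform

    hf_dataset = load_dataset("mattmdjaga/human_parsing_dataset", split="train")
    transform = get_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
    train_ds = SegmentationDataset(hf_dataset, train=True, transform=transform, target_transform=mask_transform)
    loader = DataLoader(train_ds, batch_size=4, collate_fn=collate_fn)
    images, masks = next(iter(loader))  # images: (4, 3, 224, 224), masks: (4, 1, 224, 224)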
src/get_loss.py
ADDED
@@ -0,0 +1,79 @@
from typing import Dict, Callable
import torch.nn as nn
import torch

from losses import SoftDiceLoss, SSLoss, IoULoss, TverskyLoss, FocalTversky_loss, AsymLoss, ExpLog_loss, FocalLoss, LovaszSoftmax, TopKLoss, WeightedCrossEntropyLoss, SoftDiceLoss_v2, IoULoss_v2, TverskyLoss_v2, FocalTversky_loss_v2, AsymLoss_v2, SSLoss_v2


def get_loss(loss_type: str) -> Callable | None:
    if loss_type == "cross_entropy":
        return nn.CrossEntropyLoss()
    elif loss_type == "SoftDiceLoss":
        return SoftDiceLoss()
    elif loss_type == "SSLoss":
        return SSLoss()
    elif loss_type == "IoULoss":
        return IoULoss()
    elif loss_type == "TverskyLoss":
        return TverskyLoss()
    elif loss_type == "FocalTversky_loss":
        tversky_kwargs = {
            "apply_nonlin": None,
            "batch_dice": False,
            "do_bg": True,
            "smooth": 1.0,
            "square": False
        }
        return FocalTversky_loss(tversky_kwargs=tversky_kwargs)
    elif loss_type == "AsymLoss":
        return AsymLoss()
    elif loss_type == "ExpLog_loss":
        soft_dice_kwargs = {
            "smooth": 1.0
        }
        wce_kwargs = {
            "weight": None
        }
        return ExpLog_loss(soft_dice_kwargs=soft_dice_kwargs, wce_kwargs=wce_kwargs)
    elif loss_type == "FocalLoss":
        return FocalLoss()
    elif loss_type == "LovaszSoftmax":
        return LovaszSoftmax()
    elif loss_type == "TopKLoss":
        return TopKLoss()
    elif loss_type == "WeightedCrossEntropyLoss":
        return WeightedCrossEntropyLoss()
    elif loss_type == "SoftDiceLoss_v2":
        return SoftDiceLoss_v2()
    elif loss_type == "IoULoss_v2":
        return IoULoss_v2()
    elif loss_type == "TverskyLoss_v2":
        return TverskyLoss_v2()
    elif loss_type == "FocalTversky_loss_v2":
        return FocalTversky_loss_v2()
    elif loss_type == "AsymLoss_v2":
        return AsymLoss_v2()
    elif loss_type == "SSLoss_v2":
        return SSLoss_v2()
    else:
        raise ValueError(f"Unsupported loss type: {loss_type}")


def get_composite_criterion(losses_config: Dict[str, float]) -> Callable[[torch.Tensor, torch.Tensor], torch.Tensor]:
    losses = []
    weights = []

    for loss_name, weight in losses_config.items():
        if weight != 0.0:
            loss_fn = get_loss(loss_name)
            if loss_fn is not None:
                losses.append(loss_fn)
                weights.append(weight)

    def composite_loss(output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        total_loss = 0.0
        for loss_fn, weight in zip(losses, weights):
            total_loss += weight * loss_fn(output, target)
        return total_loss

    return composite_loss
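A sketch of how the composite criterion is driven (illustrative; in training the dict comes from the file passed via --losses-path, the keys mirror the names handled in get_loss(), and the weights below are arbitrary):

    import torch
    from get_loss import get_composite_criterion

    losses_config = {"cross_entropy": 1.0, "SoftDiceLoss_v2": 0.5, "FocalLoss": 0.0}  # zero weight -> skipped
    criterion = get_composite_criterion(losses_config)

    logits = torch.randn(2, 18, 224, 224)          # (batch, classes, H, W)
    targets = torch.randint(0, 18, (2, 224, 224))  # integer class mask
    loss = criterion(logits, targets)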
src/losses.py
ADDED
@@ -0,0 +1,498 @@
from typing import Callable, List, Tuple, Dict
import torch
from torch import nn
import torch.nn.functional as F
import numpy as np


def sum_tensor(inp: torch.Tensor, axes: int | List[int], keepdim: bool = False) -> torch.Tensor:
    axes = np.unique(axes).astype(int)
    if keepdim:
        for ax in axes:
            inp = inp.sum(int(ax), keepdim=True)
    else:
        for ax in sorted(axes, reverse=True):
            inp = inp.sum(int(ax))
    return inp


def get_tp_fp_fn(net_output: torch.Tensor, gt: torch.Tensor, axes: int | Tuple[int, ...] | None = None, mask: torch.Tensor | None = None, square: bool = False) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    if axes is None:
        axes = tuple(range(2, len(net_output.size())))
    shp_x = net_output.shape
    shp_y = gt.shape
    with torch.no_grad():
        if len(shp_x) != len(shp_y):
            gt = gt.view((shp_y[0], 1, *shp_y[1:]))
        if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
            y_onehot = gt
        else:
            gt = gt.long()
            y_onehot = torch.zeros(shp_x)
            if net_output.device.type == "cuda":
                y_onehot = y_onehot.cuda(net_output.device.index)
            y_onehot.scatter_(1, gt, 1)
    tp = net_output * y_onehot
    fp = net_output * (1 - y_onehot)
    fn = (1 - net_output) * y_onehot
    if mask is not None:
        tp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(tp, dim=1)), dim=1)
        fp = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fp, dim=1)), dim=1)
        fn = torch.stack(tuple(x_i * mask[:, 0] for x_i in torch.unbind(fn, dim=1)), dim=1)
    if square:
        tp = tp ** 2
        fp = fp ** 2
        fn = fn ** 2
    tp = sum_tensor(tp, axes, keepdim=False)
    fp = sum_tensor(fp, axes, keepdim=False)
    fn = sum_tensor(fn, axes, keepdim=False)
    return tp, fp, fn


def softmax_helper(x: torch.Tensor) -> torch.Tensor:
    rpt = [1 for _ in range(len(x.size()))]
    rpt[1] = x.size(1)
    x_max = x.max(1, keepdim=True)[0].repeat(*rpt)
    e_x = torch.exp(x - x_max)
    return e_x / e_x.sum(1, keepdim=True).repeat(*rpt)


def flatten(tensor: torch.Tensor) -> torch.Tensor:
    C = tensor.size(1)
    axis_order = (1, 0) + tuple(range(2, tensor.dim()))
    transposed = tensor.permute(axis_order).contiguous()
    return transposed.view(C, -1)


class SoftDiceLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1.0, square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        dc = (2 * tp + self.smooth) / (2 * tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                dc = dc[1:]
            else:
                dc = dc[:, 1:]
        dc = dc.mean()
        return -dc


class SoftDiceLoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3))
        dl = 1 - (2.0 * intersection + self.smooth) / (union + self.smooth)
        dice_loss = torch.mean(dl)
        return dice_loss


class SSLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.r = 0.1

    def forward(self, net_output: torch.Tensor, gt: torch.Tensor) -> torch.Tensor:
        shp_x = net_output.shape
        shp_y = gt.shape
        with torch.no_grad():
            if len(shp_x) != len(shp_y):
                gt = gt.view((shp_y[0], 1, *shp_y[1:]))
            if all([i == j for i, j in zip(net_output.shape, gt.shape)]):
                y_onehot = gt
            else:
                gt = gt.long()
                y_onehot = torch.zeros(shp_x)
                if net_output.device.type == "cuda":
                    y_onehot = y_onehot.cuda(net_output.device.index)
                y_onehot.scatter_(1, gt, 1)
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            net_output = self.apply_nonlin(net_output)
        bg_onehot = 1 - y_onehot
        squared_error = (y_onehot - net_output)**2
        specificity_part = sum_tensor(squared_error*y_onehot, axes)/(sum_tensor(y_onehot, axes)+self.smooth)
        sensitivity_part = sum_tensor(squared_error*bg_onehot, axes)/(sum_tensor(bg_onehot, axes)+self.smooth)
        ss = self.r * specificity_part + (1-self.r) * sensitivity_part
        if not self.do_bg:
            if self.batch_dice:
                ss = ss[1:]
            else:
                ss = ss[:, 1:]
        ss = ss.mean()
        return ss


class SSLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5) -> None:
        super().__init__()
        self.alpha = alpha

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        cardinality = torch.sum(probs + targets, dim=(0, 2, 3))
        dice_loss = 1 - (2.0 * intersection + 1e-6) / (cardinality + 1e-6)
        ce_loss = F.cross_entropy(probs, targets, reduction='mean')
        loss = self.alpha * dice_loss.mean() + (1 - self.alpha) * ce_loss
        return loss


class IoULoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        iou = (tp + self.smooth) / (tp + fp + fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                iou = iou[1:]
            else:
                iou = iou[:, 1:]
        iou = iou.mean()
        return -iou


class IoULoss_v2(nn.Module):
    def __init__(self, smooth: float = 1.0) -> None:
        super().__init__()
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        intersection = torch.sum(probs * targets, dim=(0, 2, 3))
        union = torch.sum(probs + targets, dim=(0, 2, 3)) - intersection
        iou = 1 - (intersection + self.smooth) / (union + self.smooth)
        iou_loss = torch.mean(iou)
        return iou_loss


class TverskyLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.alpha = 0.3
        self.beta = 0.7

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        tversky = (tp + self.smooth) / (tp + self.alpha*fp + self.beta*fn + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                tversky = tversky[1:]
            else:
                tversky = tversky[:, 1:]
        tversky = tversky.mean()
        return -tversky


class TverskyLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        tversky = 1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)
        tversky_loss = torch.mean(tversky)
        return tversky_loss


class FocalTversky_loss(nn.Module):
    def __init__(self, tversky_kwargs: Dict, gamma: float = 0.75) -> None:
        super().__init__()
        self.gamma = gamma
        self.tversky = TverskyLoss(**tversky_kwargs)

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        tversky_loss = 1 + self.tversky(net_output, target)
        focal_tversky = torch.pow(tversky_loss, self.gamma)
        return focal_tversky


class FocalTversky_loss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, beta: float = 0.5, gamma: float = 1.5, smooth: float = 1.0) -> None:
        super().__init__()
        self.alpha = alpha
        self.beta = beta
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        tp = torch.sum(probs * targets, dim=(0, 2, 3))
        fp = torch.sum((1 - targets) * probs, dim=(0, 2, 3))
        fn = torch.sum(targets * (1 - probs), dim=(0, 2, 3))
        focal_tversky = (1 - (tp + self.smooth) / (tp + self.alpha * fp + self.beta * fn + self.smooth)) ** self.gamma
        focal_tversky_loss = torch.mean(focal_tversky)
        return focal_tversky_loss


class AsymLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, batch_dice: bool = True, do_bg: bool = False, smooth: float = 1., square: bool = True) -> None:
        super().__init__()
        self.square = square
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.beta = 1.5

    def forward(self, x: torch.Tensor, y: torch.Tensor, loss_mask: torch.Tensor | None = None) -> torch.Tensor:
        shp_x = x.shape
        if self.batch_dice:
            axes = [0] + list(range(2, len(shp_x)))
        else:
            axes = list(range(2, len(shp_x)))
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        tp, fp, fn = get_tp_fp_fn(x, y, axes, loss_mask, self.square)
        weight = (self.beta**2)/(1+self.beta**2)
        asym = (tp + self.smooth) / (tp + weight*fn + (1-weight)*fp + self.smooth)
        if not self.do_bg:
            if self.batch_dice:
                asym = asym[1:]
            else:
                asym = asym[:, 1:]
        asym = asym.mean()
        return -asym


class AsymLoss_v2(nn.Module):
    def __init__(self, alpha: float = 0.5, gamma: float = 2.0, smooth: float = 1e-5) -> None:
        super().__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.smooth = smooth

    def forward(self, logits: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        probs = F.softmax(logits, dim=1)
        targets_one_hot = F.one_hot(targets, num_classes=probs.size(1)).permute(0, 3, 1, 2).float()
        pos_loss = -self.alpha * (1 - probs) ** self.gamma * targets_one_hot * torch.log(probs + self.smooth)
        neg_loss = -(1 - self.alpha) * probs ** self.gamma * (1 - targets_one_hot) * torch.log(1 - probs + self.smooth)
        loss = pos_loss + neg_loss
        return loss.mean()


class ExpLog_loss(nn.Module):
    def __init__(self, soft_dice_kwargs: Dict, wce_kwargs: Dict, gamma: float = 0.3) -> None:
        super().__init__()
        self.wce = WeightedCrossEntropyLoss(**wce_kwargs)
        self.dc = SoftDiceLoss_v2(**soft_dice_kwargs)
        self.gamma = gamma

    def forward(self, net_output: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # self.dc returns 1 - dice, so the soft Dice coefficient itself is 1 - self.dc(...)
        dc_loss = 1 - self.dc(net_output, target)
        wce_loss = self.wce(net_output, target)
        explog_loss = 0.8*torch.pow(-torch.log(torch.clamp(dc_loss, 1e-6)), self.gamma) + 0.2*wce_loss
        return explog_loss


class FocalLoss(nn.Module):
    def __init__(self, apply_nonlin: Callable | None = softmax_helper, alpha: float | List[float] | np.ndarray | None = None, gamma: int = 2, balance_index: int = 0, smooth: float = 1e-4, size_average: bool = True) -> None:
        super().__init__()
        self.apply_nonlin = apply_nonlin
        self.alpha = alpha
        self.gamma = gamma
        self.balance_index = balance_index
        self.smooth = smooth
        self.size_average = size_average
        if self.smooth is not None:
            if self.smooth < 0 or self.smooth > 1.0:
                raise ValueError("smooth value should be in [0,1]")

    def forward(self, logit: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        if self.apply_nonlin is not None:
            logit = self.apply_nonlin(logit)
        num_class = logit.shape[1]
        if logit.dim() > 2:
            logit = logit.view(logit.size(0), logit.size(1), -1)
            logit = logit.permute(0, 2, 1).contiguous()
            logit = logit.view(-1, logit.size(-1))
        target = torch.squeeze(target, 1)
        target = target.view(-1, 1)
        alpha = self.alpha
        if alpha is None:
            alpha = torch.ones(num_class, 1)
        elif isinstance(alpha, (list, np.ndarray)):
            assert len(alpha) == num_class
            alpha = torch.FloatTensor(alpha).view(num_class, 1)
            alpha = alpha / alpha.sum()
        elif isinstance(alpha, float):
            alpha = torch.ones(num_class, 1)
            alpha = alpha * (1 - self.alpha)
            alpha[self.balance_index] = self.alpha
        else:
            raise TypeError("Not support alpha type")
        if alpha.device != logit.device:
            alpha = alpha.to(logit.device)
        idx = target.cpu().long()
        one_hot_key = torch.FloatTensor(target.size(0), num_class).zero_()
        one_hot_key = one_hot_key.scatter_(1, idx, 1)
        if one_hot_key.device != logit.device:
            one_hot_key = one_hot_key.to(logit.device)
        if self.smooth:
            one_hot_key = torch.clamp(
                one_hot_key, self.smooth/(num_class-1), 1.0 - self.smooth)
        pt = (one_hot_key * logit).sum(1) + self.smooth
        logpt = pt.log()
        gamma = self.gamma
        alpha = alpha[idx]
        alpha = torch.squeeze(alpha)
        loss = -1 * alpha * torch.pow((1 - pt), gamma) * logpt
        if self.size_average:
            loss = loss.mean()
        else:
            loss = loss.sum()
        return loss


def lovasz_grad(gt_sorted: torch.Tensor) -> torch.Tensor:
    p = len(gt_sorted)
    gts = gt_sorted.sum()
    intersection = gts - gt_sorted.float().cumsum(0)
    union = gts + (1 - gt_sorted).float().cumsum(0)
    jaccard = 1. - intersection / union
    if p > 1:
        jaccard[1:p] = jaccard[1:p] - jaccard[0:-1]
    return jaccard


class LovaszSoftmax(nn.Module):
    def __init__(self, reduction: str = "mean") -> None:
        super().__init__()
        self.reduction = reduction

    def prob_flatten(self, input: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        assert input.dim() in [4, 5]
        num_class = input.size(1)
        if input.dim() == 4:
            input = input.permute(0, 2, 3, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        elif input.dim() == 5:
            input = input.permute(0, 2, 3, 4, 1).contiguous()
            input_flatten = input.view(-1, num_class)
        target_flatten = target.view(-1)
        return input_flatten, target_flatten

    def lovasz_softmax_flat(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        num_classes = inputs.size(1)
        losses = []
        for c in range(num_classes):
            target_c = (targets == c).float()
            if num_classes == 1:
                input_c = inputs[:, 0]
            else:
                input_c = inputs[:, c]
            loss_c = (target_c - input_c).abs()
            loss_c_sorted, loss_index = torch.sort(loss_c, 0, descending=True)
            target_c_sorted = target_c[loss_index]
            losses.append(torch.dot(loss_c_sorted, lovasz_grad(target_c_sorted)))
        losses = torch.stack(losses)
        if self.reduction == "none":
            loss = losses
        elif self.reduction == "sum":
            loss = losses.sum()
        else:
            loss = losses.mean()
        return loss

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        inputs, targets = self.prob_flatten(inputs, targets)
        losses = self.lovasz_softmax_flat(inputs, targets)
        return losses


class TopKLoss(nn.Module):
    def __init__(self, weight: torch.Tensor | None = None, ignore_index: int = -100, k: int = 10) -> None:
        super().__init__()
        self.k = k
        self.cross_entropy = nn.CrossEntropyLoss(weight=weight, ignore_index=ignore_index, reduction="none")

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        pixel_losses = self.cross_entropy(inp, target)
        pixel_losses = pixel_losses.view(-1)
        num_voxels = pixel_losses.numel()
        res, _ = torch.topk(pixel_losses, int(num_voxels * self.k / 100), sorted=False)
        return res.mean()


class WeightedCrossEntropyLoss(torch.nn.CrossEntropyLoss):
    def __init__(self, weight: torch.Tensor | None = None) -> None:
        super().__init__()
        self.weight = weight

    def forward(self, inp: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        target = target.long()
        num_classes = inp.size()[1]
        i0 = 1
        i1 = 2
        while i1 < len(inp.shape):
            inp = inp.transpose(i0, i1)
            i0 += 1
            i1 += 1
        inp = inp.contiguous()
        inp = inp.view(-1, num_classes)
        target = target.view(-1,)
        wce_loss = torch.nn.CrossEntropyLoss(weight=self.weight)
        return wce_loss(inp, target)
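A quick sanity check of the two sign conventions in this file (illustrative, assuming src/ is on the import path): the nnU-Net-style losses return negative scores (e.g. SoftDiceLoss returns -dice, so lower is better), while the *_v2 variants return 1 - score in [0, 1].

    import torch
    from losses import SoftDiceLoss, SoftDiceLoss_v2

    logits = torch.randn(2, 18, 64, 64)
    targets = torch.randint(0, 18, (2, 64, 64))

    print(SoftDiceLoss()(logits, targets))     # roughly in [-1, 0]
    print(SoftDiceLoss_v2()(logits, targets))  # roughly in [0, 1]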
src/models/dino.py
ADDED
@@ -0,0 +1,37 @@
from transformers import Dinov2Backbone
import torch
import torch.nn as nn
import torch.nn.functional as F

from src.models.segmentation_head import SegmentationHead


class DINOSegmentationModel(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        self.image_size = image_size
        model_name = "facebook/dinov2-small"
        self.backbone = Dinov2Backbone.from_pretrained(model_name)
        for param in self.backbone.parameters():
            param.requires_grad = False
        self.segmentation_head = SegmentationHead(in_channels=384, num_classes=num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.image_size, "The image must match the size required by the DINO model"
        features = self.backbone(pixel_values=x).feature_maps[0]
        masks = self.segmentation_head(features)
        return masks


def main() -> None:
    # model = DINOSegmentationModel()
    model = SegmentationHead(384, 18)
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
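A shape-check sketch for the frozen-backbone model (illustrative; it downloads facebook/dinov2-small on first run, and only the segmentation head carries trainable parameters):

    import torch
    from src.models.dino import DINOSegmentationModel

    model = DINOSegmentationModel(image_size=224, num_classes=18).eval()
    with torch.no_grad():
        out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # expected: torch.Size([1, 18, 224, 224])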
src/models/segmentation_head.py
ADDED
@@ -0,0 +1,40 @@
import torch
import torch.nn as nn


class SegmentationHead(nn.Module):
    def __init__(self, in_channels: int, num_classes: int):
        super().__init__()
        self.head = nn.Sequential(
            nn.Conv2d(in_channels, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Conv2d(256, 256, kernel_size=3, padding=1),
            nn.BatchNorm2d(256),
            nn.ReLU(),
            nn.Upsample(size=(64, 64), mode="bilinear"),
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Conv2d(128, 128, kernel_size=3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(),
            nn.Upsample(size=(128, 128), mode="bilinear"),
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size=3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Upsample(size=(224, 224), mode="bilinear"),
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, 32, kernel_size=3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(),
            nn.Conv2d(32, num_classes, kernel_size=3, padding=1),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.head(x)
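Note that the Upsample stages are hard-coded to 64/128/224, so the head always emits 224x224 masks regardless of the incoming feature-map resolution. A small shape check (illustrative) from a 16x16 grid of 384-dim features, the layout DINOv2-small produces for 224x224 inputs:

    import torch
    from src.models.segmentation_head import SegmentationHead

    head = SegmentationHead(in_channels=384, num_classes=18)
    out = head(torch.randn(1, 384, 16, 16))
    print(out.shape)  # torch.Size([1, 18, 224, 224])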
src/models/unet.py
ADDED
@@ -0,0 +1,171 @@
import torch
from torch import nn


class UNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.mean = [0.485, 0.456, 0.406]
        self.std = [0.229, 0.224, 0.225]
        # Downsampler
        self.enc_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=3, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64)
        )
        self.pool0 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )
        self.pool1 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )
        self.pool2 = nn.MaxPool2d(kernel_size=2, stride=2)
        self.enc_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )
        self.pool3 = nn.MaxPool2d(kernel_size=2, stride=2)

        # bottleneck
        self.bottleneck_conv = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024),
            nn.Conv2d(in_channels=1024, out_channels=1024, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(1024)
        )

        # Upsampler

        self.upsample0 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, padding=1),
        )

        self.dec_conv0 = nn.Sequential(
            nn.Conv2d(in_channels=1024, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512),
            nn.Conv2d(in_channels=512, out_channels=512, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(512)
        )

        self.upsample1 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, padding=1),
        )

        self.dec_conv1 = nn.Sequential(
            nn.Conv2d(in_channels=512, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256),
            nn.Conv2d(in_channels=256, out_channels=256, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(256)
        )

        self.upsample2 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, padding=1),
        )

        self.dec_conv2 = nn.Sequential(
            nn.Conv2d(in_channels=256, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128),
            nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(128)
        )

        self.upsample3 = nn.Sequential(
            nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True),
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, padding=1),
        )

        self.dec_conv3 = nn.Sequential(
            nn.Conv2d(in_channels=128, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
            nn.LeakyReLU(inplace=True),
            nn.BatchNorm2d(64),
            nn.Conv2d(in_channels=64, out_channels=18, kernel_size=1, stride=1, padding=0)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # encoder
        e0 = self.enc_conv0(x)
        e1 = self.pool0(e0)
        e1 = self.enc_conv1(e1)
        e2 = self.pool1(e1)
        e2 = self.enc_conv2(e2)
        e3 = self.pool2(e2)
        e3 = self.enc_conv3(e3)

        # bottleneck
        b = self.pool3(e3)
        b = self.bottleneck_conv(b)

        # decoder
        d0 = self.upsample0(b)
        d0 = torch.cat([d0, e3], dim=1)
        d0 = self.dec_conv0(d0)

        d1 = self.upsample1(d0)
        d1 = torch.cat([d1, e2], dim=1)
        d1 = self.dec_conv1(d1)

        d2 = self.upsample2(d1)
        d2 = torch.cat([d2, e1], dim=1)
        d2 = self.dec_conv2(d2)

        d3 = self.upsample3(d2)
        d3 = torch.cat([d3, e0], dim=1)
        d3 = self.dec_conv3(d3)
        return d3
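With four 2x2 poolings, the input height and width should be divisible by 16 so the skip concatenations line up; the decoder then returns 18-class logits at the input resolution. A small shape check (illustrative; the import path matches how src/train.py imports the model):

    import torch
    from models.unet import UNet

    model = UNet()
    out = model(torch.randn(1, 3, 224, 224))
    print(out.shape)  # torch.Size([1, 18, 224, 224])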
src/models/vit.py
ADDED
@@ -0,0 +1,36 @@
import torch
import torch.nn as nn
from transformers import ViTModel

from src.models.segmentation_head import SegmentationHead


class ViTSegmentation(nn.Module):
    def __init__(self, image_size: int = 224, num_classes: int = 18) -> None:
        super().__init__()
        self.mean = [0.5, 0.5, 0.5]
        self.std = [0.5, 0.5, 0.5]
        self.backbone = ViTModel.from_pretrained("google/vit-base-patch16-224")
        self.segmentation_head = SegmentationHead(in_channels=768, num_classes=num_classes)
        for param in self.backbone.parameters():
            param.requires_grad = False

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        batch_size, channels, height, width = x.size()
        assert height == width == self.backbone.config.image_size, "The image must match the size required by the ViT model"
        outputs = self.backbone(pixel_values=x).last_hidden_state
        patch_dim = int(height / self.backbone.config.patch_size)
        outputs = outputs[:, 1:, :]
        outputs = outputs.permute(0, 2, 1).view(batch_size, -1, patch_dim, patch_dim)
        masks = self.segmentation_head(outputs)
        return masks


def main() -> None:
    model = ViTSegmentation(image_size=224, num_classes=18)
    num_params = sum([p.numel() for p in model.parameters()])
    print(f"params: {num_params/1e6:.2f} M")


if __name__ == "__main__":
    main()
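A sketch to confirm the freezing behaviour (illustrative; it downloads google/vit-base-patch16-224 on first run): the ViT backbone contributes no trainable parameters, so only the segmentation head is updated during training.

    from src.models.vit import ViTSegmentation

    model = ViTSegmentation(image_size=224, num_classes=18)
    frozen = sum(p.numel() for p in model.backbone.parameters() if not p.requires_grad)
    trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"frozen backbone: {frozen/1e6:.1f} M, trainable head: {trainable/1e6:.1f} M")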
src/train.py
ADDED
@@ -0,0 +1,264 @@
from typing import Tuple
from pathlib import Path
from tqdm import tqdm
from accelerate import Accelerator
from accelerate.utils import set_seed
from matplotlib import cm
import numpy as np
import matplotlib.pyplot as plt
import argparse
import json
import wandb
from datasets import load_dataset
import torch
from torch.utils.data import DataLoader

from models.unet import UNet
from dataset import SegmentationDataset, collate_fn
from utils import get_transform, mask_transform, EMA
from get_loss import get_composite_criterion
from models.vit import ViTSegmentation
from models.dino import DINOSegmentationModel


color_map = cm.get_cmap('tab20', 18)
fixed_colors = np.array([color_map(i)[:3] for i in range(18)]) * 255


def mask_to_color(mask: np.ndarray) -> np.ndarray:
    h, w = mask.shape
    color_mask = np.zeros((h, w, 3), dtype=np.uint8)
    for class_idx in range(18):
        color_mask[mask == class_idx] = fixed_colors[class_idx]
    return color_mask


def create_combined_image(
    x: torch.Tensor,
    y: torch.Tensor,
    y_pred: torch.Tensor,
    mean: list[float] = [0.485, 0.456, 0.406],
    std: list[float] = [0.229, 0.224, 0.225]
) -> np.ndarray:
    batch_size, _, height, width = x.shape
    combined_height = height * 3
    combined_width = width * batch_size
    combined_image = np.zeros((combined_height, combined_width, 3), dtype=np.uint8)

    for i in range(batch_size):
        image = x[i].cpu().permute(1, 2, 0).numpy()
        image = (image * std + mean).clip(0, 1)
        image = (image * 255).astype(np.uint8)
        true_mask = y[i].cpu().numpy()
        true_mask_color = mask_to_color(true_mask)
        pred_mask = y_pred[i].cpu().numpy()
        pred_mask_color = mask_to_color(pred_mask)
        combined_image[:height, i * width:(i + 1) * width, :] = image
        combined_image[height:2 * height, i * width:(i + 1) * width, :] = true_mask_color
        combined_image[2 * height:, i * width:(i + 1) * width, :] = pred_mask_color
    return combined_image


def compute_metrics(y_pred: torch.Tensor, y: torch.Tensor, num_classes: int = 18) -> Tuple[float, float, float, float, float, float]:
    pred_mask = y_pred.unsqueeze(-1) == torch.arange(num_classes, device=y_pred.device).reshape(1, 1, 1, -1)
    target_mask = y.unsqueeze(-1) == torch.arange(num_classes, device=y.device).reshape(1, 1, 1, -1)
    class_present = (target_mask.sum(dim=(0, 1, 2)) > 0).float()
    tp = (pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    fp = (pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    fn = (~pred_mask & target_mask).sum(dim=(0, 1, 2)).float()
    tn = (~pred_mask & ~target_mask).sum(dim=(0, 1, 2)).float()
    overall_tp = tp.sum()
    overall_fp = fp.sum()
    overall_fn = fn.sum()
    overall_tn = tn.sum()
    precision = tp / (tp + fp).clamp(min=1e-8)
    recall = tp / (tp + fn).clamp(min=1e-8)
    accuracy = (tp + tn) / (tp + tn + fp + fn)
    macro_precision = ((precision * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_recall = ((recall * class_present).sum() / class_present.sum().clamp(min=1e-8)).item()
    macro_accuracy = accuracy.mean().item()
    micro_precision = (overall_tp / (overall_tp + overall_fp).clamp(min=1e-8)).item()
    micro_recall = (overall_tp / (overall_tp + overall_fn).clamp(min=1e-8)).item()
    global_accuracy = ((y_pred == y).sum() / (y.shape[0] * y.shape[1] * y.shape[2])).item()
    return macro_precision, macro_recall, macro_accuracy, micro_precision, micro_recall, global_accuracy


def parse_args():
    parser = argparse.ArgumentParser(description="Train a model on human parsing dataset")
    parser.add_argument("--data-path", type=str, default="mattmdjaga/human_parsing_dataset", help="Path to the data")
    parser.add_argument("--batch-size", type=int, default=32, help="Batch size for training and testing")
    parser.add_argument("--pin-memory", type=bool, default=True, help="Pin memory for DataLoader")
    parser.add_argument("--num-workers", type=int, default=0, help="Number of workers for DataLoader")
    parser.add_argument("--num-epochs", type=int, default=15, help="Number of training epochs")
    parser.add_argument("--optimizer", type=str, default="AdamW", help="Optimizer type")
    parser.add_argument("--learning-rate", type=float, default=1e-4, help="Learning rate for the optimizer")
    parser.add_argument("--max-norm", type=float, default=1.0, help="Maximum gradient norm for clipping")
    parser.add_argument("--logs-dir", type=str, default="dino-logs", help="Directory for saving logs")
    parser.add_argument("--model", type=str, default="dino", choices=["unet", "vit", "dino"], help="Model class name")
    parser.add_argument("--losses-path", type=str, default="losses_config.json", help="Path to the losses")
    parser.add_argument("--mixed-precision", type=str, default="fp16", choices=["fp16", "bf16", "fp8", "no"], help="Value of the mixed precision")
    parser.add_argument("--gradient-accumulation-steps", type=int, default=2, help="Value of the gradient accumulation steps")
    parser.add_argument("--project-name", type=str, default="human_parsing_segmentation_ttk", help="WandB project name")
    parser.add_argument("--save-frequency", type=int, default=4, help="Frequency of saving model weights")
    parser.add_argument("--log-steps", type=int, default=400, help="Number of steps between logging training images and metrics")
    parser.add_argument("--seed", type=int, default=42, help="Value of the seed")
    return parser.parse_args()


def main() -> None:
    args = parse_args()

    set_seed(args.seed)
    accelerator = Accelerator(gradient_accumulation_steps=args.gradient_accumulation_steps, mixed_precision=args.mixed_precision)

    with open(args.losses_path, "r") as fp:
        losses_config = json.load(fp)

    with accelerator.main_process_first():
        logs_dir = Path(args.logs_dir)
        logs_dir.mkdir(exist_ok=True)
        wandb.init(project=args.project_name, dir=logs_dir)
        wandb.save(args.losses_path)

    optimizer_class = getattr(torch.optim, args.optimizer)

    if args.model == "unet":
        model = UNet().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "vit":
        model = ViTSegmentation().to(accelerator.device)
        optimizer = optimizer_class(model.parameters(), lr=args.learning_rate)
    elif args.model == "dino":
        model = DINOSegmentationModel().to(accelerator.device)
        optimizer = optimizer_class(model.segmentation_head.parameters(), lr=args.learning_rate)
    else:
        raise NotImplementedError("Incorrect model name")

    transform = get_transform(model.mean, model.std)

    dataset = load_dataset(args.data_path, split="train")
    train_dataset = SegmentationDataset(dataset, train=True, transform=transform, target_transform=mask_transform)
    valid_dataset = SegmentationDataset(dataset, train=False, transform=transform, target_transform=mask_transform)

    train_loader = DataLoader(train_dataset, batch_size=args.batch_size, shuffle=True, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)
    valid_loader = DataLoader(valid_dataset, batch_size=args.batch_size, shuffle=False, num_workers=args.num_workers, pin_memory=args.pin_memory, collate_fn=collate_fn)

    criterion = get_composite_criterion(losses_config)

    lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, args.num_epochs * len(train_loader))

    model, optimizer, train_loader, lr_scheduler = accelerator.prepare(model, optimizer, train_loader, lr_scheduler)

    best_accuracy = 0

    print(f"params: {sum([p.numel() for p in model.parameters()])/1e6:.2f} M")
    print(f"trainable params: {sum([p.numel() for p in model.parameters() if p.requires_grad])/1e6:.2f} M")

    train_loss_ema, train_macro_precision_ema, train_macro_recall_ema, train_macro_accuracy_ema, train_micro_precision_ema, train_micro_recall_ema, train_global_accuracy_ema = EMA(), EMA(), EMA(), EMA(), EMA(), EMA(), EMA()
    for epoch in range(1, args.num_epochs + 1):
        model.train()
        pbar = tqdm(train_loader, desc=f"Train epoch {epoch}/{args.num_epochs}")
        for index, (x, y) in enumerate(pbar):
            x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
            with accelerator.accumulate(model):
                with accelerator.autocast():
                    output = model(x)
                    loss = criterion(output, y)
                accelerator.backward(loss)
                train_loss = loss.item()
                grad_norm = None
                _, y_pred = output.max(dim=1)
                train_macro_precision, train_macro_recall, train_macro_accuracy, train_micro_precision, train_micro_recall, train_global_accuracy = compute_metrics(y_pred, y)
                if accelerator.sync_gradients:
                    grad_norm = accelerator.clip_grad_norm_(model.parameters(), args.max_norm).item()
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()
            if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                images_to_log = []
                combined_image = create_combined_image(x, y, y_pred)
                images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Train, Epoch {epoch}, Batch {index})"))
                wandb.log({"train_samples": images_to_log})
            pbar.set_postfix({"loss": train_loss_ema(train_loss), "macro_precision": train_macro_precision_ema(train_macro_precision), "macro_recall": train_macro_recall_ema(train_macro_recall), "macro_accuracy": train_macro_accuracy_ema(train_macro_accuracy), "micro_precision": train_micro_precision_ema(train_micro_precision), "micro_recall": train_micro_recall_ema(train_micro_recall), "global_accuracy": train_global_accuracy_ema(train_global_accuracy)})
            log_data = {
                "train/epoch": epoch,
                "train/loss": train_loss,
                "train/macro_accuracy": train_macro_accuracy,
                "train/learning_rate": optimizer.param_groups[0]["lr"],
                "train/macro_precision": train_macro_precision,
                "train/macro_recall": train_macro_recall,
                "train/micro_precision": train_micro_precision,
                "train/micro_recall": train_micro_recall,
                "train/global_accuracy": train_global_accuracy,
            }
            if grad_norm is not None:
                log_data["train/grad_norm"] = grad_norm
            if accelerator.is_main_process:
                wandb.log(log_data)
        accelerator.wait_for_everyone()

        model.eval()
        valid_loss, valid_macro_accuracies, valid_macro_precisions, valid_macro_recalls, valid_global_accuracies, valid_micro_precisions, valid_micro_recalls = 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
        with torch.inference_mode():
            pbar = tqdm(valid_loader, desc=f"Val epoch {epoch}/{args.num_epochs}")
            for index, (x, y) in enumerate(pbar):
                x, y = x.to(accelerator.device), y.squeeze(1).to(accelerator.device)
                output = model(x)
                _, y_pred = output.max(dim=1)
                if (index + 1) % args.log_steps == 0 and accelerator.is_main_process:
                    images_to_log = []
                    combined_image = create_combined_image(x, y, y_pred)
                    images_to_log.append(wandb.Image(combined_image, caption=f"Combined Image (Validation, Epoch {epoch})"))
                    wandb.log({"valid_samples": images_to_log})
                valid_macro_precision, valid_macro_recall, valid_macro_accuracy, valid_micro_precision, valid_micro_recall, valid_global_accuracy = compute_metrics(y_pred, y)
                valid_macro_precisions += valid_macro_precision
                valid_macro_recalls += valid_macro_recall
                valid_macro_accuracies += valid_macro_accuracy
                valid_micro_precisions += valid_micro_precision
                valid_micro_recalls += valid_micro_recall
                valid_global_accuracies += valid_global_accuracy
                loss = criterion(output, y)
                valid_loss += loss.item()
        valid_loss = valid_loss / len(valid_loader)
        valid_macro_accuracies = valid_macro_accuracies / len(valid_loader)
        valid_macro_precisions = valid_macro_precisions / len(valid_loader)
        valid_macro_recalls = valid_macro_recalls / len(valid_loader)
        valid_global_accuracies = valid_global_accuracies / len(valid_loader)
        valid_micro_precisions = valid_micro_precisions / len(valid_loader)
        valid_micro_recalls = valid_micro_recalls / len(valid_loader)
        accelerator.print(f"loss: {valid_loss:.3f}, valid_macro_precision: {valid_macro_precisions:.3f}, valid_macro_recall: {valid_macro_recalls:.3f}, valid_macro_accuracy: {valid_macro_accuracies:.3f}, valid_micro_precision: {valid_micro_precisions:.3f}, valid_micro_recall: {valid_micro_recalls:.3f}, valid_global_accuracy: {valid_global_accuracies:.3f}")
        if accelerator.is_main_process:
            wandb.log(
                {
                    "val/epoch": epoch,
                    "val/loss": valid_loss,
                    "val/macro_accuracy": valid_macro_accuracies,
                    "val/macro_precision": valid_macro_precisions,
                    "val/macro_recall": valid_macro_recalls,
                    "val/global_accuracy": valid_global_accuracies,
                    "val/micro_precision": valid_micro_precisions,
                    "val/micro_recall": valid_micro_recalls,
                }
            )
            if valid_global_accuracies > best_accuracy:
                best_accuracy = valid_global_accuracies
                if args.model in ["dino", "vit"]:
                    accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-best.pth")
                else:
                    accelerator.save(model.state_dict(), logs_dir / f"checkpoint-best.pth")
                accelerator.print(f"new best_accuracy {best_accuracy}, {epoch=}")
            if epoch % args.save_frequency == 0:
                if args.model in ["dino", "vit"]:
                    accelerator.save(model.segmentation_head.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
                else:
                    accelerator.save(model.state_dict(), logs_dir / f"checkpoint-{epoch:09}.pth")
        accelerator.wait_for_everyone()

    accelerator.wait_for_everyone()
    wandb.finish()


if __name__ == "__main__":
    main()
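For the "dino" and "vit" models only the segmentation head is written to checkpoint-best.pth, so inference means rebuilding the full model and loading the head back in. A minimal sketch, assuming the default --logs-dir ("dino-logs") was used and src/ is on the import path:

    import torch
    from models.dino import DINOSegmentationModel

    model = DINOSegmentationModel()
    state_dict = torch.load("dino-logs/checkpoint-best.pth", map_location="cpu")  # hypothetical path from the defaults
    model.segmentation_head.load_state_dict(state_dict)
    model.eval()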
src/utils.py
ADDED
@@ -0,0 +1,58 @@
import torch
import torchvision.transforms as T
import PIL.Image
from typing import List


size = (224, 224)


class ResizeWithPadding:
    def __init__(self, target_size: int = 224, fill: int = 0, mode: str = "RGB") -> None:
        self.target_size = target_size
        self.fill = fill
        self.mode = mode

    def __call__(self, image: PIL.Image.Image) -> PIL.Image.Image:
        original_width, original_height = image.size
        aspect_ratio = original_width / original_height
        if aspect_ratio > 1:
            new_width = self.target_size
            new_height = int(self.target_size / aspect_ratio)
        else:
            new_height = self.target_size
            new_width = int(self.target_size * aspect_ratio)
        resized_image = image.resize((new_width, new_height), PIL.Image.BICUBIC if self.mode == "RGB" else PIL.Image.NEAREST)
        delta_w = self.target_size - new_width
        delta_h = self.target_size - new_height
        padding = (delta_w // 2, delta_h // 2, delta_w - delta_w // 2, delta_h - delta_h // 2)
        padded_image = PIL.Image.new(self.mode, (self.target_size, self.target_size), self.fill)
        padded_image.paste(resized_image, (padding[0], padding[1]))
        return padded_image


def get_transform(mean: List[float], std: List[float]) -> T.Compose:
    return T.Compose([
        ResizeWithPadding(),
        T.ToTensor(),
        T.Normalize(mean=mean, std=std),
    ])


mask_transform = T.Compose([
    ResizeWithPadding(mode="L"),
    T.ToTensor(),
    T.Lambda(lambda x: (x * 255).long()),
])


class EMA:
    def __init__(self, alpha: float = 0.9) -> None:
        self.value = None
        self.alpha = alpha

    def __call__(self, value: float) -> float:
        if self.value is None:
            self.value = value
        else:
            self.value = self.alpha * self.value + (1 - self.alpha) * value
        return self.value
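Both transforms letterbox inputs to 224x224: the image path is normalized to a float tensor, while the mask path keeps integer class ids (ToTensor divides by 255, the Lambda scales back and casts to long). A small illustrative check on synthetic non-square inputs:

    import PIL.Image
    from utils import get_transform, mask_transform

    image = PIL.Image.new("RGB", (300, 400))
    mask = PIL.Image.new("L", (300, 400))

    x = get_transform(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(image)
    y = mask_transform(mask)
    print(x.shape, x.dtype)  # torch.Size([3, 224, 224]) torch.float32
    print(y.shape, y.dtype)  # torch.Size([1, 224, 224]) torch.int64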