"""
UNet 3+ (1D) with deep supervision.

paper: https://arxiv.org/abs/2004.08790
ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
"""
import torch
from torch import nn
from torch.functional import F
import torch.nn.functional as F  # canonical import; torch.functional only re-exports F incidentally
class UNetConv(nn.Module):
    """Stack of ``num_layers`` x (Conv1d [-> BatchNorm1d] -> ReLU) blocks.

    The first convolution maps ``in_size`` channels to ``out_size``; every
    later one keeps ``out_size`` channels. Each block is registered as an
    attribute named ``conv0`` .. ``conv{num_layers - 1}``.
    """

    def __init__(
        self,
        in_size,
        out_size,
        is_batchnorm=True,
        num_layers=2,
        kernel_size=3,
        stride=1,
        padding=1,
    ):
        super().__init__()
        self.num_layers = num_layers
        channels = in_size
        for idx in range(num_layers):
            block = [nn.Conv1d(channels, out_size, kernel_size, stride, padding)]
            if is_batchnorm:
                block.append(nn.BatchNorm1d(out_size))
            block.append(nn.ReLU())
            setattr(self, f"conv{idx}", nn.Sequential(*block))
            # every block after the first consumes out_size channels
            channels = out_size

    def forward(self, inputs):
        out = inputs
        for idx in range(self.num_layers):
            out = getattr(self, f"conv{idx}")(out)
        return out
class UNet3PlusDeepSup(nn.Module):
    """1-D UNet 3+ with deep supervision and optional Classification-Guided Module.

    paper: https://arxiv.org/abs/2004.08790
    ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py

    Expected ``config`` attributes:
        inplanes: channels of the shallowest encoder; doubled at every depth
        kernel_size: Conv1d kernel size (odd, so ``(k - 1) // 2`` padding keeps length)
        num_encoder_layers: number of conv layers inside each UNetConv encoder
        encoder_batchnorm: whether encoder convs are followed by BatchNorm1d
        num_depths: number of encoder depths (== number of supervision heads)
        interpolate_mode: ``mode`` passed to F.interpolate when upsampling
        dropout: dropout rate inside the CGM classifier head
        use_cgm: if True, gate supervision maps with the CGM class predictions
        sum_of_sup: if True, elementwise-sum all supervision maps into a single
            dense map and return it; if False, return every supervision map so
            a loss can be computed per map
        output_size: number of output classes
            (per original comment: set by TrialSetup._init_network_params)
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        inplanes = int(config.inplanes)
        kernel_size = int(config.kernel_size)
        padding = (kernel_size - 1) // 2  # "same" length for odd kernel sizes
        num_encoder_layers = int(config.num_encoder_layers)
        encoder_batchnorm = bool(config.encoder_batchnorm)
        self.num_depths = int(config.num_depths)
        self.interpolate_mode = str(config.interpolate_mode)
        dropout = float(config.dropout)
        self.use_cgm = bool(config.use_cgm)
        self.sum_of_sup = bool(config.sum_of_sup)
        # set by TrialSetup._init_network_params
        self.output_size: int = config.output_size

        # Encoder: each depth is (MaxPool -> UNetConv); the first depth skips
        # the pooling and is just UNetConv. Channels double with every depth.
        self.encoders = torch.nn.ModuleList()
        for i in range(self.num_depths):
            _encoders = []
            if i != 0:
                _encoders.append(nn.MaxPool1d(2))
            _encoders.append(
                UNetConv(
                    1 if i == 0 else (inplanes * (2 ** (i - 1))),
                    inplanes * (2**i),
                    is_batchnorm=encoder_batchnorm,
                    num_layers=num_encoder_layers,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                )
            )
            self.encoders.append(nn.Sequential(*_encoders))

        # CGM: Classification-Guided Module — predicts a 2-way (absent/present)
        # decision per output class from the deepest encoder feature.
        if self.use_cgm:
            self.cls = nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1
                ),
                nn.AdaptiveMaxPool1d(1),
                nn.Sigmoid(),
            )

        # Decoder: every stage fuses all num_depths scales — shallower encoder
        # features are max-pooled, the same-depth encoder feature is used as-is,
        # and deeper decoder features are upsampled. The upsampling itself is
        # done in forward() via F.interpolate so the target length can be taken
        # from the matching encoder feature.
        self.up_channels = inplanes * self.num_depths
        self.decoders = torch.nn.ModuleList()
        for i in reversed(range(self.num_depths - 1)):
            # each stage always has num_depths branches; only whether a branch
            # pools / passes through / reads an upsampled decoder map differs
            _decoders = torch.nn.ModuleList()
            for j in range(self.num_depths):
                _each_decoders = []
                if j < i:
                    # shallower encoder feature: pool down to this depth's length
                    _each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True))
                if i < j < self.num_depths - 1:
                    # deeper *decoder* feature (always up_channels wide)
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * self.num_depths,
                            inplanes,
                            kernel_size,
                            padding=padding,
                        )
                    )
                else:
                    # encoder feature of depth j (j <= i), or the deepest
                    # encoder output (j == num_depths - 1): inplanes * 2**j ch
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * (2**j), inplanes, kernel_size, padding=padding
                        )
                    )
                _each_decoders.append(nn.BatchNorm1d(inplanes))
                _each_decoders.append(nn.ReLU())
                _decoders.append(nn.Sequential(*_each_decoders))
            # fusion conv applied after concatenating the num_depths branches
            _decoders.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.up_channels, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.up_channels),
                    nn.ReLU(),
                )
            )
            self.decoders.append(_decoders)

        # Deep-supervision heads: the first num_depths - 1 read decoder outputs
        # (up_channels wide, originally 320 = inplanes * num_depths); the last
        # one reads the deepest encoder output.
        self.sup_conv = torch.nn.ModuleList()
        for i in range(self.num_depths - 1):
            self.sup_conv.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.output_size, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.output_size),
                    nn.ReLU(),
                )
            )
        self.sup_conv.append(
            nn.Sequential(
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)),
                    self.output_size,
                    kernel_size,
                    padding=padding,
                ),
                nn.BatchNorm1d(self.output_size),
                nn.ReLU(),
            )
        )

    def forward(self, input: torch.Tensor, y=None):
        """Return sigmoid dense prediction(s) of shape (B, output_size, L).

        ``y`` is unused; kept for call-site compatibility.
        """
        # Encoder
        output = input
        enc_features = []  # X1Ee, X2Ee, ..., X5Ee
        for encoder in self.encoders:
            output = encoder(output)
            enc_features.append(output)
        # dec_features[k] holds the feature of depth num_depths - 1 - k:
        # starts as [X5Ee] and grows to [X5Ee, X4De, ..., X1De].
        # BUGFIX: previously seeded with *every* encoder output, which made the
        # dec_features indexing below select shallow encoder features with the
        # wrong channel count.
        dec_features = [enc_features[-1]]

        # CGM
        cls_branch_max = None
        if self.use_cgm:
            # (B, 2 * output_size, 1)
            cls_branch: torch.Tensor = self.cls(enc_features[-1])
            # (B, output_size) 0/1 presence decision per class.
            # BUGFIX: cast to the input dtype — torch.einsum requires matching
            # dtypes when gating the float supervision maps below.
            cls_branch_max = (
                cls_branch.view(input.shape[0], self.output_size, 2)
                .argmax(2)
                .to(input.dtype)
            )

        # Decoder
        for i in reversed(range(self.num_depths - 1)):
            _each_dec_feature = []
            for j in range(self.num_depths):
                if j <= i:
                    # same-or-shallower encoder feature (pooled inside decoders)
                    _each = enc_features[j]
                else:
                    # deeper decoder / deepest encoder feature, upsampled to
                    # the current depth's temporal length
                    _each = F.interpolate(
                        dec_features[self.num_depths - j - 1],
                        enc_features[i].shape[2],
                        mode=self.interpolate_mode,
                    )
                _each_dec_feature.append(
                    self.decoders[self.num_depths - i - 2][j](_each)
                )
            dec_features.append(
                self.decoders[self.num_depths - i - 2][-1](
                    torch.cat(_each_dec_feature, dim=1)
                )
            )

        # Deep supervision: pair [X5Ee, X4De, ..., X1De] with the heads in
        # reverse declaration order; every map except the last (already at
        # full resolution) is upsampled to the input length.
        sup = []
        for i, (dec_feature, sup_conv) in enumerate(
            zip(dec_features, reversed(self.sup_conv))
        ):
            if i < self.num_depths - 1:
                sup.append(
                    F.interpolate(
                        sup_conv(dec_feature),
                        input.shape[2],
                        mode=self.interpolate_mode,
                    )
                )
            else:
                sup.append(sup_conv(dec_feature))

        if self.use_cgm:
            # gate every supervision map with the per-class presence mask
            gated = [
                torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
                for _sup in reversed(sup)
            ]
            if self.sum_of_sup:
                return torch.sigmoid(sum(gated))
            # BUGFIX: a generator expression was previously passed straight to
            # torch.sigmoid (TypeError at runtime); apply sigmoid per map.
            return [torch.sigmoid(_sup) for _sup in gated]
        if self.sum_of_sup:
            return torch.sigmoid(sum(sup))
        return [torch.sigmoid(_sup) for _sup in reversed(sup)]