# Source: ECG_Delineation / res / impl / UNet3PlusDeepSup.py
# (Hub page residue — uploaded by wogh2012, commit "refactor: add implementations", aefacda)
"""
paper: https://arxiv.org/abs/2004.08790
ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
"""
import torch
from torch import nn
from torch.functional import F
class UNetConv(nn.Module):
    """A stack of ``num_layers`` Conv1d(+BatchNorm1d)+ReLU blocks.

    The first block maps ``in_size`` -> ``out_size`` channels; every later
    block maps ``out_size`` -> ``out_size``. Blocks are registered as
    attributes ``conv0``, ``conv1``, ... (names kept for checkpoint
    compatibility) and applied in order.
    """

    def __init__(
        self,
        in_size,
        out_size,
        is_batchnorm=True,
        num_layers=2,
        kernel_size=3,
        stride=1,
        padding=1,
    ):
        super().__init__()
        self.num_layers = num_layers
        channels = in_size
        for idx in range(num_layers):
            layers = [nn.Conv1d(channels, out_size, kernel_size, stride, padding)]
            if is_batchnorm:
                layers.append(nn.BatchNorm1d(out_size))
            layers.append(nn.ReLU())
            # nn.Module.__setattr__ registers the Sequential as a submodule
            setattr(self, "conv%d" % idx, nn.Sequential(*layers))
            channels = out_size

    def forward(self, inputs):
        out = inputs
        for idx in range(self.num_layers):
            out = getattr(self, "conv%d" % idx)(out)
        return out
class UNet3PlusDeepSup(nn.Module):
    """1-D UNet 3+ with deep supervision for ECG delineation.

    paper: https://arxiv.org/abs/2004.08790
    ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py
    """

    def __init__(self, config):
        """Build the encoder stack, full-scale skip decoders, the optional
        Classification-Guided Module (CGM) and the deep-supervision heads.

        Attributes read from ``config``: inplanes, kernel_size,
        num_encoder_layers, encoder_batchnorm, num_depths, interpolate_mode,
        dropout, use_cgm, sum_of_sup, output_size.
        """
        super().__init__()
        self.config = config
        inplanes = int(config.inplanes)
        kernel_size = int(config.kernel_size)
        padding = (kernel_size - 1) // 2  # "same"-length padding for odd kernels
        num_encoder_layers = int(config.num_encoder_layers)
        encoder_batchnorm = bool(config.encoder_batchnorm)
        self.num_depths = int(config.num_depths)
        self.interpolate_mode = str(config.interpolate_mode)
        dropout = float(config.dropout)
        self.use_cgm = bool(config.use_cgm)
        # sum_of_sup == True: elementwise-sum all sup maps into one dense map
        #   and compute a single loss against the label.
        # sum_of_sup == False: compute one loss per sup map against the label.
        self.sum_of_sup = bool(config.sum_of_sup)
        # set by TrialSetup._init_network_params
        self.output_size: int = config.output_size

        # Encoder: depth i performs (MaxPool -> UNetConv); exceptionally,
        # depth 0 performs only (UNetConv).
        self.encoders = torch.nn.ModuleList()
        for i in range(self.num_depths):
            _encoders = []
            if i != 0:
                _encoders.append(nn.MaxPool1d(2))
            _encoders.append(
                UNetConv(
                    1 if i == 0 else (inplanes * (2 ** (i - 1))),
                    inplanes * (2**i),
                    is_batchnorm=encoder_batchnorm,
                    num_layers=num_encoder_layers,
                    kernel_size=kernel_size,
                    stride=1,
                    padding=padding,
                )
            )
            self.encoders.append(nn.Sequential(*_encoders))

        # CGM: Classification-Guided Module — per-class (present/absent) scores
        # pooled from the deepest encoder feature.
        if self.use_cgm:
            self.cls = nn.Sequential(
                nn.Dropout(dropout),
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1
                ),
                nn.AdaptiveMaxPool1d(1),
                nn.Sigmoid(),
            )

        # Decoder
        self.up_channels = inplanes * self.num_depths
        self.decoders = torch.nn.ModuleList()
        for i in reversed(range(self.num_depths - 1)):
            # Each decoder stage always has num_depths branches — one per depth:
            # shallower encoder features are MaxPool-ed down to depth i, the
            # same-depth encoder feature is used as-is, and deeper (decoder)
            # features are upsampled; every branch then runs (Conv, BatchNorm,
            # ReLU) before concatenation and a fusion (Conv, BatchNorm, ReLU).
            # Upsampling itself is done in forward() via F.interpolate() so the
            # target length can be taken directly from the matching encoder
            # feature.
            _decoders = torch.nn.ModuleList()
            for j in range(self.num_depths):
                _each_decoders = []
                if j < i:
                    # shallower encoder feature -> downsample to depth i
                    _each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True))
                if i < j < self.num_depths - 1:
                    # deeper *decoder* feature: has up_channels channels
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * self.num_depths,
                            inplanes,
                            kernel_size,
                            padding=padding,
                        )
                    )
                else:
                    # encoder feature at depth j (incl. the deepest encoder output)
                    _each_decoders.append(
                        nn.Conv1d(
                            inplanes * (2**j), inplanes, kernel_size, padding=padding
                        )
                    )
                _each_decoders.append(nn.BatchNorm1d(inplanes))
                _each_decoders.append(nn.ReLU())
                _decoders.append(nn.Sequential(*_each_decoders))
            # fusion conv over the num_depths concatenated branches
            _decoders.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.up_channels, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.up_channels),
                    nn.ReLU(),
                )
            )
            self.decoders.append(_decoders)

        # Deep-supervision heads: the first (num_depths - 1) heads take
        # up_channels (inplanes * num_depths; 320 in the original paper) decoder
        # features, the last head matches the deepest encoder's output channels.
        self.sup_conv = torch.nn.ModuleList()
        for i in range(self.num_depths - 1):
            self.sup_conv.append(
                nn.Sequential(
                    nn.Conv1d(
                        self.up_channels, self.output_size, kernel_size, padding=padding
                    ),
                    nn.BatchNorm1d(self.output_size),
                    nn.ReLU(),
                )
            )
        self.sup_conv.append(
            nn.Sequential(
                nn.Conv1d(
                    inplanes * (2 ** (self.num_depths - 1)),
                    self.output_size,
                    kernel_size,
                    padding=padding,
                ),
                nn.BatchNorm1d(self.output_size),
                nn.ReLU(),
            )
        )

    def forward(self, input: torch.Tensor, y=None):
        """Return dense per-class probability map(s) for a (B, 1, L) input.

        Returns a (B, output_size, L) tensor when sum_of_sup is True, otherwise
        a list of such tensors (shallowest decoder first).
        """
        # Encoder
        output = input
        enc_features = []  # X1Ee, X2Ee, .. , X5Ee
        dec_features = []  # X5Ee, X4De, .. , X1De
        for encoder in self.encoders:
            output = encoder(output)
            enc_features.append(output)
        # Only the deepest encoder output seeds dec_features: it doubles as the
        # first "decoder" feature and feeds the last sup head (which expects
        # inplanes * 2**(num_depths-1) channels).
        dec_features.append(output)

        # CGM
        cls_branch_max = None
        if self.use_cgm:
            # (B, 2*output_size, 1)
            cls_branch: torch.Tensor = self.cls(enc_features[-1])
            # (B, output_size) binary decision per class; cast back to the
            # feature dtype — torch.einsum requires matching dtypes and
            # argmax() yields int64 (the reference implementation casts too).
            cls_branch_max = (
                cls_branch.view(input.shape[0], self.output_size, 2)
                .argmax(2)
                .type_as(cls_branch)
            )

        # Decoder
        for i in reversed(range(self.num_depths - 1)):
            _each_dec_feature = []
            for j in range(self.num_depths):
                if j <= i:
                    _each_enc = enc_features[j]
                else:
                    _each_enc = F.interpolate(
                        dec_features[self.num_depths - j - 1],
                        enc_features[i].shape[2],
                        mode=self.interpolate_mode,
                    )
                _each_dec_feature.append(
                    self.decoders[self.num_depths - i - 2][j](_each_enc)
                )
            dec_features.append(
                self.decoders[self.num_depths - i - 2][-1](
                    torch.cat(_each_dec_feature, dim=1)
                )
            )

        # Deep supervision: project each dec feature to output_size channels and
        # upsample to the input length (the shallowest is already full length).
        sup = []
        for i, (dec_feature, sup_conv) in enumerate(
            zip(dec_features, reversed(self.sup_conv))
        ):
            if i < self.num_depths - 1:
                sup.append(
                    F.interpolate(
                        sup_conv(dec_feature),
                        input.shape[2],
                        mode=self.interpolate_mode,
                    )
                )
            else:
                sup.append(sup_conv(dec_feature))

        if self.use_cgm:
            # Gate each sup map channel-wise by the CGM decision.
            gated = [
                torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max])
                for _sup in reversed(sup)
            ]
            if self.sum_of_sup:
                return torch.sigmoid(sum(gated))
            # BUG FIX: the original wrapped a generator expression in a single
            # torch.sigmoid call inside a one-element list, which raises
            # TypeError at runtime; sigmoid must be applied per gated map.
            return [torch.sigmoid(_gated) for _gated in gated]
        if self.sum_of_sup:
            return torch.sigmoid(sum(sup))
        return [torch.sigmoid(_sup) for _sup in reversed(sup)]