""" paper: https://arxiv.org/abs/2004.08790 ref: https://github.com/ZJUGiveLab/UNet-Version/blob/master/models/UNet_3Plus.py """ import torch from torch import nn from torch.functional import F class UNetConv(nn.Module): def __init__( self, in_size, out_size, is_batchnorm=True, num_layers=2, kernel_size=3, stride=1, padding=1, ): super().__init__() self.num_layers = num_layers for i in range(num_layers): seq = [nn.Conv1d(in_size, out_size, kernel_size, stride, padding)] if is_batchnorm: seq.append(nn.BatchNorm1d(out_size)) seq.append(nn.ReLU()) conv = nn.Sequential(*seq) setattr(self, "conv%d" % i, conv) in_size = out_size def forward(self, inputs): x = inputs for i in range(self.num_layers): conv = getattr(self, "conv%d" % i) x = conv(x) return x class UNet3PlusDeepSup(nn.Module): def __init__(self, config): super().__init__() self.config = config inplanes = int(config.inplanes) kernel_size = int(config.kernel_size) padding = (kernel_size - 1) // 2 num_encoder_layers = int(config.num_encoder_layers) encoder_batchnorm = bool(config.encoder_batchnorm) self.num_depths = int(config.num_depths) self.interpolate_mode = str(config.interpolate_mode) dropout = float(config.dropout) self.use_cgm = bool(config.use_cgm) # sum_of_sup == True: 모든 sup 을 elementwise sum 하여 하나의 dense map 을 만들어 label 과 loss 를 구함 # sum_of_sup == False: 각 sup 과 label의 loss 를 각각 구하여 하나의 loss 에 저장 self.sum_of_sup = bool(config.sum_of_sup) # TrialSetup._init_network_params 에서 설정됨 self.output_size: int = config.output_size # Encoder self.encoders = torch.nn.ModuleList() for i in range(self.num_depths): """(MaxPool - UNetConv) 를 수행하는 것이 하나의 depth 이고, 예외적으로 첫번째 depth 의 encode 결과는 (UNetConv)만 수행한 것""" _encoders = [] if i != 0: _encoders.append(nn.MaxPool1d(2)) _encoders.append( UNetConv( 1 if i == 0 else (inplanes * (2 ** (i - 1))), inplanes * (2**i), is_batchnorm=encoder_batchnorm, num_layers=num_encoder_layers, kernel_size=kernel_size, stride=1, padding=padding, ) ) self.encoders.append(nn.Sequential(*_encoders)) # CGM: Classification-Guided Module if self.use_cgm: self.cls = nn.Sequential( nn.Dropout(dropout), nn.Conv1d( inplanes * (2 ** (self.num_depths - 1)), 2 * self.output_size, 1 ), nn.AdaptiveMaxPool1d(1), nn.Sigmoid(), ) # Decoder self.up_channels = inplanes * self.num_depths self.decoders = torch.nn.ModuleList() for i in reversed(range(self.num_depths - 1)): """ 각 decoder 는 각 encode 결과를 MaxPool 하거나 그대로(Conv,BatchNorm,Relu 만) 사용하거나 Upsample 된 결과를 수행하고 concat 하여 (Conv,BatchNorm,Relu)를 수행할 수 있도록 구성 다만, Upsample 은 encode 결과와 size 를 맞추기 간편하도록 forward 단계에서 torch.functional.interpolate() 로 수행 """ # 각 단계별 decoder 는 항상 num_depths 만큼 구성되고 내부적으로 MaxPool/그대로/Upsample 수행할지가 달라짐 _decoders = torch.nn.ModuleList() for j in range(self.num_depths): _each_decoders = [] if j < i: _each_decoders.append(nn.MaxPool1d(2 ** (i - j), ceil_mode=True)) if i < j < self.num_depths - 1: _each_decoders.append( nn.Conv1d( inplanes * self.num_depths, inplanes, kernel_size, padding=padding, ) ) else: _each_decoders.append( nn.Conv1d( inplanes * (2**j), inplanes, kernel_size, padding=padding ) ) _each_decoders.append(nn.BatchNorm1d(inplanes)) _each_decoders.append(nn.ReLU()) _decoders.append(nn.Sequential(*_each_decoders)) _decoders.append( nn.Sequential( nn.Conv1d( self.up_channels, self.up_channels, kernel_size, padding=padding ), nn.BatchNorm1d(self.up_channels), nn.ReLU(), ) ) self.decoders.append(_decoders) # 앞 conv 들은 in channel 이 up_channels(inplanes*num_depths(원본에서는 320)), 마지막 conv 는 마지막 encoder 결과의 output_channel 과 맞춤 self.sup_conv = torch.nn.ModuleList() for i in range(self.num_depths - 1): self.sup_conv.append( nn.Sequential( nn.Conv1d( self.up_channels, self.output_size, kernel_size, padding=padding ), nn.BatchNorm1d(self.output_size), nn.ReLU(), ) ) self.sup_conv.append( nn.Sequential( nn.Conv1d( inplanes * (2 ** (self.num_depths - 1)), self.output_size, kernel_size, padding=padding, ), nn.BatchNorm1d(self.output_size), nn.ReLU(), ) ) def forward(self, input: torch.Tensor, y=None): # Encoder output = input enc_features = [] # X1Ee, X2Ee, .. , X5Ee dec_features = [] # X5Ee, X4De, .. , X1De for encoder in self.encoders: output = encoder(output) enc_features.append(output) dec_features.append(output) # CGM cls_branch_max = None if self.use_cgm: # (B, 2*3(output_size), 1) cls_branch: torch.Tensor = self.cls(enc_features[-1]) # (B, 3(output_size)) cls_branch_max = cls_branch.view( input.shape[0], self.output_size, 2 ).argmax(2) # Decoder for i in reversed(range(self.num_depths - 1)): _each_dec_feature = [] for j in range(self.num_depths): if j <= i: _each_enc = enc_features[j] else: _each_enc = F.interpolate( dec_features[self.num_depths - j - 1], enc_features[i].shape[2], mode=self.interpolate_mode, ) _each_dec_feature.append( self.decoders[self.num_depths - i - 2][j](_each_enc) ) dec_features.append( self.decoders[self.num_depths - i - 2][-1]( torch.cat(_each_dec_feature, dim=1) ) ) sup = [] for i, (dec_feature, sup_conv) in enumerate( zip(dec_features, reversed(self.sup_conv)) ): if i < self.num_depths - 1: sup.append( F.interpolate( sup_conv(dec_feature), input.shape[2], mode=self.interpolate_mode, ) ) else: sup.append(sup_conv(dec_feature)) if self.use_cgm: if self.sum_of_sup: return torch.sigmoid( sum( [ torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max]) for _sup in reversed(sup) ] ) ) else: return [ torch.sigmoid( torch.einsum("ijk,ij->ijk", [_sup, cls_branch_max]) for _sup in reversed(sup) ) ] else: if self.sum_of_sup: return torch.sigmoid(sum(sup)) else: return [torch.sigmoid(_sup) for _sup in reversed(sup)]