File size: 2,673 Bytes
6fc43ab |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 |
import sys
sys.path.append('..')
# from feature_extractor.for_image_data.backbone import CNN_GAP, ResNet3D, UNet3D
import torch
import torch.nn as nn
from torchvision import models
import torch.nn.functional as F
# from . import UNet3D
from .unet import UNet3D
from icecream import ic
class UNet3DBase(nn.Module):
    """Encoder-only wrapper around ``UNet3D``.

    Builds a full ``UNet3D`` and keeps only its down-sampling stages
    (``down_tr64``, ``down_tr128``, ``down_tr256``, ``down_tr512`` and, when
    ``blocks == 5``, ``down_tr1024``) plus, optionally, its
    ``attention_module``. Every other submodule of the wrapped model is
    discarded when the temporary ``model`` goes out of scope.

    Args:
        n_class: Forwarded to ``UNet3D``.
        act: Currently unused; kept for signature compatibility.
        attention: If true, reuse ``UNet3D.attention_module`` and apply it to
            the deepest encoder output in :meth:`forward`.
        pretrained: Forwarded to ``UNet3D``.
        drop_rate: Currently unused; kept for signature compatibility.
        blocks: Forwarded to ``UNet3D``. ``5`` adds the fifth stage
            ``down_tr1024``; any other value uses the four stages above.
    """

    def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
        super().__init__()
        model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)
        self.blocks = blocks
        self.down_tr64 = model.down_tr64
        self.down_tr128 = model.down_tr128
        self.down_tr256 = model.down_tr256
        self.down_tr512 = model.down_tr512
        if self.blocks == 5:
            self.down_tr1024 = model.down_tr1024
        self.in_features = model.in_features
        if attention:
            self.attention_module = model.attention_module

    def forward(self, x, stage='normal', attention=False):
        """Run the encoder stages on ``x`` and return the deepest features.

        Each stage returns an ``(out, skip)`` pair; both are cached on
        ``self`` as ``out64``/``skip_out64``, ``out128``/``skip_out128``,
        ``out256``/``skip_out256``, ``out512``/``skip_out512`` and, when
        ``blocks == 5``, ``out1024``/``skip_out1024``.

        Args:
            x: Input passed to ``down_tr64`` (presumably a 5-D
                ``(N, C, D, H, W)`` volume -- TODO confirm against UNet3D).
            stage: Unused; kept for interface compatibility.
            attention: If true, return ``(att, feats)`` instead of ``feats``.
                Requires the module to have been built with
                ``attention=True``.

        Returns:
            ``feats``, or ``(att, feats)`` when ``attention`` is true. If an
            attention module is present, ``feats`` is its second output;
            otherwise it is the deepest encoder output.

        Raises:
            ValueError: If ``attention`` is true but no attention module is
                configured.
        """
        attention_module = getattr(self, 'attention_module', None)
        # Fail fast, before running the encoder; previously this case
        # surfaced as an UnboundLocalError on `att` at the return statement.
        if attention and attention_module is None:
            raise ValueError(
                "forward(attention=True) requires UNet3DBase to be built "
                "with attention=True"
            )

        # Intermediates are stored on `self` (not locals) for backward
        # compatibility with code that reads them after forward(). Note that
        # this keeps those tensors (and any autograd graph) alive until the
        # next call.
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        deepest = self.out512
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)
            deepest = self.out1024

        if attention_module is None:
            return deepest

        att, feats = attention_module(deepest)
        if attention:
            return att, feats
        return feats