import torch
import torch.nn as nn

from .unet import UNet3D


class UNet3DBase(nn.Module):
    """Encoder-only UNet3D: reuses the down-sampling path of a full UNet3D
    (plus its optional attention module) and drops the decoder, so it can
    serve as a 3-D feature extractor."""

    def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
        super(UNet3DBase, self).__init__()
        # act and drop_rate are unused here; they are presumably kept so the
        # signature matches the other backbone classes.
        model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)

        self.blocks = blocks

        # Encoder (down-transition) blocks borrowed from the full model.
        self.down_tr64 = model.down_tr64
        self.down_tr128 = model.down_tr128
        self.down_tr256 = model.down_tr256
        self.down_tr512 = model.down_tr512
        if self.blocks == 5:
            self.down_tr1024 = model.down_tr1024

        self.in_features = model.in_features
        if attention:
            self.attention_module = model.attention_module

    def forward(self, x, stage='normal', attention=False):
        # stage is unused here; retained for interface compatibility.
        # Each down-transition returns (downsampled output, skip connection);
        # both are stored on self so callers can reuse the skip tensors.
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)

        bottleneck = self.out1024 if self.blocks == 5 else self.out512
        if hasattr(self, 'attention_module'):
            att, feats = self.attention_module(bottleneck)
        else:
            # No attention module: the bottleneck itself is the feature map.
            att, feats = None, bottleneck

        if attention:
            # att is None when the model was built with attention=False.
            return att, feats
        return feats
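

if __name__ == '__main__':
    # Minimal smoke-test sketch. Assumes UNet3D accepts a single-channel 3-D
    # volume and that a 64^3 input survives the four stride-2 encoder blocks;
    # both are assumptions about unet.py, not guarantees. Run as a module
    # (e.g. `python -m <package>.<this_file>`) so the relative import resolves.
    model = UNet3DBase(n_class=1, attention=False, pretrained=False, blocks=4)
    x = torch.randn(1, 1, 64, 64, 64)  # (batch, channels, depth, height, width)
    feats = model(x)
    print('bottleneck features:', tuple(feats.shape))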