import sys

sys.path.append('..')

import torch
import torch.nn as nn

from .unet import UNet3D
class UNet3DBase(nn.Module):
    """Encoder-only wrapper around UNet3D: runs the downsampling path and
    returns bottleneck features (optionally together with attention maps)."""

    def __init__(self, n_class=1, act='relu', attention=False, pretrained=False, drop_rate=0.1, blocks=4):
        super(UNet3DBase, self).__init__()
        model = UNet3D(n_class=n_class, attention=attention, pretrained=pretrained, blocks=blocks)

        self.blocks = blocks

        # Reuse the down-transition (encoder) blocks from the full UNet3D.
        self.down_tr64 = model.down_tr64
        self.down_tr128 = model.down_tr128
        self.down_tr256 = model.down_tr256
        self.down_tr512 = model.down_tr512
        if self.blocks == 5:
            self.down_tr1024 = model.down_tr1024

        self.in_features = model.in_features

        if attention:
            self.attention_module = model.attention_module

    def forward(self, x, stage='normal', attention=False):
        # Cache each stage's output and skip connection on the module so they
        # can be inspected after the forward pass.
        self.out64, self.skip_out64 = self.down_tr64(x)
        self.out128, self.skip_out128 = self.down_tr128(self.out64)
        self.out256, self.skip_out256 = self.down_tr256(self.out128)
        self.out512, self.skip_out512 = self.down_tr512(self.out256)
        if self.blocks == 5:
            self.out1024, self.skip_out1024 = self.down_tr1024(self.out512)

        bottleneck = self.out1024 if self.blocks == 5 else self.out512

        if hasattr(self, 'attention_module'):
            att, feats = self.attention_module(bottleneck)
        else:
            feats = bottleneck

        if attention:
            # Guard against a NameError: `att` only exists when the model was
            # constructed with attention enabled.
            if not hasattr(self, 'attention_module'):
                raise RuntimeError("attention=True requires constructing UNet3DBase with attention=True")
            return att, feats
        return feats
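

# A minimal smoke test, kept as a sketch: it assumes UNet3D's first encoder
# block accepts single-channel volumes and that a 64^3 input is large enough
# to survive four downsampling stages; adjust the shape for your data.
if __name__ == '__main__':
    net = UNet3DBase(n_class=1, attention=False, blocks=4)
    x = torch.randn(1, 1, 64, 64, 64)  # (batch, channels, depth, height, width)
    feats = net(x)
    print(feats.shape)  # bottleneck features from down_tr512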