|
|
|
|
|
|
|
""" |
|
|
|
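Purpose : 3D W-Net-style models that chain two U-Net variants (AttUnet, UNet or UNetDeepSup): an encoder whose softmaxed output gives per-class probabilities, and a decoder that reconstructs the input from those probabilities.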
|
|
|
""" |
|
import torch.nn |
|
import torch |
|
import torch.nn as nn |
|
from .attention_unet3d import AttUnet |
|
from .unet3d import UNet, UNetDeepSup |
|
|
|
# Module metadata (authorship, licensing and development status).
# NOTE(review): __email__ below looks like a web-scraping obfuscation
# placeholder rather than a real address — confirm and restore if needed.
__author__ = "Chethan Radhakrishna and Soumick Chatterjee"

__credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]

__license__ = "GPL"

__version__ = "1.0.0"

__maintainer__ = "Chethan Radhakrishna"

__email__ = "[email protected]"

__status__ = "Development"
|
|
|
|
|
class WNet3dAttUNet(nn.Module):
    """
    W-Net-style 3D model built from two Attention U-Nets (``AttUnet``).

    The encoder network maps an ``in_ch``-channel input to ``out_ch``
    channels. A softmax over dim 1 turns that output into class
    probabilities, and the decoder network maps those probabilities back to
    ``in_ch`` channels as a reconstruction. A 1x1x1 ``Conv3d`` additionally
    projects the (optionally masked) encoder output to ``in_ch`` channels as
    a feature representation.

    Attention U-Net paper: https://arxiv.org/abs/1804.03999
    W-Net (cf.): https://arxiv.org/abs/1711.08506
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        """
        Args:
            in_ch: Channels of the input volume and of the reconstruction.
            out_ch: Channels produced by the encoder, i.e. the number of
                classes the softmax is taken over.
            init_features: Passed unchanged to both AttUnet instances.
        """
        super().__init__()

        # Attribute names are kept as-is: they are the state_dict key prefixes.
        self.Encoder = AttUnet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = AttUnet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        self.activation = torch.nn.Softmax(dim=1)

        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                (N, in_ch, D, H, W) — TODO confirm against AttUnet).
            ip_mask: Optional tensor multiplied element-wise onto the encoder
                output before the softmax. It must broadcast against that
                output.
            ops: One of ``"enc"``, ``"dec"`` or ``"both"``.

        Returns:
            ``"enc"``: ``(class_prob, feature_rep)``.
            ``"dec"``: ``reconstructed_op``.
            ``"both"``: ``(class_prob, feature_rep, reconstructed_op)``.

        Raises:
            ValueError: If ``ops`` is not an accepted value. This is checked
                before any sub-network runs, so no compute is wasted.
        """
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # feature_rep is not part of the "dec" result, so skip the 1x1 conv.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op
|
|
|
|
|
class WNet3dUNet(nn.Module):
    """
    W-Net-style 3D model built from two plain U-Nets (``UNet``).

    The encoder network maps an ``in_ch``-channel input to ``out_ch``
    channels. A softmax over dim 1 turns that output into class
    probabilities, and the decoder network maps those probabilities back to
    ``in_ch`` channels as a reconstruction. A 1x1x1 ``Conv3d`` additionally
    projects the (optionally masked) encoder output to ``in_ch`` channels as
    a feature representation.

    W-Net (cf.): https://arxiv.org/abs/1711.08506
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        """
        Args:
            in_ch: Channels of the input volume and of the reconstruction.
            out_ch: Channels produced by the encoder, i.e. the number of
                classes the softmax is taken over.
            init_features: Passed unchanged to both UNet instances.
        """
        super().__init__()

        # Attribute names are kept as-is: they are the state_dict key prefixes.
        self.Encoder = UNet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        self.activation = torch.nn.Softmax(dim=1)

        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                (N, in_ch, D, H, W) — TODO confirm against UNet).
            ip_mask: Optional tensor multiplied element-wise onto the encoder
                output before the softmax. It must broadcast against that
                output.
            ops: One of ``"enc"``, ``"dec"`` or ``"both"``.

        Returns:
            ``"enc"``: ``(class_prob, feature_rep)``.
            ``"dec"``: ``reconstructed_op``.
            ``"both"``: ``(class_prob, feature_rep, reconstructed_op)``.

        Raises:
            ValueError: If ``ops`` is not an accepted value. This is checked
                before any sub-network runs, so no compute is wasted.
        """
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # feature_rep is not part of the "dec" result, so skip the 1x1 conv.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op
|
|
|
|
|
class WNet3dUNetMSS(nn.Module):
    """
    W-Net-style 3D model built from two deep-supervision U-Nets (``UNetDeepSup``).

    "MSS" presumably refers to multi-scale (deep) supervision — TODO confirm.

    The encoder network maps an ``in_ch``-channel input to ``out_ch``
    channels. A softmax over dim 1 turns that output into class
    probabilities, and the decoder network maps those probabilities back to
    ``in_ch`` channels as a reconstruction. A 1x1x1 ``Conv3d`` additionally
    projects the (optionally masked) encoder output to ``in_ch`` channels as
    a feature representation.

    W-Net (cf.): https://arxiv.org/abs/1711.08506
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        """
        Args:
            in_ch: Channels of the input volume and of the reconstruction.
            out_ch: Channels produced by the encoder, i.e. the number of
                classes the softmax is taken over.
            init_features: Passed unchanged to both UNetDeepSup instances.
        """
        super().__init__()

        # Attribute names are kept as-is: they are the state_dict key prefixes.
        self.Encoder = UNetDeepSup(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNetDeepSup(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        self.activation = torch.nn.Softmax(dim=1)

        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                (N, in_ch, D, H, W) — TODO confirm against UNetDeepSup).
            ip_mask: Optional tensor multiplied element-wise onto the encoder
                output before the softmax. It must broadcast against that
                output.
            ops: One of ``"enc"``, ``"dec"`` or ``"both"``.

        Returns:
            ``"enc"``: ``(class_prob, feature_rep)``.
            ``"dec"``: ``reconstructed_op``.
            ``"both"``: ``(class_prob, feature_rep, reconstructed_op)``.

        Raises:
            ValueError: If ``ops`` is not an accepted value. This is checked
                before any sub-network runs, so no compute is wasted.
        """
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        # NOTE(review): the masking, softmax and Conv below assume
        # UNetDeepSup returns a single tensor. If it returns several
        # deep-supervision outputs, this path needs adapting — confirm
        # against unet3d.UNetDeepSup.
        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # feature_rep is not part of the "dec" result, so skip the 1x1 conv.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op