soumickmj's picture
Upload WNetMSS3D
3818098 verified
#!/usr/bin/env python
# from __future__ import print_function, division
"""
Purpose :
"""
import torch.nn
import torch
import torch.nn as nn
from .attention_unet3d import AttUnet
from .unet3d import UNet, UNetDeepSup
__author__ = "Chethan Radhakrishna and Soumick Chatterjee"
__credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Chethan Radhakrishna"
__email__ = "[email protected]"
__status__ = "Development"
class WNet3dAttUNet(nn.Module):
"""
Attention Unet implementation
Paper: https://arxiv.org/abs/1804.03999
"""
def __init__(self, in_ch=1, out_ch=6, init_features=64):
super(WNet3dAttUNet, self).__init__()
self.Encoder = AttUnet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
self.Decoder = AttUnet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
self.activation = torch.nn.Softmax(dim=1)
self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
def forward(self, ip, ip_mask=None, ops="both"):
encoder_op = self.Encoder(ip)
if ip_mask is not None:
encoder_op = ip_mask * encoder_op
class_prob = self.activation(encoder_op)
feature_rep = self.Conv(encoder_op)
if ops == "enc":
return class_prob, feature_rep
reconstructed_op = self.Decoder(class_prob)
# if ip_mask is not None:
# reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
if ops == "dec":
return reconstructed_op
if ops == "both":
return class_prob, feature_rep, reconstructed_op
else:
raise ValueError('Invalid ops, ops must be in [enc, dec, both]')
class WNet3dUNet(nn.Module):
"""
Attention Unet implementation
Paper: https://arxiv.org/abs/1804.03999
"""
def __init__(self, in_ch=1, out_ch=6, init_features=64):
super(WNet3dUNet, self).__init__()
self.Encoder = UNet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
self.Decoder = UNet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
self.activation = torch.nn.Softmax(dim=1)
self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
def forward(self, ip, ip_mask=None, ops="both"):
encoder_op = self.Encoder(ip)
if ip_mask is not None:
encoder_op = ip_mask * encoder_op
class_prob = self.activation(encoder_op)
feature_rep = self.Conv(encoder_op)
if ops == "enc":
return class_prob, feature_rep
reconstructed_op = self.Decoder(class_prob)
# if ip_mask is not None:
# reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
if ops == "dec":
return reconstructed_op
if ops == "both":
return class_prob, feature_rep, reconstructed_op
else:
raise ValueError('Invalid ops, ops must be in [enc, dec, both]')
class WNet3dUNetMSS(nn.Module):
"""
Attention Unet implementation
Paper: https://arxiv.org/abs/1804.03999
"""
def __init__(self, in_ch=1, out_ch=6, init_features=64):
super(WNet3dUNetMSS, self).__init__()
self.Encoder = UNetDeepSup(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
self.Decoder = UNetDeepSup(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
self.activation = torch.nn.Softmax(dim=1)
self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
def forward(self, ip, ip_mask=None, ops="both"):
encoder_op = self.Encoder(ip)
if ip_mask is not None:
encoder_op = ip_mask * encoder_op
class_prob = self.activation(encoder_op)
feature_rep = self.Conv(encoder_op)
if ops == "enc":
return class_prob, feature_rep
reconstructed_op = self.Decoder(class_prob)
# if ip_mask is not None:
# reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
if ops == "dec":
return reconstructed_op
if ops == "both":
return class_prob, feature_rep, reconstructed_op
else:
raise ValueError('Invalid ops, ops must be in [enc, dec, both]')