File size: 4,376 Bytes
34baeb4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 |
#!/usr/bin/env python
# from __future__ import print_function, division
"""
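Purpose : W-Net-style 3D encoder/decoder models. Each model chains an encoder U-Net, a channel-wise softmax and a decoder U-Net that maps the class probabilities back to the input channels. Variants use Attention U-Net, plain U-Net and deep-supervised U-Net backbones.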
"""
import torch.nn
import torch
import torch.nn as nn
from .attention_unet3d import AttUnet
from .unet3d import UNet, UNetDeepSup
__author__ = "Chethan Radhakrishna and Soumick Chatterjee"
__credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Chethan Radhakrishna"
__email__ = "[email protected]"
__status__ = "Development"
class WNet3dAttUNet(nn.Module):
    """
    W-Net-style 3D model built from two Attention U-Nets.

    Data flow:
    1. The encoder Attention U-Net maps the ``in_ch``-channel input to
       ``out_ch`` channels.
    2. An optional mask is applied, then a softmax over dim 1 turns these
       into class probabilities.
    3. The decoder Attention U-Net maps the probabilities back to ``in_ch``
       channels (``reconstructed_op``).

    Separately, a 1x1x1 convolution projects the encoder output to
    ``in_ch`` channels (``feature_rep``).

    Attention U-Net paper: https://arxiv.org/abs/1804.03999

    Args:
        in_ch: Channels of the input, of the reconstruction and of the
            feature representation.
        out_ch: Channels (classes) produced by the encoder.
        init_features: Base feature width passed to both backbones.
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch: int = 1, out_ch: int = 6, init_features: int = 64):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them unchanged
        # so existing checkpoints still load.
        self.Encoder = AttUnet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = AttUnet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
        self.activation = torch.nn.Softmax(dim=1)
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                N x in_ch x D x H x W -- TODO confirm).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output before the softmax; must be broadcastable to it.
            ops: Which outputs to produce:
                "enc"  -> (class_prob, feature_rep)
                "dec"  -> reconstructed_op
                "both" -> (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: If ``ops`` is not one of "enc", "dec", "both".
        """
        # Fail fast: do not run either network for an invalid mode.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # NOTE(review): zeroed logits become 0, not -inf, so masked
            # positions still get non-zero probability after the softmax --
            # confirm this is intended.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        # The 1x1x1 projection is only returned by "enc" and "both".
        feature_rep = None if ops == "dec" else self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        if ops == "dec":
            return reconstructed_op
        # ops == "both" (validated above).
        return class_prob, feature_rep, reconstructed_op
class WNet3dUNet(nn.Module):
    """
    W-Net-style 3D model built from two plain U-Nets (``UNet``).

    Data flow:
    1. The encoder U-Net maps the ``in_ch``-channel input to ``out_ch``
       channels.
    2. An optional mask is applied, then a softmax over dim 1 turns these
       into class probabilities.
    3. The decoder U-Net maps the probabilities back to ``in_ch`` channels
       (``reconstructed_op``).

    Separately, a 1x1x1 convolution projects the encoder output to
    ``in_ch`` channels (``feature_rep``).

    Args:
        in_ch: Channels of the input, of the reconstruction and of the
            feature representation.
        out_ch: Channels (classes) produced by the encoder.
        init_features: Base feature width passed to both backbones.
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch: int = 1, out_ch: int = 6, init_features: int = 64):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them unchanged
        # so existing checkpoints still load.
        self.Encoder = UNet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
        self.activation = torch.nn.Softmax(dim=1)
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                N x in_ch x D x H x W -- TODO confirm).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output before the softmax; must be broadcastable to it.
            ops: Which outputs to produce:
                "enc"  -> (class_prob, feature_rep)
                "dec"  -> reconstructed_op
                "both" -> (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: If ``ops`` is not one of "enc", "dec", "both".
        """
        # Fail fast: do not run either network for an invalid mode.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # NOTE(review): zeroed logits become 0, not -inf, so masked
            # positions still get non-zero probability after the softmax --
            # confirm this is intended.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        # The 1x1x1 projection is only returned by "enc" and "both".
        feature_rep = None if ops == "dec" else self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        if ops == "dec":
            return reconstructed_op
        # ops == "both" (validated above).
        return class_prob, feature_rep, reconstructed_op
class WNet3dUNetMSS(nn.Module):
    """
    W-Net-style 3D model built from two deep-supervision U-Nets.

    The backbone is ``UNetDeepSup``; "MSS" presumably stands for multi-scale
    supervision -- TODO confirm.

    Data flow:
    1. The encoder maps the ``in_ch``-channel input to ``out_ch`` channels.
    2. An optional mask is applied, then a softmax over dim 1 turns these
       into class probabilities.
    3. The decoder maps the probabilities back to ``in_ch`` channels
       (``reconstructed_op``).

    Separately, a 1x1x1 convolution projects the encoder output to
    ``in_ch`` channels (``feature_rep``).

    Args:
        in_ch: Channels of the input, of the reconstruction and of the
            feature representation.
        out_ch: Channels (classes) produced by the encoder.
        init_features: Base feature width passed to both backbones.
    """

    # Accepted values of ``forward``'s ``ops`` argument.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch: int = 1, out_ch: int = 6, init_features: int = 64):
        super().__init__()
        # Attribute names are part of the state_dict keys; keep them unchanged
        # so existing checkpoints still load.
        self.Encoder = UNetDeepSup(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNetDeepSup(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
        self.activation = torch.nn.Softmax(dim=1)
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input tensor fed to the encoder (presumably
                N x in_ch x D x H x W -- TODO confirm).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output before the softmax; must be broadcastable to it.
            ops: Which outputs to produce:
                "enc"  -> (class_prob, feature_rep)
                "dec"  -> reconstructed_op
                "both" -> (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: If ``ops`` is not one of "enc", "dec", "both".
        """
        # Fail fast: do not run either network for an invalid mode.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        # NOTE(review): this code treats the UNetDeepSup output as a single
        # tensor. If it returns deep-supervision side outputs (e.g. a tuple),
        # it must be unpacked here -- confirm.
        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # NOTE(review): zeroed logits become 0, not -inf, so masked
            # positions still get non-zero probability after the softmax --
            # confirm this is intended.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        # The 1x1x1 projection is only returned by "enc" and "both".
        feature_rep = None if ops == "dec" else self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        reconstructed_op = self.Decoder(class_prob)
        if ops == "dec":
            return reconstructed_op
        # ops == "both" (validated above).
        return class_prob, feature_rep, reconstructed_op