File size: 4,376 Bytes
7b17641
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
#!/usr/bin/env python

# from __future__ import print_function, division
"""
3D W-Net style models.

Purpose :
    Each model chains two U-Net variants: an encoder that maps an
    ``in_ch``-channel input volume to ``out_ch`` channel scores (turned into
    class probabilities by a channel-wise softmax), and a decoder that maps
    those probabilities back to an ``in_ch``-channel reconstruction.
    Variants: WNet3dAttUNet (AttUnet), WNet3dUNet (UNet) and
    WNet3dUNetMSS (UNetDeepSup).
"""
import torch.nn
import torch
import torch.nn as nn
from .attention_unet3d import AttUnet
from .unet3d import UNet, UNetDeepSup

__author__ = "Chethan Radhakrishna and Soumick Chatterjee"
__credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]
__license__ = "GPL"
__version__ = "1.0.0"
__maintainer__ = "Chethan Radhakrishna"
# NOTE(review): "[email protected]" looks like an e-mail obfuscation placeholder
# left behind by a web scrape, not a real address — restore the maintainer's email.
__email__ = "[email protected]"
__status__ = "Development"


class WNet3dAttUNet(nn.Module):
    """
    3D W-Net built from two Attention U-Nets.

    The encoder Attention U-Net maps an ``in_ch``-channel volume to ``out_ch``
    channel scores. A softmax over dim 1 turns those into class probabilities,
    which the decoder Attention U-Net maps back to an ``in_ch``-channel
    reconstruction. A 1x1x1 convolution additionally projects the (optionally
    masked) encoder output to ``in_ch`` channels ("feature representation").

    W-Net: https://arxiv.org/abs/1711.08506
    Attention U-Net: https://arxiv.org/abs/1804.03999

    Args:
        in_ch: Channels of the input volume and of the reconstruction.
        out_ch: Channels produced by the encoder (one score per class).
        init_features: Base feature count forwarded to both Attention U-Nets.
    """

    # Accepted values of the ``ops`` argument of ``forward``.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        super().__init__()

        # Encoder: in_ch -> out_ch; decoder: out_ch -> in_ch.
        self.Encoder = AttUnet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = AttUnet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        # Softmax over dim 1 — the channel axis for batched (N, C, D, H, W) tensors.
        self.activation = torch.nn.Softmax(dim=1)

        # 1x1x1 projection of the encoder output back to in_ch channels.
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input volume fed to the encoder; presumably (N, in_ch, D, H, W).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output *before* the softmax; must broadcast against it.
            ops: "enc", "dec" or "both" — selects what is computed and returned.

        Returns:
            "enc":  (class_prob, feature_rep)
            "dec":  reconstructed_op
            "both": (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: if ``ops`` is not one of "enc", "dec", "both".
        """
        # Validate first so a bad ``ops`` fails fast instead of running the
        # whole encoder/decoder pipeline and only then raising.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # Masking happens on the raw scores: zeroed positions become
            # 0-valued logits, not 0 probability, after the softmax.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # The 1x1x1 feature projection is not part of the "dec" output.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        # TODO: optionally re-apply the mask to the reconstruction, e.g.
        # torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op.
        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op


class WNet3dUNet(nn.Module):
    """
    3D W-Net built from two plain U-Nets.

    The encoder U-Net maps an ``in_ch``-channel volume to ``out_ch`` channel
    scores. A softmax over dim 1 turns those into class probabilities, which
    the decoder U-Net maps back to an ``in_ch``-channel reconstruction. A
    1x1x1 convolution additionally projects the (optionally masked) encoder
    output to ``in_ch`` channels ("feature representation").

    W-Net: https://arxiv.org/abs/1711.08506

    Args:
        in_ch: Channels of the input volume and of the reconstruction.
        out_ch: Channels produced by the encoder (one score per class).
        init_features: Base feature count forwarded to both U-Nets.
    """

    # Accepted values of the ``ops`` argument of ``forward``.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        super().__init__()

        # Encoder: in_ch -> out_ch; decoder: out_ch -> in_ch.
        self.Encoder = UNet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        # Softmax over dim 1 — the channel axis for batched (N, C, D, H, W) tensors.
        self.activation = torch.nn.Softmax(dim=1)

        # 1x1x1 projection of the encoder output back to in_ch channels.
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        Args:
            ip: Input volume fed to the encoder; presumably (N, in_ch, D, H, W).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output *before* the softmax; must broadcast against it.
            ops: "enc", "dec" or "both" — selects what is computed and returned.

        Returns:
            "enc":  (class_prob, feature_rep)
            "dec":  reconstructed_op
            "both": (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: if ``ops`` is not one of "enc", "dec", "both".
        """
        # Validate first so a bad ``ops`` fails fast instead of running the
        # whole encoder/decoder pipeline and only then raising.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # Masking happens on the raw scores: zeroed positions become
            # 0-valued logits, not 0 probability, after the softmax.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # The 1x1x1 feature projection is not part of the "dec" output.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        # TODO: optionally re-apply the mask to the reconstruction, e.g.
        # torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op.
        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op


class WNet3dUNetMSS(nn.Module):
    """
    3D W-Net built from two deep-supervision U-Nets (``UNetDeepSup``).

    The encoder network maps an ``in_ch``-channel volume to ``out_ch`` channel
    scores. A softmax over dim 1 turns those into class probabilities, which
    the decoder network maps back to an ``in_ch``-channel reconstruction. A
    1x1x1 convolution additionally projects the (optionally masked) encoder
    output to ``in_ch`` channels ("feature representation").

    W-Net: https://arxiv.org/abs/1711.08506

    Args:
        in_ch: Channels of the input volume and of the reconstruction.
        out_ch: Channels produced by the encoder (one score per class).
        init_features: Base feature count forwarded to both UNetDeepSup nets.
    """

    # Accepted values of the ``ops`` argument of ``forward``.
    _VALID_OPS = ("enc", "dec", "both")

    def __init__(self, in_ch=1, out_ch=6, init_features=64):
        super().__init__()

        # Encoder: in_ch -> out_ch; decoder: out_ch -> in_ch.
        self.Encoder = UNetDeepSup(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
        self.Decoder = UNetDeepSup(in_ch=out_ch, out_ch=in_ch, init_features=init_features)

        # Softmax over dim 1 — the channel axis for batched (N, C, D, H, W) tensors.
        self.activation = torch.nn.Softmax(dim=1)

        # 1x1x1 projection of the encoder output back to in_ch channels.
        self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)

    def forward(self, ip, ip_mask=None, ops="both"):
        """
        Run the encoder and, depending on ``ops``, the decoder.

        NOTE(review): this treats the Encoder/Decoder outputs as single
        tensors. If ``UNetDeepSup`` returns several deep-supervision outputs
        (e.g. a list), the masking/softmax below will fail — confirm against
        unet3d.py.

        Args:
            ip: Input volume fed to the encoder; presumably (N, in_ch, D, H, W).
            ip_mask: Optional tensor multiplied element-wise with the encoder
                output *before* the softmax; must broadcast against it.
            ops: "enc", "dec" or "both" — selects what is computed and returned.

        Returns:
            "enc":  (class_prob, feature_rep)
            "dec":  reconstructed_op
            "both": (class_prob, feature_rep, reconstructed_op)

        Raises:
            ValueError: if ``ops`` is not one of "enc", "dec", "both".
        """
        # Validate first so a bad ``ops`` fails fast instead of running the
        # whole encoder/decoder pipeline and only then raising.
        if ops not in self._VALID_OPS:
            raise ValueError('Invalid ops, ops must be in [enc, dec, both]')

        encoder_op = self.Encoder(ip)
        if ip_mask is not None:
            # Masking happens on the raw scores: zeroed positions become
            # 0-valued logits, not 0 probability, after the softmax.
            encoder_op = ip_mask * encoder_op
        class_prob = self.activation(encoder_op)

        if ops == "dec":
            # The 1x1x1 feature projection is not part of the "dec" output.
            return self.Decoder(class_prob)

        feature_rep = self.Conv(encoder_op)
        if ops == "enc":
            return class_prob, feature_rep

        # TODO: optionally re-apply the mask to the reconstruction, e.g.
        # torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op.
        reconstructed_op = self.Decoder(class_prob)
        return class_prob, feature_rep, reconstructed_op