soumickmj commited on
Commit
3818098
1 Parent(s): 2f158f8

Upload WNetMSS3D

Browse files
Files changed (7) hide show
  1. WNetConfigs.py +41 -0
  2. WNets.py +36 -0
  3. attention_unet3d.py +211 -0
  4. config.json +15 -0
  5. model.safetensors +3 -0
  6. unet3d.py +321 -0
  7. w_net_3d.py +126 -0
WNetConfigs.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class WNet3DConfig(PretrainedConfig):
5
+ model_type = "WNet"
6
+ def __init__(
7
+ self,
8
+ in_ch=1,
9
+ out_ch=5,
10
+ init_features=64,
11
+ **kwargs):
12
+ self.in_ch = in_ch
13
+ self.out_ch = out_ch
14
+ self.init_features = init_features
15
+ super().__init__(**kwargs)
16
+
17
+ class AttWNet3DConfig(PretrainedConfig):
18
+ model_type = "AttWNet"
19
+ def __init__(
20
+ self,
21
+ in_ch=1,
22
+ out_ch=5,
23
+ init_features=64,
24
+ **kwargs):
25
+ self.in_ch = in_ch
26
+ self.out_ch = out_ch
27
+ self.init_features = init_features
28
+ super().__init__(**kwargs)
29
+
30
+ class WNetMSS3DConfig(PretrainedConfig):
31
+ model_type = "WNetMSS"
32
+ def __init__(
33
+ self,
34
+ in_ch=1,
35
+ out_ch=5,
36
+ init_features=64,
37
+ **kwargs):
38
+ self.in_ch = in_ch
39
+ self.out_ch = out_ch
40
+ self.init_features = init_features
41
+ super().__init__(**kwargs)
WNets.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .w_net_3d import WNet3dUNet, WNet3dAttUNet, WNet3dUNetMSS
3
+ from .WNetConfigs import WNet3DConfig, AttWNet3DConfig, WNetMSS3DConfig
4
+
5
+ class WNet3D(PreTrainedModel):
6
+ config_class = WNet3DConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = WNet3dUNet(
10
+ in_ch=config.in_ch,
11
+ out_ch=config.out_ch,
12
+ init_features=config.init_features)
13
+ def forward(self, x):
14
+ return self.model(x)
15
+
16
+ class AttWNet3D(PreTrainedModel):
17
+ config_class = AttWNet3DConfig
18
+ def __init__(self, config):
19
+ super().__init__(config)
20
+ self.model = WNet3dAttUNet(
21
+ in_ch=config.in_ch,
22
+ out_ch=config.out_ch,
23
+ init_features=config.init_features)
24
+ def forward(self, x):
25
+ return self.model(x)
26
+
27
+ class WNetMSS3D(PreTrainedModel):
28
+ config_class = WNetMSS3DConfig
29
+ def __init__(self, config):
30
+ super().__init__(config)
31
+ self.model = WNet3dUNetMSS(
32
+ in_ch=config.in_ch,
33
+ out_ch=config.out_ch,
34
+ init_features=config.init_features)
35
+ def forward(self, x):
36
+ return self.model(x)
attention_unet3d.py ADDED
@@ -0,0 +1,211 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # from __future__ import print_function, division
4
+ """
5
+
6
+ Purpose :
7
+
8
+ """
9
+ import torch.nn
10
+ import torch
11
+ import torch.nn as nn
12
+
13
+ __author__ = "Chethan Radhakrishna and Soumick Chatterjee"
14
+ __credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]
15
+ __license__ = "GPL"
16
+ __version__ = "1.0.0"
17
+ __maintainer__ = "Chethan Radhakrishna"
18
+ __email__ = "[email protected]"
19
+ __status__ = "Development"
20
+
21
+
22
+ class ConvBlock(nn.Module):
23
+ """
24
+ Convolution Block
25
+ """
26
+
27
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
28
+ super(ConvBlock, self).__init__()
29
+ self.conv = nn.Sequential(
30
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
31
+ stride=stride, padding=padding, bias=bias),
32
+ nn.PReLU(num_parameters=out_channels, init=0.25),
33
+ # nn.Dropout3d(),
34
+ nn.BatchNorm3d(num_features=out_channels),
35
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
36
+ stride=stride, padding=padding, bias=bias),
37
+ nn.PReLU(num_parameters=out_channels, init=0.25),
38
+ # nn.Dropout3d(),
39
+ nn.BatchNorm3d(num_features=out_channels))
40
+
41
+ def forward(self, x):
42
+ x = self.conv(x)
43
+ return x
44
+
45
+
46
+ class SeparableConvBlock(nn.Module):
47
+ """
48
+ Convolution Block
49
+ """
50
+
51
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
52
+ super(SeparableConvBlock, self).__init__()
53
+ self.conv = nn.Sequential(
54
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
55
+ bias=bias),
56
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
57
+ stride=stride, padding=padding, bias=bias),
58
+ nn.PReLU(num_parameters=out_channels, init=0.25),
59
+ # nn.Dropout3d(),
60
+ nn.BatchNorm3d(num_features=out_channels),
61
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=1,
62
+ bias=bias),
63
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
64
+ stride=stride, padding=padding, bias=bias),
65
+ nn.PReLU(num_parameters=out_channels, init=0.25),
66
+ # nn.Dropout3d(),
67
+ nn.BatchNorm3d(num_features=out_channels))
68
+
69
+ def forward(self, x):
70
+ x = self.conv(x)
71
+ return x
72
+
73
+
74
+ class UpConv(nn.Module):
75
+ """
76
+ Up Convolution Block
77
+ """
78
+
79
+ # def __init__(self, in_ch, out_ch):
80
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1):
81
+ super(UpConv, self).__init__()
82
+ self.up = nn.Sequential(
83
+ nn.Upsample(scale_factor=2),
84
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
85
+ stride=stride, padding=padding),
86
+ nn.BatchNorm3d(num_features=out_channels),
87
+ nn.PReLU(num_parameters=out_channels, init=0.25))
88
+
89
+ def forward(self, x):
90
+ x = self.up(x)
91
+ return x
92
+
93
+
94
+ class AttentionBlock(nn.Module):
95
+ """
96
+ Attention Block
97
+ """
98
+
99
+ def __init__(self, f_g, f_l, f_int):
100
+ super(AttentionBlock, self).__init__()
101
+
102
+ self.W_g = nn.Sequential(
103
+ nn.Conv3d(f_l, f_int, kernel_size=1, stride=1, padding=0, bias=True),
104
+ nn.BatchNorm3d(f_int)
105
+ )
106
+
107
+ self.W_x = nn.Sequential(
108
+ nn.Conv3d(f_g, f_int, kernel_size=1, stride=1, padding=0, bias=True),
109
+ nn.BatchNorm3d(f_int)
110
+ )
111
+
112
+ self.psi = nn.Sequential(
113
+ nn.Conv3d(f_int, 1, kernel_size=1, stride=1, padding=0, bias=True),
114
+ nn.BatchNorm3d(1),
115
+ nn.Sigmoid()
116
+ )
117
+
118
+ self.relu = nn.ReLU(inplace=True)
119
+
120
+ def forward(self, g, x):
121
+ g1 = self.W_g(g)
122
+ x1 = self.W_x(x)
123
+ psi = self.relu(g1 + x1)
124
+ psi = self.psi(psi)
125
+ out = x * psi
126
+ return out
127
+
128
+
129
+ class AttUnet(nn.Module):
130
+ """
131
+ Attention Unet implementation
132
+ Paper: https://arxiv.org/abs/1804.03999
133
+ """
134
+
135
+ def __init__(self, in_ch=1, out_ch=6, init_features=64):
136
+ super(AttUnet, self).__init__()
137
+
138
+ n1 = init_features
139
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16]
140
+
141
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
142
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
143
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
144
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
145
+
146
+ self.Conv1 = ConvBlock(in_ch, filters[0])
147
+ self.Conv2 = SeparableConvBlock(filters[0], filters[1])
148
+ self.Conv3 = SeparableConvBlock(filters[1], filters[2])
149
+ self.Conv4 = SeparableConvBlock(filters[2], filters[3])
150
+ self.Conv5 = SeparableConvBlock(filters[3], filters[4])
151
+
152
+ self.Up5 = UpConv(filters[4], filters[3])
153
+ self.Att5 = AttentionBlock(f_g=filters[3], f_l=filters[3], f_int=filters[2])
154
+ self.Up_conv5 = SeparableConvBlock(filters[4], filters[3])
155
+
156
+ self.Up4 = UpConv(filters[3], filters[2])
157
+ self.Att4 = AttentionBlock(f_g=filters[2], f_l=filters[2], f_int=filters[1])
158
+ self.Up_conv4 = SeparableConvBlock(filters[3], filters[2])
159
+
160
+ self.Up3 = UpConv(filters[2], filters[1])
161
+ self.Att3 = AttentionBlock(f_g=filters[1], f_l=filters[1], f_int=filters[0])
162
+ self.Up_conv3 = SeparableConvBlock(filters[2], filters[1])
163
+
164
+ self.Up2 = UpConv(filters[1], filters[0])
165
+ self.Att2 = AttentionBlock(f_g=filters[0], f_l=filters[0], f_int=32)
166
+ self.Up_conv2 = ConvBlock(filters[1], filters[0])
167
+
168
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
169
+
170
+ # self.active = torch.nn.Sigmoid()
171
+
172
+ def forward(self, x):
173
+ e1 = self.Conv1(x)
174
+
175
+ e2 = self.Maxpool1(e1)
176
+ e2 = self.Conv2(e2)
177
+
178
+ e3 = self.Maxpool2(e2)
179
+ e3 = self.Conv3(e3)
180
+
181
+ e4 = self.Maxpool3(e3)
182
+ e4 = self.Conv4(e4)
183
+
184
+ e5 = self.Maxpool4(e4)
185
+ e5 = self.Conv5(e5)
186
+
187
+ d5 = self.Up5(e5)
188
+ x4 = self.Att5(d5, e4)
189
+ d5 = torch.cat((x4, d5), dim=1)
190
+ d5 = self.Up_conv5(d5)
191
+
192
+ d4 = self.Up4(d5)
193
+ x3 = self.Att4(d4, e3)
194
+ d4 = torch.cat((x3, d4), dim=1)
195
+ d4 = self.Up_conv4(d4)
196
+
197
+ d3 = self.Up3(d4)
198
+ x2 = self.Att3(d3, e2)
199
+ d3 = torch.cat((x2, d3), dim=1)
200
+ d3 = self.Up_conv3(d3)
201
+
202
+ d2 = self.Up2(d3)
203
+ x1 = self.Att2(d2, e1)
204
+ d2 = torch.cat((x1, d2), dim=1)
205
+ d2 = self.Up_conv2(d2)
206
+
207
+ out = self.Conv(d2)
208
+
209
+ # out = self.active(out)
210
+
211
+ return out
config.json ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "WNetMSS3D"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "WNetConfigs.WNetMSS3DConfig",
7
+ "AutoModel": "WNets.WNetMSS3D"
8
+ },
9
+ "in_ch": 1,
10
+ "init_features": 64,
11
+ "model_type": "WNetMSS",
12
+ "out_ch": 5,
13
+ "torch_dtype": "float32",
14
+ "transformers_version": "4.44.2"
15
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2095e12a7cace20a37746793dbbbfda62ad277b7cc537ad002850765f2d08782
3
+ size 929631232
unet3d.py ADDED
@@ -0,0 +1,321 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # from __future__ import print_function, division
4
+ """
5
+
6
+ Purpose :
7
+
8
+ """
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ import torch.utils.data
13
+
14
+ __author__ = "Kartik Prabhu, Mahantesh Pattadkal, and Soumick Chatterjee"
15
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
16
+ __credits__ = ["Kartik Prabhu", "Mahantesh Pattadkal", "Soumick Chatterjee"]
17
+ __license__ = "GPL"
18
+ __version__ = "1.0.0"
19
+ __maintainer__ = "Soumick Chatterjee"
20
+ __email__ = "[email protected]"
21
+ __status__ = "Production"
22
+
23
+
24
+ class ConvBlock(nn.Module):
25
+ """
26
+ Convolution Block
27
+ """
28
+
29
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
30
+ super(ConvBlock, self).__init__()
31
+ self.conv = nn.Sequential(
32
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
33
+ stride=stride, padding=padding, bias=bias),
34
+ nn.BatchNorm3d(num_features=out_channels),
35
+ nn.LeakyReLU(inplace=True),
36
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
37
+ stride=stride, padding=padding, bias=bias),
38
+ nn.BatchNorm3d(num_features=out_channels),
39
+ nn.LeakyReLU(inplace=True)
40
+ )
41
+
42
+ def forward(self, x):
43
+ x = self.conv(x)
44
+ return x
45
+
46
+
47
+ class SeparableConvBlock(nn.Module):
48
+ """
49
+ Convolution Block
50
+ """
51
+
52
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
53
+ super(SeparableConvBlock, self).__init__()
54
+ self.conv = nn.Sequential(
55
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=1,
56
+ bias=bias),
57
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
58
+ stride=stride, padding=padding, bias=bias),
59
+ nn.PReLU(num_parameters=out_channels, init=0.25),
60
+ nn.Dropout3d(),
61
+ nn.BatchNorm3d(num_features=out_channels),
62
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=1,
63
+ bias=bias),
64
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
65
+ stride=stride, padding=padding, bias=bias),
66
+ nn.PReLU(num_parameters=out_channels, init=0.25),
67
+ nn.Dropout3d(),
68
+ nn.BatchNorm3d(num_features=out_channels))
69
+
70
+ def forward(self, x):
71
+ x = self.conv(x)
72
+ return x
73
+
74
+
75
+ class UpConv(nn.Module):
76
+ """
77
+ Up Convolution Block
78
+ """
79
+
80
+ # def __init__(self, in_ch, out_ch):
81
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
82
+ super(UpConv, self).__init__()
83
+ self.up = nn.Sequential(
84
+ nn.Upsample(scale_factor=2),
85
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
86
+ stride=stride, padding=padding, bias=bias),
87
+ nn.BatchNorm3d(num_features=out_channels),
88
+ nn.LeakyReLU(inplace=True))
89
+
90
+ def forward(self, x):
91
+ x = self.up(x)
92
+ return x
93
+
94
+
95
+ class UNet(nn.Module):
96
+ """
97
+ UNet - Basic Implementation
98
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
99
+ Paper : https://arxiv.org/abs/1505.04597
100
+ """
101
+
102
+ def __init__(self, in_ch=1, out_ch=1, init_features=64):
103
+ super(UNet, self).__init__()
104
+
105
+ n1 = init_features
106
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
107
+
108
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
109
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
110
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
111
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
112
+
113
+ self.Conv1 = ConvBlock(in_ch, filters[0])
114
+ self.Conv2 = SeparableConvBlock(filters[0], filters[1])
115
+ self.Conv3 = SeparableConvBlock(filters[1], filters[2])
116
+ self.Conv4 = SeparableConvBlock(filters[2], filters[3])
117
+ self.Conv5 = SeparableConvBlock(filters[3], filters[4])
118
+
119
+ self.Up5 = UpConv(filters[4], filters[3])
120
+ self.Up_conv5 = SeparableConvBlock(filters[4], filters[3])
121
+
122
+ self.Up4 = UpConv(filters[3], filters[2])
123
+ self.Up_conv4 = SeparableConvBlock(filters[3], filters[2])
124
+
125
+ self.Up3 = UpConv(filters[2], filters[1])
126
+ self.Up_conv3 = SeparableConvBlock(filters[2], filters[1])
127
+
128
+ self.Up2 = UpConv(filters[1], filters[0])
129
+ self.Up_conv2 = ConvBlock(filters[1], filters[0])
130
+
131
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
132
+
133
+ # self.active = torch.nn.Sigmoid()
134
+
135
+ def forward(self, x):
136
+ # print("unet")
137
+ # print(x.shape)
138
+ # print(padded.shape)
139
+
140
+ e1 = self.Conv1(x)
141
+ # print("conv1:")
142
+ # print(e1.shape)
143
+
144
+ e2 = self.Maxpool1(e1)
145
+ e2 = self.Conv2(e2)
146
+ # print("conv2:")
147
+ # print(e2.shape)
148
+
149
+ e3 = self.Maxpool2(e2)
150
+ e3 = self.Conv3(e3)
151
+ # print("conv3:")
152
+ # print(e3.shape)
153
+
154
+ e4 = self.Maxpool3(e3)
155
+ e4 = self.Conv4(e4)
156
+ # print("conv4:")
157
+ # print(e4.shape)
158
+
159
+ e5 = self.Maxpool4(e4)
160
+ e5 = self.Conv5(e5)
161
+ # print("conv5:")
162
+ # print(e5.shape)
163
+
164
+ d5 = self.Up5(e5)
165
+ # print("d5:")
166
+ # print(d5.shape)
167
+ # print("e4:")
168
+ # print(e4.shape)
169
+ d5 = torch.cat((e4, d5), dim=1)
170
+ d5 = self.Up_conv5(d5)
171
+ # print("upconv5:")
172
+ # print(d5.size)
173
+
174
+ d4 = self.Up4(d5)
175
+ # print("d4:")
176
+ # print(d4.shape)
177
+ d4 = torch.cat((e3, d4), dim=1)
178
+ d4 = self.Up_conv4(d4)
179
+ # print("upconv4:")
180
+ # print(d4.shape)
181
+ d3 = self.Up3(d4)
182
+ d3 = torch.cat((e2, d3), dim=1)
183
+ d3 = self.Up_conv3(d3)
184
+ # print("upconv3:")
185
+ # print(d3.shape)
186
+ d2 = self.Up2(d3)
187
+ d2 = torch.cat((e1, d2), dim=1)
188
+ d2 = self.Up_conv2(d2)
189
+ # print("upconv2:")
190
+ # print(d2.shape)
191
+ out = self.Conv(d2)
192
+ # print("out:")
193
+ # print(out.shape)
194
+ # d1 = self.active(out)
195
+
196
+ return out
197
+
198
+
199
+ class UNetDeepSup(nn.Module):
200
+ """
201
+ UNet - Basic Implementation
202
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
203
+ Paper : https://arxiv.org/abs/1505.04597
204
+ """
205
+
206
+ def __init__(self, in_ch=1, out_ch=1, init_features=64):
207
+ super(UNetDeepSup, self).__init__()
208
+
209
+ n1 = init_features
210
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
211
+
212
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
213
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
214
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
215
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
216
+
217
+ self.Conv1 = ConvBlock(in_ch, filters[0])
218
+ self.Conv2 = SeparableConvBlock(filters[0], filters[1])
219
+ self.Conv3 = SeparableConvBlock(filters[1], filters[2])
220
+ self.Conv4 = SeparableConvBlock(filters[2], filters[3])
221
+ self.Conv5 = SeparableConvBlock(filters[3], filters[4])
222
+
223
+ # 1x1x1 Convolution for Deep Supervision
224
+ self.Conv_d3 = SeparableConvBlock(filters[1], 1)
225
+ self.Conv_d4 = SeparableConvBlock(filters[2], 1)
226
+
227
+ self.Up5 = UpConv(filters[4], filters[3])
228
+ self.Up_conv5 = SeparableConvBlock(filters[4], filters[3])
229
+
230
+ self.Up4 = UpConv(filters[3], filters[2])
231
+ self.Up_conv4 = SeparableConvBlock(filters[3], filters[2])
232
+
233
+ self.Up3 = UpConv(filters[2], filters[1])
234
+ self.Up_conv3 = SeparableConvBlock(filters[2], filters[1])
235
+
236
+ self.Up2 = UpConv(filters[1], filters[0])
237
+ self.Up_conv2 = ConvBlock(filters[1], filters[0])
238
+
239
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
240
+
241
+ for submodule in self.modules():
242
+ submodule.register_forward_hook(self.nan_hook)
243
+
244
+ # self.active = torch.nn.Sigmoid()
245
+
246
+ def nan_hook(self, module, inp, output):
247
+ for i, out in enumerate(output):
248
+ nan_mask = torch.isnan(out)
249
+ if nan_mask.any():
250
+ print("In", self.__class__.__name__)
251
+ torch.save(inp, '/nfs1/sutrave/outputs/nan_values_input/inp_2_Nov.pt')
252
+ raise RuntimeError(" classname " + self.__class__.__name__ + "i " + str(
253
+ i) + f" module: {module} classname {self.__class__.__name__} Found NAN in output {i} at indices: ",
254
+ nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])
255
+
256
+ def forward(self, x):
257
+ # print("unet")
258
+ # print(x.shape)
259
+ # print(padded.shape)
260
+
261
+ e1 = self.Conv1(x)
262
+ # print("conv1:")
263
+ # print(e1.shape)
264
+
265
+ e2 = self.Maxpool1(e1)
266
+ e2 = self.Conv2(e2)
267
+ # print("conv2:")
268
+ # print(e2.shape)
269
+
270
+ e3 = self.Maxpool2(e2)
271
+ e3 = self.Conv3(e3)
272
+ # print("conv3:")
273
+ # print(e3.shape)
274
+
275
+ e4 = self.Maxpool3(e3)
276
+ e4 = self.Conv4(e4)
277
+ # print("conv4:")
278
+ # print(e4.shape)
279
+
280
+ e5 = self.Maxpool4(e4)
281
+ e5 = self.Conv5(e5)
282
+ # print("conv5:")
283
+ # print(e5.shape)
284
+
285
+ d5 = self.Up5(e5)
286
+ # print("d5:")
287
+ # print(d5.shape)
288
+ # print("e4:")
289
+ # print(e4.shape)
290
+ d5 = torch.cat((e4, d5), dim=1)
291
+ d5 = self.Up_conv5(d5)
292
+ # print("upconv5:")
293
+ # print(d5.size)
294
+
295
+ d4 = self.Up4(d5)
296
+ # print("d4:")
297
+ # print(d4.shape)
298
+ d4 = torch.cat((e3, d4), dim=1)
299
+ d4 = self.Up_conv4(d4)
300
+ d4_out = self.Conv_d4(d4)
301
+
302
+ # print("upconv4:")
303
+ # print(d4.shape)
304
+ d3 = self.Up3(d4)
305
+ d3 = torch.cat((e2, d3), dim=1)
306
+ d3 = self.Up_conv3(d3)
307
+ d3_out = self.Conv_d3(d3)
308
+
309
+ # print("upconv3:")
310
+ # print(d3.shape)
311
+ d2 = self.Up2(d3)
312
+ d2 = torch.cat((e1, d2), dim=1)
313
+ d2 = self.Up_conv2(d2)
314
+ # print("upconv2:")
315
+ # print(d2.shape)
316
+ out = self.Conv(d2)
317
+ # print("out:")
318
+ # print(out.shape)
319
+ # d1 = self.active(out)
320
+
321
+ return out
w_net_3d.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # from __future__ import print_function, division
4
+ """
5
+
6
+ Purpose :
7
+
8
+ """
9
+ import torch.nn
10
+ import torch
11
+ import torch.nn as nn
12
+ from .attention_unet3d import AttUnet
13
+ from .unet3d import UNet, UNetDeepSup
14
+
15
+ __author__ = "Chethan Radhakrishna and Soumick Chatterjee"
16
+ __credits__ = ["Chethan Radhakrishna", "Soumick Chatterjee"]
17
+ __license__ = "GPL"
18
+ __version__ = "1.0.0"
19
+ __maintainer__ = "Chethan Radhakrishna"
20
+ __email__ = "[email protected]"
21
+ __status__ = "Development"
22
+
23
+
24
+ class WNet3dAttUNet(nn.Module):
25
+ """
26
+ Attention Unet implementation
27
+ Paper: https://arxiv.org/abs/1804.03999
28
+ """
29
+
30
+ def __init__(self, in_ch=1, out_ch=6, init_features=64):
31
+ super(WNet3dAttUNet, self).__init__()
32
+
33
+ self.Encoder = AttUnet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
34
+ self.Decoder = AttUnet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
35
+
36
+ self.activation = torch.nn.Softmax(dim=1)
37
+
38
+ self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
39
+
40
+ def forward(self, ip, ip_mask=None, ops="both"):
41
+ encoder_op = self.Encoder(ip)
42
+ if ip_mask is not None:
43
+ encoder_op = ip_mask * encoder_op
44
+ class_prob = self.activation(encoder_op)
45
+ feature_rep = self.Conv(encoder_op)
46
+ if ops == "enc":
47
+ return class_prob, feature_rep
48
+ reconstructed_op = self.Decoder(class_prob)
49
+ # if ip_mask is not None:
50
+ # reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
51
+ if ops == "dec":
52
+ return reconstructed_op
53
+ if ops == "both":
54
+ return class_prob, feature_rep, reconstructed_op
55
+ else:
56
+ raise ValueError('Invalid ops, ops must be in [enc, dec, both]')
57
+
58
+
59
+ class WNet3dUNet(nn.Module):
60
+ """
61
+ Attention Unet implementation
62
+ Paper: https://arxiv.org/abs/1804.03999
63
+ """
64
+
65
+ def __init__(self, in_ch=1, out_ch=6, init_features=64):
66
+ super(WNet3dUNet, self).__init__()
67
+
68
+ self.Encoder = UNet(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
69
+ self.Decoder = UNet(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
70
+
71
+ self.activation = torch.nn.Softmax(dim=1)
72
+
73
+ self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
74
+
75
+ def forward(self, ip, ip_mask=None, ops="both"):
76
+ encoder_op = self.Encoder(ip)
77
+ if ip_mask is not None:
78
+ encoder_op = ip_mask * encoder_op
79
+ class_prob = self.activation(encoder_op)
80
+ feature_rep = self.Conv(encoder_op)
81
+ if ops == "enc":
82
+ return class_prob, feature_rep
83
+ reconstructed_op = self.Decoder(class_prob)
84
+ # if ip_mask is not None:
85
+ # reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
86
+ if ops == "dec":
87
+ return reconstructed_op
88
+ if ops == "both":
89
+ return class_prob, feature_rep, reconstructed_op
90
+ else:
91
+ raise ValueError('Invalid ops, ops must be in [enc, dec, both]')
92
+
93
+
94
+ class WNet3dUNetMSS(nn.Module):
95
+ """
96
+ Attention Unet implementation
97
+ Paper: https://arxiv.org/abs/1804.03999
98
+ """
99
+
100
+ def __init__(self, in_ch=1, out_ch=6, init_features=64):
101
+ super(WNet3dUNetMSS, self).__init__()
102
+
103
+ self.Encoder = UNetDeepSup(in_ch=in_ch, out_ch=out_ch, init_features=init_features)
104
+ self.Decoder = UNetDeepSup(in_ch=out_ch, out_ch=in_ch, init_features=init_features)
105
+
106
+ self.activation = torch.nn.Softmax(dim=1)
107
+
108
+ self.Conv = nn.Conv3d(out_ch, in_ch, kernel_size=1, stride=1, padding=0)
109
+
110
+ def forward(self, ip, ip_mask=None, ops="both"):
111
+ encoder_op = self.Encoder(ip)
112
+ if ip_mask is not None:
113
+ encoder_op = ip_mask * encoder_op
114
+ class_prob = self.activation(encoder_op)
115
+ feature_rep = self.Conv(encoder_op)
116
+ if ops == "enc":
117
+ return class_prob, feature_rep
118
+ reconstructed_op = self.Decoder(class_prob)
119
+ # if ip_mask is not None:
120
+ # reconstructed_op = torch.amax(ip_mask, dim=1, keepdim=True) * reconstructed_op
121
+ if ops == "dec":
122
+ return reconstructed_op
123
+ if ops == "both":
124
+ return class_prob, feature_rep, reconstructed_op
125
+ else:
126
+ raise ValueError('Invalid ops, ops must be in [enc, dec, both]')