soumickmj commited on
Commit
4cd46da
1 Parent(s): 2078d5d

Upload UNetMSS3D

Browse files
Files changed (5) hide show
  1. UNetConfigs.py +32 -0
  2. UNets.py +27 -0
  3. config.json +16 -0
  4. model.safetensors +3 -0
  5. unet3d.py +305 -0
UNetConfigs.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class UNet3DConfig(PretrainedConfig):
5
+ model_type = "UNet"
6
+ def __init__(
7
+ self,
8
+ in_ch=1,
9
+ out_ch=1,
10
+ init_features=64,
11
+ dropout_rate=0.5,
12
+ **kwargs):
13
+ self.in_ch = in_ch
14
+ self.out_ch = out_ch
15
+ self.init_features = init_features
16
+ self.dropout_rate = dropout_rate
17
+ super().__init__(**kwargs)
18
+
19
+ class UNetMSS3DConfig(PretrainedConfig):
20
+ model_type = "UNetMSS"
21
+ def __init__(
22
+ self,
23
+ in_ch=1,
24
+ out_ch=1,
25
+ init_features=64,
26
+ dropout_rate=0.5,
27
+ **kwargs):
28
+ self.in_ch = in_ch
29
+ self.out_ch = out_ch
30
+ self.init_features = init_features
31
+ self.dropout_rate = dropout_rate
32
+ super().__init__(**kwargs)
UNets.py ADDED
@@ -0,0 +1,27 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .unet3d import UNet, UNetDeepSup
3
+ from .UNetConfigs import UNet3DConfig, UNetMSS3DConfig
4
+
5
+ class UNet3D(PreTrainedModel):
6
+ config_class = UNet3DConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = UNet(
10
+ in_ch=config.in_ch,
11
+ out_ch=config.out_ch,
12
+ init_features=config.init_features,
13
+ dropout_rate=config.dropout_rate)
14
+ def forward(self, x):
15
+ return self.model(x)
16
+
17
+ class UNetMSS3D(PreTrainedModel):
18
+ config_class = UNetMSS3DConfig
19
+ def __init__(self, config):
20
+ super().__init__(config)
21
+ self.model = UNetDeepSup(
22
+ in_ch=config.in_ch,
23
+ out_ch=config.out_ch,
24
+ init_features=config.init_features,
25
+ dropout_rate=config.dropout_rate)
26
+ def forward(self, x):
27
+ return self.model(x)
config.json ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "UNetMSS3D"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "UNetConfigs.UNetMSS3DConfig",
7
+ "AutoModel": "UNets.UNetMSS3D"
8
+ },
9
+ "dropout_rate": 0.5,
10
+ "in_ch": 1,
11
+ "init_features": 64,
12
+ "model_type": "UNetMSS",
13
+ "out_ch": 1,
14
+ "torch_dtype": "float32",
15
+ "transformers_version": "4.44.2"
16
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4b05c582cc75112697435fb9d0b9d4fdb7af4119f6b0477513375738b56ed9ba
3
+ size 414260220
unet3d.py ADDED
@@ -0,0 +1,305 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ """
4
+
5
+ Purpose :
6
+
7
+ """
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.utils.data
12
+
13
+ __author__ = "Kartik Prabhu, Mahantesh Pattadkal, and Soumick Chatterjee"
14
+ __copyright__ = "Copyright 2020, Faculty of Computer Science, Otto von Guericke University Magdeburg, Germany"
15
+ __credits__ = ["Kartik Prabhu", "Mahantesh Pattadkal", "Soumick Chatterjee"]
16
+ __license__ = "GPL"
17
+ __version__ = "1.0.0"
18
+ __maintainer__ = "Soumick Chatterjee"
19
+ __email__ = "[email protected]"
20
+ __status__ = "Production"
21
+
22
+
23
+ class ConvBlock(nn.Module):
24
+ """
25
+ Convolution Block
26
+ """
27
+
28
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True, dropout_rate=None):
29
+ super(ConvBlock, self).__init__()
30
+ if bool(dropout_rate):
31
+ self.conv = nn.Sequential(
32
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
33
+ stride=stride, padding=padding, bias=bias),
34
+ nn.BatchNorm3d(num_features=out_channels),
35
+ nn.LeakyReLU(inplace=True),
36
+ nn.Dropout3d(p=dropout_rate), #This changes the order in the sequential model
37
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
38
+ stride=stride, padding=padding, bias=bias),
39
+ nn.BatchNorm3d(num_features=out_channels),
40
+ nn.LeakyReLU(inplace=True)
41
+ )
42
+ else:
43
+ self.conv = nn.Sequential(
44
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
45
+ stride=stride, padding=padding, bias=bias),
46
+ nn.BatchNorm3d(num_features=out_channels),
47
+ nn.LeakyReLU(inplace=True),
48
+ nn.Conv3d(in_channels=out_channels, out_channels=out_channels, kernel_size=k_size,
49
+ stride=stride, padding=padding, bias=bias),
50
+ nn.BatchNorm3d(num_features=out_channels),
51
+ nn.LeakyReLU(inplace=True)
52
+ )
53
+
54
+ def forward(self, x):
55
+ x = self.conv(x)
56
+ return x
57
+
58
+
59
+ class UpConv(nn.Module):
60
+ """
61
+ Up Convolution Block
62
+ """
63
+
64
+ # def __init__(self, in_ch, out_ch):
65
+ def __init__(self, in_channels, out_channels, k_size=3, stride=1, padding=1, bias=True):
66
+ super(UpConv, self).__init__()
67
+ self.up = nn.Sequential(
68
+ nn.Upsample(scale_factor=2),
69
+ nn.Conv3d(in_channels=in_channels, out_channels=out_channels, kernel_size=k_size,
70
+ stride=stride, padding=padding, bias=bias),
71
+ nn.BatchNorm3d(num_features=out_channels),
72
+ nn.LeakyReLU(inplace=True))
73
+
74
+ def forward(self, x):
75
+ x = self.up(x)
76
+ return x
77
+
78
+
79
+ class UNet(nn.Module):
80
+ """
81
+ UNet - Basic Implementation
82
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
83
+ Paper : https://arxiv.org/abs/1505.04597
84
+ """
85
+
86
+ def __init__(self, in_ch=1, out_ch=1, init_features=64, dropout_rate=None):
87
+ super(UNet, self).__init__()
88
+
89
+ n1 = init_features
90
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
91
+
92
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
93
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
94
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
95
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
96
+
97
+ self.Conv1 = ConvBlock(in_ch, filters[0], dropout_rate=dropout_rate)
98
+ self.Conv2 = ConvBlock(filters[0], filters[1], dropout_rate=dropout_rate)
99
+ self.Conv3 = ConvBlock(filters[1], filters[2], dropout_rate=dropout_rate)
100
+ self.Conv4 = ConvBlock(filters[2], filters[3], dropout_rate=dropout_rate)
101
+ self.Conv5 = ConvBlock(filters[3], filters[4], dropout_rate=dropout_rate)
102
+
103
+ self.Up5 = UpConv(filters[4], filters[3])
104
+ self.Up_conv5 = ConvBlock(filters[4], filters[3], dropout_rate=dropout_rate)
105
+
106
+ self.Up4 = UpConv(filters[3], filters[2])
107
+ self.Up_conv4 = ConvBlock(filters[3], filters[2], dropout_rate=dropout_rate)
108
+
109
+ self.Up3 = UpConv(filters[2], filters[1])
110
+ self.Up_conv3 = ConvBlock(filters[2], filters[1], dropout_rate=dropout_rate)
111
+
112
+ self.Up2 = UpConv(filters[1], filters[0])
113
+ self.Up_conv2 = ConvBlock(filters[1], filters[0], dropout_rate=dropout_rate)
114
+
115
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
116
+
117
+ # self.active = torch.nn.Sigmoid()
118
+
119
+ def forward(self, x):
120
+ # print("unet")
121
+ # print(x.shape)
122
+ # print(padded.shape)
123
+
124
+ e1 = self.Conv1(x)
125
+ # print("conv1:")
126
+ # print(e1.shape)
127
+
128
+ e2 = self.Maxpool1(e1)
129
+ e2 = self.Conv2(e2)
130
+ # print("conv2:")
131
+ # print(e2.shape)
132
+
133
+ e3 = self.Maxpool2(e2)
134
+ e3 = self.Conv3(e3)
135
+ # print("conv3:")
136
+ # print(e3.shape)
137
+
138
+ e4 = self.Maxpool3(e3)
139
+ e4 = self.Conv4(e4)
140
+ # print("conv4:")
141
+ # print(e4.shape)
142
+
143
+ e5 = self.Maxpool4(e4)
144
+ e5 = self.Conv5(e5)
145
+ # print("conv5:")
146
+ # print(e5.shape)
147
+
148
+ d5 = self.Up5(e5)
149
+ # print("d5:")
150
+ # print(d5.shape)
151
+ # print("e4:")
152
+ # print(e4.shape)
153
+ d5 = torch.cat((e4, d5), dim=1)
154
+ d5 = self.Up_conv5(d5)
155
+ # print("upconv5:")
156
+ # print(d5.size)
157
+
158
+ d4 = self.Up4(d5)
159
+ # print("d4:")
160
+ # print(d4.shape)
161
+ d4 = torch.cat((e3, d4), dim=1)
162
+ d4 = self.Up_conv4(d4)
163
+ # print("upconv4:")
164
+ # print(d4.shape)
165
+ d3 = self.Up3(d4)
166
+ d3 = torch.cat((e2, d3), dim=1)
167
+ d3 = self.Up_conv3(d3)
168
+ # print("upconv3:")
169
+ # print(d3.shape)
170
+ d2 = self.Up2(d3)
171
+ d2 = torch.cat((e1, d2), dim=1)
172
+ d2 = self.Up_conv2(d2)
173
+ # print("upconv2:")
174
+ # print(d2.shape)
175
+ out = self.Conv(d2)
176
+ # print("out:")
177
+ # print(out.shape)
178
+ # d1 = self.active(out)
179
+
180
+ return [out]
181
+
182
+
183
+ class UNetDeepSup(nn.Module):
184
+ """
185
+ UNet - Basic Implementation
186
+ Input _ [batch * channel(# of channels of each image) * depth(# of frames) * height * width].
187
+ Paper : https://arxiv.org/abs/1505.04597
188
+ """
189
+
190
+ def __init__(self, in_ch=1, out_ch=1, init_features=64, dropout_rate=None):
191
+ super(UNetDeepSup, self).__init__()
192
+
193
+ n1 = init_features
194
+ filters = [n1, n1 * 2, n1 * 4, n1 * 8, n1 * 16] # 64,128,256,512,1024
195
+
196
+ self.Maxpool1 = nn.MaxPool3d(kernel_size=2, stride=2)
197
+ self.Maxpool2 = nn.MaxPool3d(kernel_size=2, stride=2)
198
+ self.Maxpool3 = nn.MaxPool3d(kernel_size=2, stride=2)
199
+ self.Maxpool4 = nn.MaxPool3d(kernel_size=2, stride=2)
200
+
201
+ self.Conv1 = ConvBlock(in_ch, filters[0], dropout_rate=dropout_rate)
202
+ self.Conv2 = ConvBlock(filters[0], filters[1], dropout_rate=dropout_rate)
203
+ self.Conv3 = ConvBlock(filters[1], filters[2], dropout_rate=dropout_rate)
204
+ self.Conv4 = ConvBlock(filters[2], filters[3], dropout_rate=dropout_rate)
205
+ self.Conv5 = ConvBlock(filters[3], filters[4], dropout_rate=dropout_rate)
206
+
207
+ # 1x1x1 Convolution for Deep Supervision
208
+ self.Conv_d3 = ConvBlock(filters[1], 1, dropout_rate=None)
209
+ self.Conv_d4 = ConvBlock(filters[2], 1, dropout_rate=None)
210
+
211
+ self.Up5 = UpConv(filters[4], filters[3])
212
+ self.Up_conv5 = ConvBlock(filters[4], filters[3], dropout_rate=dropout_rate)
213
+
214
+ self.Up4 = UpConv(filters[3], filters[2])
215
+ self.Up_conv4 = ConvBlock(filters[3], filters[2], dropout_rate=dropout_rate)
216
+
217
+ self.Up3 = UpConv(filters[2], filters[1])
218
+ self.Up_conv3 = ConvBlock(filters[2], filters[1], dropout_rate=dropout_rate)
219
+
220
+ self.Up2 = UpConv(filters[1], filters[0])
221
+ self.Up_conv2 = ConvBlock(filters[1], filters[0], dropout_rate=dropout_rate)
222
+
223
+ self.Conv = nn.Conv3d(filters[0], out_ch, kernel_size=1, stride=1, padding=0)
224
+
225
+ for submodule in self.modules():
226
+ submodule.register_forward_hook(self.nan_hook)
227
+
228
+ # self.active = torch.nn.Sigmoid()
229
+
230
+ def nan_hook(self, module, inp, output):
231
+ for i, out in enumerate(output):
232
+ nan_mask = torch.isnan(out)
233
+ if nan_mask.any():
234
+ print("In", self.__class__.__name__)
235
+ torch.save(inp, '/nfs1/sutrave/outputs/nan_values_input/inp_2_Nov.pt')
236
+ raise RuntimeError(" classname " + self.__class__.__name__ + "i " + str(
237
+ i) + f" module: {module} classname {self.__class__.__name__} Found NAN in output {i} at indices: ",
238
+ nan_mask.nonzero(), "where:", out[nan_mask.nonzero()[:, 0].unique(sorted=True)])
239
+
240
+ def forward(self, x):
241
+ # print("unet")
242
+ # print(x.shape)
243
+ # print(padded.shape)
244
+
245
+ e1 = self.Conv1(x)
246
+ # print("conv1:")
247
+ # print(e1.shape)
248
+
249
+ e2 = self.Maxpool1(e1)
250
+ e2 = self.Conv2(e2)
251
+ # print("conv2:")
252
+ # print(e2.shape)
253
+
254
+ e3 = self.Maxpool2(e2)
255
+ e3 = self.Conv3(e3)
256
+ # print("conv3:")
257
+ # print(e3.shape)
258
+
259
+ e4 = self.Maxpool3(e3)
260
+ e4 = self.Conv4(e4)
261
+ # print("conv4:")
262
+ # print(e4.shape)
263
+
264
+ e5 = self.Maxpool4(e4)
265
+ e5 = self.Conv5(e5)
266
+ # print("conv5:")
267
+ # print(e5.shape)
268
+
269
+ d5 = self.Up5(e5)
270
+ # print("d5:")
271
+ # print(d5.shape)
272
+ # print("e4:")
273
+ # print(e4.shape)
274
+ d5 = torch.cat((e4, d5), dim=1)
275
+ d5 = self.Up_conv5(d5)
276
+ # print("upconv5:")
277
+ # print(d5.size)
278
+
279
+ d4 = self.Up4(d5)
280
+ # print("d4:")
281
+ # print(d4.shape)
282
+ d4 = torch.cat((e3, d4), dim=1)
283
+ d4 = self.Up_conv4(d4)
284
+ d4_out = self.Conv_d4(d4)
285
+
286
+ # print("upconv4:")
287
+ # print(d4.shape)
288
+ d3 = self.Up3(d4)
289
+ d3 = torch.cat((e2, d3), dim=1)
290
+ d3 = self.Up_conv3(d3)
291
+ d3_out = self.Conv_d3(d3)
292
+
293
+ # print("upconv3:")
294
+ # print(d3.shape)
295
+ d2 = self.Up2(d3)
296
+ d2 = torch.cat((e1, d2), dim=1)
297
+ d2 = self.Up_conv2(d2)
298
+ # print("upconv2:")
299
+ # print(d2.shape)
300
+ out = self.Conv(d2)
301
+ # print("out:")
302
+ # print(out.shape)
303
+ # d1 = self.active(out)
304
+
305
+ return [out, d3_out, d4_out]