soumickmj commited on
Commit
0e6a5ed
1 Parent(s): f992e86

Upload ReconResNet

Browse files
Files changed (5) hide show
  1. ReconResNet.py +25 -0
  2. ReconResNetBase.py +267 -0
  3. ReconResNetConfig.py +37 -0
  4. config.json +26 -0
  5. model.safetensors +3 -0
ReconResNet.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PreTrainedModel
2
+ from .ReconResNetBase import ReconResNetBase
3
+ from .ReconResNetConfig import ReconResNetConfig
4
+
5
+ class ReconResNet(PreTrainedModel):
6
+ config_class = ReconResNetConfig
7
+ def __init__(self, config):
8
+ super().__init__(config)
9
+ self.model = ReconResNetBase(
10
+ in_channels=config.in_channels,
11
+ out_channels=config.out_channels,
12
+ res_blocks=config.res_blocks,
13
+ starting_nfeatures=config.starting_nfeatures,
14
+ updown_blocks=config.updown_blocks,
15
+ is_relu_leaky=config.is_relu_leaky,
16
+ do_batchnorm=config.do_batchnorm,
17
+ res_drop_prob=config.res_drop_prob,
18
+ is_replicatepad=config.is_replicatepad,
19
+ out_act=config.out_act,
20
+ forwardV=config.forwardV,
21
+ upinterp_algo=config.upinterp_algo,
22
+ post_interp_convtrans=config.post_interp_convtrans,
23
+ is3D=config.is3D)
24
+ def forward(self, x):
25
+ return self.model(x)
ReconResNetBase.py ADDED
@@ -0,0 +1,267 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python
2
+
3
+ # This model is part of the paper "ReconResNet: Regularised Residual Learning for MR Image Reconstruction of Undersampled Cartesian and Radial Data" (https://doi.org/10.1016/j.compbiomed.2022.105321)
4
+ # and has been published on GitHub: https://github.com/soumickmj/NCC1701/blob/main/Bridge/WarpDrives/ReconResNet/ReconResNet.py
5
+
6
+ import torch.nn as nn
7
+ from tricorder.torch.transforms import Interpolator
8
+
9
+ __author__ = "Soumick Chatterjee"
10
+ __copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL"
11
+ __credits__ = ["Soumick Chatterjee"]
12
+
13
+ __license__ = "apache-2.0"
14
+ __version__ = "1.0.0"
15
+ __email__ = "[email protected]"
16
+ __status__ = "Published"
17
+
18
+
19
+ class ResidualBlock(nn.Module):
20
+ def __init__(self, in_features, drop_prob=0.2):
21
+ super(ResidualBlock, self).__init__()
22
+
23
+ conv_block = [layer_pad(1),
24
+ layer_conv(in_features, in_features, 3),
25
+ layer_norm(in_features),
26
+ act_relu(),
27
+ layer_drop(p=drop_prob, inplace=True),
28
+ layer_pad(1),
29
+ layer_conv(in_features, in_features, 3),
30
+ layer_norm(in_features)]
31
+
32
+ self.conv_block = nn.Sequential(*conv_block)
33
+
34
+ def forward(self, x):
35
+ return x + self.conv_block(x)
36
+
37
+
38
+ class DownsamplingBlock(nn.Module):
39
+ def __init__(self, in_features, out_features):
40
+ super(DownsamplingBlock, self).__init__()
41
+
42
+ conv_block = [layer_conv(in_features, out_features, 3, stride=2, padding=1),
43
+ layer_norm(out_features),
44
+ act_relu()]
45
+ self.conv_block = nn.Sequential(*conv_block)
46
+
47
+ def forward(self, x):
48
+ return self.conv_block(x)
49
+
50
+
51
+ class UpsamplingBlock(nn.Module):
52
+ def __init__(self, in_features, out_features, mode="convtrans", interpolator=None, post_interp_convtrans=False):
53
+ super(UpsamplingBlock, self).__init__()
54
+
55
+ self.interpolator = interpolator
56
+ self.mode = mode
57
+ self.post_interp_convtrans = post_interp_convtrans
58
+ if self.post_interp_convtrans:
59
+ self.post_conv = layer_conv(out_features, out_features, 1)
60
+
61
+ if mode == "convtrans":
62
+ conv_block = [layer_convtrans(
63
+ in_features, out_features, 3, stride=2, padding=1, output_padding=1), ]
64
+ else:
65
+ conv_block = [layer_pad(1),
66
+ layer_conv(in_features, out_features, 3), ]
67
+ conv_block += [layer_norm(out_features),
68
+ act_relu()]
69
+ self.conv_block = nn.Sequential(*conv_block)
70
+
71
+ def forward(self, x, out_shape=None):
72
+ if self.mode == "convtrans":
73
+ if self.post_interp_convtrans:
74
+ x = self.conv_block(x)
75
+ if x.shape[2:] != out_shape:
76
+ return self.post_conv(self.interpolator(x, out_shape))
77
+ else:
78
+ return x
79
+ else:
80
+ return self.conv_block(x)
81
+ else:
82
+ return self.conv_block(self.interpolator(x, out_shape))
83
+
84
+
85
+ class ReconResNetBase(nn.Module):
86
+ def __init__(self, in_channels=1, out_channels=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2, is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2,
87
+ is_replicatepad=0, out_act="sigmoid", forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False, is3D=False): # should use 14 as that gives number of trainable parameters close to number of possible pixel values in a image 256x256
88
+ super(ReconResNetBase, self).__init__()
89
+
90
+ layers = {}
91
+ if is3D:
92
+ layers["layer_conv"] = nn.Conv3d
93
+ layers["layer_convtrans"] = nn.ConvTranspose3d
94
+ if do_batchnorm:
95
+ layers["layer_norm"] = nn.BatchNorm3d
96
+ else:
97
+ layers["layer_norm"] = nn.InstanceNorm3d
98
+ layers["layer_drop"] = nn.Dropout3d
99
+ if is_replicatepad == 0:
100
+ layers["layer_pad"] = nn.ReflectionPad3d
101
+ elif is_replicatepad == 1:
102
+ layers["layer_pad"] = nn.ReplicationPad3d
103
+ layers["interp_mode"] = 'trilinear'
104
+ else:
105
+ layers["layer_conv"] = nn.Conv2d
106
+ layers["layer_convtrans"] = nn.ConvTranspose2d
107
+ if do_batchnorm:
108
+ layers["layer_norm"] = nn.BatchNorm2d
109
+ else:
110
+ layers["layer_norm"] = nn.InstanceNorm2d
111
+ layers["layer_drop"] = nn.Dropout2d
112
+ if is_replicatepad == 0:
113
+ layers["layer_pad"] = nn.ReflectionPad2d
114
+ elif is_replicatepad == 1:
115
+ layers["layer_pad"] = nn.ReplicationPad2d
116
+ layers["interp_mode"] = 'bilinear'
117
+ if is_relu_leaky:
118
+ layers["act_relu"] = nn.PReLU
119
+ else:
120
+ layers["act_relu"] = nn.ReLU
121
+ globals().update(layers)
122
+
123
+ self.forwardV = forwardV
124
+ self.upinterp_algo = upinterp_algo
125
+
126
+ interpolator = Interpolator(
127
+ mode=layers["interp_mode"] if self.upinterp_algo == "convtrans" else self.upinterp_algo)
128
+
129
+ # Initial convolution block
130
+ intialConv = [layer_pad(3),
131
+ layer_conv(in_channels, starting_nfeatures, 7),
132
+ layer_norm(starting_nfeatures),
133
+ act_relu()]
134
+
135
+ # Downsampling [need to save the shape for upsample]
136
+ downsam = []
137
+ in_features = starting_nfeatures
138
+ out_features = in_features*2
139
+ for _ in range(updown_blocks):
140
+ downsam.append(DownsamplingBlock(in_features, out_features))
141
+ in_features = out_features
142
+ out_features = in_features*2
143
+
144
+ # Residual blocks
145
+ resblocks = []
146
+ for _ in range(res_blocks):
147
+ resblocks += [ResidualBlock(in_features, res_drop_prob)]
148
+
149
+ # Upsampling
150
+ upsam = []
151
+ out_features = in_features//2
152
+ for _ in range(updown_blocks):
153
+ upsam.append(UpsamplingBlock(in_features, out_features,
154
+ self.upinterp_algo, interpolator, post_interp_convtrans))
155
+ in_features = out_features
156
+ out_features = in_features//2
157
+
158
+ # Output layer
159
+ finalconv = [layer_pad(3),
160
+ layer_conv(starting_nfeatures, out_channels, 7), ]
161
+
162
+ if out_act == "sigmoid":
163
+ finalconv += [nn.Sigmoid(), ]
164
+ elif out_act == "relu":
165
+ finalconv += [act_relu(), ]
166
+ elif out_act == "tanh":
167
+ finalconv += [nn.Tanh(), ]
168
+
169
+ self.intialConv = nn.Sequential(*intialConv)
170
+ self.downsam = nn.ModuleList(downsam)
171
+ self.resblocks = nn.Sequential(*resblocks)
172
+ self.upsam = nn.ModuleList(upsam)
173
+ self.finalconv = nn.Sequential(*finalconv)
174
+
175
+ if self.forwardV == 0:
176
+ self.forward = self.forwardV0
177
+ elif self.forwardV == 1:
178
+ self.forward = self.forwardV1
179
+ elif self.forwardV == 2:
180
+ self.forward = self.forwardV2
181
+ elif self.forwardV == 3:
182
+ self.forward = self.forwardV3
183
+ elif self.forwardV == 4:
184
+ self.forward = self.forwardV4
185
+ elif self.forwardV == 5:
186
+ self.forward = self.forwardV5
187
+
188
+ def forwardV0(self, x):
189
+ # v0: Original Version
190
+ x = self.intialConv(x)
191
+ shapes = []
192
+ for downblock in self.downsam:
193
+ shapes.append(x.shape[2:])
194
+ x = downblock(x)
195
+ x = self.resblocks(x)
196
+ for i, upblock in enumerate(self.upsam):
197
+ x = upblock(x, shapes[-1-i])
198
+ return self.finalconv(x)
199
+
200
+ def forwardV1(self, x):
201
+ # v1: input is added to the final output
202
+ out = self.intialConv(x)
203
+ shapes = []
204
+ for downblock in self.downsam:
205
+ shapes.append(out.shape[2:])
206
+ out = downblock(out)
207
+ out = self.resblocks(out)
208
+ for i, upblock in enumerate(self.upsam):
209
+ out = upblock(out, shapes[-1-i])
210
+ return x + self.finalconv(out)
211
+
212
+ def forwardV2(self, x):
213
+ # v2: residual of v1 + input to the residual blocks added back with the output
214
+ out = self.intialConv(x)
215
+ shapes = []
216
+ for downblock in self.downsam:
217
+ shapes.append(out.shape[2:])
218
+ out = downblock(out)
219
+ out = out + self.resblocks(out)
220
+ for i, upblock in enumerate(self.upsam):
221
+ out = upblock(out, shapes[-1-i])
222
+ return x + self.finalconv(out)
223
+
224
+ def forwardV3(self, x):
225
+ # v3: residual of v2 + input of the initial conv added back with the output
226
+ out = x + self.intialConv(x)
227
+ shapes = []
228
+ for downblock in self.downsam:
229
+ shapes.append(out.shape[2:])
230
+ out = downblock(out)
231
+ out = out + self.resblocks(out)
232
+ for i, upblock in enumerate(self.upsam):
233
+ out = upblock(out, shapes[-1-i])
234
+ return x + self.finalconv(out)
235
+
236
+ def forwardV4(self, x):
237
+ # v4: residual of v3 + output of the initial conv added back with the input of final conv
238
+ iniconv = x + self.intialConv(x)
239
+ shapes = []
240
+ if len(self.downsam) > 0:
241
+ for i, downblock in enumerate(self.downsam):
242
+ if i == 0:
243
+ shapes.append(iniconv.shape[2:])
244
+ out = downblock(iniconv)
245
+ else:
246
+ shapes.append(out.shape[2:])
247
+ out = downblock(out)
248
+ else:
249
+ out = iniconv
250
+ out = out + self.resblocks(out)
251
+ for i, upblock in enumerate(self.upsam):
252
+ out = upblock(out, shapes[-1-i])
253
+ out = iniconv + out
254
+ return x + self.finalconv(out)
255
+
256
+ def forwardV5(self, x):
257
+ # v5: residual of v4 + individual down blocks with individual up blocks
258
+ outs = [x + self.intialConv(x)]
259
+ shapes = []
260
+ for i, downblock in enumerate(self.downsam):
261
+ shapes.append(outs[-1].shape[2:])
262
+ outs.append(downblock(outs[-1]))
263
+ outs[-1] = outs[-1] + self.resblocks(outs[-1])
264
+ for i, upblock in enumerate(self.upsam):
265
+ outs[-1] = upblock(outs[-1], shapes[-1-i])
266
+ outs[-1] = outs[-2] + outs.pop()
267
+ return x + self.finalconv(outs.pop())
ReconResNetConfig.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+ class ReconResNetConfig(PretrainedConfig):
5
+ model_type = "ReconResNet"
6
+ def __init__(
7
+ self,
8
+ in_channels=1,
9
+ out_channels=1,
10
+ res_blocks=14,
11
+ starting_nfeatures=64,
12
+ updown_blocks=2,
13
+ is_relu_leaky=True,
14
+ do_batchnorm=False,
15
+ res_drop_prob=0.2,
16
+ is_replicatepad=0,
17
+ out_act="sigmoid",
18
+ forwardV=0,
19
+ upinterp_algo='convtrans',
20
+ post_interp_convtrans=False,
21
+ is3D=False,
22
+ **kwargs):
23
+ self.in_channels = in_channels
24
+ self.out_channels = out_channels
25
+ self.res_blocks = res_blocks
26
+ self.starting_nfeatures = starting_nfeatures
27
+ self.updown_blocks = updown_blocks
28
+ self.is_relu_leaky = is_relu_leaky
29
+ self.do_batchnorm = do_batchnorm
30
+ self.res_drop_prob = res_drop_prob
31
+ self.is_replicatepad = is_replicatepad
32
+ self.out_act = out_act
33
+ self.forwardV = forwardV
34
+ self.upinterp_algo = upinterp_algo
35
+ self.post_interp_convtrans = post_interp_convtrans
36
+ self.is3D = is3D
37
+ super().__init__(**kwargs)
config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ReconResNet"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "ReconResNetConfig.ReconResNetConfig",
7
+ "AutoModel": "ReconResNet.ReconResNet"
8
+ },
9
+ "do_batchnorm": false,
10
+ "forwardV": 0,
11
+ "in_channels": 1,
12
+ "is3D": false,
13
+ "is_relu_leaky": true,
14
+ "is_replicatepad": 0,
15
+ "model_type": "ReconResNet",
16
+ "out_act": "sigmoid",
17
+ "out_channels": 1,
18
+ "post_interp_convtrans": false,
19
+ "res_blocks": 14,
20
+ "res_drop_prob": 0.2,
21
+ "starting_nfeatures": 64,
22
+ "torch_dtype": "float32",
23
+ "transformers_version": "4.44.2",
24
+ "updown_blocks": 2,
25
+ "upinterp_algo": "convtrans"
26
+ }
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c144f0f64f152f337fd5c32d2fa1519d570ac5f2e4349759c36317844173290e
3
+ size 69075000