Upload ReconResNet
Browse files- ReconResNet.py +25 -0
- ReconResNetBase.py +267 -0
- ReconResNetConfig.py +37 -0
- config.json +26 -0
- model.safetensors +3 -0
ReconResNet.py
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PreTrainedModel
|
2 |
+
from .ReconResNetBase import ReconResNetBase
|
3 |
+
from .ReconResNetConfig import ReconResNetConfig
|
4 |
+
|
5 |
+
class ReconResNet(PreTrainedModel):
|
6 |
+
config_class = ReconResNetConfig
|
7 |
+
def __init__(self, config):
|
8 |
+
super().__init__(config)
|
9 |
+
self.model = ReconResNetBase(
|
10 |
+
in_channels=config.in_channels,
|
11 |
+
out_channels=config.out_channels,
|
12 |
+
res_blocks=config.res_blocks,
|
13 |
+
starting_nfeatures=config.starting_nfeatures,
|
14 |
+
updown_blocks=config.updown_blocks,
|
15 |
+
is_relu_leaky=config.is_relu_leaky,
|
16 |
+
do_batchnorm=config.do_batchnorm,
|
17 |
+
res_drop_prob=config.res_drop_prob,
|
18 |
+
is_replicatepad=config.is_replicatepad,
|
19 |
+
out_act=config.out_act,
|
20 |
+
forwardV=config.forwardV,
|
21 |
+
upinterp_algo=config.upinterp_algo,
|
22 |
+
post_interp_convtrans=config.post_interp_convtrans,
|
23 |
+
is3D=config.is3D)
|
24 |
+
def forward(self, x):
|
25 |
+
return self.model(x)
|
ReconResNetBase.py
ADDED
@@ -0,0 +1,267 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
# This model is part of the paper "ReconResNet: Regularised Residual Learning for MR Image Reconstruction of Undersampled Cartesian and Radial Data" (https://doi.org/10.1016/j.compbiomed.2022.105321)
|
4 |
+
# and has been published on GitHub: https://github.com/soumickmj/NCC1701/blob/main/Bridge/WarpDrives/ReconResNet/ReconResNet.py
|
5 |
+
|
6 |
+
import torch.nn as nn
|
7 |
+
from tricorder.torch.transforms import Interpolator
|
8 |
+
|
9 |
+
__author__ = "Soumick Chatterjee"
|
10 |
+
__copyright__ = "Copyright 2019, Soumick Chatterjee & OvGU:ESF:MEMoRIAL"
|
11 |
+
__credits__ = ["Soumick Chatterjee"]
|
12 |
+
|
13 |
+
__license__ = "apache-2.0"
|
14 |
+
__version__ = "1.0.0"
|
15 |
+
__email__ = "[email protected]"
|
16 |
+
__status__ = "Published"
|
17 |
+
|
18 |
+
|
19 |
+
class ResidualBlock(nn.Module):
|
20 |
+
def __init__(self, in_features, drop_prob=0.2):
|
21 |
+
super(ResidualBlock, self).__init__()
|
22 |
+
|
23 |
+
conv_block = [layer_pad(1),
|
24 |
+
layer_conv(in_features, in_features, 3),
|
25 |
+
layer_norm(in_features),
|
26 |
+
act_relu(),
|
27 |
+
layer_drop(p=drop_prob, inplace=True),
|
28 |
+
layer_pad(1),
|
29 |
+
layer_conv(in_features, in_features, 3),
|
30 |
+
layer_norm(in_features)]
|
31 |
+
|
32 |
+
self.conv_block = nn.Sequential(*conv_block)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
return x + self.conv_block(x)
|
36 |
+
|
37 |
+
|
38 |
+
class DownsamplingBlock(nn.Module):
|
39 |
+
def __init__(self, in_features, out_features):
|
40 |
+
super(DownsamplingBlock, self).__init__()
|
41 |
+
|
42 |
+
conv_block = [layer_conv(in_features, out_features, 3, stride=2, padding=1),
|
43 |
+
layer_norm(out_features),
|
44 |
+
act_relu()]
|
45 |
+
self.conv_block = nn.Sequential(*conv_block)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
return self.conv_block(x)
|
49 |
+
|
50 |
+
|
51 |
+
class UpsamplingBlock(nn.Module):
|
52 |
+
def __init__(self, in_features, out_features, mode="convtrans", interpolator=None, post_interp_convtrans=False):
|
53 |
+
super(UpsamplingBlock, self).__init__()
|
54 |
+
|
55 |
+
self.interpolator = interpolator
|
56 |
+
self.mode = mode
|
57 |
+
self.post_interp_convtrans = post_interp_convtrans
|
58 |
+
if self.post_interp_convtrans:
|
59 |
+
self.post_conv = layer_conv(out_features, out_features, 1)
|
60 |
+
|
61 |
+
if mode == "convtrans":
|
62 |
+
conv_block = [layer_convtrans(
|
63 |
+
in_features, out_features, 3, stride=2, padding=1, output_padding=1), ]
|
64 |
+
else:
|
65 |
+
conv_block = [layer_pad(1),
|
66 |
+
layer_conv(in_features, out_features, 3), ]
|
67 |
+
conv_block += [layer_norm(out_features),
|
68 |
+
act_relu()]
|
69 |
+
self.conv_block = nn.Sequential(*conv_block)
|
70 |
+
|
71 |
+
def forward(self, x, out_shape=None):
|
72 |
+
if self.mode == "convtrans":
|
73 |
+
if self.post_interp_convtrans:
|
74 |
+
x = self.conv_block(x)
|
75 |
+
if x.shape[2:] != out_shape:
|
76 |
+
return self.post_conv(self.interpolator(x, out_shape))
|
77 |
+
else:
|
78 |
+
return x
|
79 |
+
else:
|
80 |
+
return self.conv_block(x)
|
81 |
+
else:
|
82 |
+
return self.conv_block(self.interpolator(x, out_shape))
|
83 |
+
|
84 |
+
|
85 |
+
class ReconResNetBase(nn.Module):
|
86 |
+
def __init__(self, in_channels=1, out_channels=1, res_blocks=14, starting_nfeatures=64, updown_blocks=2, is_relu_leaky=True, do_batchnorm=False, res_drop_prob=0.2,
|
87 |
+
is_replicatepad=0, out_act="sigmoid", forwardV=0, upinterp_algo='convtrans', post_interp_convtrans=False, is3D=False): # should use 14 as that gives number of trainable parameters close to number of possible pixel values in a image 256x256
|
88 |
+
super(ReconResNetBase, self).__init__()
|
89 |
+
|
90 |
+
layers = {}
|
91 |
+
if is3D:
|
92 |
+
layers["layer_conv"] = nn.Conv3d
|
93 |
+
layers["layer_convtrans"] = nn.ConvTranspose3d
|
94 |
+
if do_batchnorm:
|
95 |
+
layers["layer_norm"] = nn.BatchNorm3d
|
96 |
+
else:
|
97 |
+
layers["layer_norm"] = nn.InstanceNorm3d
|
98 |
+
layers["layer_drop"] = nn.Dropout3d
|
99 |
+
if is_replicatepad == 0:
|
100 |
+
layers["layer_pad"] = nn.ReflectionPad3d
|
101 |
+
elif is_replicatepad == 1:
|
102 |
+
layers["layer_pad"] = nn.ReplicationPad3d
|
103 |
+
layers["interp_mode"] = 'trilinear'
|
104 |
+
else:
|
105 |
+
layers["layer_conv"] = nn.Conv2d
|
106 |
+
layers["layer_convtrans"] = nn.ConvTranspose2d
|
107 |
+
if do_batchnorm:
|
108 |
+
layers["layer_norm"] = nn.BatchNorm2d
|
109 |
+
else:
|
110 |
+
layers["layer_norm"] = nn.InstanceNorm2d
|
111 |
+
layers["layer_drop"] = nn.Dropout2d
|
112 |
+
if is_replicatepad == 0:
|
113 |
+
layers["layer_pad"] = nn.ReflectionPad2d
|
114 |
+
elif is_replicatepad == 1:
|
115 |
+
layers["layer_pad"] = nn.ReplicationPad2d
|
116 |
+
layers["interp_mode"] = 'bilinear'
|
117 |
+
if is_relu_leaky:
|
118 |
+
layers["act_relu"] = nn.PReLU
|
119 |
+
else:
|
120 |
+
layers["act_relu"] = nn.ReLU
|
121 |
+
globals().update(layers)
|
122 |
+
|
123 |
+
self.forwardV = forwardV
|
124 |
+
self.upinterp_algo = upinterp_algo
|
125 |
+
|
126 |
+
interpolator = Interpolator(
|
127 |
+
mode=layers["interp_mode"] if self.upinterp_algo == "convtrans" else self.upinterp_algo)
|
128 |
+
|
129 |
+
# Initial convolution block
|
130 |
+
intialConv = [layer_pad(3),
|
131 |
+
layer_conv(in_channels, starting_nfeatures, 7),
|
132 |
+
layer_norm(starting_nfeatures),
|
133 |
+
act_relu()]
|
134 |
+
|
135 |
+
# Downsampling [need to save the shape for upsample]
|
136 |
+
downsam = []
|
137 |
+
in_features = starting_nfeatures
|
138 |
+
out_features = in_features*2
|
139 |
+
for _ in range(updown_blocks):
|
140 |
+
downsam.append(DownsamplingBlock(in_features, out_features))
|
141 |
+
in_features = out_features
|
142 |
+
out_features = in_features*2
|
143 |
+
|
144 |
+
# Residual blocks
|
145 |
+
resblocks = []
|
146 |
+
for _ in range(res_blocks):
|
147 |
+
resblocks += [ResidualBlock(in_features, res_drop_prob)]
|
148 |
+
|
149 |
+
# Upsampling
|
150 |
+
upsam = []
|
151 |
+
out_features = in_features//2
|
152 |
+
for _ in range(updown_blocks):
|
153 |
+
upsam.append(UpsamplingBlock(in_features, out_features,
|
154 |
+
self.upinterp_algo, interpolator, post_interp_convtrans))
|
155 |
+
in_features = out_features
|
156 |
+
out_features = in_features//2
|
157 |
+
|
158 |
+
# Output layer
|
159 |
+
finalconv = [layer_pad(3),
|
160 |
+
layer_conv(starting_nfeatures, out_channels, 7), ]
|
161 |
+
|
162 |
+
if out_act == "sigmoid":
|
163 |
+
finalconv += [nn.Sigmoid(), ]
|
164 |
+
elif out_act == "relu":
|
165 |
+
finalconv += [act_relu(), ]
|
166 |
+
elif out_act == "tanh":
|
167 |
+
finalconv += [nn.Tanh(), ]
|
168 |
+
|
169 |
+
self.intialConv = nn.Sequential(*intialConv)
|
170 |
+
self.downsam = nn.ModuleList(downsam)
|
171 |
+
self.resblocks = nn.Sequential(*resblocks)
|
172 |
+
self.upsam = nn.ModuleList(upsam)
|
173 |
+
self.finalconv = nn.Sequential(*finalconv)
|
174 |
+
|
175 |
+
if self.forwardV == 0:
|
176 |
+
self.forward = self.forwardV0
|
177 |
+
elif self.forwardV == 1:
|
178 |
+
self.forward = self.forwardV1
|
179 |
+
elif self.forwardV == 2:
|
180 |
+
self.forward = self.forwardV2
|
181 |
+
elif self.forwardV == 3:
|
182 |
+
self.forward = self.forwardV3
|
183 |
+
elif self.forwardV == 4:
|
184 |
+
self.forward = self.forwardV4
|
185 |
+
elif self.forwardV == 5:
|
186 |
+
self.forward = self.forwardV5
|
187 |
+
|
188 |
+
def forwardV0(self, x):
|
189 |
+
# v0: Original Version
|
190 |
+
x = self.intialConv(x)
|
191 |
+
shapes = []
|
192 |
+
for downblock in self.downsam:
|
193 |
+
shapes.append(x.shape[2:])
|
194 |
+
x = downblock(x)
|
195 |
+
x = self.resblocks(x)
|
196 |
+
for i, upblock in enumerate(self.upsam):
|
197 |
+
x = upblock(x, shapes[-1-i])
|
198 |
+
return self.finalconv(x)
|
199 |
+
|
200 |
+
def forwardV1(self, x):
|
201 |
+
# v1: input is added to the final output
|
202 |
+
out = self.intialConv(x)
|
203 |
+
shapes = []
|
204 |
+
for downblock in self.downsam:
|
205 |
+
shapes.append(out.shape[2:])
|
206 |
+
out = downblock(out)
|
207 |
+
out = self.resblocks(out)
|
208 |
+
for i, upblock in enumerate(self.upsam):
|
209 |
+
out = upblock(out, shapes[-1-i])
|
210 |
+
return x + self.finalconv(out)
|
211 |
+
|
212 |
+
def forwardV2(self, x):
|
213 |
+
# v2: residual of v1 + input to the residual blocks added back with the output
|
214 |
+
out = self.intialConv(x)
|
215 |
+
shapes = []
|
216 |
+
for downblock in self.downsam:
|
217 |
+
shapes.append(out.shape[2:])
|
218 |
+
out = downblock(out)
|
219 |
+
out = out + self.resblocks(out)
|
220 |
+
for i, upblock in enumerate(self.upsam):
|
221 |
+
out = upblock(out, shapes[-1-i])
|
222 |
+
return x + self.finalconv(out)
|
223 |
+
|
224 |
+
def forwardV3(self, x):
|
225 |
+
# v3: residual of v2 + input of the initial conv added back with the output
|
226 |
+
out = x + self.intialConv(x)
|
227 |
+
shapes = []
|
228 |
+
for downblock in self.downsam:
|
229 |
+
shapes.append(out.shape[2:])
|
230 |
+
out = downblock(out)
|
231 |
+
out = out + self.resblocks(out)
|
232 |
+
for i, upblock in enumerate(self.upsam):
|
233 |
+
out = upblock(out, shapes[-1-i])
|
234 |
+
return x + self.finalconv(out)
|
235 |
+
|
236 |
+
def forwardV4(self, x):
|
237 |
+
# v4: residual of v3 + output of the initial conv added back with the input of final conv
|
238 |
+
iniconv = x + self.intialConv(x)
|
239 |
+
shapes = []
|
240 |
+
if len(self.downsam) > 0:
|
241 |
+
for i, downblock in enumerate(self.downsam):
|
242 |
+
if i == 0:
|
243 |
+
shapes.append(iniconv.shape[2:])
|
244 |
+
out = downblock(iniconv)
|
245 |
+
else:
|
246 |
+
shapes.append(out.shape[2:])
|
247 |
+
out = downblock(out)
|
248 |
+
else:
|
249 |
+
out = iniconv
|
250 |
+
out = out + self.resblocks(out)
|
251 |
+
for i, upblock in enumerate(self.upsam):
|
252 |
+
out = upblock(out, shapes[-1-i])
|
253 |
+
out = iniconv + out
|
254 |
+
return x + self.finalconv(out)
|
255 |
+
|
256 |
+
def forwardV5(self, x):
|
257 |
+
# v5: residual of v4 + individual down blocks with individual up blocks
|
258 |
+
outs = [x + self.intialConv(x)]
|
259 |
+
shapes = []
|
260 |
+
for i, downblock in enumerate(self.downsam):
|
261 |
+
shapes.append(outs[-1].shape[2:])
|
262 |
+
outs.append(downblock(outs[-1]))
|
263 |
+
outs[-1] = outs[-1] + self.resblocks(outs[-1])
|
264 |
+
for i, upblock in enumerate(self.upsam):
|
265 |
+
outs[-1] = upblock(outs[-1], shapes[-1-i])
|
266 |
+
outs[-1] = outs[-2] + outs.pop()
|
267 |
+
return x + self.finalconv(outs.pop())
|
ReconResNetConfig.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
class ReconResNetConfig(PretrainedConfig):
|
5 |
+
model_type = "ReconResNet"
|
6 |
+
def __init__(
|
7 |
+
self,
|
8 |
+
in_channels=1,
|
9 |
+
out_channels=1,
|
10 |
+
res_blocks=14,
|
11 |
+
starting_nfeatures=64,
|
12 |
+
updown_blocks=2,
|
13 |
+
is_relu_leaky=True,
|
14 |
+
do_batchnorm=False,
|
15 |
+
res_drop_prob=0.2,
|
16 |
+
is_replicatepad=0,
|
17 |
+
out_act="sigmoid",
|
18 |
+
forwardV=0,
|
19 |
+
upinterp_algo='convtrans',
|
20 |
+
post_interp_convtrans=False,
|
21 |
+
is3D=False,
|
22 |
+
**kwargs):
|
23 |
+
self.in_channels = in_channels
|
24 |
+
self.out_channels = out_channels
|
25 |
+
self.res_blocks = res_blocks
|
26 |
+
self.starting_nfeatures = starting_nfeatures
|
27 |
+
self.updown_blocks = updown_blocks
|
28 |
+
self.is_relu_leaky = is_relu_leaky
|
29 |
+
self.do_batchnorm = do_batchnorm
|
30 |
+
self.res_drop_prob = res_drop_prob
|
31 |
+
self.is_replicatepad = is_replicatepad
|
32 |
+
self.out_act = out_act
|
33 |
+
self.forwardV = forwardV
|
34 |
+
self.upinterp_algo = upinterp_algo
|
35 |
+
self.post_interp_convtrans = post_interp_convtrans
|
36 |
+
self.is3D = is3D
|
37 |
+
super().__init__(**kwargs)
|
config.json
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"ReconResNet"
|
4 |
+
],
|
5 |
+
"auto_map": {
|
6 |
+
"AutoConfig": "ReconResNetConfig.ReconResNetConfig",
|
7 |
+
"AutoModel": "ReconResNet.ReconResNet"
|
8 |
+
},
|
9 |
+
"do_batchnorm": false,
|
10 |
+
"forwardV": 0,
|
11 |
+
"in_channels": 1,
|
12 |
+
"is3D": false,
|
13 |
+
"is_relu_leaky": true,
|
14 |
+
"is_replicatepad": 0,
|
15 |
+
"model_type": "ReconResNet",
|
16 |
+
"out_act": "sigmoid",
|
17 |
+
"out_channels": 1,
|
18 |
+
"post_interp_convtrans": false,
|
19 |
+
"res_blocks": 14,
|
20 |
+
"res_drop_prob": 0.2,
|
21 |
+
"starting_nfeatures": 64,
|
22 |
+
"torch_dtype": "float32",
|
23 |
+
"transformers_version": "4.44.2",
|
24 |
+
"updown_blocks": 2,
|
25 |
+
"upinterp_algo": "convtrans"
|
26 |
+
}
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:c144f0f64f152f337fd5c32d2fa1519d570ac5f2e4349759c36317844173290e
|
3 |
+
size 69075000
|