Yonuts commited on
Commit
b4684a7
·
1 Parent(s): 3925250
evals.py CHANGED
@@ -7,7 +7,7 @@ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
7
  from torchvision import transforms
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
10
- from utils import get_model
11
 
12
  DEFAULT_MODEL_PARAMS = {
13
  "in_channels": [1, 2, 3],
 
7
  from torchvision import transforms
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
10
+ from model_factory import get_model
11
 
12
  DEFAULT_MODEL_PARAMS = {
13
  "in_channels": [1, 2, 3],
utils.py → model_factory.py RENAMED
@@ -3,7 +3,6 @@ import torch.nn as nn
3
  import deepinv as dinv
4
 
5
  from models.unext_wip import UNeXt
6
- from models.unrolled_dpir import get_unrolled_architecture
7
  from physics.multiscale import Pad
8
 
9
 
 
3
  import deepinv as dinv
4
 
5
  from models.unext_wip import UNeXt
 
6
  from physics.multiscale import Pad
7
 
8
 
models/PDNet.py DELETED
@@ -1,322 +0,0 @@
1
- from pathlib import Path
2
-
3
- import torch
4
- from torch.func import vmap
5
- from torch.utils.data import DataLoader
6
- import deepinv as dinv
7
- from deepinv.unfolded import unfolded_builder
8
- from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
9
- from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
10
- from deepinv.optim import Prior, DataFidelity
11
- from deepinv.utils import TensorList
12
-
13
- from physics.multiscale import MultiScaleLinearPhysics
14
- from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
15
-
16
-
17
- def get_PDNet_architecture(in_channels=[1, 2, 3], out_channels=[1, 2, 3], n_primal=3, n_dual=3, device='cuda'):
18
- class PDNetIteration(CPIteration):
19
- r"""Single iteration of learned primal dual.
20
- We only redefine the fStep and gStep classes.
21
- The forward method is inherited from the CPIteration class.
22
- """
23
-
24
- def __init__(self, **kwargs):
25
- super().__init__(**kwargs)
26
- self.g_step = gStepPDNet(**kwargs)
27
- self.f_step = fStepPDNet(**kwargs)
28
-
29
- def forward(
30
- self, X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs
31
- ):
32
- r"""
33
- Single iteration of the Chambolle-Pock algorithm.
34
-
35
- :param dict X: Dictionary containing the current iterate and the estimated cost.
36
- :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
37
- :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
38
- :param dict cur_params: dictionary containing the current parameters of the algorithm.
39
- :param torch.Tensor y: Input data.
40
- :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
41
- :return: Dictionary `{"est": (x, ), "cost": F}` containing the updated current iterate and the estimated current cost.
42
- """
43
- x_prev, z_prev, u_prev = X["est"] # x : primal, z : relaxed primal, u : dual
44
- BS, C_primal, H_primal, W_primal = x_prev.shape
45
- _, C_dual, H_dual, W_dual = u_prev.shape
46
- n_channels = C_primal // n_primal
47
- K = lambda x: torch.cat(
48
- [physics.A(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_primal)], dim=1)
49
- K_adjoint = lambda x: torch.cat(
50
- [physics.A_adjoint(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_dual)], dim=1)
51
- u = self.f_step(u_prev, K(z_prev), cur_data_fidelity, y, physics, n_channels,
52
- cur_params) # dual update (data_fid)
53
- x = self.g_step(x_prev, K_adjoint(u), cur_prior, n_channels, cur_params) # primal update (prior)
54
- z = x + cur_params["beta"] * (x - x_prev)
55
- F = (
56
- self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
57
- if self.has_cost
58
- else None
59
- )
60
- return {"est": (x, z, u), "cost": F}
61
-
62
- class fStepPDNet(fStep):
63
- r"""
64
- Dual update of the PDNet algorithm.
65
- We write it as a proximal operator of the data fidelity term.
66
- This proximal mapping is to be replaced by a trainable model.
67
- """
68
-
69
- def __init__(self, **kwargs):
70
- super().__init__(**kwargs)
71
-
72
- def forward(self, x, w, cur_data_fidelity, y, physics, n_channels, *args):
73
- r"""
74
- :param torch.Tensor x: Current first variable :math:`u`.
75
- :param torch.Tensor w: Current second variable :math:`A z`.
76
- :param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
77
- :param torch.Tensor y: Input data.
78
- """
79
- return cur_data_fidelity.prox(x, w, y, n_channels)
80
-
81
- class gStepPDNet(gStep):
82
- r"""
83
- Primal update of the PDNet algorithm.
84
- We write it as a proximal operator of the prior term.
85
- This proximal mapping is to be replaced by a trainable model.
86
- """
87
-
88
- def __init__(self, **kwargs):
89
- super().__init__(**kwargs)
90
-
91
- def forward(self, x, w, cur_prior, n_channels, *args):
92
- r"""
93
- :param torch.Tensor x: Current first variable :math:`x`.
94
- :param torch.Tensor w: Current second variable :math:`A^\top u`.
95
- :param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
96
- """
97
- return cur_prior.prox(x, w, n_channels)
98
-
99
- # %%
100
- # Define the trainable prior and data fidelity terms.
101
- # ---------------------------------------------------
102
- # Prior and data-fidelity are respectively defined as subclass of :class:`deepinv.optim.Prior` and :class:`deepinv.optim.DataFidelity`.
103
- # Their proximal operators are replaced by trainable models.
104
-
105
- class PDNetPrior(Prior):
106
- def __init__(self, model, *args, **kwargs):
107
- super().__init__(*args, **kwargs)
108
- self.model = model
109
-
110
- def prox(self, x, w, n_channels):
111
- # give to the model : full primal + premier de dual
112
- dual_cond = w[:, 0:n_channels, :, :]
113
- return self.model(x, dual_cond)
114
-
115
- class PDNetDataFid(DataFidelity):
116
- def __init__(self, model, *args, **kwargs):
117
- super().__init__(*args, **kwargs)
118
- self.model = model
119
-
120
- def prox(self, x, w, y, n_channels):
121
- # give to the model : full dual + deuxieme de primal + y = n_channel*n_dual + n_channel + n_channel
122
- if n_primal > 1:
123
- primal_cond = w[:, n_channels:(2 * n_channels), :, :]
124
- else:
125
- primal_cond = w[:, 0:n_channels, :, :]
126
- return self.model(x, primal_cond, y)
127
-
128
- # Unrolled optimization algorithm parameters
129
- max_iter = 10
130
-
131
- # Set up the data fidelity term. Each layer has its own data fidelity module.
132
- in_channels_dual = [in_channel * n_dual + in_channel + in_channel for in_channel in in_channels]
133
- out_channels_dual = [in_channel * n_dual for in_channel in in_channels]
134
- in_channels_primal = [in_channel * n_primal + in_channel for in_channel in in_channels]
135
- out_channels_primal = [in_channel * n_primal for in_channel in in_channels]
136
-
137
- data_fidelity = [
138
- PDNetDataFid(model=PDNet_DualBlock(in_channels=in_channels_dual, out_channels=out_channels_dual).to(device)) for
139
- i in range(max_iter)
140
- ]
141
-
142
- # Set up the trainable prior. Each layer has its own prior module.
143
- prior = [
144
- PDNetPrior(model=PDNet_PrimalBlock(in_channels=in_channels_primal, out_channels=out_channels_primal).to(device))
145
- for i in range(max_iter)]
146
-
147
- # %%
148
- # Define the model.
149
- # -------------------------------
150
-
151
- def custom_init(y, physics):
152
- x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
153
- u0 = (0 * y).repeat(1, n_dual, 1, 1)
154
- return {"est": (x0, x0, u0)}
155
-
156
- def custom_output(X):
157
- x = X["est"][0]
158
- n_channels = x.shape[1] // n_primal
159
- if n_primal > 1:
160
- return X["est"][0][:, n_channels:(2 * n_channels), :, :]
161
- else:
162
- return X["est"][0][:, 0:n_channels, :, :]
163
-
164
- # %%
165
- # Define the unfolded trainable model.
166
- # -------------------------------------
167
- # The original paper of the learned primal dual algorithm the authors used the adjoint operator
168
- # in the primal update. However, the same authors (among others) find in the paper
169
- #
170
- # A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
171
- # Multi-scale learned iterative reconstruction,
172
- # IEEE Transactions on Computational Imaging 6, 843-856, 2020.
173
- #
174
- # that using a filtered gradient can improve both the training speed and reconstruction quality significantly.
175
- # Following this approach, we use the filtered backprojection instead of the adjoint operator in the primal step.
176
-
177
- model = unfolded_builder(
178
- iteration=PDNetIteration(),
179
- params_algo={"beta": 0.0},
180
- data_fidelity=data_fidelity,
181
- prior=prior,
182
- max_iter=max_iter,
183
- custom_init=custom_init,
184
- get_output=custom_output,
185
- )
186
-
187
- return model.to(device)
188
-
189
-
190
- def init_weights(m):
191
- if isinstance(m, torch.nn.Linear):
192
- torch.torch.nn.init.xavier_uniform(m.weight)
193
- m.bias.data.fill_(0.0)
194
-
195
-
196
- class PDNet_PrimalBlock(torch.nn.Module):
197
- r"""
198
- Primal block for the Primal-Dual unfolding model.
199
-
200
- From https://arxiv.org/abs/1707.06474.
201
-
202
- Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
203
- primal block is the concatenation of the current primal variable and the backprojected dual variable along
204
- the channel dimension. The output of each primal block is the current primal variable.
205
-
206
- :param int in_channels: number of input channels. Default: 6.
207
- :param int out_channels: number of output channels. Default: 5.
208
- :param int depth: number of convolutional layers in the block. Default: 3.
209
- :param bool bias: whether to use bias in convolutional layers. Default: True.
210
- :param int nf: number of features in the convolutional layers. Default: 32.
211
- """
212
-
213
- def __init__(self, in_channels=[1, 2, 3], out_channels=[1, 2, 3], depth=3, bias=True, nf=32):
214
- super(PDNet_PrimalBlock, self).__init__()
215
-
216
- self.separate_head = isinstance(in_channels, list)
217
- self.depth = depth
218
-
219
- self.in_conv = InHead(in_channels, nf, bias=bias)
220
- # self.m_head.apply(init_weights)
221
-
222
- # self.in_conv = torch.nn.Conv2d(
223
- # in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
224
- # )
225
-
226
- self.in_conv.apply(init_weights)
227
- self.conv_list = torch.nn.ModuleList(
228
- [
229
- torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
230
- for _ in range(self.depth - 2)
231
- ]
232
- )
233
- self.conv_list.apply(init_weights)
234
- # self.out_conv = torch.nn.Conv2d(
235
- # nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
236
- # )
237
- self.out_conv = OutTail(nf, out_channels, bias=bias)
238
- self.out_conv.apply(init_weights)
239
-
240
- self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
241
-
242
- def forward(self, x, Atu):
243
- r"""
244
- Forward pass of the primal block.
245
-
246
- :param torch.Tensor x: current primal variable.
247
- :param torch.Tensor Atu: backprojected dual variable.
248
- :return: (:class:`torch.Tensor`) the current primal variable.
249
- """
250
- primal_channels = x.shape[1]
251
- x_in = torch.cat((x, Atu), dim=1)
252
-
253
- x_ = self.in_conv(x_in)
254
- x_ = self.nl_list[0](x_)
255
-
256
- for i in range(self.depth - 2):
257
- x_l = self.conv_list[i](x_)
258
- x_ = self.nl_list[i + 1](x_l)
259
-
260
- return self.out_conv(x_, primal_channels) + x
261
-
262
-
263
- class PDNet_DualBlock(torch.nn.Module):
264
- r"""
265
- Dual block for the Primal-Dual unfolding model.
266
-
267
- From https://arxiv.org/abs/1707.06474.
268
-
269
- Dual variables are images of shape (batch_size, in_channels, height, width). The input of each
270
- primal block is the concatenation of the current dual variable with the projected primal variable and
271
- the measurements. The output of each dual block is the current primal variable.
272
-
273
- :param int in_channels: number of input channels. Default: 7.
274
- :param int out_channels: number of output channels. Default: 5.
275
- :param int depth: number of convolutional layers in the block. Default: 3.
276
- :param bool bias: whether to use bias in convolutional layers. Default: True.
277
- :param int nf: number of features in the convolutional layers. Default: 32.
278
- """
279
-
280
- def __init__(self, in_channels=[1, 2, 3], out_channels=[6, 2, 3], depth=3, bias=True, nf=32):
281
- super(PDNet_DualBlock, self).__init__()
282
-
283
- self.depth = depth
284
- self.in_conv = InHead(in_channels, nf, bias=bias)
285
- # self.in_conv = torch.nn.Conv2d(
286
- # in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
287
- # )
288
- self.in_conv.apply(init_weights)
289
- self.conv_list = torch.nn.ModuleList(
290
- [
291
- torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
292
- for _ in range(self.depth - 2)
293
- ]
294
- )
295
- self.conv_list.apply(init_weights)
296
- self.out_conv = OutTail(nf, out_channels, bias=bias)
297
- # self.out_conv = torch.nn.Conv2d(
298
- # nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
299
- # )
300
- self.out_conv.apply(init_weights)
301
-
302
- self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
303
-
304
- def forward(self, u, Ax_cur, y):
305
- r"""
306
- Forward pass of the dual block.
307
-
308
- :param torch.Tensor u: current dual variable.
309
- :param torch.Tensor Ax_cur: projection of the primal variable.
310
- :param torch.Tensor y: measurements.
311
- """
312
- dual_channels = u.shape[1]
313
- x_in = torch.cat((u, Ax_cur, y), dim=1)
314
-
315
- x_ = self.in_conv(x_in)
316
- x_ = self.nl_list[0](x_)
317
-
318
- for i in range(self.depth - 2):
319
- x_l = self.conv_list[i](x_)
320
- x_ = self.nl_list[i + 1](x_l)
321
-
322
- return self.out_conv(x_, dual_channels) + u
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
physics/inpainting_generator.py DELETED
@@ -1,107 +0,0 @@
1
- import torch
2
- from deepinv.physics.generator import PhysicsGenerator
3
-
4
-
5
- class InpaintingMaskGenerator(PhysicsGenerator):
6
-
7
- def __init__(
8
- self,
9
- mask_shape: tuple,
10
- num_channels: int = 1,
11
- device: str = "cpu",
12
- dtype: type = torch.float32,
13
- block_size_ratio=0.1,
14
- num_blocks=5,
15
- ) -> None:
16
- kwargs = {
17
- "mask_shape": mask_shape,
18
- "block_size_ratio": block_size_ratio,
19
- "num_blocks": num_blocks,
20
- }
21
- if len(mask_shape) != 2:
22
- raise ValueError(
23
- "mask_shape must 2D. Add channels via num_channels parameter"
24
- )
25
- super().__init__(
26
- num_channels=num_channels,
27
- device=device,
28
- dtype=dtype,
29
- **kwargs,
30
- )
31
-
32
- def generate_mask(self, image_shape, block_size_ratio, num_blocks):
33
- # Create an all-ones tensor which will serve as the initial mask
34
- mask = torch.ones(image_shape)
35
- batch_size = mask.shape[0]
36
-
37
- # Calculate block size based on the image dimensions and block_size_ratio
38
- block_width = int(image_shape[-2] * block_size_ratio)
39
- block_height = int(image_shape[-1] * block_size_ratio)
40
-
41
- # Generate random coordinates for each block in each batch
42
- x_coords = torch.randint(
43
- 0, image_shape[-1] - block_width, (batch_size, num_blocks)
44
- )
45
- y_coords = torch.randint(
46
- 0, image_shape[-2] - block_height, (batch_size, num_blocks)
47
- )
48
-
49
- # Create grids of indices for the block dimensions
50
- x_range = torch.arange(block_width).view(1, 1, -1)
51
- y_range = torch.arange(block_height).view(1, 1, -1)
52
-
53
- # Expand ranges to match the batch and num_blocks dimensions
54
- x_indices = x_coords.unsqueeze(-1) + x_range
55
- y_indices = y_coords.unsqueeze(-1) + y_range
56
-
57
- # Expand and flatten the indices for advanced indexing
58
- x_indices = x_indices.unsqueeze(2).expand(-1, -1, block_height, -1).reshape(-1)
59
- y_indices = y_indices.unsqueeze(3).expand(-1, -1, -1, block_width).reshape(-1)
60
-
61
- # Create batch indices for advanced indexing
62
- batch_indices = (
63
- torch.arange(batch_size)
64
- .view(-1, 1, 1)
65
- .expand(-1, num_blocks, block_width * block_height)
66
- .reshape(-1)
67
- )
68
- channel_indices = (
69
- torch.arange(3)
70
- .view(1, 1, 1, -1)
71
- .expand(batch_size, num_blocks, block_width * block_height, -1)
72
- .reshape(-1)
73
- )
74
-
75
- # Apply the blocks using advanced indexing
76
- mask[batch_indices, :, y_indices, x_indices] = 0
77
-
78
- return mask
79
-
80
- def step(
81
- self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
82
- ):
83
- r"""
84
- Generate a random motion blur PSF with parameters :math:`\sigma` and :math:`l`
85
-
86
- :param int batch_size: batch_size.
87
- :param float sigma: the standard deviation of the Gaussian Process
88
- :param float l: the length scale of the trajectory
89
-
90
- :return: dictionary with key **'filter'**: the generated PSF of shape `(batch_size, 1, psf_size[0], psf_size[1])`
91
- """
92
-
93
- # TODO: add randomness
94
- block_size_ratio = (
95
- self.block_size_ratio if block_size_ratio is None else block_size_ratio
96
- )
97
- num_blocks = self.num_blocks if num_blocks is None else num_blocks
98
- batch_shape = (
99
- batch_size,
100
- self.num_channels,
101
- self.mask_shape[-2],
102
- self.mask_shape[-1],
103
- )
104
-
105
- mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)
106
-
107
- return {"mask": mask.to(self.factory_kwargs["device"])}