Spaces:
Sleeping
Sleeping
Cleanup
Browse files- evals.py +1 -1
- utils.py → model_factory.py +0 -1
- models/PDNet.py +0 -322
- physics/inpainting_generator.py +0 -107
evals.py
CHANGED
@@ -7,7 +7,7 @@ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
|
|
7 |
from torchvision import transforms
|
8 |
|
9 |
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
|
10 |
-
from
|
11 |
|
12 |
DEFAULT_MODEL_PARAMS = {
|
13 |
"in_channels": [1, 2, 3],
|
|
|
7 |
from torchvision import transforms
|
8 |
|
9 |
from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
|
10 |
+
from model_factory import get_model
|
11 |
|
12 |
DEFAULT_MODEL_PARAMS = {
|
13 |
"in_channels": [1, 2, 3],
|
utils.py → model_factory.py
RENAMED
@@ -3,7 +3,6 @@ import torch.nn as nn
|
|
3 |
import deepinv as dinv
|
4 |
|
5 |
from models.unext_wip import UNeXt
|
6 |
-
from models.unrolled_dpir import get_unrolled_architecture
|
7 |
from physics.multiscale import Pad
|
8 |
|
9 |
|
|
|
3 |
import deepinv as dinv
|
4 |
|
5 |
from models.unext_wip import UNeXt
|
|
|
6 |
from physics.multiscale import Pad
|
7 |
|
8 |
|
models/PDNet.py
DELETED
@@ -1,322 +0,0 @@
|
|
1 |
-
from pathlib import Path
|
2 |
-
|
3 |
-
import torch
|
4 |
-
from torch.func import vmap
|
5 |
-
from torch.utils.data import DataLoader
|
6 |
-
import deepinv as dinv
|
7 |
-
from deepinv.unfolded import unfolded_builder
|
8 |
-
from deepinv.utils.phantoms import RandomPhantomDataset, SheppLoganDataset
|
9 |
-
from deepinv.optim.optim_iterators import CPIteration, fStep, gStep
|
10 |
-
from deepinv.optim import Prior, DataFidelity
|
11 |
-
from deepinv.utils import TensorList
|
12 |
-
|
13 |
-
from physics.multiscale import MultiScaleLinearPhysics
|
14 |
-
from models.heads import Heads, Tails, InHead, OutTail, ConvChannels, SNRModule, EquivConvModule, EquivHeads
|
15 |
-
|
16 |
-
|
17 |
-
def get_PDNet_architecture(in_channels=[1, 2, 3], out_channels=[1, 2, 3], n_primal=3, n_dual=3, device='cuda'):
|
18 |
-
class PDNetIteration(CPIteration):
|
19 |
-
r"""Single iteration of learned primal dual.
|
20 |
-
We only redefine the fStep and gStep classes.
|
21 |
-
The forward method is inherited from the CPIteration class.
|
22 |
-
"""
|
23 |
-
|
24 |
-
def __init__(self, **kwargs):
|
25 |
-
super().__init__(**kwargs)
|
26 |
-
self.g_step = gStepPDNet(**kwargs)
|
27 |
-
self.f_step = fStepPDNet(**kwargs)
|
28 |
-
|
29 |
-
def forward(
|
30 |
-
self, X, cur_data_fidelity, cur_prior, cur_params, y, physics, *args, **kwargs
|
31 |
-
):
|
32 |
-
r"""
|
33 |
-
Single iteration of the Chambolle-Pock algorithm.
|
34 |
-
|
35 |
-
:param dict X: Dictionary containing the current iterate and the estimated cost.
|
36 |
-
:param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
|
37 |
-
:param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
|
38 |
-
:param dict cur_params: dictionary containing the current parameters of the algorithm.
|
39 |
-
:param torch.Tensor y: Input data.
|
40 |
-
:param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
|
41 |
-
:return: Dictionary `{"est": (x, ), "cost": F}` containing the updated current iterate and the estimated current cost.
|
42 |
-
"""
|
43 |
-
x_prev, z_prev, u_prev = X["est"] # x : primal, z : relaxed primal, u : dual
|
44 |
-
BS, C_primal, H_primal, W_primal = x_prev.shape
|
45 |
-
_, C_dual, H_dual, W_dual = u_prev.shape
|
46 |
-
n_channels = C_primal // n_primal
|
47 |
-
K = lambda x: torch.cat(
|
48 |
-
[physics.A(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_primal)], dim=1)
|
49 |
-
K_adjoint = lambda x: torch.cat(
|
50 |
-
[physics.A_adjoint(x[:, i * n_channels:(i + 1) * n_channels, :, :]) for i in range(n_dual)], dim=1)
|
51 |
-
u = self.f_step(u_prev, K(z_prev), cur_data_fidelity, y, physics, n_channels,
|
52 |
-
cur_params) # dual update (data_fid)
|
53 |
-
x = self.g_step(x_prev, K_adjoint(u), cur_prior, n_channels, cur_params) # primal update (prior)
|
54 |
-
z = x + cur_params["beta"] * (x - x_prev)
|
55 |
-
F = (
|
56 |
-
self.F_fn(x, cur_data_fidelity, cur_prior, cur_params, y, physics)
|
57 |
-
if self.has_cost
|
58 |
-
else None
|
59 |
-
)
|
60 |
-
return {"est": (x, z, u), "cost": F}
|
61 |
-
|
62 |
-
class fStepPDNet(fStep):
|
63 |
-
r"""
|
64 |
-
Dual update of the PDNet algorithm.
|
65 |
-
We write it as a proximal operator of the data fidelity term.
|
66 |
-
This proximal mapping is to be replaced by a trainable model.
|
67 |
-
"""
|
68 |
-
|
69 |
-
def __init__(self, **kwargs):
|
70 |
-
super().__init__(**kwargs)
|
71 |
-
|
72 |
-
def forward(self, x, w, cur_data_fidelity, y, physics, n_channels, *args):
|
73 |
-
r"""
|
74 |
-
:param torch.Tensor x: Current first variable :math:`u`.
|
75 |
-
:param torch.Tensor w: Current second variable :math:`A z`.
|
76 |
-
:param deepinv.optim.data_fidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data fidelity term.
|
77 |
-
:param torch.Tensor y: Input data.
|
78 |
-
"""
|
79 |
-
return cur_data_fidelity.prox(x, w, y, n_channels)
|
80 |
-
|
81 |
-
class gStepPDNet(gStep):
|
82 |
-
r"""
|
83 |
-
Primal update of the PDNet algorithm.
|
84 |
-
We write it as a proximal operator of the prior term.
|
85 |
-
This proximal mapping is to be replaced by a trainable model.
|
86 |
-
"""
|
87 |
-
|
88 |
-
def __init__(self, **kwargs):
|
89 |
-
super().__init__(**kwargs)
|
90 |
-
|
91 |
-
def forward(self, x, w, cur_prior, n_channels, *args):
|
92 |
-
r"""
|
93 |
-
:param torch.Tensor x: Current first variable :math:`x`.
|
94 |
-
:param torch.Tensor w: Current second variable :math:`A^\top u`.
|
95 |
-
:param deepinv.optim.prior cur_prior: Instance of the Prior class defining the current prior.
|
96 |
-
"""
|
97 |
-
return cur_prior.prox(x, w, n_channels)
|
98 |
-
|
99 |
-
# %%
|
100 |
-
# Define the trainable prior and data fidelity terms.
|
101 |
-
# ---------------------------------------------------
|
102 |
-
# Prior and data-fidelity are respectively defined as subclass of :class:`deepinv.optim.Prior` and :class:`deepinv.optim.DataFidelity`.
|
103 |
-
# Their proximal operators are replaced by trainable models.
|
104 |
-
|
105 |
-
class PDNetPrior(Prior):
|
106 |
-
def __init__(self, model, *args, **kwargs):
|
107 |
-
super().__init__(*args, **kwargs)
|
108 |
-
self.model = model
|
109 |
-
|
110 |
-
def prox(self, x, w, n_channels):
|
111 |
-
# give to the model : full primal + premier de dual
|
112 |
-
dual_cond = w[:, 0:n_channels, :, :]
|
113 |
-
return self.model(x, dual_cond)
|
114 |
-
|
115 |
-
class PDNetDataFid(DataFidelity):
|
116 |
-
def __init__(self, model, *args, **kwargs):
|
117 |
-
super().__init__(*args, **kwargs)
|
118 |
-
self.model = model
|
119 |
-
|
120 |
-
def prox(self, x, w, y, n_channels):
|
121 |
-
# give to the model : full dual + deuxieme de primal + y = n_channel*n_dual + n_channel + n_channel
|
122 |
-
if n_primal > 1:
|
123 |
-
primal_cond = w[:, n_channels:(2 * n_channels), :, :]
|
124 |
-
else:
|
125 |
-
primal_cond = w[:, 0:n_channels, :, :]
|
126 |
-
return self.model(x, primal_cond, y)
|
127 |
-
|
128 |
-
# Unrolled optimization algorithm parameters
|
129 |
-
max_iter = 10
|
130 |
-
|
131 |
-
# Set up the data fidelity term. Each layer has its own data fidelity module.
|
132 |
-
in_channels_dual = [in_channel * n_dual + in_channel + in_channel for in_channel in in_channels]
|
133 |
-
out_channels_dual = [in_channel * n_dual for in_channel in in_channels]
|
134 |
-
in_channels_primal = [in_channel * n_primal + in_channel for in_channel in in_channels]
|
135 |
-
out_channels_primal = [in_channel * n_primal for in_channel in in_channels]
|
136 |
-
|
137 |
-
data_fidelity = [
|
138 |
-
PDNetDataFid(model=PDNet_DualBlock(in_channels=in_channels_dual, out_channels=out_channels_dual).to(device)) for
|
139 |
-
i in range(max_iter)
|
140 |
-
]
|
141 |
-
|
142 |
-
# Set up the trainable prior. Each layer has its own prior module.
|
143 |
-
prior = [
|
144 |
-
PDNetPrior(model=PDNet_PrimalBlock(in_channels=in_channels_primal, out_channels=out_channels_primal).to(device))
|
145 |
-
for i in range(max_iter)]
|
146 |
-
|
147 |
-
# %%
|
148 |
-
# Define the model.
|
149 |
-
# -------------------------------
|
150 |
-
|
151 |
-
def custom_init(y, physics):
|
152 |
-
x0 = physics.A_dagger(y).repeat(1, n_primal, 1, 1)
|
153 |
-
u0 = (0 * y).repeat(1, n_dual, 1, 1)
|
154 |
-
return {"est": (x0, x0, u0)}
|
155 |
-
|
156 |
-
def custom_output(X):
|
157 |
-
x = X["est"][0]
|
158 |
-
n_channels = x.shape[1] // n_primal
|
159 |
-
if n_primal > 1:
|
160 |
-
return X["est"][0][:, n_channels:(2 * n_channels), :, :]
|
161 |
-
else:
|
162 |
-
return X["est"][0][:, 0:n_channels, :, :]
|
163 |
-
|
164 |
-
# %%
|
165 |
-
# Define the unfolded trainable model.
|
166 |
-
# -------------------------------------
|
167 |
-
# The original paper of the learned primal dual algorithm the authors used the adjoint operator
|
168 |
-
# in the primal update. However, the same authors (among others) find in the paper
|
169 |
-
#
|
170 |
-
# A. Hauptmann, J. Adler, S. Arridge, O. Öktem,
|
171 |
-
# Multi-scale learned iterative reconstruction,
|
172 |
-
# IEEE Transactions on Computational Imaging 6, 843-856, 2020.
|
173 |
-
#
|
174 |
-
# that using a filtered gradient can improve both the training speed and reconstruction quality significantly.
|
175 |
-
# Following this approach, we use the filtered backprojection instead of the adjoint operator in the primal step.
|
176 |
-
|
177 |
-
model = unfolded_builder(
|
178 |
-
iteration=PDNetIteration(),
|
179 |
-
params_algo={"beta": 0.0},
|
180 |
-
data_fidelity=data_fidelity,
|
181 |
-
prior=prior,
|
182 |
-
max_iter=max_iter,
|
183 |
-
custom_init=custom_init,
|
184 |
-
get_output=custom_output,
|
185 |
-
)
|
186 |
-
|
187 |
-
return model.to(device)
|
188 |
-
|
189 |
-
|
190 |
-
def init_weights(m):
|
191 |
-
if isinstance(m, torch.nn.Linear):
|
192 |
-
torch.torch.nn.init.xavier_uniform(m.weight)
|
193 |
-
m.bias.data.fill_(0.0)
|
194 |
-
|
195 |
-
|
196 |
-
class PDNet_PrimalBlock(torch.nn.Module):
|
197 |
-
r"""
|
198 |
-
Primal block for the Primal-Dual unfolding model.
|
199 |
-
|
200 |
-
From https://arxiv.org/abs/1707.06474.
|
201 |
-
|
202 |
-
Primal variables are images of shape (batch_size, in_channels, height, width). The input of each
|
203 |
-
primal block is the concatenation of the current primal variable and the backprojected dual variable along
|
204 |
-
the channel dimension. The output of each primal block is the current primal variable.
|
205 |
-
|
206 |
-
:param int in_channels: number of input channels. Default: 6.
|
207 |
-
:param int out_channels: number of output channels. Default: 5.
|
208 |
-
:param int depth: number of convolutional layers in the block. Default: 3.
|
209 |
-
:param bool bias: whether to use bias in convolutional layers. Default: True.
|
210 |
-
:param int nf: number of features in the convolutional layers. Default: 32.
|
211 |
-
"""
|
212 |
-
|
213 |
-
def __init__(self, in_channels=[1, 2, 3], out_channels=[1, 2, 3], depth=3, bias=True, nf=32):
|
214 |
-
super(PDNet_PrimalBlock, self).__init__()
|
215 |
-
|
216 |
-
self.separate_head = isinstance(in_channels, list)
|
217 |
-
self.depth = depth
|
218 |
-
|
219 |
-
self.in_conv = InHead(in_channels, nf, bias=bias)
|
220 |
-
# self.m_head.apply(init_weights)
|
221 |
-
|
222 |
-
# self.in_conv = torch.nn.Conv2d(
|
223 |
-
# in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
|
224 |
-
# )
|
225 |
-
|
226 |
-
self.in_conv.apply(init_weights)
|
227 |
-
self.conv_list = torch.nn.ModuleList(
|
228 |
-
[
|
229 |
-
torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
|
230 |
-
for _ in range(self.depth - 2)
|
231 |
-
]
|
232 |
-
)
|
233 |
-
self.conv_list.apply(init_weights)
|
234 |
-
# self.out_conv = torch.nn.Conv2d(
|
235 |
-
# nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
236 |
-
# )
|
237 |
-
self.out_conv = OutTail(nf, out_channels, bias=bias)
|
238 |
-
self.out_conv.apply(init_weights)
|
239 |
-
|
240 |
-
self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
|
241 |
-
|
242 |
-
def forward(self, x, Atu):
|
243 |
-
r"""
|
244 |
-
Forward pass of the primal block.
|
245 |
-
|
246 |
-
:param torch.Tensor x: current primal variable.
|
247 |
-
:param torch.Tensor Atu: backprojected dual variable.
|
248 |
-
:return: (:class:`torch.Tensor`) the current primal variable.
|
249 |
-
"""
|
250 |
-
primal_channels = x.shape[1]
|
251 |
-
x_in = torch.cat((x, Atu), dim=1)
|
252 |
-
|
253 |
-
x_ = self.in_conv(x_in)
|
254 |
-
x_ = self.nl_list[0](x_)
|
255 |
-
|
256 |
-
for i in range(self.depth - 2):
|
257 |
-
x_l = self.conv_list[i](x_)
|
258 |
-
x_ = self.nl_list[i + 1](x_l)
|
259 |
-
|
260 |
-
return self.out_conv(x_, primal_channels) + x
|
261 |
-
|
262 |
-
|
263 |
-
class PDNet_DualBlock(torch.nn.Module):
|
264 |
-
r"""
|
265 |
-
Dual block for the Primal-Dual unfolding model.
|
266 |
-
|
267 |
-
From https://arxiv.org/abs/1707.06474.
|
268 |
-
|
269 |
-
Dual variables are images of shape (batch_size, in_channels, height, width). The input of each
|
270 |
-
primal block is the concatenation of the current dual variable with the projected primal variable and
|
271 |
-
the measurements. The output of each dual block is the current primal variable.
|
272 |
-
|
273 |
-
:param int in_channels: number of input channels. Default: 7.
|
274 |
-
:param int out_channels: number of output channels. Default: 5.
|
275 |
-
:param int depth: number of convolutional layers in the block. Default: 3.
|
276 |
-
:param bool bias: whether to use bias in convolutional layers. Default: True.
|
277 |
-
:param int nf: number of features in the convolutional layers. Default: 32.
|
278 |
-
"""
|
279 |
-
|
280 |
-
def __init__(self, in_channels=[1, 2, 3], out_channels=[6, 2, 3], depth=3, bias=True, nf=32):
|
281 |
-
super(PDNet_DualBlock, self).__init__()
|
282 |
-
|
283 |
-
self.depth = depth
|
284 |
-
self.in_conv = InHead(in_channels, nf, bias=bias)
|
285 |
-
# self.in_conv = torch.nn.Conv2d(
|
286 |
-
# in_channels, nf, kernel_size=3, stride=1, padding=1, bias=bias
|
287 |
-
# )
|
288 |
-
self.in_conv.apply(init_weights)
|
289 |
-
self.conv_list = torch.nn.ModuleList(
|
290 |
-
[
|
291 |
-
torch.nn.Conv2d(nf, nf, kernel_size=3, stride=1, padding=1, bias=bias)
|
292 |
-
for _ in range(self.depth - 2)
|
293 |
-
]
|
294 |
-
)
|
295 |
-
self.conv_list.apply(init_weights)
|
296 |
-
self.out_conv = OutTail(nf, out_channels, bias=bias)
|
297 |
-
# self.out_conv = torch.nn.Conv2d(
|
298 |
-
# nf, out_channels, kernel_size=3, stride=1, padding=1, bias=bias
|
299 |
-
# )
|
300 |
-
self.out_conv.apply(init_weights)
|
301 |
-
|
302 |
-
self.nl_list = torch.nn.ModuleList([torch.nn.PReLU() for _ in range(self.depth - 1)])
|
303 |
-
|
304 |
-
def forward(self, u, Ax_cur, y):
|
305 |
-
r"""
|
306 |
-
Forward pass of the dual block.
|
307 |
-
|
308 |
-
:param torch.Tensor u: current dual variable.
|
309 |
-
:param torch.Tensor Ax_cur: projection of the primal variable.
|
310 |
-
:param torch.Tensor y: measurements.
|
311 |
-
"""
|
312 |
-
dual_channels = u.shape[1]
|
313 |
-
x_in = torch.cat((u, Ax_cur, y), dim=1)
|
314 |
-
|
315 |
-
x_ = self.in_conv(x_in)
|
316 |
-
x_ = self.nl_list[0](x_)
|
317 |
-
|
318 |
-
for i in range(self.depth - 2):
|
319 |
-
x_l = self.conv_list[i](x_)
|
320 |
-
x_ = self.nl_list[i + 1](x_l)
|
321 |
-
|
322 |
-
return self.out_conv(x_, dual_channels) + u
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
physics/inpainting_generator.py
DELETED
@@ -1,107 +0,0 @@
|
|
1 |
-
import torch
|
2 |
-
from deepinv.physics.generator import PhysicsGenerator
|
3 |
-
|
4 |
-
|
5 |
-
class InpaintingMaskGenerator(PhysicsGenerator):
|
6 |
-
|
7 |
-
def __init__(
|
8 |
-
self,
|
9 |
-
mask_shape: tuple,
|
10 |
-
num_channels: int = 1,
|
11 |
-
device: str = "cpu",
|
12 |
-
dtype: type = torch.float32,
|
13 |
-
block_size_ratio=0.1,
|
14 |
-
num_blocks=5,
|
15 |
-
) -> None:
|
16 |
-
kwargs = {
|
17 |
-
"mask_shape": mask_shape,
|
18 |
-
"block_size_ratio": block_size_ratio,
|
19 |
-
"num_blocks": num_blocks,
|
20 |
-
}
|
21 |
-
if len(mask_shape) != 2:
|
22 |
-
raise ValueError(
|
23 |
-
"mask_shape must 2D. Add channels via num_channels parameter"
|
24 |
-
)
|
25 |
-
super().__init__(
|
26 |
-
num_channels=num_channels,
|
27 |
-
device=device,
|
28 |
-
dtype=dtype,
|
29 |
-
**kwargs,
|
30 |
-
)
|
31 |
-
|
32 |
-
def generate_mask(self, image_shape, block_size_ratio, num_blocks):
|
33 |
-
# Create an all-ones tensor which will serve as the initial mask
|
34 |
-
mask = torch.ones(image_shape)
|
35 |
-
batch_size = mask.shape[0]
|
36 |
-
|
37 |
-
# Calculate block size based on the image dimensions and block_size_ratio
|
38 |
-
block_width = int(image_shape[-2] * block_size_ratio)
|
39 |
-
block_height = int(image_shape[-1] * block_size_ratio)
|
40 |
-
|
41 |
-
# Generate random coordinates for each block in each batch
|
42 |
-
x_coords = torch.randint(
|
43 |
-
0, image_shape[-1] - block_width, (batch_size, num_blocks)
|
44 |
-
)
|
45 |
-
y_coords = torch.randint(
|
46 |
-
0, image_shape[-2] - block_height, (batch_size, num_blocks)
|
47 |
-
)
|
48 |
-
|
49 |
-
# Create grids of indices for the block dimensions
|
50 |
-
x_range = torch.arange(block_width).view(1, 1, -1)
|
51 |
-
y_range = torch.arange(block_height).view(1, 1, -1)
|
52 |
-
|
53 |
-
# Expand ranges to match the batch and num_blocks dimensions
|
54 |
-
x_indices = x_coords.unsqueeze(-1) + x_range
|
55 |
-
y_indices = y_coords.unsqueeze(-1) + y_range
|
56 |
-
|
57 |
-
# Expand and flatten the indices for advanced indexing
|
58 |
-
x_indices = x_indices.unsqueeze(2).expand(-1, -1, block_height, -1).reshape(-1)
|
59 |
-
y_indices = y_indices.unsqueeze(3).expand(-1, -1, -1, block_width).reshape(-1)
|
60 |
-
|
61 |
-
# Create batch indices for advanced indexing
|
62 |
-
batch_indices = (
|
63 |
-
torch.arange(batch_size)
|
64 |
-
.view(-1, 1, 1)
|
65 |
-
.expand(-1, num_blocks, block_width * block_height)
|
66 |
-
.reshape(-1)
|
67 |
-
)
|
68 |
-
channel_indices = (
|
69 |
-
torch.arange(3)
|
70 |
-
.view(1, 1, 1, -1)
|
71 |
-
.expand(batch_size, num_blocks, block_width * block_height, -1)
|
72 |
-
.reshape(-1)
|
73 |
-
)
|
74 |
-
|
75 |
-
# Apply the blocks using advanced indexing
|
76 |
-
mask[batch_indices, :, y_indices, x_indices] = 0
|
77 |
-
|
78 |
-
return mask
|
79 |
-
|
80 |
-
def step(
|
81 |
-
self, batch_size: int = 1, block_size_ratio: float = None, num_blocks=None
|
82 |
-
):
|
83 |
-
r"""
|
84 |
-
Generate a random motion blur PSF with parameters :math:`\sigma` and :math:`l`
|
85 |
-
|
86 |
-
:param int batch_size: batch_size.
|
87 |
-
:param float sigma: the standard deviation of the Gaussian Process
|
88 |
-
:param float l: the length scale of the trajectory
|
89 |
-
|
90 |
-
:return: dictionary with key **'filter'**: the generated PSF of shape `(batch_size, 1, psf_size[0], psf_size[1])`
|
91 |
-
"""
|
92 |
-
|
93 |
-
# TODO: add randomness
|
94 |
-
block_size_ratio = (
|
95 |
-
self.block_size_ratio if block_size_ratio is None else block_size_ratio
|
96 |
-
)
|
97 |
-
num_blocks = self.num_blocks if num_blocks is None else num_blocks
|
98 |
-
batch_shape = (
|
99 |
-
batch_size,
|
100 |
-
self.num_channels,
|
101 |
-
self.mask_shape[-2],
|
102 |
-
self.mask_shape[-1],
|
103 |
-
)
|
104 |
-
|
105 |
-
mask = self.generate_mask(batch_shape, block_size_ratio, num_blocks)
|
106 |
-
|
107 |
-
return {"mask": mask.to(self.factory_kwargs["device"])}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|