import copy

import numpy as np
import torch

import deepinv as dinv
from deepinv.optim.prior import PnP
from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep
from deepinv.optim.utils import conjugate_gradient
from deepinv.unfolded import unfolded_builder


class PoissonGaussianDistance(dinv.optim.Distance):
    r"""
    Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm

    .. math::
        f(x) = (x-y)^{T}\Sigma_y(x-y)

    with :math:`\Sigma_y=\text{diag}(gamma y + \sigma^2)`

    :param float sigma: Gaussian noise parameter. Default: 1.
    :param float gain: Poisson noise parameter. Default 0.
    """

    def __init__(self, sigma=1.0, gain=0.):
        super().__init__()
        self.sigma = sigma
        self.gain = gain

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the distance :math:`\distance{x}{y}`, i.e.

        .. math::

            \distance{x}{y} = \frac{1}{2}(x-y)^{T}\Sigma_y^{-1}(x-y)


        :param torch.Tensor x: Variable :math:`x` at which the distance is computed.
        :param torch.Tensor y: Data :math:`y`.
        :return: (:class:`torch.Tensor`) distance :math:`\distance{x}{y}` of size `B` with `B` the batch size.
        """
        norm = 1.0 / (self.sigma**2 + y * self.gain)
        # Weight the residual by the square root of the inverse covariance so the
        # squared norm below applies the weight exactly once (consistent with
        # `grad` and `prox`).
        z = (x - y) * torch.sqrt(norm)
        d = 0.5 * torch.norm(z.reshape(z.shape[0], -1), p=2, dim=-1) ** 2
        return d

    def grad(self, x, y, *args, **kwargs):
        r"""
        Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{x}\distance{x}{y}`, i.e.

        .. math::

            \nabla_{x}\distance{x}{y} = \Sigma_y^{-1}(x-y)


        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
        """
        norm = 1.0 / (self.sigma**2 + y * self.gain)
        return (x - y) * norm

    def prox(self, x, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2}(x-y)^{T}\Sigma_y^{-1}(x-y)`.

        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.

        .. math::

           \operatorname{prox}_{\gamma \distancename}(x) = \underset{u}{\text{argmin}} \; \frac{\gamma}{2}(u-y)^{T}\Sigma_y^{-1}(u-y)+\frac{1}{2}\|u-x\|_2^2


        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        norm = 1.0 / (self.sigma**2 + y * self.gain)
        return (x + norm * gamma * y) / (1 + gamma * norm)
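

# Illustrative consistency check (an addition to this write-up, not part of the
# original module): the analytic `grad` above should match the autograd gradient
# of `fn`.
def _check_distance_gradient():
    d = PoissonGaussianDistance(sigma=0.1, gain=0.01)
    x = torch.rand(2, 3, 8, 8, requires_grad=True)
    y = torch.rand(2, 3, 8, 8)
    d.fn(x, y).sum().backward()
    assert torch.allclose(x.grad, d.grad(x.detach(), y), atol=1e-4)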


class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
    r"""
    Implementation of the data-fidelity as the normalized :math:`\ell_2` norm

    .. math::

        f(x) = \|\forw{x}-y\|^2_{\text{diag}(\sigma^2 + y \gamma)}

    It can be used to define a log-likelihood function associated with Poisson Gaussian noise
    by setting an appropriate noise level :math:`\sigma`.

    :param float sigma: Standard deviation of the noise to be used as a normalisation factor.
    :param float gain: Gain factor of the data-fidelity term.
    """

    def __init__(self, sigma=1.0, gain=0.):
        super().__init__()
        self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
        self.gain = gain
        self.sigma = sigma

    def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
        r"""
        Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2}\|Ax-y\|^2_{\Sigma_y^{-1}}`.

        Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.

        .. math::

           \operatorname{prox}_{\gamma \datafidname}(x) = \underset{u}{\text{argmin}} \; \frac{\gamma}{2}\|Au-y\|^2_{\Sigma_y^{-1}}+\frac{1}{2}\|u-x\|_2^2

        The solution is approximated with a few conjugate-gradient iterations on the
        normal equations :math:`(\gamma A^{T}\Sigma_y^{-1}A + I)\,u = \gamma A^{T}\Sigma_y^{-1}y + x`.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: physics model.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
        """
        assert isinstance(
            physics, dinv.physics.LinearPhysics
        ), "not implemented for non-linear physics"
        if isinstance(physics, dinv.physics.StackedPhysics):
            device = y[0].device
            noise_model = physics[-1].noise_model
        else:
            device = y.device
            noise_model = physics.noise_model
        # Pull the noise parameters from the physics when available.
        if hasattr(noise_model, "gain"):
            self.gain = noise_model.gain.detach().to(device)
        if hasattr(noise_model, "sigma"):
            self.sigma = noise_model.sigma.detach().to(device)
        # Ensure sigma is a 1D tensor (one entry per batch element or a single shared value)
        if isinstance(self.sigma, float):
            self.sigma = torch.tensor([self.sigma], device=device)
        if self.sigma.ndim == 0:
            self.sigma = self.sigma.unsqueeze(0).to(device)
        # Ensure gain is a 1D tensor as well
        if isinstance(self.gain, float):
            self.gain = torch.tensor([self.gain], device=device)
        if self.gain.ndim == 0:
            self.gain = self.gain.unsqueeze(0).to(device)
        if self.gain[0] > 0:
            norm = gamma / (
                self.sigma[:, None, None, None] ** 2
                + y * self.gain[:, None, None, None]
            )
        else:
            norm = gamma / (self.sigma[:, None, None, None] ** 2)
        # Solve (gamma * A^T Sigma^{-1} A + I) u = gamma * A^T Sigma^{-1} y + x
        # with a few conjugate-gradient steps.
        A = lambda u: physics.A_adjoint(physics.A(u) * norm) + u
        b = physics.A_adjoint(norm * y) + x
        return conjugate_gradient(A, b, init=x, max_iter=3, tol=1e-3)
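

# Illustrative check (an addition to this write-up, not part of the original
# file): with identity physics (denoising), the conjugate-gradient prox above
# should agree with the closed-form prox of `PoissonGaussianDistance`.
def _check_datafidelity_prox():
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))
    fidelity = PoissonGaussianDataFidelity(sigma=0.1, gain=0.0)
    x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
    p_cg = fidelity.prox(x, y, physics, gamma=0.5)
    p_closed = fidelity.d.prox(x, y, gamma=0.5)
    # CG converges in one step for this diagonal system, so the tolerance is tight.
    assert torch.allclose(p_cg, p_closed, atol=1e-4)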


class myHQSIteration(OptimIterator):
    r"""
    Single iteration of half-quadratic splitting.

    Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising :math:`f(x) + \lambda \regname(x)`.
    The iteration is given by


    .. math::
        \begin{equation*}
        \begin{aligned}
        u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\
        x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k).
        \end{aligned}
        \end{equation*}


    where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to
    a minimizer of :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of
    :math:`\gamma\, {}^1f + \sigma \lambda \regname`, where :math:`{}^1f` denotes
    the Moreau envelope of :math:`f` with parameter 1.

    """

    def __init__(self, **kwargs):
        super(myHQSIteration, self).__init__(**kwargs)
        self.g_step = mygStepHQS(**kwargs)
        self.f_step = myfStepHQS(**kwargs)
        self.requires_prox_g = True

class myfStepHQS(fStep):
    r"""
    HQS fStep module.
    """

    def __init__(self, **kwargs):
        super(myfStepHQS, self).__init__(**kwargs)

    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
        r"""
        Single proximal step on the data-fidelity term :math:`f`.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        :param torch.Tensor y: Input data.
        :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
        """
        return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])

class mygStepHQS(gStep):
    r"""
    HQS gStep module.
    """

    def __init__(self, **kwargs):
        super(mygStepHQS, self).__init__(**kwargs)

    def forward(self, x, cur_prior, cur_params):
        r"""
        Single proximal step on the prior term :math:`\lambda \regname`.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param dict cur_prior: Class containing the current prior.
        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
        """
        return cur_prior.prox(
            x,
            sigma_denoiser=cur_params["g_param"],
            gain_denoiser=cur_params["gain_param"],
            gamma=cur_params["lambda"] * cur_params["stepsize"],
        )
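

# Illustrative sketch (an addition to this write-up, not part of the original
# training code): one HQS iteration applied by hand, mirroring the update in
# `myHQSIteration`: a data-fidelity prox step followed by a prior prox step.
# `prior` is assumed to expose the same
# `prox(x, sigma_denoiser=..., gain_denoiser=..., gamma=...)` signature as the
# `myPnP` prior defined below.
def _manual_hqs_step(x, y, physics, data_fidelity, prior, params):
    u = data_fidelity.prox(x, y, physics, gamma=params["stepsize"])
    return prior.prox(
        u,
        sigma_denoiser=params["g_param"],
        gain_denoiser=params["gain_param"],
        gamma=params["lambda"] * params["stepsize"],
    )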


def get_unrolled_architecture(gain_param_init=1e-3, weight_tied=True, model=None, device="cpu"):

    # Unrolled optimization algorithm parameters
    max_iter = 8  # number of unfolded layers

    # Set up the trainable denoising prior.
    # The denoiser is shared across all iterations when `weight_tied` is True.
    if model is not None:
        denoiser = model.to(device)
    else:
        denoiser = dinv.models.DRUNet(
            pretrained='/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth',
        ).to(device)

    class myPnP(PnP):
        r"""
        PnP prior whose proximal step calls a denoiser conditioned on both the
        Gaussian noise level ``sigma`` and the Poisson gain ``gamma``.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
            if not self.training:
                # Pad spatial dims to a multiple of 8 so they fit the denoiser's
                # downsampling stages (e.g. DRUNet), then crop back afterwards.
                pad = (-x.size(-2) % 8, -x.size(-1) % 8)
                x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
            if not self.training:
                out = out[..., : -pad[0] or None, : -pad[1] or None]
            return out

    # Select the data-fidelity term.
    data_fidelity = PoissonGaussianDataFidelity()

    if not weight_tied:
        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for _ in range(max_iter)]
    else:
        prior = [myPnP(denoiser=denoiser)]
    
    def get_DPIR_params(noise_level_img, max_iter=8):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm.

        :param float noise_level_img: Noise level of the input image.
        :param int max_iter: Number of unfolded iterations.
        """
        s1 = 49.0 / 255.0
        s2 = noise_level_img
        # Log-spaced denoiser noise levels, decreasing from s1 to s2.
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1 / 0.23
        return list(sigma_denoiser), list(lamb * stepsize)

    sigma_denoiser, stepsize = get_DPIR_params(0.05)
    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser) ** 2)
    gain_denoiser = [gain_param_init] * len(sigma_denoiser)
    params_algo = {
        "stepsize": stepsize,
        "g_param": sigma_denoiser,
        "gain_param": gain_denoiser,
        "lambda": 1.0,  # explicit default, since mygStepHQS reads cur_params["lambda"]
    }

    trainable_params = [
        "g_param",
        "gain_param",
        "stepsize",
    ]  # define which parameters from 'params_algo' are trainable

    # Define the unfolded trainable model.
    model = unfolded_builder(
        iteration=myHQSIteration(),
        params_algo=params_algo.copy(),
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        device=device,
    )

    return model.to(device)
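

if __name__ == "__main__":
    # Minimal usage sketch (illustrative): the default DRUNet checkpoint path
    # above is cluster-specific, so a toy identity denoiser is passed instead.
    # The denoiser is assumed to accept `sigma` and `gamma` keyword arguments,
    # as required by `myPnP.prox`.
    class ToyDenoiser(torch.nn.Module):
        def forward(self, x, sigma=None, gamma=None):
            return x  # placeholder: a real model would denoise here

    model = get_unrolled_architecture(model=ToyDenoiser(), device="cpu")
    physics = dinv.physics.Denoising(
        dinv.physics.PoissonGaussianNoise(gain=1e-3, sigma=0.05)
    )
    x = torch.rand(1, 3, 32, 32)
    y = physics(x)
    x_hat = model(y, physics)
    print(x_hat.shape)  # expected: torch.Size([1, 3, 32, 32])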