# denoising/models/unrolled_dpir.py
import numpy as np
import deepinv
import torch
import deepinv as dinv
from deepinv.optim.data_fidelity import L2
from deepinv.optim.prior import PnP
from deepinv.unfolded import unfolded_builder
import copy
import deepinv.optim.utils
class PoissonGaussianDistance(dinv.optim.Distance):
r"""
Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm
.. math::
f(x) = (x-y)^{T}\Sigma_y(x-y)
with :math:`\Sigma_y=\text{diag}(gamma y + \sigma^2)`
:param float sigma: Gaussian noise parameter. Default: 1.
:param float gain: Poisson noise parameter. Default 0.
"""
def __init__(self, sigma=1.0, gain=0.):
super().__init__()
self.sigma = sigma
self.gain = gain
def fn(self, x, y, *args, **kwargs):
r"""
Computes the distance :math:`\distance{x}{y}` i.e.
.. math::
\distance{x}{y} = \frac{1}{2}\|x-y\|^2
:param torch.Tensor u: Variable :math:`x` at which the data fidelity is computed.
:param torch.Tensor y: Data :math:`y`.
:return: (:class:`torch.Tensor`) data fidelity :math:`\datafid{u}{y}` of size `B` with `B` the size of the batch.
"""
norm = 1.0 / (self.sigma**2 + y * self.gain)
z = (x - y) * norm
d = 0.5 * torch.norm(z.reshape(z.shape[0], -1), p=2, dim=-1) ** 2
return d
def grad(self, x, y, *args, **kwargs):
r"""
Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{x}\distance{x}{y}`, i.e.
.. math::
\nabla_{x}\distance{x}{y} = \frac{1}{\sigma^2} x-y
:param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
:param torch.Tensor y: Observation :math:`y`.
:return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
"""
norm = 1.0 / (self.sigma**2 + y * self.gain)
return (x - y) * norm
def prox(self, x, y, *args, gamma=1.0, **kwargs):
r"""
Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2 \sigma^2} \|x-y\|^2`.
Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.
.. math::
\operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|u-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
:param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
:param torch.Tensor y: Data :math:`y`.
:param float gamma: thresholding parameter.
:return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
"""
norm = 1.0 / (self.sigma**2 + y * self.gain)
return (x + norm * gamma * y) / (1 + gamma * norm)
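# Sanity-check sketch (illustrative, not part of the module): with gain=0 the
# distance above reduces to the usual normalized L2 distance, so its gradient
# is simply (x - y) / sigma**2, e.g.
#
#   dist = PoissonGaussianDistance(sigma=0.1, gain=0.0)
#   x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
#   assert torch.allclose(dist.grad(x, y), (x - y) / 0.1**2)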
class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
r"""
Implementation of the data-fidelity as the normalized :math:`\ell_2` norm
.. math::
f(x) = \|\forw{x}-y\|^2_{\text{diag}(\sigma^2 + y \gamma)}
It can be used to define a log-likelihood function associated with Poisson Gaussian noise
by setting an appropriate noise level :math:`\sigma`.
:param float sigma: Standard deviation of the noise to be used as a normalisation factor.
:param float gain: Gain factor of the data-fidelity term.
"""
def __init__(self, sigma=1.0, gain=0.):
super().__init__()
self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
self.gain = gain
self.sigma = sigma
def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
r"""
Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`.
Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.
.. math::
\operatorname{prox}_{\gamma \datafidname} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
:param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
:param torch.Tensor y: Data :math:`y`.
:param deepinv.physics.Physics physics: physics model.
:param float gamma: stepsize of the proximity operator.
:return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
"""
        assert isinstance(
            physics, dinv.physics.LinearPhysics
        ), "not implemented for non-linear physics"
        if isinstance(physics, dinv.physics.StackedPhysics):
            device = y[0].device
            noise_model = physics[-1].noise_model
        else:
            device = y.device
            noise_model = physics.noise_model
        # If the physics carries its own noise model, use its parameters.
        if hasattr(noise_model, "gain"):
            self.gain = noise_model.gain.detach().to(device)
        if hasattr(noise_model, "sigma"):
            self.sigma = noise_model.sigma.detach().to(device)
        # Ensure sigma is a 1D tensor (one value per batch element or a single shared value)
        if isinstance(self.sigma, float):
            self.sigma = torch.tensor([self.sigma], device=device)
        if self.sigma.ndim == 0:
            self.sigma = self.sigma.unsqueeze(0).to(device)
        # Ensure gain is a 1D tensor as well
        if isinstance(self.gain, float):
            self.gain = torch.tensor([self.gain], device=device)
        if self.gain.ndim == 0:
            self.gain = self.gain.unsqueeze(0).to(device)
        # Diagonal weighting gamma * Sigma_y^{-1}, broadcast over (B, C, H, W)
        if self.gain[0] > 0:
            norm = gamma / (
                self.sigma[:, None, None, None] ** 2
                + y * self.gain[:, None, None, None]
            )
        else:
            norm = gamma / (self.sigma[:, None, None, None] ** 2)
        # Solve (gamma * A^T Sigma_y^{-1} A + I) u = gamma * A^T Sigma_y^{-1} y + x
        # with a few conjugate-gradient iterations.
        A = lambda u: physics.A_adjoint(physics.A(u) * norm) + u
        b = physics.A_adjoint(norm * y) + x
        return deepinv.optim.utils.conjugate_gradient(A, b, init=x, max_iter=3, tol=1e-3)
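# Usage sketch (illustrative; assumes a simple dinv.physics.Denoising operator):
# for a denoising physics (A = Id) the conjugate-gradient solve above should
# closely match the closed-form prox of the underlying distance, e.g.
#
#   physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.1))
#   fid = PoissonGaussianDataFidelity(sigma=0.1, gain=0.0)
#   x, y = torch.rand(1, 3, 8, 8), torch.rand(1, 3, 8, 8)
#   x_prox = fid.prox(x, y, physics, gamma=1.0)   # ~ fid.d.prox(x, y, gamma=1.0)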
from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep
class myHQSIteration(OptimIterator):
r"""
Single iteration of half-quadratic splitting.
Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising :math:`f(x) + \lambda \regname(x)`.
The iteration is given by
.. math::
\begin{equation*}
\begin{aligned}
u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\
x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k).
\end{aligned}
\end{equation*}
where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to
a minimizer of :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of
:math:`\gamma\, ^1f+\sigma \lambda \regname`, where :math:`^1f` denotes
the Moreau envelope of :math:`f`
"""
def __init__(self, **kwargs):
super(myHQSIteration, self).__init__(**kwargs)
self.g_step = mygStepHQS(**kwargs)
self.f_step = myfStepHQS(**kwargs)
self.requires_prox_g = True
class myfStepHQS(fStep):
r"""
HQS fStep module.
"""
def __init__(self, **kwargs):
super(myfStepHQS, self).__init__(**kwargs)
def forward(self, x, cur_data_fidelity, cur_params, y, physics):
r"""
Single proximal step on the data-fidelity term :math:`f`.
:param torch.Tensor x: Current iterate :math:`x_k`.
:param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
:param dict cur_params: Dictionary containing the current parameters of the algorithm.
:param torch.Tensor y: Input data.
:param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
"""
return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])
class mygStepHQS(gStep):
r"""
HQS gStep module.
"""
def __init__(self, **kwargs):
super(mygStepHQS, self).__init__(**kwargs)
def forward(self, x, cur_prior, cur_params):
r"""
Single proximal step on the prior term :math:`\lambda \regname`.
:param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.Prior cur_prior: Instance of the Prior class defining the current prior.
:param dict cur_params: Dictionary containing the current parameters of the algorithm.
"""
        return cur_prior.prox(
            x,
            sigma_denoiser=cur_params["g_param"],
            gain_denoiser=cur_params["gain_param"],
            gamma=cur_params["lambda"] * cur_params["stepsize"],
        )
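# Illustrative sketch (comments only): within one HQS iteration the two modules
# above are chained by OptimIterator.forward roughly as
#   u_k     = f_step(x_k, cur_data_fidelity, cur_params, y, physics)  # prox of the data term
#   x_{k+1} = g_step(u_k, cur_prior, cur_params)                      # denoiser step on the prior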
def get_unrolled_architecture(gain_param_init=1e-3, weight_tied=True, model=None, device='cpu'):
    # Unrolled optimization algorithm parameters
    max_iter = 8  # number of unfolded layers

    # Set up the trainable denoising prior; the denoiser is shared across all
    # iterations unless weight_tied is False.
    if model is not None:
        denoiser = model.to(device)
    else:
        # Finetuned DRUNet checkpoint (cluster-specific path).
        denoiser = dinv.models.DRUNet(
            pretrained='/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth',
        ).to(device)
    class myPnP(PnP):
        r"""
        Plug-and-Play prior whose proximal step is a denoiser conditioned on both
        a Gaussian noise level and a Poisson gain.
        """

        def __init__(self, *args, **kwargs):
            super().__init__(*args, **kwargs)

        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
            if not self.training:
                # Pad to a multiple of 8 so the denoiser's downsampling stages match.
                pad = (-x.size(-2) % 8, -x.size(-1) % 8)
                x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
            if not self.training:
                # Crop back to the original spatial size.
                out = out[..., : -pad[0] or None, : -pad[1] or None]
            return out
    # Select the data fidelity term
    data_fidelity = PoissonGaussianDataFidelity()

    # One denoiser per iteration if weights are untied, otherwise a single shared denoiser.
    if not weight_tied:
        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for _ in range(max_iter)]
    else:
        prior = [myPnP(denoiser=denoiser)]
    def get_DPIR_params(noise_level_img, max_iter=8):
        r"""
        Default parameters for the DPIR Plug-and-Play algorithm.

        :param float noise_level_img: Noise level of the input image.
        :param int max_iter: Number of unrolled iterations.
        """
        s1 = 49.0 / 255.0
        s2 = noise_level_img
        # Log-spaced denoiser noise levels decreasing from s1 to s2.
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
            np.float32
        )
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1 / 0.23
        return list(sigma_denoiser), list(lamb * stepsize)
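    # Illustrative values (sketch, for the defaults used below): with
    # noise_level_img = 0.05 and max_iter = 8, sigma_denoiser decays log-linearly
    # from 49/255 ~ 0.192 down to 0.05, and the stepsizes decay from
    # ((49/255) / 0.05) ** 2 / 0.23 down to 1 / 0.23 before the rescaling applied below.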
sigma_denoiser, stepsize = get_DPIR_params(0.05)
stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser)**2)
gain_denoiser = [gain_param_init]*len(sigma_denoiser)
params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
    trainable_params = [
        "g_param",
        "gain_param",
        "stepsize",
    ]  # define which parameters from 'params_algo' are trainable
# Define the unfolded trainable model.
model = unfolded_builder(
iteration=myHQSIteration(),
params_algo=params_algo.copy(),
trainable_params=trainable_params,
data_fidelity=data_fidelity,
max_iter=max_iter,
prior=prior,
device=device,
)
return model.to(device)
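

# ----------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the gradio demo). The
# unrolled prox calls the denoiser as denoiser(x, sigma=..., gamma=...), so the
# stand-in below exposes that interface; any real denoiser passed through
# `model` must accept the same keyword arguments.
if __name__ == "__main__":
    class _IdentityDenoiser(torch.nn.Module):
        """Toy denoiser with the (sigma, gamma) interface expected by myPnP."""

        def forward(self, x, sigma=None, gamma=None):
            return x

    device = "cpu"
    model = get_unrolled_architecture(model=_IdentityDenoiser(), device=device)

    # Simulate a noisy observation with a simple denoising physics.
    physics = dinv.physics.Denoising(dinv.physics.GaussianNoise(sigma=0.05))
    x = torch.rand(1, 3, 64, 64, device=device)
    y = physics(x)

    with torch.no_grad():
        x_hat = model(y, physics)
    print(x_hat.shape)  # expected: torch.Size([1, 3, 64, 64])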