Spaces:
Sleeping
Sleeping
import numpy as np | |
import deepinv | |
import torch | |
import deepinv as dinv | |
from deepinv.optim.data_fidelity import L2 | |
from deepinv.optim.prior import PnP | |
from deepinv.unfolded import unfolded_builder | |
import copy | |
import deepinv.optim.utils | |
class PoissonGaussianDistance(dinv.optim.Distance):
    r"""
    Weighted squared :math:`\ell_2` distance adapted to Poisson-Gaussian noise.

    .. math::

        f(x) = \frac{1}{2}(x-y)^{T}\Sigma_y^{-1}(x-y)

    with :math:`\Sigma_y=\text{diag}(\gamma y + \sigma^2)`, where :math:`\gamma` is the Poisson
    ``gain`` and :math:`\sigma` the Gaussian standard deviation. :meth:`fn`, :meth:`grad` and
    :meth:`prox` all use the same per-entry weight :math:`w = 1 / (\sigma^2 + \gamma y)`.

    :param float sigma: Gaussian noise standard deviation. Default: 1.
    :param float gain: Poisson noise gain. Default: 0.
    """

    def __init__(self, sigma=1.0, gain=0.0):
        super().__init__()
        self.sigma = sigma
        self.gain = gain

    def _inverse_variance(self, y):
        r"""Return the per-entry weight :math:`1 / (\sigma^2 + \gamma y)` (broadcast like ``y``)."""
        return 1.0 / (self.sigma**2 + y * self.gain)

    def fn(self, x, y, *args, **kwargs):
        r"""
        Computes the distance :math:`\distance{x}{y}`, i.e.

        .. math::

            \distance{x}{y} = \frac{1}{2}\sum_i \frac{(x_i-y_i)^2}{\sigma^2 + \gamma y_i}

        :param torch.Tensor x: Variable :math:`x` at which the distance is computed.
        :param torch.Tensor y: Data :math:`y`.
        :return: (:class:`torch.Tensor`) distance of size `B` with `B` the size of the batch.
        """
        # Weight the *squared* residual (not the residual itself) so that this value is the
        # potential whose gradient is :meth:`grad` and whose proximal map is :meth:`prox`.
        weighted = (x - y) ** 2 * self._inverse_variance(y)
        # Sum over every non-batch dimension -> one value per batch element.
        return 0.5 * weighted.reshape(weighted.shape[0], -1).sum(dim=-1)

    def grad(self, x, y, *args, **kwargs):
        r"""
        Computes the gradient :math:`\nabla_{x}\distance{x}{y}`, i.e.

        .. math::

            \nabla_{x}\distance{x}{y} = \frac{x-y}{\sigma^2 + \gamma y}

        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
        :param torch.Tensor y: Observation :math:`y`.
        :return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
        """
        return (x - y) * self._inverse_variance(y)

    def prox(self, x, y, *args, gamma=1.0, **kwargs):
        r"""
        Proximal operator of :math:`\gamma \distance{x}{y}`, i.e.

        .. math::

            \operatorname{prox}_{\gamma \distancename}(x) = \underset{u}{\text{argmin}}
            \frac{\gamma}{2}\sum_i \frac{(u_i-y_i)^2}{\sigma^2 + \gamma_{\text{gain}} y_i}+\frac{1}{2}\|u-x\|_2^2
            = \frac{x + \gamma w y}{1 + \gamma w}, \quad w = \frac{1}{\sigma^2 + \gamma_{\text{gain}} y}

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
        """
        weight = self._inverse_variance(y)
        return (x + weight * gamma * y) / (1 + gamma * weight)
class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
    r"""
    Weighted :math:`\ell_2` data-fidelity adapted to Poisson-Gaussian noise.

    .. math::

        f(x) = \frac{1}{2}\|\forw{x}-y\|^2_{W}, \qquad
        W = \text{diag}\left(\frac{1}{\sigma^2 + \gamma y}\right)

    where :math:`\gamma` is the Poisson gain and :math:`\sigma` the Gaussian standard deviation.
    It can be used as an (approximate) log-likelihood for Poisson-Gaussian noise by setting
    appropriate noise parameters.

    :param float sigma: Gaussian noise standard deviation used in the weighting. Default: 1.
    :param float gain: Poisson gain used in the weighting. Default: 0.
    :param int cg_max_iter: maximum number of conjugate-gradient iterations used by :meth:`prox`. Default: 3.
    :param float cg_tol: conjugate-gradient tolerance used by :meth:`prox`. Default: 1e-3.
    """

    def __init__(self, sigma=1.0, gain=0.0, cg_max_iter=3, cg_tol=1e-3):
        super().__init__()
        # Distance built from the constructor values. Note: :meth:`prox` does not update it.
        self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
        self.gain = gain
        self.sigma = sigma
        self.cg_max_iter = cg_max_iter
        self.cg_tol = cg_tol

    @staticmethod
    def _as_batch_tensor(value, device):
        r"""Return ``value`` (Python number or tensor) as a detached, flattened, floating-point tensor on ``device``."""
        value = torch.as_tensor(value).detach().to(device)
        if not value.is_floating_point():
            value = value.float()
        return value.reshape(-1)

    def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
        r"""
        Proximal operator of :math:`\gamma f`, i.e.

        .. math::

            \operatorname{prox}_{\gamma f}(x) = \underset{u}{\text{argmin}}
            \frac{\gamma}{2}\|Au-y\|_W^2+\frac{1}{2}\|u-x\|_2^2

        It is computed approximately by solving the normal equations
        :math:`(I + \gamma A^\top W A)u = x + \gamma A^\top W y` with a few conjugate-gradient
        iterations warm-started at :math:`x`.

        If the physics' noise model has ``gain`` and/or ``sigma`` attributes, they replace
        ``self.gain`` / ``self.sigma``. Both are stored as 1-D tensors on the device of ``y`` and
        broadcast along the leading (batch) dimension.

        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
        :param torch.Tensor y: Data :math:`y`.
        :param deepinv.physics.Physics physics: linear physics model.
        :param float gamma: stepsize of the proximity operator.
        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma f}(x)`.
        """
        assert isinstance(physics, dinv.physics.LinearPhysics), "not implemented for non-linear physics"
        if isinstance(physics, dinv.physics.StackedPhysics):
            # Stacked physics: measurements are indexed per operator; use the last noise model.
            device = y[0].device
            noise_model = physics[-1].noise_model
        else:
            device = y.device
            noise_model = physics.noise_model

        # Noise parameters exposed by the physics take precedence over the stored ones.
        if hasattr(noise_model, "gain"):
            self.gain = noise_model.gain
        if hasattr(noise_model, "sigma"):
            self.sigma = noise_model.sigma
        # Always normalise (also moves previously cached tensors to the current device).
        self.sigma = self._as_batch_tensor(self.sigma, device)
        self.gain = self._as_batch_tensor(self.gain, device)

        # Broadcast per-sample parameters along the batch dimension of the measurements.
        ndim = y.ndim if isinstance(y, torch.Tensor) else x.ndim
        batch_shape = (-1,) + (1,) * (ndim - 1)
        sigma2 = self.sigma.reshape(batch_shape) ** 2
        gain = self.gain.reshape(batch_shape)

        # Check every batch entry, not only the first one.
        if torch.any(gain > 0):
            weight = gamma / (sigma2 + y * gain)
        else:
            # Pure Gaussian case: the weight is a per-sample constant.
            weight = gamma / sigma2

        def normal_operator(u):
            return physics.A_adjoint(physics.A(u) * weight) + u

        rhs = physics.A_adjoint(weight * y) + x
        return deepinv.optim.utils.conjugate_gradient(
            normal_operator, rhs, init=x, max_iter=self.cg_max_iter, tol=self.cg_tol
        )
from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep | |
class myHQSIteration(OptimIterator):
    r"""
    Single iteration of a half-quadratic splitting (HQS) scheme with a gain-aware prior step.

    An iteration combines a data-fidelity proximal step and a prior proximal step:

    .. math::

        u_{k} = \operatorname{prox}_{\gamma f}(x_k), \qquad
        x_{k+1} = \operatorname{prox}_{\lambda \gamma \regname}(u_k),

    where :math:`\gamma` is the current ``stepsize`` and :math:`\lambda` the current ``lambda``.
    The data step is :class:`myfStepHQS`; the prior step is :class:`mygStepHQS`, which also
    forwards ``g_param`` and ``gain_param`` to the prior. The order in which the two steps run
    is handled by the :class:`OptimIterator` base class (not visible here).
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Swap in the customised steps (keyword arguments are forwarded unchanged).
        self.g_step = mygStepHQS(**kwargs)
        self.f_step = myfStepHQS(**kwargs)
        # The prior is used through its proximal operator.
        self.requires_prox_g = True
class myfStepHQS(fStep):
    r"""
    HQS data-fidelity step: one proximal step on :math:`f` with the current ``stepsize``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
        r"""
        Apply :math:`\operatorname{prox}_{\gamma f}` to the current iterate.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param deepinv.optim.DataFidelity cur_data_fidelity: data-fidelity term whose ``prox`` is called.
        :param dict cur_params: current algorithm parameters; ``cur_params["stepsize"]`` is used as :math:`\gamma`.
        :param torch.Tensor y: Input data.
        :param deepinv.physics.Physics physics: physics passed through to the data-fidelity ``prox``.
        """
        step = cur_params["stepsize"]
        return cur_data_fidelity.prox(x, y, physics, gamma=step)
class mygStepHQS(gStep):
    r"""
    HQS prior step that forwards both a noise level and a gain to the prior's ``prox``.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def forward(self, x, cur_prior, cur_params):
        r"""
        Apply the prior's proximal step to the current iterate.

        :param torch.Tensor x: Current iterate :math:`x_k`.
        :param cur_prior: prior object whose ``prox`` is called.
        :param dict cur_params: current algorithm parameters. ``g_param`` and ``gain_param`` are
            forwarded as ``sigma_denoiser`` / ``gain_denoiser``; ``lambda * stepsize`` is the ``gamma``.
        """
        noise_level = cur_params["g_param"]
        gain_level = cur_params["gain_param"]
        prox_weight = cur_params["lambda"] * cur_params["stepsize"]
        return cur_prior.prox(
            x,
            sigma_denoiser=noise_level,
            gain_denoiser=gain_level,
            gamma=prox_weight,
        )
def get_unrolled_architecture(
    gain_param_init=1e-3,
    weight_tied=True,
    model=None,
    device="cpu",
    max_iter=8,
    pretrained="/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth",
):
    r"""
    Build an unrolled HQS network with a Poisson-Gaussian data-fidelity step and a denoising prior.

    :param float gain_param_init: initial value of every per-iteration ``gain_param`` passed to the denoiser.
    :param bool weight_tied: if ``True``, a single denoiser is shared by all iterations; otherwise
        each of the ``max_iter`` iterations gets its own deep copy.
    :param torch.nn.Module model: optional denoiser to use instead of the default DRUNet. It is called
        as ``model(x, sigma=..., gamma=...)``.
    :param str device: device on which the denoiser and the unrolled model are placed.
    :param int max_iter: number of unrolled iterations (layers). Default: 8.
    :param str pretrained: checkpoint passed to :class:`deepinv.models.DRUNet` when ``model`` is ``None``.
    :return: the unrolled model returned by ``unfolded_builder``, moved to ``device``.
    """
    if model is not None:
        denoiser = model.to(device)
    else:
        denoiser = dinv.models.DRUNet(pretrained=pretrained).to(device)

    class myPnP(PnP):
        r"""
        Plug-and-play prior whose proximal step calls the denoiser with a noise level and a gain.

        In eval mode the input is zero-padded (bottom/right) so that its last two dimensions are
        multiples of 8, and the output is cropped back to the input size.
        """

        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
            if self.training:
                return self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
            height, width = x.shape[-2], x.shape[-1]
            x = torch.nn.functional.pad(x, (0, -width % 8, 0, -height % 8), mode="constant")
            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
            return out[..., :height, :width]

    def get_DPIR_params(noise_level_img, max_iter=8):
        r"""
        Default DPIR-style schedule.

        :param float noise_level_img: noise level of the input image.
        :param int max_iter: number of iterations.
        :return: (list of denoiser noise levels, list of stepsizes), each of length ``max_iter``.
        """
        s1 = 49.0 / 255.0
        s2 = noise_level_img
        # Geometrically decreasing denoiser noise levels from s1 to s2.
        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(np.float32)
        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
        lamb = 1 / 0.23
        return list(sigma_denoiser), list(lamb * stepsize)

    data_fidelity = PoissonGaussianDataFidelity()
    if weight_tied:
        prior = [myPnP(denoiser=denoiser)]
    else:
        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for _ in range(max_iter)]

    # The schedule length must match the number of unrolled iterations.
    sigma_denoiser, stepsize = get_DPIR_params(0.05, max_iter=max_iter)
    # Rescale the DPIR stepsizes by the squared denoiser noise levels.
    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser) ** 2)
    gain_denoiser = [gain_param_init] * len(sigma_denoiser)
    params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
    # Keys of 'params_algo' that become trainable. (The original list lacked a comma, which
    # silently produced the single unknown key "gain_paramstepsize".)
    trainable_params = [
        "g_param",
        "gain_param",
        "stepsize",
    ]

    unrolled = unfolded_builder(
        iteration=myHQSIteration(),
        params_algo=params_algo.copy(),
        trainable_params=trainable_params,
        data_fidelity=data_fidelity,
        max_iter=max_iter,
        prior=prior,
        device=device,
    )
    return unrolled.to(device)