# denoising/physics/multiscale.py
# Author: Yonuts
# Commit: 12a4d59 ("gradio demo")
import torch
from deepinv.physics import Physics, LinearPhysics, Downsampling
class Upsampling(Downsampling):
    r"""
    Upsampling operator defined as the adjoint of :class:`deepinv.physics.Downsampling`.

    The forward operator ``A`` is the parent's ``A_adjoint`` and vice versa, so
    constructor arguments (``img_size``, ``filter``, ``factor``, ``device``) have the
    same meaning as for the parent class.
    """

    def A(self, x, **kwargs):
        """Apply the upsampling operator, i.e. the parent ``Downsampling.A_adjoint``."""
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the adjoint of the upsampling operator, i.e. the parent ``Downsampling.A``."""
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        r"""
        Proximal operator of the data-fidelity term for the *swapped* operator.

        ``Downsampling.prox_l2`` may use a closed-form solution tailored to the
        downsampling operator. Because ``A`` and ``A_adjoint`` are swapped here, that
        formula would solve the wrong problem, so the generic
        :meth:`LinearPhysics.prox_l2` is used instead. That version works for any
        linear operator through ``self.A`` / ``self.A_adjoint``.

        :param z: point at which the prox is evaluated.
        :param y: measurements.
        :param gamma: step size of the proximal operator.
        :param kwargs: forwarded to :meth:`LinearPhysics.prox_l2`.
        """
        return LinearPhysics.prox_l2(self, z, y, gamma, **kwargs)
class MultiScalePhysics(Physics):
    r"""
    Wrap a physics operator so that it can be applied to reduced-resolution inputs.

    At scale ``s > 0`` the forward operator is ``base.A(Upsamplings[s - 1].A(x))``.
    Here ``Upsamplings[s - 1]`` is an :class:`Upsampling` with factor ``scales[s - 1]``.
    At scale ``0`` the base operator is applied directly.

    The current scale is stored on the instance. Passing ``scale=...`` to :meth:`A`,
    :meth:`downsample` or :meth:`upsample` changes it for every later call that
    omits ``scale``.

    :param physics: base forward operator; its ``noise_model`` is reused.
    :param img_shape: shape forwarded as ``img_size`` to every :class:`Upsampling`.
    :param str filter: filter name forwarded to every :class:`Upsampling`.
    :param scales: upsampling factors, one per scale level ``1 .. len(scales)``.
        Defaults to ``[2, 4, 8]``.
    :param device: device on which the :class:`Upsampling` operators are built.
    :param kwargs: forwarded to the parent physics constructor.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=None, device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        # Fresh list per instance: a list literal as default would be shared by all.
        if scales is None:
            scales = [2, 4, 8]
        self.base = physics
        self.scales = scales
        self.img_shape = img_shape
        self.Upsamplings = [
            Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
            for factor in scales
        ]
        self.scale = 0

    def set_scale(self, scale):
        """
        Set the current scale level; ``None`` keeps the current one.

        :param scale: integer in ``[0, len(self.Upsamplings)]``; ``0`` means the base
            operator is used without upsampling.
        :raises ValueError: if ``scale`` is outside that range. Without this check a
            negative value would silently select the wrong operator via negative
            indexing.
        """
        if scale is None:
            return
        if not 0 <= scale <= len(self.Upsamplings):
            raise ValueError(
                f"scale must be in [0, {len(self.Upsamplings)}], got {scale}"
            )
        self.scale = scale

    def _upsampler(self):
        """Return the :class:`Upsampling` for the current scale (only valid if scale > 0)."""
        return self.Upsamplings[self.scale - 1]

    def A(self, x, scale=None, **kwargs):
        """
        Apply the forward operator at the given (or current) scale.

        :param x: input image, presumably at the reduced resolution of the selected
            scale (full resolution when the scale is ``0``).
        :param scale: scale level; when not ``None`` it becomes the current scale.
        :param kwargs: forwarded to ``base.A``.
        """
        self.set_scale(scale)
        if self.scale == 0:
            return self.base.A(x, **kwargs)
        return self.base.A(self._upsampler().A(x), **kwargs)

    def downsample(self, x, scale=None):
        """
        Map ``x`` to the selected scale with ``Upsampling.A_adjoint``.

        ``Upsampling.A_adjoint`` is the parent ``Downsampling.A``. This is the
        identity at scale ``0``.

        :param scale: scale level; when not ``None`` it becomes the current scale.
        """
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self._upsampler().A_adjoint(x)

    def upsample(self, x, scale=None):
        """
        Map ``x`` from the selected scale back with ``Upsampling.A``.

        This is the identity at scale ``0``.

        :param scale: scale level; when not ``None`` it becomes the current scale.
        """
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self._upsampler().A(x)

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)
class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    r"""
    Linear variant of :class:`MultiScalePhysics` that also provides the adjoint.

    At scale ``s > 0`` the forward operator is ``base.A(U(x))`` with
    ``U = Upsamplings[s - 1]``. Its adjoint is therefore
    ``U.A_adjoint(base.A_adjoint(y))``.

    Constructor parameters are those of :class:`MultiScalePhysics`. ``scales``
    defaults to ``[2, 4, 8]``.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=None, **kwargs):
        # Fresh list per instance: a list literal as default would be shared by all.
        if scales is None:
            scales = [2, 4, 8]
        super().__init__(
            physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs
        )

    def A_adjoint(self, y, scale=None, **kwargs):
        """
        Apply the adjoint operator at the given (or current) scale.

        :param y: measurements.
        :param scale: scale level; when not ``None`` it becomes the current scale.
        :param kwargs: forwarded to ``base.A_adjoint``.
        """
        self.set_scale(scale)
        x = self.base.A_adjoint(y, **kwargs)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A_adjoint(x)
class Pad(LinearPhysics):
    r"""
    Compose a linear physics operator with a crop of the leading entries of its input.

    ``A(x) = base.A(x[..., pad[0]:, pad[1]:])``. The first ``pad[0]`` entries of the
    second-to-last dimension and the first ``pad[1]`` entries of the last dimension
    are dropped. The adjoint zero-pads ``base.A_adjoint(y)`` back at the start of
    those two dimensions.

    :param physics: base linear operator; its ``noise_model`` is reused.
    :param pad: pair ``(pad[0], pad[1])`` of integer offsets (presumably
        non-negative; negative values are not handled consistently by the adjoint).
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        """Crop ``x`` and apply the base operator; ``kwargs`` go to ``base.A``."""
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the base adjoint, then zero-pad at the start of the last two dims."""
        x = self.base.A_adjoint(y, **kwargs)
        # F.pad lists the last dimension first: (left, right, top, bottom).
        return torch.nn.functional.pad(x, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        """Drop the leading ``pad[0]`` / ``pad[1]`` entries of the last two dims."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)