Spaces:
Sleeping
Sleeping
import torch | |
from deepinv.physics import Physics, LinearPhysics, Downsampling | |
class Upsampling(Downsampling):
    r"""
    Upsampling operator obtained as the adjoint of :class:`Downsampling`.

    ``A`` is the parent's ``A_adjoint`` and ``A_adjoint`` is the parent's ``A``.
    The constructor is inherited unchanged from :class:`Downsampling`.
    """

    def A(self, x, **kwargs):
        """Upsample ``x`` by applying the parent's adjoint (``Downsampling.A_adjoint``)."""
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Adjoint of :meth:`A`, i.e. the parent's forward (``Downsampling.A``)."""
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        """Proximal operator of the l2 data-fidelity term for this operator, evaluated at ``z``.

        NOTE(review): ``Downsampling.prox_l2`` may use a closed-form solver derived for
        the downsampling operator itself. Because ``A``/``A_adjoint`` are swapped here,
        such a solver does not apply, so the generic ``LinearPhysics`` solver is used.
        It relies only on ``A`` and ``A_adjoint`` as defined in this class.
        """
        return LinearPhysics.prox_l2(self, z, y, gamma, **kwargs)
class MultiScalePhysics(Physics):
    r"""
    Wrap a physics operator so that it can be applied to coarser images.

    At scale ``0`` the base operator is applied directly. At scale ``s >= 1``, the input
    is first passed through the :class:`Upsampling` operator built with factor
    ``scales[s - 1]``. The current scale is stored on the instance: any method that
    receives a non-``None`` ``scale`` updates it, and the value persists for later calls.

    :param physics: base physics operator; its ``noise_model`` is reused for this wrapper.
    :param img_shape: forwarded as ``img_size`` to every :class:`Upsampling`.
    :param str filter: filter name forwarded to every :class:`Upsampling`.
    :param scales: one upsampling factor per coarse scale.
    :param device: device forwarded to every :class:`Upsampling`.
    :param kwargs: extra keyword arguments forwarded to the parent ``__init__``.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=(2, 4, 8), device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        self.base = physics
        # Copy so the instance never shares a list with the caller or with a default value.
        self.scales = list(scales)
        self.img_shape = img_shape
        # Upsamplings[i] corresponds to scale i + 1, i.e. factor scales[i].
        self.Upsamplings = [
            Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
            for factor in self.scales
        ]
        self.scale = 0

    def set_scale(self, scale):
        """Set the current scale; ``None`` leaves it unchanged.

        :param scale: ``0`` for the base operator, or ``1..len(scales)`` to select the
            upsampling factor ``scales[scale - 1]``.
        :raises ValueError: if ``scale`` is outside ``[0, len(scales)]``. Without this
            check, a negative value would index the operator list from the end and
            silently select the wrong scale.
        """
        if scale is None:
            return
        if not 0 <= scale <= len(self.Upsamplings):
            raise ValueError(
                f"scale must be between 0 and {len(self.Upsamplings)}, got {scale}"
            )
        self.scale = scale

    def A(self, x, scale=None, **kwargs):
        """Apply the base operator to ``x`` after upsampling it to the current scale.

        At scale ``0`` this is ``base.A(x)``. ``kwargs`` are forwarded to ``base.A`` only.
        """
        return self.base.A(self.upsample(x, scale=scale), **kwargs)

    def downsample(self, x, scale=None):
        """Apply the adjoint of the current scale's upsampling (identity at scale ``0``)."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        """Apply the current scale's upsampling operator (identity at scale ``0``)."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the wrapped base operator."""
        self.base.update_parameters(**kwargs)
class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    r"""
    Linear variant of :class:`MultiScalePhysics` that also provides the adjoint.

    At scale ``s >= 1`` the forward map is ``base.A`` composed with the scale's upsampling
    ``U_s``. Its adjoint therefore applies ``base.A_adjoint`` first and then ``U_s``'s
    adjoint. At scale ``0`` the adjoint is just ``base.A_adjoint``.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=(2, 4, 8), **kwargs):
        # Pass a fresh list so no instance ever aliases a shared default list.
        super().__init__(
            physics=physics,
            img_shape=img_shape,
            filter=filter,
            scales=list(scales),
            **kwargs,
        )

    def A_adjoint(self, y, scale=None, **kwargs):
        """Adjoint of :meth:`A` at the current (optionally updated) scale.

        ``kwargs`` are forwarded to ``base.A_adjoint`` only.
        """
        self.set_scale(scale)
        x = self.base.A_adjoint(y, **kwargs)
        # Scale is already set above; downsample() uses it and is the identity at scale 0.
        return self.downsample(x)
class Pad(LinearPhysics):
    r"""
    Wrap a physics operator so that it accepts inputs with extra leading rows/columns.

    ``A`` drops the first ``pad[0]`` entries along dim ``-2`` and the first ``pad[1]``
    along dim ``-1``, then applies the base operator. ``A_adjoint`` applies the base
    adjoint and zero-pads those leading entries back, which is the adjoint of the slice.

    :param physics: base linear operator; its ``noise_model`` is reused.
    :param pad: pair ``(rows, cols)`` of leading entries to strip (presumably non-negative
        ints — confirm at call sites).
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        """Strip the padding, then apply the base operator (``kwargs`` forwarded to it)."""
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the base adjoint (``kwargs`` forwarded), then zero-pad the leading rows/cols."""
        y = self.base.A_adjoint(y, **kwargs)
        # F.pad order is (last-dim left, last-dim right, 2nd-last top, 2nd-last bottom).
        return torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        """Drop the leading ``pad[0]`` rows (dim -2) and ``pad[1]`` columns (dim -1)."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the wrapped base operator."""
        self.base.update_parameters(**kwargs)