import torch
from deepinv.physics import Physics, LinearPhysics, Downsampling


class Upsampling(Downsampling):
    """Upsampling operator, defined as the adjoint of :class:`Downsampling`.

    ``A`` is ``Downsampling.A_adjoint`` and ``A_adjoint`` is ``Downsampling.A``.
    Constructor arguments are those of :class:`Downsampling`.
    """

    def A(self, x, **kwargs):
        # Forward of the upsampler == adjoint of the downsampler.
        return super().A_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        # Adjoint of the upsampler == forward of the downsampler.
        return super().A(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        """Proximal operator of the L2 data-fidelity term for this operator.

        ``Downsampling.prox_l2`` is specialised to the downsampling operator,
        not to its adjoint, so calling it through ``super()`` would solve the
        wrong problem. Defer to the generic ``LinearPhysics`` implementation
        instead, which works from this class's ``A`` / ``A_adjoint``.
        """
        return LinearPhysics.prox_l2(self, z, y, gamma, **kwargs)


class MultiScalePhysics(Physics):
    """Apply a base physics operator to coarser versions of the input.

    At scale ``s > 0`` the input is first mapped through
    ``Upsamplings[s - 1]`` (built with factor ``scales[s - 1]`` and
    ``img_size=img_shape``) and then through ``physics.A``. Scale ``0`` is
    the base operator itself.

    :param physics: base physics operator.
    :param img_shape: image shape given as ``img_size`` to each
        :class:`Upsampling`.
    :param str filter: filter name given to each :class:`Upsampling`.
    :param scales: one factor per coarse scale; ``scales[s - 1]`` is used at
        scale ``s``. Defaults to ``[2, 4, 8]``.
    :param device: device given to each :class:`Upsampling`.

    .. note::
        The current scale is stored on the instance: passing ``scale`` to a
        method that accepts it changes the scale used by later calls that
        omit it.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=None, device="cpu", **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        if scales is None:
            scales = [2, 4, 8]
        self.base = physics
        # Copy so later mutation of the caller's list cannot desynchronise
        # ``scales`` from ``Upsamplings``.
        self.scales = list(scales)
        self.img_shape = img_shape
        # NOTE(review): held in a plain list, so these are presumably not
        # registered as nn.Module children (e.g. not moved by ``.to()``);
        # confirm whether ``torch.nn.ModuleList`` is wanted here.
        self.Upsamplings = [
            Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
            for factor in self.scales
        ]
        self.scale = 0

    def set_scale(self, scale):
        """Set the current scale; ``None`` keeps the current one.

        :raises IndexError: if ``scale`` is outside ``[0, len(Upsamplings)]``.
            The stored scale is left unchanged in that case.
        """
        if scale is None:
            return
        n_scales = len(self.Upsamplings)
        # Reject negatives explicitly: ``Upsamplings[scale - 1]`` would
        # otherwise silently pick an upsampler via negative indexing.
        if not 0 <= scale <= n_scales:
            raise IndexError(f"scale must be in [0, {n_scales}], got {scale}")
        self.scale = scale

    def _current_upsampler(self):
        """Return the upsampler for the current scale, or ``None`` at scale 0."""
        if self.scale == 0:
            return None
        return self.Upsamplings[self.scale - 1]

    def A(self, x, scale=None, **kwargs):
        """Apply ``physics.A`` after upsampling ``x`` from the current scale."""
        self.set_scale(scale)
        up = self._current_upsampler()
        if up is not None:
            x = up.A(x)
        return self.base.A(x, **kwargs)

    def downsample(self, x, scale=None):
        """Map ``x`` to the current scale (identity at scale 0)."""
        self.set_scale(scale)
        up = self._current_upsampler()
        return x if up is None else up.A_adjoint(x)

    def upsample(self, x, scale=None):
        """Map ``x`` from the current scale back up (identity at scale 0)."""
        self.set_scale(scale)
        up = self._current_upsampler()
        return x if up is None else up.A(x)

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)


class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    """Linear variant of :class:`MultiScalePhysics` that also provides the adjoint."""

    def __init__(self, physics, img_shape, filter="sinc", scales=None, **kwargs):
        # ``scales=None`` is resolved to the default list by the parent.
        super().__init__(physics=physics, img_shape=img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        """Adjoint of ``A``: ``physics.A_adjoint`` followed by the upsampler's adjoint."""
        self.set_scale(scale)
        x = self.base.A_adjoint(y, **kwargs)
        up = self._current_upsampler()
        return x if up is None else up.A_adjoint(x)


class Pad(LinearPhysics):
    """Wrap a physics operator so that it accepts a top/left-padded input.

    ``A`` drops the first ``pad[0]`` entries of the second-to-last dimension
    and the first ``pad[1]`` entries of the last dimension, then applies
    ``physics.A``. ``A_adjoint`` zero-pads ``physics.A_adjoint(y)`` back at
    the same positions, so the pair remains adjoint.

    :param physics: base physics operator.
    :param pad: pair ``(pad[0], pad[1])`` of leading sizes for the last two
        dimensions.
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        x = self.base.A_adjoint(y, **kwargs)
        # F.pad pads the last dim with (left, right), then the one before it
        # with (top, bottom): zeros go only before the data, mirroring ``A``.
        return torch.nn.functional.pad(x, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        """Drop the leading padding from the last two dimensions of ``x``."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the base physics."""
        self.base.update_parameters(**kwargs)