File size: 2,717 Bytes
12a4d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
import torch
from deepinv.physics import Physics, LinearPhysics, Downsampling

class Upsampling(Downsampling):
    """``Downsampling`` operator with its forward and adjoint maps swapped.

    ``A`` delegates to the parent's ``A_adjoint`` and ``A_adjoint`` delegates
    to the parent's ``A``. Construction is inherited unchanged from
    ``Downsampling``.
    """

    def A(self, x, **kwargs):
        """Apply the parent's adjoint operator to ``x``."""
        parent_adjoint = super().A_adjoint
        return parent_adjoint(x, **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the parent's forward operator to ``y``."""
        parent_forward = super().A
        return parent_forward(y, **kwargs)

    def prox_l2(self, z, y, gamma, **kwargs):
        """Delegate to the parent's ``prox_l2`` with the same arguments."""
        # NOTE(review): unlike A / A_adjoint this is NOT swapped; it calls the
        # parent's implementation as-is. Confirm that implementation is still
        # valid for the swapped operator.
        parent_prox = super().prox_l2
        return parent_prox(z, y, gamma, **kwargs)


class MultiScalePhysics(Physics):
    """Wrap a physics operator so it can be applied to coarser-grid inputs.

    One ``Upsampling`` operator is built per entry of ``scales``. The current
    scale index selects which one is applied before the wrapped physics:

    * ``scale == 0``: the wrapped physics is used directly.
    * ``scale == k`` (``k >= 1``): ``Upsamplings[k - 1]`` (built with
      ``factor=scales[k - 1]``) is applied first.

    The scale is stateful: passing ``scale=...`` to any method updates
    ``self.scale`` for subsequent calls as well.

    Args:
        physics: wrapped operator; must expose ``noise_model``, ``A`` and
            ``update_parameters``.
        img_shape: forwarded as ``img_size`` to every ``Upsampling``.
        filter: forwarded as ``filter`` to every ``Upsampling``.
        scales: iterable of resampling factors, one ``Upsampling`` each.
        device: forwarded as ``device`` to every ``Upsampling``.
        **kwargs: forwarded to the parent class constructor.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=(2, 4, 8), device='cpu', **kwargs):
        super().__init__(noise_model=physics.noise_model, **kwargs)
        self.base = physics
        # Own copy: never alias a shared default or the caller's sequence, so
        # mutating ``self.scales`` on one instance cannot leak into another.
        self.scales = list(scales)
        self.img_shape = img_shape
        # NOTE(review): if ``Physics`` is a ``torch.nn.Module``, operators held
        # in a plain list are not registered as submodules (``.to()`` will not
        # reach them) -- confirm whether that is intended.
        self.Upsamplings = [
            Upsampling(img_size=img_shape, filter=filter, factor=factor, device=device)
            for factor in self.scales
        ]
        self.scale = 0

    def set_scale(self, scale):
        """Select the active scale index; ``None`` keeps the current one.

        Raises:
            IndexError: if ``scale`` is outside ``0..len(self.Upsamplings)``.
                Checked up front so a negative value cannot silently wrap
                around via ``Upsamplings[scale - 1]`` and an invalid value
                never overwrites the current scale.
        """
        if scale is None:
            return
        n_levels = len(self.Upsamplings)
        if not 0 <= scale <= n_levels:
            raise IndexError(
                f"scale must be between 0 and {n_levels}, got {scale}"
            )
        self.scale = scale

    def A(self, x, scale=None, **kwargs):
        """Apply the wrapped physics, upsampling ``x`` first when scale != 0.

        ``kwargs`` are forwarded to the wrapped physics only.
        """
        self.set_scale(scale)
        if self.scale != 0:
            x = self.Upsamplings[self.scale - 1].A(x)
        return self.base.A(x, **kwargs)

    def downsample(self, x, scale=None):
        """Return ``x`` at scale 0, else ``A_adjoint`` of the active ``Upsampling``."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A_adjoint(x)

    def upsample(self, x, scale=None):
        """Return ``x`` at scale 0, else ``A`` of the active ``Upsampling``."""
        self.set_scale(scale)
        if self.scale == 0:
            return x
        return self.Upsamplings[self.scale - 1].A(x)

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the wrapped physics."""
        self.base.update_parameters(**kwargs)


class MultiScaleLinearPhysics(MultiScalePhysics, LinearPhysics):
    """Linear variant of ``MultiScalePhysics`` that also provides an adjoint.

    Construction arguments are passed straight through to
    ``MultiScalePhysics``.
    """

    def __init__(self, physics, img_shape, filter="sinc", scales=[2, 4, 8], **kwargs):
        super().__init__(physics, img_shape, filter=filter, scales=scales, **kwargs)

    def A_adjoint(self, y, scale=None, **kwargs):
        """Apply the wrapped adjoint, then the active ``Upsampling`` adjoint.

        The scale is updated first (``None`` keeps the current one). At scale
        0 only the wrapped physics' adjoint is applied; ``kwargs`` are
        forwarded to the wrapped physics only.
        """
        self.set_scale(scale)
        backprojected = self.base.A_adjoint(y, **kwargs)
        if self.scale != 0:
            resampler = self.Upsamplings[self.scale - 1]
            backprojected = resampler.A_adjoint(backprojected)
        return backprojected


class Pad(LinearPhysics):
    """Wrap a linear physics so it acts on inputs carrying a leading margin.

    The forward operator drops the first ``pad[0]`` entries along the
    second-to-last dimension and the first ``pad[1]`` entries along the last
    dimension, then applies the wrapped physics. The adjoint applies the
    wrapped adjoint and pads those same leading positions back.

    Args:
        physics: wrapped operator; must expose ``noise_model``, ``A``,
            ``A_adjoint`` and ``update_parameters``.
        pad: indexable pair ``(pad[0], pad[1])`` of margin sizes as described
            above.
    """

    def __init__(self, physics, pad):
        super().__init__(noise_model=physics.noise_model)
        self.base = physics
        self.pad = pad

    def A(self, x, **kwargs):
        """Crop the leading margin from ``x`` and apply the wrapped physics.

        ``kwargs`` are forwarded to the wrapped physics, so callers that pass
        extra keyword arguments to ``A`` no longer fail with ``TypeError``.
        """
        return self.base.A(self.remove_pad(x), **kwargs)

    def A_adjoint(self, y, **kwargs):
        """Apply the wrapped adjoint and pad the leading margin back.

        ``kwargs`` are forwarded to the wrapped physics.
        """
        y = self.base.A_adjoint(y, **kwargs)
        # F.pad lists sizes from the last dimension backwards, as
        # (start, end) pairs: last dim gets pad[1] in front, the one before
        # it gets pad[0] in front -- mirroring the crop in ``remove_pad``.
        return torch.nn.functional.pad(y, (self.pad[1], 0, self.pad[0], 0))

    def remove_pad(self, x):
        """Return ``x`` without its leading margin on the last two dims."""
        return x[..., self.pad[0]:, self.pad[1]:]

    def update_parameters(self, **kwargs):
        """Forward parameter updates to the wrapped physics."""
        self.base.update_parameters(**kwargs)