File size: 4,931 Bytes
12a4d59
 
 
 
 
 
 
 
 
 
 
33dc149
 
 
 
 
 
 
 
12a4d59
33dc149
12a4d59
 
 
 
 
 
 
 
 
 
 
 
 
33dc149
12a4d59
33dc149
 
 
 
 
 
12a4d59
 
 
 
 
 
 
 
 
 
33dc149
 
 
 
 
 
 
 
 
12a4d59
 
 
 
 
 
 
 
33dc149
 
 
 
 
12a4d59
 
 
 
 
 
33dc149
12a4d59
 
 
 
 
 
 
 
 
 
 
 
33dc149
12a4d59
 
 
 
 
 
 
 
33dc149
 
 
 
 
 
 
 
 
12a4d59
 
 
 
 
 
33dc149
 
 
 
 
12a4d59
 
 
 
33dc149
12a4d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
import torch
import torch.nn as nn
import deepinv as dinv

from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad


class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it forwards the
    physics operator (and noise parameters read from it) to the backbone.

    In the end we should not use this for unext !!

    :param torch.nn.Module backbone_net: network applied to the back-projected measurements.
    :param bool pinv: if ``True`` the backbone input is ``physics.A_dagger(y)``,
        otherwise ``physics.A_adjoint(y)``.
    :param str ckpt_path: optional path to a state dict that is loaded strictly into
        ``backbone_net``; the backbone is then switched to eval mode.
    :param device: ``map_location`` used when loading ``ckpt_path``, and the device a
        ``UNetRes`` backbone is moved to. ``None`` keeps torch's default placement.
    :param bool fm_mode: stored as ``self.fm_mode``; not read anywhere in this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint saved on GPU be restored on a CPU-only host;
            # with device=None this is exactly torch.load's default behaviour.
            # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # A backbone whose class is named UNetRes gets all parameters frozen
        # and is moved to `device`.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstruct a signal estimate from measurements ``y``.

        :param torch.Tensor y: measurements.
        :param deepinv.physics.Physics physics: forward operator. If ``None``, a noiseless
            ``dinv.physics.Denoising`` operator on ``y.device`` is used.
        :param x_in: ignored; the backbone input is always recomputed from ``y``
            (kept for signature compatibility).
        :param t: forwarded unchanged to the backbone.
        :raises ValueError: if both ``y`` and ``physics`` are ``None``.
        :return: backbone output (with the eval-mode padding removed).
        """
        if physics is None:
            if y is None:
                raise ValueError("forward_basic needs measurements `y` when `physics` is None.")
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        if not self.training:
            # In eval mode, wrap the operator so that its adjoint output is padded up to
            # a multiple of 8 in the last two dims; the padding is stripped at the end.
            # (Multiple of 8 is presumably required by the backbone's downsampling — TODO confirm.)
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        # WARNING: the 1e-3 fallbacks are arbitrary defaults used when the noise model
        # does not expose `sigma` / `gain`; they may not be what we want.
        noise_model = physics.noise_model
        sigma = getattr(noise_model, "sigma", 1e-3)
        gamma = getattr(noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)

        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        r"""
        Dispatch to the backbone.

        Backbones whose class name contains ``"unext"`` (case-insensitive) go through
        :meth:`forward_basic`. Any other backbone is called directly with ``physics``,
        ``y`` and the extra keyword arguments; ``x_in`` is not forwarded in that case.
        """
        if "unext" in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        return self.backbone_net(physics=physics, y=y, **kwargs)


def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Build a reconstruction model by name.

    :param str model_name: name of the model (case-insensitive). Supported values are
        ``"pdnet"`` and ``"unext_emb_physics_config_c"``.
    :param str device: device passed to the architecture builders (the UNeXt is also
        moved there with ``.to(device)``).
    :param list in_channels: channel counts passed as both ``in_channels`` and
        ``out_channels`` to the architecture. ``None`` (default) means ``[1, 2, 3]``.
        A fresh list is created on each call so callers never share a default object.
    :param bool grayscale: currently unused.
    :param int nc_base: base width; the UNeXt uses widths ``nc_base * 2**i`` for ``i`` in 0..3.
    :param bool blind: currently unused.
    :param bool weight_tied: currently unused.

    The remaining keyword arguments (``conv_type``, ``pool_type``, ``layer_scale_init_value``,
    ``init_type``, ``gain_init_conv``, ``gain_init_linear``, ``drop_prob``, ``replk``,
    ``mult_fact``, ``antialias``, ``cond_type``, ``pretrained_pth``, ``N``, ``c_mult``,
    ``depth_encoding``, ``relu_in_encoding``, ``skip_in_encoding``) are forwarded unchanged
    to ``UNeXt`` and are ignored for ``"pdnet"``.

    :return: the PDNet architecture, or the UNeXt wrapped in :class:`ArtifactRemoval`.
    :raises ValueError: if ``model_name`` is not supported.
    """
    if in_channels is None:
        in_channels = [1, 2, 3]  # grayscale, complex and color

    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    elif model_name == "unext_emb_physics_config_c":
        # The only name reaching this branch contains no "residual" marker.
        residual = False
        nc = [nc_base * 2**i for i in range(4)]

        model = UNeXt(
            in_channels=in_channels,
            out_channels=in_channels,
            device=device,
            residual=residual,
            conv_type=conv_type,
            pool_type=pool_type,
            layer_scale_init_value=layer_scale_init_value,
            init_type=init_type,
            gain_init_conv=gain_init_conv,
            gain_init_linear=gain_init_linear,
            drop_prob=drop_prob,
            replk=replk,
            mult_fact=mult_fact,
            antialias=antialias,
            nc=nc,
            cond_type=cond_type,
            emb_physics=True,
            config="C",
            pretrained_pth=pretrained_pth,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)

    else:
        raise ValueError(f"Model {model_name} is not supported.")