File size: 4,442 Bytes
12a4d59
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import torch
import torch.nn as nn
import deepinv as dinv

from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad


class ArtifactRemoval(nn.Module):
    """Run a backbone network on a back-projection of the measurements.

    The measurements ``y`` are mapped back through ``physics`` (with
    ``A_adjoint`` or ``A_dagger``) and the result is passed to the backbone
    together with the noise parameters read from ``physics.noise_model``.

    Args:
        backbone_net: Network invoked as
            ``backbone_net(x_in, physics=..., y=..., sigma=..., gamma=..., t=...)``.
        pinv: When true the backbone input is ``physics.A_dagger(y)``;
            otherwise it is ``physics.A_adjoint(y)``.
        ckpt_path: Optional checkpoint handed to ``torch.load``. Its content
            is loaded strictly into the backbone, which is then put in eval
            mode.
        device: Used as ``map_location`` when a checkpoint is loaded, and as
            the target device when the backbone class is named ``UNetRes``.
        fm_mode: Stored on the instance; it is not read anywhere in this
            class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint that was saved on another device
            # (e.g. a GPU) be restored on the requested one; with device=None
            # this is the plain torch.load behaviour.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # The backbone is identified by class name because UNetRes is not
        # imported in this module. Such a backbone is frozen and moved to
        # `device`; any other backbone is left where the caller put it.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        """Back-project ``y`` and run the backbone on the result.

        Args:
            y: Measurements. Required in practice: it is passed to
                ``physics.A_adjoint`` / ``physics.A_dagger`` and its device is
                used when ``physics`` is omitted.
            physics: Forward operator. Defaults to a deepinv ``Denoising``
                operator with zero-sigma Gaussian noise on ``y.device``.
            x_in: Accepted for interface compatibility; any value passed in
                is discarded and recomputed from ``y``.
            t: Forwarded unchanged to the backbone.
            **kwargs: Ignored.

        Returns:
            The backbone output; outside training mode it is first passed
            through ``physics.remove_pad``.
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        if not self.training:
            # Amount missing for each of the last two dimensions of the
            # back-projection to reach the next multiple of 8.
            # NOTE(review): presumably matches the backbone's downsampling
            # factor -- confirm against the backbone architecture.
            backprojected = physics.A_adjoint(y)
            pad = (-backprojected.size(-2) % 8, -backprojected.size(-1) % 8)
            physics = Pad(physics, pad)

        if self.pinv:
            x_in = physics.A_dagger(y)
        else:
            x_in = physics.A_adjoint(y)

        # Noise models that expose no sigma / gain fall back to 1e-3.
        sigma = getattr(physics.noise_model, "sigma", 1e-3)
        gamma = getattr(physics.noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            # `physics` is the Pad wrapper here; undo the padding on the output.
            out = physics.remove_pad(out)

        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        """Delegate to :meth:`forward_basic` with the same arguments."""
        return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)


def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """Build one of the supported reconstruction models by name.

    Args:
        model_name: Case-insensitive selector. Supported values are
            ``"pdnet"``, ``"unrolled_dpir"`` and
            ``"unext_emb_physics_config_c"``.
        device: Device forwarded to the constructors; UNeXt models are also
            moved to it.
        in_channels: Channel specification, used for both the input and the
            output channels of the model. ``None`` (the default) means
            ``[1, 2, 3]``; a fresh list is created on every call so the
            default is never shared between models.
        nc_base: Base width; the UNeXt ``nc`` argument is
            ``[nc_base, 2 * nc_base, 4 * nc_base, 8 * nc_base]``.
        weight_tied: Forwarded to ``get_unrolled_architecture``; only used by
            ``"unrolled_dpir"``.
        N, c_mult, depth_encoding, relu_in_encoding, skip_in_encoding:
            Forwarded to ``UNeXt``; only used by
            ``"unext_emb_physics_config_c"``.
        conv_type, pool_type, layer_scale_init_value, init_type,
        gain_init_conv, gain_init_linear, drop_prob, replk, mult_fact,
        antialias, cond_type, pretrained_pth: Forwarded unchanged to
            ``UNeXt``.

    Returns:
        The object returned by ``get_PDNet_architecture`` for ``"pdnet"``;
        otherwise an ``ArtifactRemoval`` wrapper (``pinv=True`` for
        ``"unrolled_dpir"``, ``pinv=False`` for
        ``"unext_emb_physics_config_c"``).

    Raises:
        ValueError: If ``model_name`` is not one of the supported names.
    """
    if in_channels is None:
        # Built per call: a list literal in the signature would be one shared
        # object handed to every model created with the default.
        in_channels = [1, 2, 3]

    model_name = model_name.lower()
    nc = [nc_base * 2**i for i in range(4)]

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    if model_name not in ("unrolled_dpir", "unext_emb_physics_config_c"):
        raise ValueError(f"Model {model_name} is not supported.")

    # Keyword arguments common to both UNeXt-based variants.
    unext_kwargs = dict(
        in_channels=in_channels,
        out_channels=in_channels,
        device=device,
        conv_type=conv_type,
        pool_type=pool_type,
        layer_scale_init_value=layer_scale_init_value,
        init_type=init_type,
        gain_init_conv=gain_init_conv,
        gain_init_linear=gain_init_linear,
        drop_prob=drop_prob,
        replk=replk,
        mult_fact=mult_fact,
        antialias=antialias,
        nc=nc,
        cond_type=cond_type,
        pretrained_pth=pretrained_pth,
    )

    if model_name == "unrolled_dpir":
        # Plain UNeXt (no physics embedding) used inside the unrolled
        # architecture.
        model = UNeXt(emb_physics=False, config=None, **unext_kwargs).to(device)
        model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
        return ArtifactRemoval(model, pinv=True, device=device)

    # model_name == "unext_emb_physics_config_c"
    model = UNeXt(
        emb_physics=True,
        config="C",
        N=N,
        c_mult=c_mult,
        depth_encoding=depth_encoding,
        relu_in_encoding=relu_in_encoding,
        skip_in_encoding=skip_in_encoding,
        **unext_kwargs,
    ).to(device)
    return ArtifactRemoval(model, pinv=False, device=device)