Spaces:
Sleeping
Sleeping
File size: 4,931 Bytes
12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 33dc149 12a4d59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 |
import torch
import torch.nn as nn
import deepinv as dinv
from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it forwards the
    physics operator (together with noise-level information) to the backbone.
    :meth:`forward` routes backbones whose class name contains "unext"
    (case-insensitive) through :meth:`forward_basic`; any other backbone is
    called directly as ``backbone_net(physics=physics, y=y, **kwargs)``.

    NOTE(review): the original author noted "In the end we should not use this
    for unext !!", although unext backbones are currently the ones routed
    through ``forward_basic``.

    :param torch.nn.Module backbone_net: network to wrap.
    :param bool pinv: if True, the network input is ``physics.A_dagger(y)``
        instead of ``physics.A_adjoint(y)``.
    :param str ckpt_path: optional path to a state dict that is loaded strictly
        into ``backbone_net``; the backbone is then switched to eval mode.
    :param device: device the backbone is moved to; also used as
        ``map_location`` when loading ``ckpt_path``. ``None`` leaves it in place.
    :param bool fm_mode: stored as ``self.fm_mode``; not read by this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint saved on GPU be restored on a
            # CPU-only machine (or onto a different GPU).
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # UNetRes backbones are frozen: none of their parameters get gradients.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False

        # Honor the requested device for every backbone type.
        if device is not None:
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstruct a signal estimate from measurements ``y``.

        :param torch.Tensor y: measurements.
        :param deepinv.physics.Physics physics: forward operator. If None, a
            ``dinv.physics.Denoising`` operator with zero-sigma Gaussian noise
            on ``y.device`` is used.
        :param x_in: accepted for interface compatibility but ignored; the
            network input is always recomputed from ``y``.
        :param t: forwarded unchanged to the backbone.
        :return: backbone output; in eval mode the padding added below is removed.
        """
        if physics is None:
            physics = dinv.physics.Denoising(
                noise_model=dinv.physics.GaussianNoise(sigma=0.0), device=y.device
            )

        if not self.training:
            # In eval mode, wrap the operator so the last two dims are padded
            # up to multiples of 8 (presumably what the backbone's
            # downsampling requires -- TODO confirm).
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        # Fall back to fixed noise levels when the noise model lacks them.
        # WARNING: 1e-3 is an arbitrary default that may not be appropriate.
        sigma = getattr(physics.noise_model, "sigma", 1e-3)
        gamma = getattr(physics.noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        """
        Dispatch on the backbone type.

        Backbones whose class name contains "unext" (case-insensitive) go through
        :meth:`forward_basic`. All other backbones are called as
        ``backbone_net(physics=physics, y=y, **kwargs)``; ``x_in`` is not
        forwarded to them.
        """
        if "unext" in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        return self.backbone_net(physics=physics, y=y, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Build a reconstruction model by name.

    Supported names (matched case-insensitively):

    * ``"pdnet"``: the model returned by ``get_PDNet_architecture``.
    * ``"unext_emb_physics_config_c"``: a ``UNeXt`` backbone built with
      ``emb_physics=True`` and ``config="C"``, wrapped in ``ArtifactRemoval``.

    :param str model_name: name of the model.
    :param str device: device the model is placed on.
    :param list in_channels: channel list passed as both ``in_channels`` and
        ``out_channels``. Defaults to a fresh ``[1, 2, 3]`` (grayscale, complex
        and color = 1 + 2 + 3).
    :param int nc_base: base width; the UNeXt gets ``nc = [nc_base * 2**i for i in range(4)]``.
    :param grayscale: accepted but currently unused.
    :param blind: accepted but currently unused.
    :param weight_tied: accepted but currently unused.
    :param conv_type, pool_type, layer_scale_init_value, init_type, gain_init_conv,
        gain_init_linear, drop_prob, replk, mult_fact, antialias, cond_type,
        pretrained_pth, N, c_mult, depth_encoding, relu_in_encoding,
        skip_in_encoding: forwarded unchanged to ``UNeXt``.
    :return: the constructed model.
    :raises ValueError: if ``model_name`` is not supported.
    """
    model_name = model_name.lower()
    if in_channels is None:
        # New list per call: a list default would be shared across calls and
        # across every model that receives (and might mutate) it.
        in_channels = [1, 2, 3]

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    if model_name == "unext_emb_physics_config_c":
        # NOTE: the only accepted name never contains "residual", so this is
        # always False; kept so a residual variant is easy to re-enable.
        residual = "residual" in model_name
        nc = [nc_base * 2**i for i in range(4)]
        model = UNeXt(
            in_channels=in_channels,
            out_channels=in_channels,
            device=device,
            residual=residual,
            conv_type=conv_type,
            pool_type=pool_type,
            layer_scale_init_value=layer_scale_init_value,
            init_type=init_type,
            gain_init_conv=gain_init_conv,
            gain_init_linear=gain_init_linear,
            drop_prob=drop_prob,
            replk=replk,
            mult_fact=mult_fact,
            antialias=antialias,
            nc=nc,
            cond_type=cond_type,
            emb_physics=True,
            config="C",
            pretrained_pth=pretrained_pth,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)

    raise ValueError(f"Model {model_name} is not supported.")