Spaces:
Sleeping
Sleeping
import torch | |
import torch.nn as nn | |
import deepinv as dinv | |
from models.unext_wip import UNeXt | |
from models.unrolled_dpir import get_unrolled_architecture | |
from models.PDNet import get_PDNet_architecture | |
from physics.multiscale import Pad | |
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it forwards the
    ``physics`` operator (and related arguments) to the backbone network.

    Backbones whose class name contains ``"unext"`` (case-insensitive) go through
    :meth:`forward_basic`, which builds the network input from the measurements.
    Any other backbone is called directly as ``backbone_net(physics=..., y=..., **kwargs)``.

    NOTE (original author): in the end this wrapper should not be used for UNeXt.

    :param torch.nn.Module backbone_net: network to wrap.
    :param bool pinv: if True, the UNeXt input is ``physics.A_dagger(y)`` instead of
        ``physics.A_adjoint(y)``.
    :param str ckpt_path: optional path to a state dict loaded with ``strict=True``
        into ``backbone_net``; the backbone is then switched to eval mode.
    :param device: device the backbone is moved to. Only applied when the backbone's
        class is named ``UNetRes``.
    :param bool fm_mode: stored as ``self.fm_mode``; not read by this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # Deserialize onto the CPU so that checkpoints saved on a GPU can also be
            # loaded on CPU-only machines. load_state_dict then copies the tensors
            # into the parameters on whatever device they live on.
            state_dict = torch.load(ckpt_path, map_location="cpu")
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # A UNetRes backbone (presumably a pretrained denoiser) is frozen and moved
        # to the requested device.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstruct a signal estimate from measurements ``y`` with the backbone.

        In eval mode (``not self.training``), ``physics`` is wrapped in ``Pad`` with
        the amounts needed to bring the last two dimensions of ``A^T y`` up to a
        multiple of 8. The padding is removed from the output with
        ``physics.remove_pad``.

        :param torch.Tensor y: measurements.
        :param deepinv.physics.Physics physics: forward operator. If None, a
            noiseless ``dinv.physics.Denoising`` operator on ``y.device`` is used.
        :param x_in: ignored; the network input is always recomputed from ``y``.
        :param t: forwarded to the backbone as ``t``.
        :param kwargs: ignored.
        :return: the backbone output (with padding removed in eval mode).
        """
        if physics is None:
            physics = dinv.physics.Denoising(
                noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device
            )

        if not self.training:
            # Only the spatial size of A^T y is needed to compute the padding.
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        # WARNING: 1e-3 is a default value that we may not want to use?
        sigma = getattr(physics.noise_model, "sigma", 1e-3)
        # WARNING: 1e-3 is a default value that we may not want to use?
        gamma = getattr(physics.noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        """
        Dispatch to :meth:`forward_basic` for UNeXt-named backbones, else call the backbone.

        :param y: measurements.
        :param physics: forward operator.
        :param x_in: passed to ``forward_basic`` (which ignores it); dropped for
            other backbones.
        :param kwargs: forwarded to ``forward_basic`` or to the backbone.
        :return: the reconstruction produced by the selected path.
        """
        if "unext" in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        return self.backbone_net(physics=physics, y=y, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Build a reconstruction model by name.

    Supported names (compared case-insensitively):

    - ``"pdnet"``: returned by ``get_PDNet_architecture``.
    - ``"unext_emb_physics_config_c"``: a ``UNeXt`` built with ``emb_physics=True``
      and ``config="C"``, moved to ``device`` and wrapped in ``ArtifactRemoval``.

    :param str model_name: name of the model.
    :param str device: device passed to the architecture builders.
    :param list in_channels: channel counts used for both input and output. Defaults
        to ``[1, 2, 3]`` (grayscale, complex and color). A new list is created on
        every call.
    :param bool grayscale: not used by this function.
    :param bool blind: not used by this function.
    :param bool weight_tied: not used by this function.
    :param int nc_base: base width; the UNeXt receives
        ``nc = [nc_base, 2 * nc_base, 4 * nc_base, 8 * nc_base]``.

    The remaining keyword arguments (``conv_type``, ``pool_type``,
    ``layer_scale_init_value``, ``init_type``, ``gain_init_conv``,
    ``gain_init_linear``, ``drop_prob``, ``replk``, ``mult_fact``, ``antialias``,
    ``cond_type``, ``pretrained_pth``, ``N``, ``c_mult``, ``depth_encoding``,
    ``relu_in_encoding``, ``skip_in_encoding``) are forwarded unchanged to ``UNeXt``.

    :return: the constructed model.
    :raises ValueError: if ``model_name`` is not supported.
    """
    if in_channels is None:
        # Fresh list per call, so no mutable default is shared between calls.
        in_channels = [1, 2, 3]

    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(
            in_channels=in_channels, out_channels=in_channels, device=device
        )

    if model_name == "unext_emb_physics_config_c":
        # Always False for this exact name.
        residual = "residual" in model_name
        nc = [nc_base * 2**i for i in range(4)]
        model = UNeXt(
            in_channels=in_channels,
            out_channels=in_channels,
            device=device,
            residual=residual,
            conv_type=conv_type,
            pool_type=pool_type,
            layer_scale_init_value=layer_scale_init_value,
            init_type=init_type,
            gain_init_conv=gain_init_conv,
            gain_init_linear=gain_init_linear,
            drop_prob=drop_prob,
            replk=replk,
            mult_fact=mult_fact,
            antialias=antialias,
            nc=nc,
            cond_type=cond_type,
            emb_physics=True,
            config="C",
            pretrained_pth=pretrained_pth,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)

    raise ValueError(f"Model {model_name} is not supported.")