import torch
import torch.nn as nn

import deepinv as dinv

from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad


class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it forwards the
    physics operator (and the noise parameters ``sigma`` / ``gamma``) to the
    backbone network.

    NOTE(review): the original author flagged that this wrapper should
    eventually not be used for UNeXt backbones.

    :param torch.nn.Module backbone_net: network applied to the back-projected
        measurements.
    :param bool pinv: if True, back-project with ``physics.A_dagger``;
        otherwise with ``physics.A_adjoint``.
    :param str ckpt_path: optional path to a state dict loaded (strictly) into
        ``backbone_net``, which is then put in eval mode.
    :param device: device for ``UNetRes`` backbones; also used as
        ``map_location`` when loading ``ckpt_path``.
    :param bool fm_mode: stored as ``self.fm_mode``; not read by this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint saved on GPU be loaded on a
            # CPU-only host; with device=None this is torch's default behavior.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # NOTE(review): the collapsed source lost its indentation; this layout
        # (freeze + move only for "UNetRes" backbones) follows the upstream
        # deepinv ArtifactRemoval this class is derived from — confirm.
        if type(self.backbone_net).__name__ == "UNetRes":
            for _, v in self.backbone_net.named_parameters():
                v.requires_grad = False
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstruct a signal estimate from measurements ``y``.

        :param torch.Tensor y: measurements.
        :param deepinv.physics.Physics physics: forward operator. If None, a
            ``Denoising`` operator with zero-sigma Gaussian noise is used.
        :param x_in: accepted for interface compatibility but currently ignored;
            the network input is always recomputed from ``y``.
        :param t: forwarded unchanged to the backbone as ``t``.
        :return: backbone output; in eval mode the padding added here is removed.
        """
        if physics is None:
            physics = dinv.physics.Denoising(
                noise_model=dinv.physics.GaussianNoise(sigma=0.0), device=y.device
            )

        if not self.training:
            # In eval mode, wrap the physics so the back-projection's two last
            # dims are padded up to multiples of 8 (presumably required by the
            # backbone's downsampling — confirm).
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        if hasattr(physics.noise_model, "sigma"):
            sigma = physics.noise_model.sigma
        else:
            sigma = 1e-3  # WARNING: this is a default value that we may not want to use?

        if hasattr(physics.noise_model, "gain"):
            gamma = physics.noise_model.gain
        else:
            gamma = 1e-3  # WARNING: this is a default value that we may not want to use?

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        r"""
        Dispatch on backbone type.

        Backbones whose class name contains ``"unext"`` (case-insensitive) go
        through :meth:`forward_basic`. Any other backbone is called directly with
        ``physics``, ``y`` and ``kwargs``; ``x_in`` is not forwarded in that case.
        """
        if "unext" in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        return self.backbone_net(physics=physics, y=y, **kwargs)


def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Build a reconstruction model by name.

    :param str model_name: case-insensitive model name; supported values are
        ``"pdnet"`` and ``"unext_emb_physics_config_c"``.
    :param str device: device the model is built on / moved to.
    :param list in_channels: supported channel counts, used for both input and
        output. Defaults to ``[1, 2, 3]`` (grayscale, complex, color).
    :param bool grayscale: currently unused.
    :param bool blind: currently unused.
    :param bool weight_tied: currently unused.
    :param int nc_base: base width; the UNeXt stages use
        ``nc_base * 2**i`` for ``i`` in ``0..3``.

    The remaining keyword arguments are forwarded unchanged to ``UNeXt``.

    :return: the PDNet architecture, or a UNeXt wrapped in :class:`ArtifactRemoval`.
    :raises ValueError: if ``model_name`` is not supported.
    """
    # Sentinel instead of a mutable list default, which would be shared
    # across calls and handed to the network constructors.
    if in_channels is None:
        in_channels = [1, 2, 3]

    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(
            in_channels=in_channels, out_channels=in_channels, device=device
        )

    if model_name == "unext_emb_physics_config_c":
        # NOTE(review): always False on this branch, since model_name is
        # exactly "unext_emb_physics_config_c" here.
        residual = "residual" in model_name
        nc = [nc_base * 2**i for i in range(4)]
        model = UNeXt(
            in_channels=in_channels,
            out_channels=in_channels,
            device=device,
            residual=residual,
            conv_type=conv_type,
            pool_type=pool_type,
            layer_scale_init_value=layer_scale_init_value,
            init_type=init_type,
            gain_init_conv=gain_init_conv,
            gain_init_linear=gain_init_linear,
            drop_prob=drop_prob,
            replk=replk,
            mult_fact=mult_fact,
            antialias=antialias,
            nc=nc,
            cond_type=cond_type,
            emb_physics=True,
            config="C",
            pretrained_pth=pretrained_pth,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)

    raise ValueError(f"Model {model_name} is not supported.")