denoising / utils.py
Yonuts's picture
Bugfix
33dc149
raw
history blame
4.93 kB
import torch
import torch.nn as nn
import deepinv as dinv
from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it forwards the
    physics operator (and its noise parameters) to the backbone network.
    In the end we should not use this for UNeXt!

    :param torch.nn.Module backbone_net: network :math:`\phi` being wrapped.
    :param bool pinv: if True, :meth:`forward_basic` feeds the backbone
        ``physics.A_dagger(y)`` instead of ``physics.A_adjoint(y)``.
    :param str ckpt_path: optional path to a state dict loaded (strictly) into
        ``backbone_net``; the backbone is then switched to eval mode.
    :param device: ``map_location`` used when loading the checkpoint, and the
        device a ``UNetRes`` backbone is moved to.
    :param bool fm_mode: stored as ``self.fm_mode``; not read by this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint saved on GPU be restored on a
            # CPU-only host; device=None keeps torch.load's default behaviour.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        # Matched by class name so UNetRes need not be imported here: such a
        # backbone is frozen and moved to `device`.
        # NOTE(review): the source's indentation was lost; the move to `device`
        # is assumed to apply only to UNetRes (as in deepinv) — confirm.
        if type(self.backbone_net).__name__ == "UNetRes":
            for param in self.backbone_net.parameters():
                param.requires_grad = False
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstruct a signal estimate from measurements ``y``.

        :param torch.Tensor y: measurements.
        :param deepinv.physics.Physics physics: forward operator; when None, a
            noiseless ``dinv.physics.Denoising`` operator on ``y.device`` is used.
        :param x_in: ignored — the backbone input is always recomputed from
            ``y``; kept only for signature compatibility.
        :param t: forwarded unchanged to the backbone.
        :return: the backbone output; in eval mode it is passed through
            ``physics.remove_pad`` before being returned.
        :raises ValueError: if both ``y`` and ``physics`` are None.
        """
        if physics is None:
            if y is None:
                raise ValueError("forward_basic needs measurements `y` when `physics` is None.")
            physics = dinv.physics.Denoising(
                noise_model=dinv.physics.GaussianNoise(sigma=0.0), device=y.device
            )

        if not self.training:
            # Amount to add to H and W to reach the next multiple of 8
            # (unary minus binds tighter than %, i.e. (-H) % 8).
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        # Noise parameters handed to the backbone; 1e-3 is a fallback when the
        # noise model does not expose them.
        # WARNING: this default value may not be what we want.
        noise_model = physics.noise_model
        sigma = getattr(noise_model, "sigma", 1e-3)
        gamma = getattr(noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        r"""
        Dispatch the call to the backbone.

        Backbones whose class name contains ``"unext"`` (case-insensitive) go
        through :meth:`forward_basic`. Any other backbone is called directly as
        ``backbone_net(physics=physics, y=y, **kwargs)``, and ``x_in`` is dropped.
        """
        if "unext" in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        return self.backbone_net(physics=physics, y=y, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Build a reconstruction model by name.

    :param str model_name: model identifier, matched case-insensitively.
        Supported: ``"pdnet"`` and ``"unext_emb_physics_config_c"``.
    :param str device: device passed to the architecture builder / moved to.
    :param in_channels: channel specification used for both input and output;
        defaults to ``[1, 2, 3]`` (grayscale, complex and color heads), built
        fresh on every call.
    :param bool grayscale: currently unused.
    :param bool blind: currently unused.
    :param bool weight_tied: currently unused.
    :param int nc_base: width of the first UNeXt level; level ``i`` uses
        ``nc_base * 2**i`` channels (4 levels).
    The remaining keyword arguments (``conv_type``, ``pool_type``,
    ``layer_scale_init_value``, ``init_type``, ``gain_init_conv``,
    ``gain_init_linear``, ``drop_prob``, ``replk``, ``mult_fact``,
    ``antialias``, ``cond_type``, ``pretrained_pth``, ``N``, ``c_mult``,
    ``depth_encoding``, ``relu_in_encoding``, ``skip_in_encoding``) are
    forwarded unchanged to ``UNeXt``.
    :return: the PDNet architecture for ``"pdnet"``; otherwise a ``UNeXt``
        (physics embedding, config "C") wrapped in ``ArtifactRemoval``.
    :raises ValueError: if ``model_name`` is not supported.
    """
    # Avoid a shared mutable default: each call gets its own list.
    if in_channels is None:
        in_channels = [1, 2, 3]

    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    if model_name != "unext_emb_physics_config_c":
        raise ValueError(f"Model {model_name} is not supported.")

    nc = [nc_base * 2**i for i in range(4)]
    model = UNeXt(
        in_channels=in_channels,
        out_channels=in_channels,
        device=device,
        # No "residual" model name is currently accepted, so this is always False.
        residual=False,
        conv_type=conv_type,
        pool_type=pool_type,
        layer_scale_init_value=layer_scale_init_value,
        init_type=init_type,
        gain_init_conv=gain_init_conv,
        gain_init_linear=gain_init_linear,
        drop_prob=drop_prob,
        replk=replk,
        mult_fact=mult_fact,
        antialias=antialias,
        nc=nc,
        cond_type=cond_type,
        emb_physics=True,
        config="C",
        pretrained_pth=pretrained_pth,
        N=N,
        c_mult=c_mult,
        depth_encoding=depth_encoding,
        relu_in_encoding=relu_in_encoding,
        skip_in_encoding=skip_in_encoding,
    ).to(device)
    return ArtifactRemoval(model, pinv=False, device=device)