# denoising/utils.py
# Author: msong97 — commit 12a4d59 ("gradio demo"), 4.44 kB.
import torch
import torch.nn as nn
import deepinv as dinv
from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad
class ArtifactRemoval(nn.Module):
    """Wrap a backbone network so it can be called as ``model(y, physics)``.

    The network input is built from the measurements ``y``: it is
    ``physics.A_adjoint(y)``, or ``physics.A_dagger(y)`` when ``pinv`` is True.
    That input is passed to ``backbone_net`` together with the physics
    operator, ``y``, the noise model's ``sigma``/``gain`` and ``t``.

    Args:
        backbone_net: network called as
            ``backbone_net(x_in, physics=..., y=..., sigma=..., gamma=..., t=...)``.
        pinv: if True, build the network input with ``physics.A_dagger``
            instead of ``physics.A_adjoint``.
        ckpt_path: optional path to a ``state_dict`` checkpoint.
            It is loaded strictly into ``backbone_net``, which is then put in
            eval mode.
        device: device the backbone is moved to (when not None).
            It is also used as ``map_location`` when loading ``ckpt_path``.
        fm_mode: stored as ``self.fm_mode``; not read by this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location lets a checkpoint saved on a GPU be restored on a
            # CPU-only host (or another GPU). With device=None, torch.load
            # keeps its default placement behavior.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()

        if type(self.backbone_net).__name__ == "UNetRes":
            # UNetRes backbones are frozen: their weights never receive gradients.
            for param in self.backbone_net.parameters():
                param.requires_grad = False

        if device is not None:
            self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        """Reconstruct an estimate from measurements ``y`` under ``physics``.

        Args:
            y: measurements; must be given when ``physics`` is None, because
                its ``.device`` is read.
            physics: forward operator. When None, a ``dinv.physics.Denoising``
                operator with zero-sigma Gaussian noise on ``y.device`` is used.
            x_in: ignored. The network input is always recomputed from ``y``;
                the argument is kept for signature compatibility.
            t: forwarded unchanged to the backbone.
            **kwargs: ignored.

        Returns:
            The backbone output. In eval mode, the padding added here is
            removed (via ``remove_pad``) before returning.
        """
        if physics is None:
            physics = dinv.physics.Denoising(
                noise_model=dinv.physics.GaussianNoise(sigma=0.0), device=y.device
            )

        if not self.training:
            # At inference, pad the spatial size of the adjoint up to the next
            # multiple of 8. This is presumably required by the backbone's
            # down/upsampling (TODO confirm).
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)

        # Fall back to 1e-3 when the noise model does not expose the attribute.
        sigma = getattr(physics.noise_model, "sigma", 1e-3)
        gamma = getattr(physics.noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        """Entry point for ``nn.Module.__call__``; delegates to ``forward_basic``."""
        return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """Build a reconstruction model by name.

    ``model_name`` is matched case-insensitively against:

    * ``"pdnet"``: the result of ``get_PDNet_architecture``, returned as-is.
    * ``"unrolled_dpir"``: a ``UNeXt`` (``emb_physics=False``, ``config=None``).
      It is wrapped by ``get_unrolled_architecture`` and then by
      ``ArtifactRemoval(pinv=True)``.
    * ``"unext_emb_physics_config_c"``: a ``UNeXt`` (``emb_physics=True``,
      ``config="C"``) wrapped in ``ArtifactRemoval(pinv=False)``.

    Args:
        model_name: which architecture to build (see above).
        device: device passed to the builders and used to place the models.
        in_channels: passed as both ``in_channels`` and ``out_channels``.
            Defaults to a fresh ``[1, 2, 3]`` on every call.
        nc_base: the UNeXt receives ``nc = [nc_base * 2**i for i in range(4)]``.
        weight_tied: forwarded to ``get_unrolled_architecture`` (unrolled_dpir only).
        N, c_mult, depth_encoding, relu_in_encoding, skip_in_encoding:
            forwarded to ``UNeXt`` for the config-C model only.
        Remaining arguments: forwarded unchanged to ``UNeXt``.

    Returns:
        The constructed model.

    Raises:
        ValueError: if ``model_name`` is not one of the supported names.
    """
    if in_channels is None:
        # Build a new list per call; a list default would be shared between calls.
        in_channels = [1, 2, 3]

    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    # Arguments common to both UNeXt-based variants.
    unext_kwargs = dict(
        in_channels=in_channels,
        out_channels=in_channels,
        device=device,
        conv_type=conv_type,
        pool_type=pool_type,
        layer_scale_init_value=layer_scale_init_value,
        init_type=init_type,
        gain_init_conv=gain_init_conv,
        gain_init_linear=gain_init_linear,
        drop_prob=drop_prob,
        replk=replk,
        mult_fact=mult_fact,
        antialias=antialias,
        nc=[nc_base * 2**i for i in range(4)],
        cond_type=cond_type,
        pretrained_pth=pretrained_pth,
    )

    if model_name == "unrolled_dpir":
        model = UNeXt(**unext_kwargs, emb_physics=False, config=None).to(device)
        model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
        return ArtifactRemoval(model, pinv=True, device=device)

    if model_name == "unext_emb_physics_config_c":
        model = UNeXt(
            **unext_kwargs,
            emb_physics=True,
            config="C",
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)

    raise ValueError(f"Model {model_name} is not supported.")