Spaces:
Running
on
T4
Running
on
T4
File size: 4,442 Bytes
12a4d59 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 |
import torch
import torch.nn as nn
import deepinv as dinv
from models.unext_wip import UNeXt
from models.unrolled_dpir import get_unrolled_architecture
from models.PDNet import get_PDNet_architecture
from physics.multiscale import Pad
class ArtifactRemoval(nn.Module):
    """Reconstruction wrapper that feeds a backbone a back-projection of ``y``.

    The backbone receives ``physics.A_adjoint(y)`` (or ``physics.A_dagger(y)``
    when ``pinv`` is set) together with the physics operator, the measurements
    and the noise parameters read from ``physics.noise_model``.

    Args:
        backbone_net: Network called as
            ``backbone_net(x_in, physics=..., y=..., sigma=..., gamma=..., t=...)``.
        pinv: If True, build the backbone input with ``physics.A_dagger``;
            otherwise with ``physics.A_adjoint``.
        ckpt_path: Optional path to a state dict that is loaded strictly into
            the backbone; the backbone is then switched to eval mode.
        device: Device the backbone is moved to. Also used as ``map_location``
            when loading ``ckpt_path``.
        fm_mode: Stored as ``self.fm_mode``; not read anywhere in this class.
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode
        if ckpt_path is not None:
            # map_location lets a checkpoint saved on one device (e.g. a GPU) be
            # loaded on another (e.g. a CPU-only machine); None keeps torch's default.
            state_dict = torch.load(ckpt_path, map_location=device)
            self.backbone_net.load_state_dict(state_dict, strict=True)
            self.backbone_net.eval()
        if type(self.backbone_net).__name__ == "UNetRes":
            # UNetRes backbones are frozen: none of their parameters are trained.
            self.backbone_net.requires_grad_(False)
        self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        """Reconstruct an estimate from measurements ``y``.

        Args:
            y: Measurement tensor. Required: it is passed to the physics
                operators, and its ``.device`` is used when ``physics`` is None.
            physics: Forward operator. If None, a ``dinv.physics.Denoising``
                operator with ``GaussianNoise(sigma=0.)`` is used.
            x_in: Ignored. The backbone input is always recomputed from ``y``;
                the parameter is kept for call-signature compatibility.
            t: Forwarded unchanged to the backbone.
            **kwargs: Accepted and ignored.

        Returns:
            The backbone output. In eval mode, the padding added before the
            backbone call is removed with ``physics.remove_pad``.
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
        if not self.training:
            # Compute how much the last two dims of the back-projection fall short
            # of a multiple of 8, and wrap the operator in Pad with that amount.
            # Presumably this lets the backbone's down/up-sampling stages divide
            # evenly — NOTE(review): confirm against physics.multiscale.Pad.
            x_temp = physics.A_adjoint(y)
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)
        x_in = physics.A_dagger(y) if self.pinv else physics.A_adjoint(y)
        # Noise parameters default to 1e-3 when the noise model lacks the attribute.
        sigma = getattr(physics.noise_model, "sigma", 1e-3)
        gamma = getattr(physics.noise_model, "gain", 1e-3)
        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
        if not self.training:
            out = physics.remove_pad(out)
        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        """Delegate to :meth:`forward_basic`; extra kwargs (e.g. ``t``) pass through."""
        return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """Build one of the supported reconstruction models by name.

    Args:
        model_name: Model identifier, matched case-insensitively. Supported:
            ``"pdnet"``, ``"unrolled_dpir"`` and ``"unext_emb_physics_config_C"``.
        device: Device passed to the builders; UNeXt backbones are also
            moved there with ``.to(device)``.
        in_channels: Channel configuration passed as both ``in_channels`` and
            ``out_channels``. Defaults to ``[1, 2, 3]``. A fresh list is created
            on every call, so built models never share the default object.
        conv_type, pool_type, layer_scale_init_value, init_type, gain_init_conv,
        gain_init_linear, drop_prob, replk, mult_fact, antialias, cond_type,
        pretrained_pth: Forwarded unchanged to ``UNeXt`` (UNeXt-based models only).
        nc_base: Base width; the UNeXt ``nc`` list is
            ``[nc_base * 2**i for i in range(4)]``.
        weight_tied: Forwarded to ``get_unrolled_architecture``
            (``"unrolled_dpir"`` only).
        N, c_mult, depth_encoding, relu_in_encoding, skip_in_encoding:
            Forwarded to ``UNeXt`` for ``"unext_emb_physics_config_C"`` only.

    Returns:
        For ``"pdnet"``, the result of ``get_PDNet_architecture``. Otherwise an
        ``ArtifactRemoval`` wrapping the backbone: ``pinv=True`` for
        ``"unrolled_dpir"``, ``pinv=False`` for the physics-embedding UNeXt.

    Raises:
        ValueError: If ``model_name`` is not a supported name.
    """
    if in_channels is None:
        # Fresh list per call: avoids the shared-mutable-default pitfall.
        in_channels = [1, 2, 3]
    model_name = model_name.lower()

    if model_name == "pdnet":
        return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)

    nc = [nc_base * 2**i for i in range(4)]
    # Arguments shared by every UNeXt-based model; branch-specific ones are added below.
    unext_kwargs = dict(
        in_channels=in_channels,
        out_channels=in_channels,
        device=device,
        conv_type=conv_type,
        pool_type=pool_type,
        layer_scale_init_value=layer_scale_init_value,
        init_type=init_type,
        gain_init_conv=gain_init_conv,
        gain_init_linear=gain_init_linear,
        drop_prob=drop_prob,
        replk=replk,
        mult_fact=mult_fact,
        antialias=antialias,
        nc=nc,
        cond_type=cond_type,
        pretrained_pth=pretrained_pth,
    )

    if model_name == "unrolled_dpir":
        backbone = UNeXt(**unext_kwargs, emb_physics=False, config=None).to(device)
        model = get_unrolled_architecture(model=backbone, weight_tied=weight_tied, device=device)
        return ArtifactRemoval(model, pinv=True, device=device)

    if model_name == "unext_emb_physics_config_c":
        backbone = UNeXt(
            **unext_kwargs,
            emb_physics=True,
            config="C",
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(backbone, pinv=False, device=device)

    raise ValueError(f"Model {model_name} is not supported.")