Spaces:
Running
on
T4
Running
on
T4
File size: 4,688 Bytes
4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ccc37f0 ff76a8d 4dc3e99 ff76a8d 4dc3e99 ccc37f0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 |
import torch
import torch.nn as nn
import deepinv as dinv
from models.unext_wip import UNeXt
from physics.multiscale import Pad
class ArtifactRemoval(nn.Module):
    r"""
    Artifact removal architecture :math:`\phi(A^{\top}y)`.

    This differs from ``dinv.models.ArtifactRemoval`` in that it allows to
    forward the physics operator to the backbone network.

    In the end we should not use this for unext !!

    :param torch.nn.Module backbone_net: reconstruction backbone network.
    :param bool pinv: if ``True``, feed the backbone ``A_dagger(y)`` (pseudo-inverse)
        instead of ``A_adjoint(y)``.
    :param str ckpt_path: optional path to a state-dict checkpoint loaded into
        the backbone (which is then set to eval mode).
    :param device: device the backbone is moved to.
    :param bool fm_mode: stored flag (not used inside this class).
    """

    def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
        super().__init__()
        self.pinv = pinv
        self.backbone_net = backbone_net
        self.fm_mode = fm_mode

        if ckpt_path is not None:
            # map_location avoids a crash when a checkpoint saved on GPU is
            # loaded on a CPU-only machine (and lands weights on `device` directly).
            state = torch.load(ckpt_path, map_location=device if device is not None else "cpu")
            self.backbone_net.load_state_dict(state, strict=True)
            self.backbone_net.eval()

        # A pretrained UNetRes denoiser is used frozen (e.g. as a plug-in prior).
        if type(self.backbone_net).__name__ == "UNetRes":
            for _, v in self.backbone_net.named_parameters():
                v.requires_grad = False

        self.backbone_net = self.backbone_net.to(device)

    def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
        r"""
        Reconstructs a signal estimate from measurements y.

        :param torch.Tensor y: measurements
        :param deepinv.physics.Physics physics: forward operator; defaults to
            noiseless denoising (identity) when ``None``.
        :param x_in: ignored; the backbone input is recomputed from ``y``.
        :param t: optional conditioning (e.g. timestep) passed to the backbone.
        """
        if physics is None:
            physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)

        if not self.training:
            x_temp = physics.A_adjoint(y)
            # Pad so both spatial dims become multiples of 8 (backbone requirement
            # at eval time); the pad is undone on the output below.
            pad = (-x_temp.size(-2) % 8, -x_temp.size(-1) % 8)
            physics = Pad(physics, pad)

        x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)

        # Robustly read noise parameters: the physics may lack a noise model, or
        # the noise model may lack sigma/gain.
        noise_model = getattr(physics, "noise_model", None)
        # WARNING: 1e-3 is a default value that we may not want to use?
        sigma = getattr(noise_model, "sigma", 1e-3)
        gamma = getattr(noise_model, "gain", 1e-3)

        out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)

        if not self.training:
            out = physics.remove_pad(out)

        return out

    def forward(self, y=None, physics=None, x_in=None, **kwargs):
        # UNeXt-style backbones need the measurement/physics-aware path; any
        # other backbone is called directly with the raw arguments.
        if 'unext' in type(self.backbone_net).__name__.lower():
            return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
        else:
            return self.backbone_net(physics=physics, y=y, **kwargs)
def get_model(
    model_name="unext_emb_physics_config_C",
    device="cpu",
    in_channels=None,
    grayscale=False,
    conv_type="base",
    pool_type="base",
    layer_scale_init_value=1e-6,
    init_type="ortho",
    gain_init_conv=1.0,
    gain_init_linear=1.0,
    drop_prob=0.0,
    replk=False,
    mult_fact=4,
    antialias="gaussian",
    nc_base=64,
    cond_type="base",
    blind=False,
    pretrained_pth=None,
    weight_tied=True,
    N=4,
    c_mult=1,
    depth_encoding=1,
    relu_in_encoding=False,
    skip_in_encoding=True,
):
    """
    Load the model.

    :param str model_name: name of the model (case-insensitive); only
        ``"unext_emb_physics_config_C"`` is currently supported.
    :param str device: device the model is moved to.
    :param list in_channels: per-head channel counts; defaults to ``[1, 2, 3]``
        (grayscale + complex + color = 1 + 2 + 3).
    :param bool grayscale: if True, the model is trained on grayscale images.
        NOTE(review): not forwarded to UNeXt here — confirm whether it should be.
    :param int nc_base: base channel width; widths double at each of 4 scales.
    :param str pretrained_pth: optional pretrained weights path passed to UNeXt.
    :return: :class:`ArtifactRemoval` wrapping the configured UNeXt backbone.
    :raises ValueError: if ``model_name`` is not supported.
    """
    if in_channels is None:
        # Avoid a shared mutable default argument; same value as before.
        in_channels = [1, 2, 3]

    model_name = model_name.lower()

    if model_name == "unext_emb_physics_config_c":
        # Always False for the supported name; kept for parity with possible
        # future "*residual*" variants.
        residual = "residual" in model_name
        # Channel widths per scale: nc_base, 2*nc_base, 4*nc_base, 8*nc_base.
        nc = [nc_base * 2**i for i in range(4)]
        model = UNeXt(
            in_channels=in_channels,
            out_channels=in_channels,
            device=device,
            residual=residual,
            conv_type=conv_type,
            pool_type=pool_type,
            layer_scale_init_value=layer_scale_init_value,
            init_type=init_type,
            gain_init_conv=gain_init_conv,
            gain_init_linear=gain_init_linear,
            drop_prob=drop_prob,
            replk=replk,
            mult_fact=mult_fact,
            antialias=antialias,
            nc=nc,
            cond_type=cond_type,
            emb_physics=True,
            config="C",
            pretrained_pth=pretrained_pth,
            N=N,
            c_mult=c_mult,
            depth_encoding=depth_encoding,
            relu_in_encoding=relu_in_encoding,
            skip_in_encoding=skip_in_encoding,
        ).to(device)
        return ArtifactRemoval(model, pinv=False, device=device)
    else:
        raise ValueError(f"Model {model_name} is not supported.")
|