Spaces:
Sleeping
Sleeping
Bugfix
Browse files- .gitignore +2 -2
- app.py +6 -4
- evals.py +3 -5
- models/blocks.py +1 -1
- utils.py +47 -30
.gitignore
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
.
|
2 |
-
__pycache__
|
|
|
1 |
+
.ipynb_checkpoints/
|
2 |
+
__pycache__/
|
app.py
CHANGED
@@ -14,6 +14,9 @@ from torchvision import transforms
|
|
14 |
from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
15 |
|
16 |
|
|
|
|
|
|
|
17 |
### Gradio Utils
|
18 |
def generate_imgs(dataset: EvalDataset, idx: int,
|
19 |
model: EvalModel, baseline: BaselineModel,
|
@@ -152,8 +155,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
152 |
# Loading things
|
153 |
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
|
154 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
|
155 |
-
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("
|
156 |
-
physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("
|
157 |
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
158 |
|
159 |
@gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
|
@@ -265,5 +268,4 @@ with gr.Blocks(title=title, css=custom_css) as interface:
|
|
265 |
metrics_placeholder],
|
266 |
outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
267 |
|
268 |
-
|
269 |
-
interface.launch()
|
|
|
14 |
from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
|
15 |
|
16 |
|
17 |
+
DEVICE_STR = 'cuda'
|
18 |
+
|
19 |
+
|
20 |
### Gradio Utils
|
21 |
def generate_imgs(dataset: EvalDataset, idx: int,
|
22 |
model: EvalModel, baseline: BaselineModel,
|
|
|
155 |
# Loading things
|
156 |
model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
|
157 |
model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
|
158 |
+
dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
|
159 |
+
physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
|
160 |
metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
|
161 |
|
162 |
@gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
|
|
|
268 |
metrics_placeholder],
|
269 |
outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
|
270 |
|
271 |
+
interface.launch()
|
|
evals.py
CHANGED
@@ -486,8 +486,6 @@ class BaselineModel(torch.nn.Module):
|
|
486 |
x_adj = physics.A_adjoint(y)
|
487 |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
488 |
return output
|
489 |
-
elif 'UNROLLED_DPIR' in self.name:
|
490 |
-
return self.model(y, physics=physics)
|
491 |
else:
|
492 |
return self.model(y)
|
493 |
|
@@ -504,19 +502,19 @@ class EvalDataset(torch.utils.data.Dataset):
|
|
504 |
if self.name not in self.all_datasets:
|
505 |
raise ValueError(f"{self.name} is unavailable.")
|
506 |
if self.name == 'Natural':
|
507 |
-
self.root = '
|
508 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
509 |
self.dataset = dinv.datasets.LsdirHR(root=self.root,
|
510 |
download=False,
|
511 |
transform=self.transform)
|
512 |
elif self.name == 'MRI':
|
513 |
-
self.root = '
|
514 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
515 |
self.dataset = Preprocessed_fastMRI(root=self.root,
|
516 |
transform=self.transform,
|
517 |
preprocess=False)
|
518 |
elif self.name == "CT":
|
519 |
-
self.root = '
|
520 |
self.transform = None
|
521 |
self.dataset = Preprocessed_LIDCIDRI(root=self.root,
|
522 |
transform=self.transform)
|
|
|
486 |
x_adj = physics.A_adjoint(y)
|
487 |
output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
|
488 |
return output
|
|
|
|
|
489 |
else:
|
490 |
return self.model(y)
|
491 |
|
|
|
502 |
if self.name not in self.all_datasets:
|
503 |
raise ValueError(f"{self.name} is unavailable.")
|
504 |
if self.name == 'Natural':
|
505 |
+
self.root = 'img_samples/LSDIR_samples'
|
506 |
self.transform = transforms.Compose([transforms.ToTensor()])
|
507 |
self.dataset = dinv.datasets.LsdirHR(root=self.root,
|
508 |
download=False,
|
509 |
transform=self.transform)
|
510 |
elif self.name == 'MRI':
|
511 |
+
self.root = 'img_samples/FastMRI_samples'
|
512 |
self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
|
513 |
self.dataset = Preprocessed_fastMRI(root=self.root,
|
514 |
transform=self.transform,
|
515 |
preprocess=False)
|
516 |
elif self.name == "CT":
|
517 |
+
self.root = 'img_samples/LIDC_IDRI_samples'
|
518 |
self.transform = None
|
519 |
self.dataset = Preprocessed_LIDCIDRI(root=self.root,
|
520 |
transform=self.transform)
|
models/blocks.py
CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
|
|
7 |
from deepinv.models.unet import BFBatchNorm2d
|
8 |
from deepinv.physics.blur import gaussian_blur
|
9 |
from deepinv.physics.functional import conv2d
|
10 |
-
from deepinv.utils
|
11 |
|
12 |
from timm.models.layers import trunc_normal_, DropPath
|
13 |
|
|
|
7 |
from deepinv.models.unet import BFBatchNorm2d
|
8 |
from deepinv.physics.blur import gaussian_blur
|
9 |
from deepinv.physics.functional import conv2d
|
10 |
+
from deepinv.utils import TensorList
|
11 |
|
12 |
from timm.models.layers import trunc_normal_, DropPath
|
13 |
|
utils.py
CHANGED
@@ -9,8 +9,16 @@ from physics.multiscale import Pad
|
|
9 |
|
10 |
|
11 |
class ArtifactRemoval(nn.Module):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
|
13 |
-
super().__init__()
|
14 |
self.pinv = pinv
|
15 |
self.backbone_net = backbone_net
|
16 |
self.fm_mode = fm_mode
|
@@ -24,7 +32,14 @@ class ArtifactRemoval(nn.Module):
|
|
24 |
v.requires_grad = False
|
25 |
self.backbone_net = self.backbone_net.to(device)
|
26 |
|
|
|
27 |
def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
if physics is None:
|
29 |
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
|
30 |
|
@@ -35,8 +50,15 @@ class ArtifactRemoval(nn.Module):
|
|
35 |
|
36 |
x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
|
37 |
|
38 |
-
|
39 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
40 |
|
41 |
out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
|
42 |
|
@@ -45,14 +67,18 @@ class ArtifactRemoval(nn.Module):
|
|
45 |
|
46 |
return out
|
47 |
|
48 |
-
def forward(self,
|
49 |
-
|
|
|
|
|
|
|
50 |
|
51 |
|
52 |
def get_model(
|
53 |
model_name="unext_emb_physics_config_C",
|
54 |
device="cpu",
|
55 |
in_channels=[1, 2, 3],
|
|
|
56 |
conv_type="base",
|
57 |
pool_type="base",
|
58 |
layer_scale_init_value=1e-6,
|
@@ -65,6 +91,7 @@ def get_model(
|
|
65 |
antialias="gaussian",
|
66 |
nc_base=64,
|
67 |
cond_type="base",
|
|
|
68 |
pretrained_pth=None,
|
69 |
weight_tied=True,
|
70 |
N=4,
|
@@ -73,41 +100,31 @@ def get_model(
|
|
73 |
relu_in_encoding=False,
|
74 |
skip_in_encoding=True,
|
75 |
):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
model_name = model_name.lower()
|
77 |
-
nc = [nc_base * 2**i for i in range(4)]
|
78 |
|
79 |
if model_name == "pdnet":
|
80 |
return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
|
81 |
|
82 |
-
elif model_name == "unrolled_dpir":
|
83 |
-
model = UNeXt(
|
84 |
-
in_channels=in_channels,
|
85 |
-
out_channels=in_channels,
|
86 |
-
device=device,
|
87 |
-
conv_type=conv_type,
|
88 |
-
pool_type=pool_type,
|
89 |
-
layer_scale_init_value=layer_scale_init_value,
|
90 |
-
init_type=init_type,
|
91 |
-
gain_init_conv=gain_init_conv,
|
92 |
-
gain_init_linear=gain_init_linear,
|
93 |
-
drop_prob=drop_prob,
|
94 |
-
replk=replk,
|
95 |
-
mult_fact=mult_fact,
|
96 |
-
antialias=antialias,
|
97 |
-
nc=nc,
|
98 |
-
cond_type=cond_type,
|
99 |
-
emb_physics=False,
|
100 |
-
config=None,
|
101 |
-
pretrained_pth=pretrained_pth,
|
102 |
-
).to(device)
|
103 |
-
model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
|
104 |
-
return ArtifactRemoval(model, pinv=True, device=device)
|
105 |
-
|
106 |
elif model_name == "unext_emb_physics_config_c":
|
|
|
|
|
|
|
|
|
|
|
107 |
model = UNeXt(
|
108 |
in_channels=in_channels,
|
109 |
out_channels=in_channels,
|
110 |
device=device,
|
|
|
111 |
conv_type=conv_type,
|
112 |
pool_type=pool_type,
|
113 |
layer_scale_init_value=layer_scale_init_value,
|
|
|
9 |
|
10 |
|
11 |
class ArtifactRemoval(nn.Module):
|
12 |
+
r"""
|
13 |
+
Artifact removal architecture :math:`\phi(A^{\top}y)`.
|
14 |
+
|
15 |
+
This differs from the dinv.models.ArtifactRemoval in that it allows to forward the physics.
|
16 |
+
|
17 |
+
In the end we should not use this for unext !!
|
18 |
+
"""
|
19 |
+
|
20 |
def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
|
21 |
+
super(ArtifactRemoval, self).__init__()
|
22 |
self.pinv = pinv
|
23 |
self.backbone_net = backbone_net
|
24 |
self.fm_mode = fm_mode
|
|
|
32 |
v.requires_grad = False
|
33 |
self.backbone_net = self.backbone_net.to(device)
|
34 |
|
35 |
+
|
36 |
def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
|
37 |
+
r"""
|
38 |
+
Reconstructs a signal estimate from measurements y
|
39 |
+
|
40 |
+
:param torch.tensor y: measurements
|
41 |
+
:param deepinv.physics.Physics physics: forward operator
|
42 |
+
"""
|
43 |
if physics is None:
|
44 |
physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
|
45 |
|
|
|
50 |
|
51 |
x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
|
52 |
|
53 |
+
if hasattr(physics.noise_model, "sigma"):
|
54 |
+
sigma = physics.noise_model.sigma
|
55 |
+
else:
|
56 |
+
sigma = 1e-3 # WARNING: this is a default value that we may not want to use?
|
57 |
+
|
58 |
+
if hasattr(physics.noise_model, "gain"):
|
59 |
+
gamma = physics.noise_model.gain
|
60 |
+
else:
|
61 |
+
gamma = 1e-3 # WARNING: this is a default value that we may not want to use?
|
62 |
|
63 |
out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
|
64 |
|
|
|
67 |
|
68 |
return out
|
69 |
|
70 |
+
def forward(self, y=None, physics=None, x_in=None, **kwargs):
|
71 |
+
if 'unext' in type(self.backbone_net).__name__.lower():
|
72 |
+
return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
|
73 |
+
else:
|
74 |
+
return self.backbone_net(physics=physics, y=y, **kwargs)
|
75 |
|
76 |
|
77 |
def get_model(
|
78 |
model_name="unext_emb_physics_config_C",
|
79 |
device="cpu",
|
80 |
in_channels=[1, 2, 3],
|
81 |
+
grayscale=False,
|
82 |
conv_type="base",
|
83 |
pool_type="base",
|
84 |
layer_scale_init_value=1e-6,
|
|
|
91 |
antialias="gaussian",
|
92 |
nc_base=64,
|
93 |
cond_type="base",
|
94 |
+
blind=False,
|
95 |
pretrained_pth=None,
|
96 |
weight_tied=True,
|
97 |
N=4,
|
|
|
100 |
relu_in_encoding=False,
|
101 |
skip_in_encoding=True,
|
102 |
):
|
103 |
+
"""
|
104 |
+
Load the model.
|
105 |
+
|
106 |
+
:param str model_name: name of the model
|
107 |
+
:param str device: device
|
108 |
+
:param bool grayscale: if True, the model is trained on grayscale images
|
109 |
+
:param bool train: if True, the model is trained
|
110 |
+
:return: model
|
111 |
+
"""
|
112 |
model_name = model_name.lower()
|
|
|
113 |
|
114 |
if model_name == "pdnet":
|
115 |
return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
|
116 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
117 |
elif model_name == "unext_emb_physics_config_c":
|
118 |
+
n_chan = [1, 2, 3] # 6 for old head grayscale, complex and color = 1 + 2 + 3
|
119 |
+
residual = True if "residual" in model_name else False
|
120 |
+
nc = [nc_base * 2**i for i in range(4)]
|
121 |
+
|
122 |
+
|
123 |
model = UNeXt(
|
124 |
in_channels=in_channels,
|
125 |
out_channels=in_channels,
|
126 |
device=device,
|
127 |
+
residual=residual,
|
128 |
conv_type=conv_type,
|
129 |
pool_type=pool_type,
|
130 |
layer_scale_init_value=layer_scale_init_value,
|