Yonuts commited on
Commit
33dc149
·
1 Parent(s): cd12993
Files changed (5) hide show
  1. .gitignore +2 -2
  2. app.py +6 -4
  3. evals.py +3 -5
  4. models/blocks.py +1 -1
  5. utils.py +47 -30
.gitignore CHANGED
@@ -1,2 +1,2 @@
1
- .ipynb
2
- __pycache__
 
1
+ .ipynb_checkpoints/
2
+ __pycache__/
app.py CHANGED
@@ -14,6 +14,9 @@ from torchvision import transforms
14
  from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
15
 
16
 
 
 
 
17
  ### Gradio Utils
18
  def generate_imgs(dataset: EvalDataset, idx: int,
19
  model: EvalModel, baseline: BaselineModel,
@@ -152,8 +155,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
152
  # Loading things
153
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
154
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
155
- dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("DIV2K_valid_HR"))
156
- physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("Denoising")) # lambda expression to instanciate a callable in a gr.State
157
  metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
158
 
159
  @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
@@ -265,5 +268,4 @@ with gr.Blocks(title=title, css=custom_css) as interface:
265
  metrics_placeholder],
266
  outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
267
 
268
- if __name__ == "__main__":
269
- interface.launch()
 
14
  from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
15
 
16
 
17
+ DEVICE_STR = 'cuda'
18
+
19
+
20
  ### Gradio Utils
21
  def generate_imgs(dataset: EvalDataset, idx: int,
22
  model: EvalModel, baseline: BaselineModel,
 
155
  # Loading things
156
  model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
157
  model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
158
+ dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
159
+ physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
160
  metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
161
 
162
  @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
 
268
  metrics_placeholder],
269
  outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
270
 
271
+ interface.launch()
 
evals.py CHANGED
@@ -486,8 +486,6 @@ class BaselineModel(torch.nn.Module):
486
  x_adj = physics.A_adjoint(y)
487
  output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
488
  return output
489
- elif 'UNROLLED_DPIR' in self.name:
490
- return self.model(y, physics=physics)
491
  else:
492
  return self.model(y)
493
 
@@ -504,19 +502,19 @@ class EvalDataset(torch.utils.data.Dataset):
504
  if self.name not in self.all_datasets:
505
  raise ValueError(f"{self.name} is unavailable.")
506
  if self.name == 'Natural':
507
- self.root = 'datasets/LSDIR_samples'
508
  self.transform = transforms.Compose([transforms.ToTensor()])
509
  self.dataset = dinv.datasets.LsdirHR(root=self.root,
510
  download=False,
511
  transform=self.transform)
512
  elif self.name == 'MRI':
513
- self.root = 'datasets/FastMRI_samples'
514
  self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
515
  self.dataset = Preprocessed_fastMRI(root=self.root,
516
  transform=self.transform,
517
  preprocess=False)
518
  elif self.name == "CT":
519
- self.root = 'datasets/LIDC_IDRI_samples'
520
  self.transform = None
521
  self.dataset = Preprocessed_LIDCIDRI(root=self.root,
522
  transform=self.transform)
 
486
  x_adj = physics.A_adjoint(y)
487
  output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
488
  return output
 
 
489
  else:
490
  return self.model(y)
491
 
 
502
  if self.name not in self.all_datasets:
503
  raise ValueError(f"{self.name} is unavailable.")
504
  if self.name == 'Natural':
505
+ self.root = 'img_samples/LSDIR_samples'
506
  self.transform = transforms.Compose([transforms.ToTensor()])
507
  self.dataset = dinv.datasets.LsdirHR(root=self.root,
508
  download=False,
509
  transform=self.transform)
510
  elif self.name == 'MRI':
511
+ self.root = 'img_samples/FastMRI_samples'
512
  self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
513
  self.dataset = Preprocessed_fastMRI(root=self.root,
514
  transform=self.transform,
515
  preprocess=False)
516
  elif self.name == "CT":
517
+ self.root = 'img_samples/LIDC_IDRI_samples'
518
  self.transform = None
519
  self.dataset = Preprocessed_LIDCIDRI(root=self.root,
520
  transform=self.transform)
models/blocks.py CHANGED
@@ -7,7 +7,7 @@ import torch.nn.functional as F
7
  from deepinv.models.unet import BFBatchNorm2d
8
  from deepinv.physics.blur import gaussian_blur
9
  from deepinv.physics.functional import conv2d
10
- from deepinv.utils.tensorlist import TensorList
11
 
12
  from timm.models.layers import trunc_normal_, DropPath
13
 
 
7
  from deepinv.models.unet import BFBatchNorm2d
8
  from deepinv.physics.blur import gaussian_blur
9
  from deepinv.physics.functional import conv2d
10
+ from deepinv.utils import TensorList
11
 
12
  from timm.models.layers import trunc_normal_, DropPath
13
 
utils.py CHANGED
@@ -9,8 +9,16 @@ from physics.multiscale import Pad
9
 
10
 
11
  class ArtifactRemoval(nn.Module):
 
 
 
 
 
 
 
 
12
  def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
13
- super().__init__()
14
  self.pinv = pinv
15
  self.backbone_net = backbone_net
16
  self.fm_mode = fm_mode
@@ -24,7 +32,14 @@ class ArtifactRemoval(nn.Module):
24
  v.requires_grad = False
25
  self.backbone_net = self.backbone_net.to(device)
26
 
 
27
  def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
 
 
 
 
 
 
28
  if physics is None:
29
  physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
30
 
@@ -35,8 +50,15 @@ class ArtifactRemoval(nn.Module):
35
 
36
  x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
37
 
38
- sigma = getattr(physics.noise_model, "sigma", 1e-3)
39
- gamma = getattr(physics.noise_model, "gain", 1e-3)
 
 
 
 
 
 
 
40
 
41
  out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
42
 
@@ -45,14 +67,18 @@ class ArtifactRemoval(nn.Module):
45
 
46
  return out
47
 
48
- def forward(self, y=None, physics=None, x_in=None, **kwargs):
49
- return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
 
 
 
50
 
51
 
52
  def get_model(
53
  model_name="unext_emb_physics_config_C",
54
  device="cpu",
55
  in_channels=[1, 2, 3],
 
56
  conv_type="base",
57
  pool_type="base",
58
  layer_scale_init_value=1e-6,
@@ -65,6 +91,7 @@ def get_model(
65
  antialias="gaussian",
66
  nc_base=64,
67
  cond_type="base",
 
68
  pretrained_pth=None,
69
  weight_tied=True,
70
  N=4,
@@ -73,41 +100,31 @@ def get_model(
73
  relu_in_encoding=False,
74
  skip_in_encoding=True,
75
  ):
 
 
 
 
 
 
 
 
 
76
  model_name = model_name.lower()
77
- nc = [nc_base * 2**i for i in range(4)]
78
 
79
  if model_name == "pdnet":
80
  return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
81
 
82
- elif model_name == "unrolled_dpir":
83
- model = UNeXt(
84
- in_channels=in_channels,
85
- out_channels=in_channels,
86
- device=device,
87
- conv_type=conv_type,
88
- pool_type=pool_type,
89
- layer_scale_init_value=layer_scale_init_value,
90
- init_type=init_type,
91
- gain_init_conv=gain_init_conv,
92
- gain_init_linear=gain_init_linear,
93
- drop_prob=drop_prob,
94
- replk=replk,
95
- mult_fact=mult_fact,
96
- antialias=antialias,
97
- nc=nc,
98
- cond_type=cond_type,
99
- emb_physics=False,
100
- config=None,
101
- pretrained_pth=pretrained_pth,
102
- ).to(device)
103
- model = get_unrolled_architecture(model=model, weight_tied=weight_tied, device=device)
104
- return ArtifactRemoval(model, pinv=True, device=device)
105
-
106
  elif model_name == "unext_emb_physics_config_c":
 
 
 
 
 
107
  model = UNeXt(
108
  in_channels=in_channels,
109
  out_channels=in_channels,
110
  device=device,
 
111
  conv_type=conv_type,
112
  pool_type=pool_type,
113
  layer_scale_init_value=layer_scale_init_value,
 
9
 
10
 
11
  class ArtifactRemoval(nn.Module):
12
+ r"""
13
+ Artifact removal architecture :math:`\phi(A^{\top}y)`.
14
+
15
+ This differs from the dinv.models.ArtifactRemoval in that it allows to forward the physics.
16
+
17
+ In the end we should not use this for unext !!
18
+ """
19
+
20
  def __init__(self, backbone_net, pinv=False, ckpt_path=None, device=None, fm_mode=False):
21
+ super(ArtifactRemoval, self).__init__()
22
  self.pinv = pinv
23
  self.backbone_net = backbone_net
24
  self.fm_mode = fm_mode
 
32
  v.requires_grad = False
33
  self.backbone_net = self.backbone_net.to(device)
34
 
35
+
36
  def forward_basic(self, y=None, physics=None, x_in=None, t=None, **kwargs):
37
+ r"""
38
+ Reconstructs a signal estimate from measurements y
39
+
40
+ :param torch.tensor y: measurements
41
+ :param deepinv.physics.Physics physics: forward operator
42
+ """
43
  if physics is None:
44
  physics = dinv.physics.Denoising(noise_model=dinv.physics.GaussianNoise(sigma=0.), device=y.device)
45
 
 
50
 
51
  x_in = physics.A_adjoint(y) if not self.pinv else physics.A_dagger(y)
52
 
53
+ if hasattr(physics.noise_model, "sigma"):
54
+ sigma = physics.noise_model.sigma
55
+ else:
56
+ sigma = 1e-3 # WARNING: this is a default value that we may not want to use?
57
+
58
+ if hasattr(physics.noise_model, "gain"):
59
+ gamma = physics.noise_model.gain
60
+ else:
61
+ gamma = 1e-3 # WARNING: this is a default value that we may not want to use?
62
 
63
  out = self.backbone_net(x_in, physics=physics, y=y, sigma=sigma, gamma=gamma, t=t)
64
 
 
67
 
68
  return out
69
 
70
+ def forward(self, y=None, physics=None, x_in=None, **kwargs):
71
+ if 'unext' in type(self.backbone_net).__name__.lower():
72
+ return self.forward_basic(physics=physics, y=y, x_in=x_in, **kwargs)
73
+ else:
74
+ return self.backbone_net(physics=physics, y=y, **kwargs)
75
 
76
 
77
  def get_model(
78
  model_name="unext_emb_physics_config_C",
79
  device="cpu",
80
  in_channels=[1, 2, 3],
81
+ grayscale=False,
82
  conv_type="base",
83
  pool_type="base",
84
  layer_scale_init_value=1e-6,
 
91
  antialias="gaussian",
92
  nc_base=64,
93
  cond_type="base",
94
+ blind=False,
95
  pretrained_pth=None,
96
  weight_tied=True,
97
  N=4,
 
100
  relu_in_encoding=False,
101
  skip_in_encoding=True,
102
  ):
103
+ """
104
+ Load the model.
105
+
106
+ :param str model_name: name of the model
107
+ :param str device: device
108
+ :param bool grayscale: if True, the model is trained on grayscale images
109
+ :param bool train: if True, the model is trained
110
+ :return: model
111
+ """
112
  model_name = model_name.lower()
 
113
 
114
  if model_name == "pdnet":
115
  return get_PDNet_architecture(in_channels=in_channels, out_channels=in_channels, device=device)
116
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
117
  elif model_name == "unext_emb_physics_config_c":
118
+ n_chan = [1, 2, 3] # 6 for old head grayscale, complex and color = 1 + 2 + 3
119
+ residual = True if "residual" in model_name else False
120
+ nc = [nc_base * 2**i for i in range(4)]
121
+
122
+
123
  model = UNeXt(
124
  in_channels=in_channels,
125
  out_channels=in_channels,
126
  device=device,
127
+ residual=residual,
128
  conv_type=conv_type,
129
  pool_type=pool_type,
130
  layer_scale_init_value=layer_scale_init_value,