Yonuts commited on
Commit
25883b9
·
1 Parent(s): b4684a7

new layout

Browse files
Files changed (3) hide show
  1. app.py +102 -90
  2. evals.py → factories.py +3 -2
  3. models/unrolled_dpir.py +0 -304
app.py CHANGED
@@ -11,24 +11,16 @@ import torch
11
  from PIL import Image
12
  from torchvision import transforms
13
 
14
- from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
15
 
16
 
17
- DEVICE_STR = 'cuda'
 
 
18
 
19
 
20
  ### Gradio Utils
21
 
22
- def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
23
- model: EvalModel, baseline: BaselineModel,
24
- physics: PhysicsWithGenerator, use_gen: bool,
25
- metrics: List[Metric]):
26
- ### Load 1 image
27
- x = dataset[idx] # shape : (3, 256, 256)
28
- x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
29
-
30
- return generate_imgs(x, model, baseline, physics, use_gen, metrics)
31
-
32
  def generate_imgs_from_user(image,
33
  model: EvalModel, baseline: BaselineModel,
34
  physics: PhysicsWithGenerator, use_gen: bool,
@@ -37,9 +29,31 @@ def generate_imgs_from_user(image,
37
  return None, None, None, None, None, None, None, None
38
 
39
  # PIL image -> torch.Tensor
40
- x = transforms.ToTensor()(image).unsqueeze(0).to('cuda')
 
 
 
 
 
 
 
 
 
 
41
 
42
  return generate_imgs(x, model, baseline, physics, use_gen, metrics)
 
 
 
 
 
 
 
 
 
 
 
 
43
 
44
  def generate_imgs(x: torch.Tensor,
45
  model: EvalModel, baseline: BaselineModel,
@@ -75,7 +89,7 @@ def generate_imgs(x: torch.Tensor,
75
  metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
76
 
77
  ### Process y when y shape is different from x shape
78
- if physics.name == "MRI" or "SR" in physics.name:
79
  y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
80
  else:
81
  y_plot = y.clone()
@@ -93,18 +107,6 @@ def generate_imgs(x: torch.Tensor,
93
 
94
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
95
 
96
- def generate_random_imgs_from_dataset(dataset: EvalDataset,
97
- model: EvalModel,
98
- baseline: BaselineModel,
99
- physics: PhysicsWithGenerator,
100
- use_gen: bool,
101
- metrics: List[Metric]):
102
- idx = random.randint(0, len(dataset)-1)
103
- x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
104
- dataset, idx, model, baseline, physics, use_gen, metrics
105
- )
106
- return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
107
-
108
 
109
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
110
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
@@ -112,7 +114,8 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
112
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
113
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
114
 
115
- AVAILABLE_PHYSICS = PhysicsWithGenerator.all_physics
 
116
  def get_dataset(dataset_name):
117
  global AVAILABLE_PHYSICS
118
  if dataset_name == 'MRI':
@@ -124,10 +127,15 @@ def get_dataset(dataset_name):
124
  baseline_name = 'DPIR_CT'
125
  physics_name = 'CT'
126
  else:
127
- AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard', 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
 
128
  baseline_name = 'DPIR'
129
  physics_name = 'MotionBlur_easy'
130
- return get_dataset_on_DEVICE_STR(dataset_name), get_physics_on_DEVICE_STR(physics_name), get_baseline_model_on_DEVICE_STR(baseline_name)
 
 
 
 
131
 
132
 
133
  ### Gradio Blocks interface
@@ -135,9 +143,9 @@ def get_dataset(dataset_name):
135
  # Define custom CSS
136
  custom_css = """
137
  .fixed-textbox textarea {
138
- height: 90px !important; /* Adjust height to fit exactly 4 lines */
139
- overflow: scroll; /* Add a scroll bar if necessary */
140
- resize: none; /* User can resize vertically the textbox */
141
  }
142
  """
143
 
@@ -145,87 +153,88 @@ title = "Inverse problem playground" # displayed on gradio tab and in the gradi
145
  with gr.Blocks(title=title, css=custom_css) as interface:
146
  gr.Markdown("## " + title)
147
 
148
- # Loading things
149
- model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
150
- model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR")) # lambda expression to instanciate a callable in a gr.State
151
- dataset_placeholder = gr.State(lambda: get_dataset_on_DEVICE_STR("Natural"))
152
- physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy")) # lambda expression to instanciate a callable in a gr.State
153
- metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
 
 
 
 
 
154
 
155
- @gr.render(inputs=[dataset_placeholder, physics_placeholder, metrics_placeholder])
156
- def dynamic_layout(dataset, physics, metrics):
157
  ### LAYOUT
158
- dataset_name = dataset.name
159
- physics_name = physics.name
160
- metric_names = [metric.name for metric in metrics]
161
 
162
- # Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel
 
 
 
 
 
 
 
163
  with gr.Row():
164
  with gr.Column():
 
 
 
 
 
165
  with gr.Row():
166
- with gr.Column():
167
- clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=True)
168
- physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
169
- with gr.Column():
170
- y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
171
- y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
 
 
 
 
 
172
 
 
 
 
173
  choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
174
- label="List of PhysicsWithGenerator",
175
- value=physics_name)
 
 
176
  with gr.Row():
177
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
178
- label="Updatable Parameter Key",
179
- scale=2)
180
- value_text = gr.Textbox(label="Update Value", scale=2)
181
- update_button = gr.Button("Manually update parameter value", scale=1)
182
-
183
- with gr.Column():
184
- with gr.Row():
185
- with gr.Column():
186
- model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
187
- out_a_metric = gr.Textbox(label="Metrics(RAM(y, physics), x)", elem_classes=["fixed-textbox"])
188
- with gr.Column():
189
- model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
190
- out_b_metric = gr.Textbox(label="Metrics(DPIR(y, physics), x)", elem_classes=["fixed-textbox"])
191
- with gr.Row():
192
- choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
193
- label="List of EvalDataset",
194
- value=dataset_name,
195
- scale=2)
196
- idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
197
 
198
- # Components: Load Metric + Load image Buttons
199
- with gr.Row():
200
- with gr.Column(scale=3):
201
- choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
202
- value=metric_names,
203
- label="Choose metrics you are interested")
204
- use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
205
- run_button = gr.Button("Run current image", scale=1)
206
- with gr.Column(scale=1):
207
- load_button = gr.Button("Load images from dataset...")
208
- load_random_button = gr.Button("Load randomly from dataset...")
209
 
210
  ### Event listeners
 
211
  choose_dataset.change(fn=get_dataset,
212
  inputs=choose_dataset,
213
  outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
214
  choose_physics.change(fn=get_physics_on_DEVICE_STR,
215
  inputs=choose_physics,
216
  outputs=[physics_placeholder])
217
- update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
218
- choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
219
- inputs=choose_metrics,
220
- outputs=metrics_placeholder)
221
  run_button.click(fn=generate_imgs_from_user,
222
- inputs=[clean,
223
  model_a_placeholder,
224
  model_b_placeholder,
225
  physics_placeholder,
226
  use_generator_button,
227
  metrics_placeholder],
228
- outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
 
229
  load_button.click(fn=generate_imgs_from_dataset,
230
  inputs=[dataset_placeholder,
231
  idx_slider,
@@ -234,7 +243,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
234
  physics_placeholder,
235
  use_generator_button,
236
  metrics_placeholder],
237
- outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
 
238
  load_random_button.click(fn=generate_random_imgs_from_dataset,
239
  inputs=[dataset_placeholder,
240
  model_a_placeholder,
@@ -242,6 +252,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
242
  physics_placeholder,
243
  use_generator_button,
244
  metrics_placeholder],
245
- outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
 
 
246
 
247
  interface.launch()
 
11
  from PIL import Image
12
  from torchvision import transforms
13
 
14
+ from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
15
 
16
 
17
+ ### Config
18
+ DEVICE_STR = 'cuda' # run model inference on NVIDIA gpu
19
+ torch.set_grad_enabled(False) # stops tracking values for gradients
20
 
21
 
22
  ### Gradio Utils
23
 
 
 
 
 
 
 
 
 
 
 
24
  def generate_imgs_from_user(image,
25
  model: EvalModel, baseline: BaselineModel,
26
  physics: PhysicsWithGenerator, use_gen: bool,
 
29
  return None, None, None, None, None, None, None, None
30
 
31
  # PIL image -> torch.Tensor
32
+ x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
33
+
34
+ return generate_imgs(x, model, baseline, physics, use_gen, metrics)
35
+
36
+ def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
37
+ model: EvalModel, baseline: BaselineModel,
38
+ physics: PhysicsWithGenerator, use_gen: bool,
39
+ metrics: List[Metric]):
40
+ ### Load 1 image
41
+ x = dataset[idx] # shape : (C, H, W)
42
+ x = x.unsqueeze(0) # shape : (1, C, H, W)
43
 
44
  return generate_imgs(x, model, baseline, physics, use_gen, metrics)
45
+
46
+ def generate_random_imgs_from_dataset(dataset: EvalDataset,
47
+ model: EvalModel,
48
+ baseline: BaselineModel,
49
+ physics: PhysicsWithGenerator,
50
+ use_gen: bool,
51
+ metrics: List[Metric]):
52
+ idx = random.randint(0, len(dataset)-1)
53
+ x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
54
+ dataset, idx, model, baseline, physics, use_gen, metrics
55
+ )
56
+ return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
57
 
58
  def generate_imgs(x: torch.Tensor,
59
  model: EvalModel, baseline: BaselineModel,
 
89
  metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
90
 
91
  ### Process y when y shape is different from x shape
92
+ if physics.name == "MRI" in physics.name:
93
  y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
94
  else:
95
  y_plot = y.clone()
 
107
 
108
  return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
109
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
112
  get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
 
114
  get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
115
  get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
116
 
117
+ AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
118
+ 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
119
  def get_dataset(dataset_name):
120
  global AVAILABLE_PHYSICS
121
  if dataset_name == 'MRI':
 
127
  baseline_name = 'DPIR_CT'
128
  physics_name = 'CT'
129
  else:
130
+ AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
131
+ 'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
132
  baseline_name = 'DPIR'
133
  physics_name = 'MotionBlur_easy'
134
+
135
+ dataset = get_dataset_on_DEVICE_STR(dataset_name)
136
+ physics = get_physics_on_DEVICE_STR(physics_name)
137
+ baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
138
+ return dataset, physics, baseline
139
 
140
 
141
  ### Gradio Blocks interface
 
143
  # Define custom CSS
144
  custom_css = """
145
  .fixed-textbox textarea {
146
+ height: 100px !important; /* Adjust height to fit exactly 4 lines */
147
+ overflow: scroll; /* Add a scroll bar if necessary */
148
+ resize: none; /* User can resize vertically the textbox */
149
  }
150
  """
151
 
 
153
  with gr.Blocks(title=title, css=custom_css) as interface:
154
  gr.Markdown("## " + title)
155
 
156
+ # DEFAULT VALUES
157
+ # Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
158
+ # Solution: using lambda expression
159
+ model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
160
+ model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
161
+ dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
162
+ physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
163
+ idx_placeholder = gr.State(0)
164
+
165
+ metric_names = ["PSNR"]
166
+ metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(metric_names))
167
 
168
+ @gr.render(inputs=[dataset_placeholder, physics_placeholder])
169
+ def dynamic_layout(dataset, physics):
170
  ### LAYOUT
 
 
 
171
 
172
+ # Display images
173
+ with gr.Row():
174
+ gt_img = gr.Image(label=f"Ground-truth IMAGE", interactive=True)
175
+ observed_img = gr.Image(label=f"Observed IMAGE", interactive=False)
176
+ model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
177
+ model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
178
+
179
+ # Manage datasets and display metric values
180
  with gr.Row():
181
  with gr.Column():
182
+ run_button = gr.Button("Demo on above image")
183
+ choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
184
+ label="Datasets",
185
+ value=dataset.name)
186
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index")
187
  with gr.Row():
188
+ load_button = gr.Button("Run on index image from dataset")
189
+ load_random_button = gr.Button("Run on random image from dataset")
190
+ with gr.Column():
191
+ observed_metrics = gr.Textbox(label="PSNR(Observed, Ground-truth)",
192
+ elem_classes=["fixed-textbox"])
193
+ with gr.Column():
194
+ out_a_metric = gr.Textbox(label="PSNR(RAM(Observed, Ground-truth)",
195
+ elem_classes=["fixed-textbox"])
196
+ with gr.Column():
197
+ out_b_metric = gr.Textbox(label="PSNR(DPIR(Observed, Ground-truth)",
198
+ elem_classes=["fixed-textbox"])
199
 
200
+ # Manage physics
201
+ with gr.Row():
202
+ with gr.Column(scale=1):
203
  choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
204
+ label="Physics",
205
+ value=physics.name)
206
+ use_generator_button = gr.Checkbox(label="Generate physics parameters during inference")
207
+ with gr.Column(scale=1):
208
  with gr.Row():
209
  key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
210
+ label="Updatable Parameter Key")
211
+ value_text = gr.Textbox(label="Update Value")
212
+ update_button = gr.Button("Manually update parameter value")
213
+ with gr.Column(scale=2):
214
+ physics_params = gr.Textbox(label="Physics parameters",
215
+ elem_classes=["fixed-textbox"],
216
+ value=physics.display_saved_params())
 
 
 
 
 
 
 
 
 
 
 
 
217
 
 
 
 
 
 
 
 
 
 
 
 
218
 
219
  ### Event listeners
220
+
221
  choose_dataset.change(fn=get_dataset,
222
  inputs=choose_dataset,
223
  outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
224
  choose_physics.change(fn=get_physics_on_DEVICE_STR,
225
  inputs=choose_physics,
226
  outputs=[physics_placeholder])
227
+ update_button.click(fn=physics.update_and_display_params,
228
+ inputs=[key_selector, value_text], outputs=physics_params)
 
 
229
  run_button.click(fn=generate_imgs_from_user,
230
+ inputs=[gt_img,
231
  model_a_placeholder,
232
  model_b_placeholder,
233
  physics_placeholder,
234
  use_generator_button,
235
  metrics_placeholder],
236
+ outputs=[gt_img, observed_img, model_a_out, model_b_out,
237
+ physics_params, observed_metrics, out_a_metric, out_b_metric])
238
  load_button.click(fn=generate_imgs_from_dataset,
239
  inputs=[dataset_placeholder,
240
  idx_slider,
 
243
  physics_placeholder,
244
  use_generator_button,
245
  metrics_placeholder],
246
+ outputs=[gt_img, observed_img, model_a_out, model_b_out,
247
+ physics_params, observed_metrics, out_a_metric, out_b_metric])
248
  load_random_button.click(fn=generate_random_imgs_from_dataset,
249
  inputs=[dataset_placeholder,
250
  model_a_placeholder,
 
252
  physics_placeholder,
253
  use_generator_button,
254
  metrics_placeholder],
255
+ outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
256
+ physics_params, observed_metrics, out_a_metric, out_b_metric])
257
+
258
 
259
  interface.launch()
evals.py → factories.py RENAMED
@@ -8,6 +8,7 @@ from torchvision import transforms
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
10
  from model_factory import get_model
 
11
 
12
  DEFAULT_MODEL_PARAMS = {
13
  "in_channels": [1, 2, 3],
@@ -159,7 +160,7 @@ class PhysicsWithGenerator(torch.nn.Module):
159
 
160
  def _update_save_params(self, key: str, value: Any) -> None:
161
  """Update value of an existing key in save_params."""
162
- if key in list(self.saved_params["updatable_params"].keys()):
163
  if type(value) == str: # it may be only a str representation
164
  # type: str -> ???
165
  value = self.saved_params["updatable_params_converter"][key](value)
@@ -168,7 +169,7 @@ class PhysicsWithGenerator(torch.nn.Module):
168
  value = float(f"{value:.4f}") # keeps only 4 significant digits
169
  self.saved_params["updatable_params"][key] = value
170
 
171
- def update_and_display_params(self, key, value) -> str:
172
  """_update_save_params + update physics with saved_params + display_saved_params"""
173
  self._update_save_params(key, value)
174
 
 
8
 
9
  from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
10
  from model_factory import get_model
11
+ from physics.blur_generator import GaussianBlurGenerator
12
 
13
  DEFAULT_MODEL_PARAMS = {
14
  "in_channels": [1, 2, 3],
 
160
 
161
  def _update_save_params(self, key: str, value: Any) -> None:
162
  """Update value of an existing key in save_params."""
163
+ if value != "" and key in list(self.saved_params["updatable_params"].keys()):
164
  if type(value) == str: # it may be only a str representation
165
  # type: str -> ???
166
  value = self.saved_params["updatable_params_converter"][key](value)
 
169
  value = float(f"{value:.4f}") # keeps only 4 significant digits
170
  self.saved_params["updatable_params"][key] = value
171
 
172
+ def update_and_display_params(self, key: str, value: Any) -> str:
173
  """_update_save_params + update physics with saved_params + display_saved_params"""
174
  self._update_save_params(key, value)
175
 
models/unrolled_dpir.py DELETED
@@ -1,304 +0,0 @@
1
- import numpy as np
2
- import deepinv
3
- import torch
4
- import deepinv as dinv
5
- from deepinv.optim.data_fidelity import L2
6
- from deepinv.optim.prior import PnP
7
- from deepinv.unfolded import unfolded_builder
8
- import copy
9
- import deepinv.optim.utils
10
-
11
- class PoissonGaussianDistance(dinv.optim.Distance):
12
- r"""
13
- Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm
14
-
15
- .. math::
16
- f(x) = (x-y)^{T}\Sigma_y(x-y)
17
-
18
- with :math:`\Sigma_y=\text{diag}(gamma y + \sigma^2)`
19
-
20
- :param float sigma: Gaussian noise parameter. Default: 1.
21
- :param float gain: Poisson noise parameter. Default 0.
22
- """
23
-
24
- def __init__(self, sigma=1.0, gain=0.):
25
- super().__init__()
26
- self.sigma = sigma
27
- self.gain = gain
28
-
29
- def fn(self, x, y, *args, **kwargs):
30
- r"""
31
- Computes the distance :math:`\distance{x}{y}` i.e.
32
-
33
- .. math::
34
-
35
- \distance{x}{y} = \frac{1}{2}\|x-y\|^2
36
-
37
-
38
- :param torch.Tensor u: Variable :math:`x` at which the data fidelity is computed.
39
- :param torch.Tensor y: Data :math:`y`.
40
- :return: (:class:`torch.Tensor`) data fidelity :math:`\datafid{u}{y}` of size `B` with `B` the size of the batch.
41
- """
42
- norm = 1.0 / (self.sigma**2 + y * self.gain)
43
- z = (x - y) * norm
44
- d = 0.5 * torch.norm(z.reshape(z.shape[0], -1), p=2, dim=-1) ** 2
45
- return d
46
-
47
- def grad(self, x, y, *args, **kwargs):
48
- r"""
49
- Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{x}\distance{x}{y}`, i.e.
50
-
51
- .. math::
52
-
53
- \nabla_{x}\distance{x}{y} = \frac{1}{\sigma^2} x-y
54
-
55
-
56
- :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
57
- :param torch.Tensor y: Observation :math:`y`.
58
- :return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
59
- """
60
- norm = 1.0 / (self.sigma**2 + y * self.gain)
61
- return (x - y) * norm
62
-
63
- def prox(self, x, y, *args, gamma=1.0, **kwargs):
64
- r"""
65
- Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2 \sigma^2} \|x-y\|^2`.
66
-
67
- Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.
68
-
69
- .. math::
70
-
71
- \operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|u-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
72
-
73
-
74
- :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
75
- :param torch.Tensor y: Data :math:`y`.
76
- :param float gamma: thresholding parameter.
77
- :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
78
- """
79
- norm = 1.0 / (self.sigma**2 + y * self.gain)
80
- return (x + norm * gamma * y) / (1 + gamma * norm)
81
-
82
-
83
- class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
84
- r"""
85
- Implementation of the data-fidelity as the normalized :math:`\ell_2` norm
86
-
87
- .. math::
88
-
89
- f(x) = \|\forw{x}-y\|^2_{\text{diag}(\sigma^2 + y \gamma)}
90
-
91
- It can be used to define a log-likelihood function associated with Poisson Gaussian noise
92
- by setting an appropriate noise level :math:`\sigma`.
93
-
94
- :param float sigma: Standard deviation of the noise to be used as a normalisation factor.
95
- :param float gain: Gain factor of the data-fidelity term.
96
- """
97
-
98
- def __init__(self, sigma=1.0, gain=0.):
99
- super().__init__()
100
- self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
101
- self.gain = gain
102
- self.sigma = sigma
103
-
104
- def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
105
- r"""
106
- Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`.
107
-
108
- Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.
109
-
110
- .. math::
111
-
112
- \operatorname{prox}_{\gamma \datafidname} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
113
-
114
-
115
- :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
116
- :param torch.Tensor y: Data :math:`y`.
117
- :param deepinv.physics.Physics physics: physics model.
118
- :param float gamma: stepsize of the proximity operator.
119
- :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
120
- """
121
- assert isinstance(physics, dinv.physics.LinearPhysics), "not implemented for non-linear physics"
122
- if isinstance(physics, dinv.physics.StackedPhysics):
123
- device=y[0].device
124
- noise_model = physics[-1].noise_model
125
- else:
126
- device=y.device
127
- noise_model = physics.noise_model
128
- if hasattr(noise_model, "gain"):
129
- self.gain = noise_model.gain.detach().to(device)
130
- if hasattr(noise_model, "sigma"):
131
- self.sigma = noise_model.sigma.detach().to(device)
132
- # Ensure sigma is a tensor and reshape if necessary
133
- if isinstance(self.sigma, float):
134
- self.sigma = torch.tensor([self.sigma], device=device)
135
- if self.sigma.ndim == 0 :
136
- self.sigma = self.sigma.unsqueeze(0).to(device)
137
- # Ensure gain is a tensor and reshape if necessary
138
- if isinstance(self.gain, float):
139
- self.gain = torch.tensor([self.gain], device=device)
140
- if self.gain.ndim == 0 :
141
- self.gain = self.gain.unsqueeze(0).to(device)
142
- if self.gain[0] > 0 :
143
- norm = gamma / (self.sigma[:, None, None, None]**2 + y * self.gain[:, None, None, None])
144
- else :
145
- norm = gamma / (self.sigma[:, None, None, None]**2)
146
- A = lambda u: physics.A_adjoint(physics.A(u)*norm) + u
147
- b = physics.A_adjoint(norm*y) + x
148
- return deepinv.optim.utils.conjugate_gradient(A, b, init=x, max_iter=3, tol=1e-3)
149
-
150
- from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep
151
-
152
- class myHQSIteration(OptimIterator):
153
- r"""
154
- Single iteration of half-quadratic splitting.
155
-
156
- Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising :math:`f(x) + \lambda \regname(x)`.
157
- The iteration is given by
158
-
159
-
160
- .. math::
161
- \begin{equation*}
162
- \begin{aligned}
163
- u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\
164
- x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k).
165
- \end{aligned}
166
- \end{equation*}
167
-
168
-
169
- where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to
170
- a minimizer of :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of
171
- :math:`\gamma\, ^1f+\sigma \lambda \regname`, where :math:`^1f` denotes
172
- the Moreau envelope of :math:`f`
173
-
174
- """
175
-
176
- def __init__(self, **kwargs):
177
- super(myHQSIteration, self).__init__(**kwargs)
178
- self.g_step = mygStepHQS(**kwargs)
179
- self.f_step = myfStepHQS(**kwargs)
180
- self.requires_prox_g = True
181
-
182
- class myfStepHQS(fStep):
183
- r"""
184
- HQS fStep module.
185
- """
186
-
187
- def __init__(self, **kwargs):
188
- super(myfStepHQS, self).__init__(**kwargs)
189
-
190
- def forward(self, x, cur_data_fidelity, cur_params, y, physics):
191
- r"""
192
- Single proximal step on the data-fidelity term :math:`f`.
193
-
194
- :param torch.Tensor x: Current iterate :math:`x_k`.
195
- :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
196
- :param dict cur_params: Dictionary containing the current parameters of the algorithm.
197
- :param torch.Tensor y: Input data.
198
- :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
199
- """
200
- return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])
201
-
202
- class mygStepHQS(gStep):
203
- r"""
204
- HQS gStep module.
205
- """
206
-
207
- def __init__(self, **kwargs):
208
- super(mygStepHQS, self).__init__(**kwargs)
209
-
210
- def forward(self, x, cur_prior, cur_params):
211
- r"""
212
- Single proximal step on the prior term :math:`\lambda \regname`.
213
-
214
- :param torch.Tensor x: Current iterate :math:`x_k`.
215
- :param dict cur_prior: Class containing the current prior.
216
- :param dict cur_params: Dictionary containing the current parameters of the algorithm.
217
- """
218
- return cur_prior.prox(
219
- x,
220
- sigma_denoiser = cur_params["g_param"],
221
- gain_denoiser = cur_params["gain_param"],
222
- gamma=cur_params["lambda"] * cur_params["stepsize"],
223
- )
224
-
225
-
226
- def get_unrolled_architecture(gain_param_init = 1e-3, weight_tied = True, model = None, device = 'cpu'):
227
-
228
- # Unrolled optimization algorithm parameters
229
- max_iter = 8 # number of unfolded layers
230
-
231
- # Select the data fidelity term
232
-
233
-
234
- # Set up the trainable denoising prior
235
- # Here the prior model is common for all iterations
236
- if model is not None :
237
- denoiser = model.to(device)
238
- else :
239
- denoiser = dinv.models.DRUNet(
240
- pretrained= '/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth',
241
- ).to(device)
242
-
243
- class myPnP(PnP):
244
- r"""
245
- Gradient-Step Denoiser prior.
246
- """
247
-
248
- def __init__(self, *args, **kwargs):
249
- super().__init__(*args, **kwargs)
250
-
251
- def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
252
- if not self.training:
253
- pad = (-x.size(-2) % 8, -x.size(-1) % 8)
254
- x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
255
- out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
256
- if not self.training:
257
- out = out[..., : -pad[0] or None, : -pad[1] or None]
258
- return out
259
-
260
- data_fidelity = PoissonGaussianDataFidelity()
261
-
262
- if not weight_tied :
263
- prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for i in range(max_iter)]
264
- else :
265
- prior = [myPnP(denoiser=denoiser)]
266
-
267
- def get_DPIR_params(noise_level_img, max_iter=8):
268
- r"""
269
- Default parameters for the DPIR Plug-and-Play algorithm.
270
-
271
- :param float noise_level_img: Noise level of the input image.
272
- """
273
- s1 = 49.0 / 255.0
274
- s2 = noise_level_img
275
- sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
276
- np.float32
277
- )
278
- stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
279
- lamb = 1 / 0.23
280
- return list(sigma_denoiser), list(lamb * stepsize)
281
-
282
- sigma_denoiser, stepsize = get_DPIR_params(0.05)
283
- stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser)**2)
284
- gain_denoiser = [gain_param_init]*len(sigma_denoiser)
285
- params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
286
-
287
- trainable_params = [
288
- "g_param",
289
- "gain_param"
290
- "stepsize",
291
- ] # define which parameters from 'params_algo' are trainable
292
-
293
- # Define the unfolded trainable model.
294
- model = unfolded_builder(
295
- iteration=myHQSIteration(),
296
- params_algo=params_algo.copy(),
297
- trainable_params=trainable_params,
298
- data_fidelity=data_fidelity,
299
- max_iter=max_iter,
300
- prior=prior,
301
- device=device,
302
- )
303
-
304
- return model.to(device)