msong97 committed
Commit 4dc3e99 · 0 Parent(s):

gradio demo

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitattributes +36 -0
  2. .gitignore +2 -0
  3. README.md +13 -0
  4. app.py +269 -0
  5. datasets.py +84 -0
  6. evals.py +564 -0
  7. img_samples/LSDIR_samples/0001000/0000007_s005.png +0 -0
  8. img_samples/LSDIR_samples/0001000/0000030_s003.png +0 -0
  9. img_samples/LSDIR_samples/0001000/0000067_s005.png +0 -0
  10. img_samples/LSDIR_samples/0001000/0000082_s003.png +0 -0
  11. img_samples/LSDIR_samples/0001000/0000110_s002.png +0 -0
  12. img_samples/LSDIR_samples/0001000/0000125_s003.png +0 -0
  13. img_samples/LSDIR_samples/0001000/0000154_s007.png +0 -0
  14. img_samples/LSDIR_samples/0001000/0000247_s007.png +0 -0
  15. img_samples/LSDIR_samples/0001000/0000259_s003.png +0 -0
  16. img_samples/LSDIR_samples/0001000/0000405_s008.png +0 -0
  17. img_samples/LSDIR_samples/0001000/0000578_s002.png +0 -0
  18. img_samples/LSDIR_samples/0001000/0000669_s010.png +0 -0
  19. img_samples/LSDIR_samples/0001000/0000689_s006.png +0 -0
  20. img_samples/LSDIR_samples/0001000/0000715_s011.png +0 -0
  21. img_samples/LSDIR_samples/0001000/0000752_s010.png +0 -0
  22. img_samples/LSDIR_samples/0001000/0000803_s012.png +0 -0
  23. img_samples/LSDIR_samples/0001000/0000825_s012.png +0 -0
  24. img_samples/LSDIR_samples/0001000/0000921_s012.png +0 -0
  25. img_samples/LSDIR_samples/0001000/0000958_s004.png +0 -0
  26. img_samples/LSDIR_samples/0001000/0000994_s021.png +0 -0
  27. img_samples/LSDIR_samples/0009000/0008033_s006.png +0 -0
  28. img_samples/LSDIR_samples/0009000/0008068_s005.png +0 -0
  29. img_samples/LSDIR_samples/0009000/0008115_s004.png +0 -0
  30. img_samples/LSDIR_samples/0009000/0008217_s002.png +0 -0
  31. img_samples/LSDIR_samples/0009000/0008294_s010.png +0 -0
  32. img_samples/LSDIR_samples/0009000/0008315_s053.png +0 -0
  33. img_samples/LSDIR_samples/0009000/0008340_s015.png +0 -0
  34. img_samples/LSDIR_samples/0009000/0008361_s009.png +0 -0
  35. img_samples/LSDIR_samples/0009000/0008386_s007.png +0 -0
  36. img_samples/LSDIR_samples/0009000/0008491_s006.png +0 -0
  37. img_samples/LSDIR_samples/0009000/0008528_s007.png +0 -0
  38. img_samples/LSDIR_samples/0009000/0008571_s007.png +0 -0
  39. img_samples/LSDIR_samples/0009000/0008573_s012.png +0 -0
  40. img_samples/LSDIR_samples/0009000/0008605_s007.png +0 -0
  41. img_samples/LSDIR_samples/0009000/0008611_s002.png +0 -0
  42. img_samples/LSDIR_samples/0009000/0008631_s005.png +0 -0
  43. img_samples/LSDIR_samples/0009000/0008681_s008.png +0 -0
  44. img_samples/LSDIR_samples/0009000/0008703_s013.png +0 -0
  45. img_samples/LSDIR_samples/0009000/0008714_s010.png +0 -0
  46. img_samples/LSDIR_samples/0009000/0008774_s004.png +0 -0
  47. img_samples/LSDIR_samples/0023000/0022020_s005.png +0 -0
  48. img_samples/LSDIR_samples/0023000/0022037_s011.png +0 -0
  49. img_samples/LSDIR_samples/0023000/0022059_s008.png +0 -0
  50. img_samples/LSDIR_samples/0023000/0022135_s002.png +0 -0
.gitattributes ADDED
@@ -0,0 +1,36 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
+ .ipynb
+ __pycache__
README.md ADDED
@@ -0,0 +1,13 @@
+ ---
+ title: Denoising
+ emoji: 💻
+ colorFrom: blue
+ colorTo: yellow
+ sdk: gradio
+ sdk_version: 4.19.2
+ app_file: app.py
+ pinned: false
+ license: bsd-3-clause
+ ---
+
+ Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
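
For local testing outside of Spaces, the app can be launched directly; a minimal sketch, assuming the pinned dependencies (gradio 4.19.2, deepinv, torch, torchvision) plus the ckpt/ checkpoints and datasets/ folders referenced by the code are in place:

# equivalent to `python app.py`: app.py builds the Blocks app at import
# time and only calls launch() under __main__
from app import interface

interface.launch()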
app.py ADDED
@@ -0,0 +1,269 @@
+ import json
+ import os
+ import random
+ from functools import partial
+ from pathlib import Path
+ from typing import List
+
+ import deepinv as dinv
+ import gradio as gr
+ import torch
+ from PIL import Image
+ from torchvision import transforms
+
+ from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
+
+ # assumed definitions: DEVICE_STR and SAVE_IMG_DIR are referenced below
+ # but are not defined anywhere else in this file
+ DEVICE_STR = "cuda" if torch.cuda.is_available() else "cpu"
+ SAVE_IMG_DIR = Path("saved_imgs")
+ SAVE_IMG_DIR.mkdir(parents=True, exist_ok=True)
+
+
+ ### Gradio Utils
+ def generate_imgs(dataset: EvalDataset, idx: int,
+                   model: EvalModel, baseline: BaselineModel,
+                   physics: PhysicsWithGenerator, use_gen: bool,
+                   metrics: List[Metric]):
+     ### Load 1 image
+     x = dataset[idx]    # shape : (3, 256, 256)
+     x = x.unsqueeze(0)  # shape : (1, 3, 256, 256)
+
+     with torch.no_grad():
+         ### Compute y
+         y = physics(x, use_gen)  # y may be smaller than x because of "valid" blur padding
+
+         ### Compute x_hat
+         out = model(y=y, physics=physics.physics)
+         out_baseline = baseline(y=y, physics=physics.physics)
+
+     ### Process tensors before metric computation
+     if "Blur" in physics.name:
+         w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
+         h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
+
+         x = x[..., w_1:w_2, h_1:h_2]
+         out = out[..., w_1:w_2, h_1:h_2]
+         if out_baseline.shape != out.shape:
+             out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
+
+     ### Metrics
+     metrics_y = ""
+     metrics_out = ""
+     metrics_out_baseline = ""
+     for metric in metrics:
+         if y.shape == x.shape:
+             metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
+         metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
+         metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
+
+     ### Process y when its shape differs from the shape of x
+     if physics.name == "MRI" or "SR" in physics.name:
+         y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
+     else:
+         y_plot = y.clone()
+
+     ### Process images for plotting:
+     #   - clip values outside [0, 1]
+     #   - shape (1, C, H, W) -> (C, H, W)
+     #   - torch.Tensor -> PIL.Image
+     process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
+     to_pil = transforms.ToPILImage()
+     x = to_pil(process_img(x)[0].to('cpu'))
+     y = to_pil(process_img(y_plot)[0].to('cpu'))
+     out = to_pil(process_img(out)[0].to('cpu'))
+     out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
+
+     return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
+
+
+ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
+                                         model: EvalModel,
+                                         baseline: BaselineModel,
+                                         physics: PhysicsWithGenerator,
+                                         use_gen: bool,
+                                         metrics: List[Metric]):
+     idx = random.randint(0, len(dataset) - 1)  # randint is inclusive on both ends
+     x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(
+         dataset, idx, model, baseline, physics, use_gen, metrics)
+     return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
+
+
+ def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
+               model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
+               x: Image.Image, y: Image.Image,
+               out_a: Image.Image, out_b: Image.Image,
+               y_metrics_str: str,
+               out_a_metric_str: str, out_b_metric_str: str) -> None:
+     ### Build the filename from the physics parameters and metric strings
+     physics_params_str = ""
+     for param_name, param_value in physics.saved_params["updatable_params"].items():
+         physics_params_str += f"{param_name}_{param_value}-"
+     physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str
+     y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-")
+     y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str
+     out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-")
+     out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str
+     out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-")
+     out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str
+
+     save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
+     titles = [f"{dataset.name}[{idx}]",
+               f"y = {physics.name}(x)",
+               f"{model_a.name}",
+               f"{model_b.name}"]
+
+     # PIL.Image -> torch.Tensor
+     to_tensor = transforms.ToTensor()
+     x = to_tensor(x)
+     y = to_tensor(y)
+     out_a = to_tensor(out_a)
+     out_b = to_tensor(out_b)
+
+     dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
+
+
+ get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
+ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
+ get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
+ get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
+ get_physics_generator_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
+
+
+ def get_model(model_name, ckpt_pth):
+     if model_name in BaselineModel.all_baselines:
+         return get_baseline_model_on_DEVICE_STR(model_name)
+     else:
+         return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
+
+
+ ### Gradio Blocks interface
+
+ # Define custom CSS
+ custom_css = """
+ .fixed-textbox textarea {
+     height: 90px !important;  /* adjust height to fit exactly 4 lines */
+     overflow: scroll;         /* add a scroll bar if necessary */
+     resize: none;             /* prevent the user from resizing the textbox */
+ }
+ """
+
+ title = "Inverse problem playground"  # displayed on the gradio tab and in the gradio page
+ with gr.Blocks(title=title, css=custom_css) as interface:
+     gr.Markdown("## " + title)
+
+     # Loading things
+     # NB: these defaults must be names registered in evals.py
+     # (EvalDataset.all_datasets / PhysicsWithGenerator.all_physics);
+     # unregistered names raise a ValueError at startup.
+     model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))  # lambda expression to instantiate a callable in a gr.State
+     model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET"))  # lambda expression to instantiate a callable in a gr.State
+     dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
+     physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("MotionBlur_easy"))  # lambda expression to instantiate a callable in a gr.State
+     metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
+
+     @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
+     def dynamic_layout(model_a, model_b, dataset, physics, metrics):
+         ### LAYOUT
+         model_a_name = model_a.base_name
+         model_a_full_name = model_a.name
+         model_b_name = model_b.base_name
+         model_b_full_name = model_b.name
+         dataset_name = dataset.name
+         physics_name = physics.name
+         metric_names = [metric.name for metric in metrics]
+
+         # Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel
+         with gr.Row():
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False)
+                         physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
+                     with gr.Column():
+                         y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
+                         y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"])
+                 with gr.Row():
+                     choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
+                                               label="List of EvalDataset",
+                                               value=dataset_name,
+                                               scale=2)
+                     idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
+
+                 choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
+                                           label="List of PhysicsWithGenerator",
+                                           value=physics_name)
+                 with gr.Row():
+                     key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
+                                                label="Updatable Parameter Key",
+                                                scale=2)
+                     value_text = gr.Textbox(label="Update Value", scale=2)
+                     with gr.Column(scale=1):
+                         update_button = gr.Button("Update Param")
+                         use_generator_button = gr.Checkbox(label="Use param generator")
+
+             with gr.Column():
+                 with gr.Row():
+                     with gr.Column():
+                         model_a_out = gr.Image(label=f"{model_a_full_name} OUTPUT", interactive=False)
+                         out_a_metric = gr.Textbox(label="Metrics(model_a(y), x)", elem_classes=["fixed-textbox"])
+                         load_model_a = gr.Button("Load model A...", scale=1)
+                     with gr.Column():
+                         model_b_out = gr.Image(label=f"{model_b_full_name} OUTPUT", interactive=False)
+                         out_b_metric = gr.Textbox(label="Metrics(model_b(y), x)", elem_classes=["fixed-textbox"])
+                         load_model_b = gr.Button("Load model B...", scale=1)
+                 with gr.Row():
+                     choose_model_a = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
+                                                  label="List of Model A",
+                                                  value=model_a_name,
+                                                  scale=2)
+                     path_a_str = gr.Textbox(value=model_a.ckpt_pth, label="Checkpoint path", scale=3)
+                 with gr.Row():
+                     choose_model_b = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
+                                                  label="List of Model B",
+                                                  value=model_b_name,
+                                                  scale=2)
+                     path_b_str = gr.Textbox(value=model_b.ckpt_pth, label="Checkpoint path", scale=3)
+
+         # Components: Load Metric + Load/Save Buttons
+         with gr.Row():
+             with gr.Column():
+                 choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
+                                                   value=metric_names,
+                                                   label="Choose the metrics you are interested in")
+             with gr.Column():
+                 load_button = gr.Button("Load images...")
+                 load_random_button = gr.Button("Load randomly...")
+                 save_button = gr.Button("Save images...")
+
+         ### Event listeners
+         choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
+                               inputs=choose_dataset,
+                               outputs=dataset_placeholder)
+         choose_physics.change(fn=get_physics_generator_on_DEVICE_STR,
+                               inputs=choose_physics,
+                               outputs=physics_placeholder)
+         update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
+         load_model_a.click(fn=get_model,
+                            inputs=[choose_model_a, path_a_str],
+                            outputs=model_a_placeholder)
+         load_model_b.click(fn=get_model,
+                            inputs=[choose_model_b, path_b_str],
+                            outputs=model_b_placeholder)
+         choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
+                               inputs=choose_metrics,
+                               outputs=metrics_placeholder)
+         load_button.click(fn=generate_imgs,
+                           inputs=[dataset_placeholder,
+                                   idx_slider,
+                                   model_a_placeholder,
+                                   model_b_placeholder,
+                                   physics_placeholder,
+                                   use_generator_button,
+                                   metrics_placeholder],
+                           outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
+         load_random_button.click(fn=update_random_idx_and_generate_imgs,
+                                  inputs=[dataset_placeholder,
+                                          model_a_placeholder,
+                                          model_b_placeholder,
+                                          physics_placeholder,
+                                          use_generator_button,
+                                          metrics_placeholder],
+                                  outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
+         # assumed wiring: save_imgs is defined above but no handler is
+         # registered for save_button, so this completion is an assumption
+         save_button.click(fn=save_imgs,
+                           inputs=[dataset_placeholder, idx_slider, physics_placeholder,
+                                   model_a_placeholder, model_b_placeholder,
+                                   clean, y_image, model_a_out, model_b_out,
+                                   y_metrics, out_a_metric, out_b_metric],
+                           outputs=[])
+
+ if __name__ == "__main__":
+     interface.launch()
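
All of the event handlers above funnel into generate_imgs; a headless sketch of the same call path, assuming the ckpt/ checkpoints and datasets/ folders referenced by evals.py are present locally (importing app also builds the Blocks UI, which is harmless here):

from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
from app import generate_imgs

dataset = EvalDataset("Natural", device_str="cpu")
model = EvalModel("unext_emb_physics_config_C", ckpt_pth="", device_str="cpu")
baseline = BaselineModel("DRUNET", device_str="cpu")
physics = PhysicsWithGenerator("MotionBlur_easy", device_str="cpu")
metrics = Metric.get_list_metrics(["PSNR"], device_str="cpu")

# same tuple the Gradio callbacks unpack into the UI components
x, y, out, out_b, params, m_y, m_out, m_out_b = generate_imgs(
    dataset, 0, model, baseline, physics, True, metrics)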
datasets.py ADDED
@@ -0,0 +1,84 @@
+ from pathlib import Path
+ from typing import Callable, Optional
+ import os
+
+ import torch
+ from torch.utils.data import Dataset
+
+
+ class Preprocessed_fastMRI(torch.utils.data.Dataset):
+     """FastMRI from preprocessed data, for faster loading."""
+
+     def __init__(
+         self,
+         root: str,
+         transform: Optional[Callable] = None,
+         preprocess: bool = False,
+     ) -> None:
+         self.root = root
+         self.transform = transform
+         self.preprocess = preprocess
+
+         # should contain all the information needed to load a data sample from storage
+         self.sample_identifiers = []
+
+         # append all filenames in self.root ending with .pt
+         for root_dir, _, files in os.walk(self.root):
+             for file in files:
+                 if file.endswith(".pt"):
+                     self.sample_identifiers.append(file)
+
+     def __len__(self) -> int:
+         return len(self.sample_identifiers)
+
+     def __getitem__(self, idx: int):
+         fname = self.sample_identifiers[idx]
+
+         tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
+         img = tensor['data'].float()
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         if not self.preprocess:
+             return img
+         else:
+             # remove the extension and prefix from the filename
+             fname = Path(fname).stem
+             return img, fname
+
+
+ class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
+     """LIDC-IDRI from preprocessed data, for faster loading."""
+
+     def __init__(
+         self,
+         root: str,
+         transform: Optional[Callable] = None,
+     ) -> None:
+         self.root = root
+         self.transform = transform
+
+         # should contain all the information needed to load a data sample from storage
+         self.sample_identifiers = []
+
+         # append all filenames in self.root ending with .pt
+         for root_dir, _, files in os.walk(self.root):
+             for file in files:
+                 if file.endswith(".pt"):
+                     self.sample_identifiers.append(file)
+
+     def __len__(self) -> int:
+         return len(self.sample_identifiers)
+
+     def __getitem__(self, idx: int):
+         fname = self.sample_identifiers[idx]
+
+         tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
+         img = tensor['data'].float()
+
+         if self.transform is not None:
+             img = self.transform(img)
+
+         img = img.unsqueeze(0)  # add channel dim
+         return img              # the return was missing in the original diff
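
Both classes read each `.pt` file as a dict with a `'data'` tensor. A minimal sketch of a compatible preprocessing step (the file name and the 480x480 slice size are illustrative):

import torch

# write one sample the way Preprocessed_* expects to read it:
# a dict with a 'data' entry holding the image tensor
sample = {"data": torch.rand(480, 480)}  # e.g. one CT slice
torch.save(sample, "datasets/LIDC_IDRI_samples/slice_0000.pt")

# round-trip check, mirroring __getitem__
img = torch.load("datasets/LIDC_IDRI_samples/slice_0000.pt",
                 weights_only=True)["data"].float()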
evals.py ADDED
@@ -0,0 +1,564 @@
+ from typing import Any, List
+
+ import deepinv as dinv
+ import numpy as np
+ import torch
+ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
+ # GaussianBlurGenerator is used below but was not imported in this commit;
+ # importing it from the same deepinv module is an assumption.
+ from deepinv.physics.generator import GaussianBlurGenerator
+ from torchvision import transforms
+
+ from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
+ from utils import get_model
+
+ DEFAULT_MODEL_PARAMS = {
+     "in_channels": [1, 2, 3],
+     "grayscale": False,
+     "conv_type": "base",
+     "pool_type": "base",
+     "layer_scale_init_value": 1e-6,
+     "init_type": "ortho",
+     "gain_init_conv": 1.0,
+     "gain_init_linear": 1.0,
+     "drop_prob": 0.0,
+     "replk": False,
+     "mult_fact": 4,
+     "antialias": "gaussian",
+     "nc_base": 64,
+     "cond_type": "base",
+     "blind": False,
+     "pretrained_pth": None,
+     "N": 2,
+     "c_mult": 2,
+     "depth_encoding": 2,
+     "relu_in_encoding": False,
+     "skip_in_encoding": True
+ }
+
+
+ class PhysicsWithGenerator(torch.nn.Module):
+     """Interface between Physics, Generator and Gradio."""
+     all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur",
+                    "MRI", "CT"]
+
+     def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
+         super().__init__()
+
+         self.name = physics_name
+         if self.name not in self.all_physics:
+             raise ValueError(f"{self.name} is unavailable.")
+
+         self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
+         if self.name == "MotionBlur_easy":
+             psf_size = 31
+             self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid",
+                                              device=device_str)
+             self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
+             self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
+             self.saved_params = {"updatable_params": {"sigma": 0.05},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
+                                                   "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
+         elif self.name == "MotionBlur_medium":
+             psf_size = 31
+             self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid",
+                                              device=device_str)
+             self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
+             self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
+             self.saved_params = {"updatable_params": {"sigma": 0.05},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
+                                                   "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
+         elif self.name == "MotionBlur_hard":
+             psf_size = 31
+             self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid",
+                                              device=device_str)
+             self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
+             self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
+             self.saved_params = {"updatable_params": {"sigma": 0.05},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1,
+                                                   "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
+         elif self.name == "GaussianBlur":
+             psf_size = 31
+             self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid",
+                                              device=device_str)
+             self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1,
+                                                            device=device_str)
+             self.generator = self.physics_generator + self.sigma_generator
+             self.saved_params = {"updatable_params": {"sigma": 0.05},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
+                                                   "psf_size": 31, "num_channels": 1}}
+         elif self.name == "MRI":
+             self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
+                                             device=device_str)
+             self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
+             self.generator = self.physics_generator  # + self.sigma_generator
+             self.saved_params = {"updatable_params": {"sigma": 0.05},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
+                                                   "acceleration_factor": 4}}
+         elif self.name == "CT":
+             acceleration_factor = 10
+             img_h = 480
+             angles = int(img_h / acceleration_factor)
+             # angles = torch.linspace(0, 180, steps=10)
+             self.physics = dinv.physics.Tomography(
+                 img_width=img_h,
+                 angles=angles,
+                 circle=False,
+                 normalize=True,
+                 device=device_str,
+                 noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
+                 max_iter=10,
+             )
+             self.physics_generator = None
+             self.generator = self.sigma_generator
+             self.saved_params = {"updatable_params": {"sigma": 0.1},
+                                  "updatable_params_converter": {"sigma": float},
+                                  "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
+                                                   "angles": angles, "max_iter": 10}}
+
+     def display_saved_params(self) -> str:
+         """Printable version of saved_params."""
+         updatable_params_str = "Updatable parameters:\n"
+         for param_name, param_value in self.saved_params["updatable_params"].items():
+             updatable_params_str += f"\t\t{param_name} = {param_value}" + "\n"
+
+         fixed_params_str = "Fixed parameters:\n"
+         for param_name, param_value in self.saved_params["fixed_params"].items():
+             fixed_params_str += f"\t\t{param_name} = {param_value}" + "\n"
+
+         return updatable_params_str + fixed_params_str
+
+     def _update_save_params(self, key: str, value: Any) -> None:
+         """Update the value of an existing key in saved_params."""
+         if key in list(self.saved_params["updatable_params"].keys()):
+             if type(value) == str:  # the value may arrive as a str representation
+                 # str -> the type registered for this key
+                 value = self.saved_params["updatable_params_converter"][key](value)
+             elif isinstance(value, torch.Tensor):
+                 value = value.item()  # torch.Tensor -> float
+             value = float(f"{value:.4f}")  # keep only 4 significant digits
+             self.saved_params["updatable_params"][key] = value
+
+     def update_and_display_params(self, key, value) -> str:
+         """_update_save_params + update physics with saved_params + display_saved_params"""
+         self._update_save_params(key, value)
+
+         if self.name == "Denoising":
+             self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"])
+         else:
+             self.physics.update_parameters(**self.saved_params["updatable_params"])
+
+         return self.display_saved_params()
+
+     def update_saved_params_and_physics(self, **kwargs) -> None:
+         """Update saved_params, then update the physics."""
+         for key, value in kwargs.items():
+             self._update_save_params(key, value)
+
+         self.physics.update(**kwargs)
+
+     def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
+         # force the generator on first use, before a filter/mask has been set
+         if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"):
+             use_gen = True
+         elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
+             use_gen = True
+
+         if use_gen:
+             kwargs = self.generator.step(batch_size=x.shape[0])  # generate a set of params for each sample
+             self.update_saved_params_and_physics(**kwargs)
+
+         return self.physics(x)
+
+
+ class EvalModel(torch.nn.Module):
+     """Model under evaluation.
+
+     How does it differ from BaselineModel?
+     -> A BaselineModel is already trained and keeps fixed weights.
+     -> An EvalModel changes depending on the checkpoint being evaluated.
+     """
+     all_models = ["unext_emb_physics_config_C"]
+
+     def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
+         """Load the model we want to evaluate."""
+         super().__init__()
+         self.base_name = model_name
+         self.ckpt_pth = ckpt_pth
+         self.name = self.base_name
+         if self.base_name not in self.all_models:
+             raise ValueError(f"{self.base_name} is unavailable.")
+         if self.base_name == "unext_emb_physics_config_C":
+             if self.ckpt_pth == "":
+                 self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
+             self.model = get_model(model_name=self.base_name,
+                                    device='cpu',
+                                    **DEFAULT_MODEL_PARAMS)
+
+             # load the checkpoint once on cpu, for both the weights and the epoch
+             ckpt = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)
+             self.model.load_state_dict(ckpt['state_dict'])
+             self.model.to(device_str)
+             self.model.eval()
+
+             # add the epoch to the model name
+             self.name = self.name + f"+{ckpt['epoch']}"
+
+     def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
+         return self.model(y, physics=physics)
+
+
+ class BaselineModel(torch.nn.Module):
+     """Baseline model.
+
+     How does it differ from EvalModel?
+     -> A BaselineModel is already trained and keeps fixed weights.
+     -> An EvalModel changes depending on the checkpoint being evaluated.
+     """
+     all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR",
+                      "DPIR_MRI", "DPIR_CT", "PDNET"]
+
+     def __init__(self, model_name: str, device_str: str = "cpu") -> None:
+         super().__init__()
+         self.base_name = model_name
+         self.ckpt_pth = ""
+         self.name = self.base_name
+         if self.name not in self.all_baselines:
+             raise ValueError(f"{self.name} is unavailable.")
+         elif self.name == "DRUNET":
+             n_channels = 3
+             ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
+             self.model = dinv.models.DRUNet(in_channels=n_channels,
+                                             out_channels=n_channels,
+                                             device=device_str,
+                                             pretrained=ckpt_pth)
+             self.model.eval()  # set the model to evaluation mode
+         elif self.name == 'PDNET':
+             ckpt_pth = "ckpt/pdnet.pth.tar"
+             self.model = get_model(model_name='pdnet',
+                                    device=device_str)
+             self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict'])
+             self.model.eval()
+         elif self.name == "SWINIRx2":
+             n_channels = 3
+             scale = 2
+             ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth"
+             upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
+             self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
+                                             img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
+                                             num_heads=[6, 6, 6, 6, 6, 6],
+                                             mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
+                                             pretrained=ckpt_pth)
+             self.model.to(device_str)
+             self.model.eval()  # set the model to evaluation mode
+         elif self.name == "SWINIRx4":
+             n_channels = 3
+             scale = 4
+             ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth"
+             upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
+             self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
+                                             img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
+                                             num_heads=[6, 6, 6, 6, 6, 6],
+                                             mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
+                                             pretrained=ckpt_pth)
+             self.model.to(device_str)
+             self.model.eval()  # set the model to evaluation mode
+         elif self.name == "PnP-PGD-DRUNET":
+             n_channels = 3
+             ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
+             drunet = dinv.models.DRUNet(in_channels=n_channels,
+                                         out_channels=n_channels,
+                                         device=device_str,
+                                         pretrained=ckpt_pth)
+             drunet.eval()  # set the model to evaluation mode
+             self.model = dinv.optim.optim_builder(iteration="PGD",
+                                                   prior=dinv.optim.PnP(drunet).to(device_str),
+                                                   data_fidelity=dinv.optim.L2(),
+                                                   max_iter=20,
+                                                   params_algo={'stepsize': 1., 'g_param': .05})
+         elif self.name == "DPIR":
+             n_channels = 3
+             ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
+             drunet = dinv.models.DRUNet(in_channels=n_channels,
+                                         out_channels=n_channels,
+                                         device=device_str,
+                                         pretrained=ckpt_pth)
+             drunet.eval()  # set the model to evaluation mode
+
+             # Specify the denoising prior
+             self.prior = dinv.optim.prior.PnP(denoiser=drunet)
+         elif self.name == "DPIR_MRI":
+             class ComplexDenoiser(torch.nn.Module):
+                 """Apply a grayscale denoiser to the two MRI channels by stacking them along the batch dimension."""
+                 def __init__(self, denoiser):
+                     super().__init__()
+                     self.denoiser = denoiser
+
+                 def forward(self, x, sigma):
+                     noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0)
+                     input_min = noisy_batch.min()
+                     denoised_batch = self.denoiser(noisy_batch - input_min, sigma)
+                     denoised_batch = denoised_batch + input_min
+                     denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1)
+                     return denoised
+
+             # Load PnP denoiser backbone
+             n_channels = 1
+             ckpt_pth = "ckpt/drunet_gray.pth"
+             drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
+                                         pretrained=ckpt_pth)
+             complex_drunet = ComplexDenoiser(drunet)
+             complex_drunet.eval()
+
+             # Specify the denoising prior
+             self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet)
+         elif self.name == "DPIR_CT":
+             class CTDenoiser(torch.nn.Module):
+                 """Shift CT slices to a non-negative range before denoising, then shift back."""
+                 def __init__(self, denoiser):
+                     super().__init__()
+                     self.denoiser = denoiser
+
+                 def forward(self, x, sigma):
+                     # keep the offset so it can be restored after denoising
+                     # (the original diff recomputed x.min() after the shift, which is always 0)
+                     x_min = x.min()
+                     denoised = self.denoiser(x - x_min, sigma)
+                     return denoised + x_min
+
+             # Load PnP denoiser backbone
+             n_channels = 1
+             ckpt_pth = "ckpt/drunet_gray.pth"
+             drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
+                                         pretrained=ckpt_pth)
+             ct_drunet = CTDenoiser(drunet)
+             ct_drunet.eval()
+
+             # Specify the denoising prior
+             self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet)
+
+     def circular_roll(self, tensor, p_h, p_w):
+         """Circularly shift the last two (spatial) dimensions."""
+         return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1))
+
+     def get_DPIR_params(self, noise_level_img, max_iter=8):
+         r"""
+         Default parameters for the DPIR Plug-and-Play algorithm: the denoiser
+         noise level decreases on a log-spaced schedule from 49/255 down to the
+         image noise level, and the stepsize follows it quadratically.
+
+         :param float noise_level_img: Noise level of the input image.
+         :return: tuple (list of denoiser noise levels per iteration, list of stepsizes per iteration).
+         """
+         s1 = 49.0 / 255.0
+         s2 = max(noise_level_img, 0.01)
+         sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
+             np.float32
+         )
+         stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
+         lamb = 1 / 0.23
+         return list(sigma_denoiser), list(lamb * stepsize)
+
+     def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8):
+         r"""
+         Default parameters for the DPIR Plug-and-Play algorithm on MRI.
+
+         :param float noise_level_img: Noise level of the input image.
+         :return: tuple (lambda, list of denoiser noise levels, list of stepsizes, number of iterations).
+         """
+         s1 = 49.0 / 255.0
+         s2 = noise_level_img
+         sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
+             np.float32
+         )
+         stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
+         lamb = 1.
+         return lamb, list(sigma_denoiser), list(stepsize), max_iter
+
+     def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0):
+         r"""
+         Default parameters for the DPIR Plug-and-Play algorithm on CT.
+
+         :param float noise_level_img: Noise level of the input image.
+         :return: tuple (lambda, list of denoiser noise levels, list of stepsizes, number of iterations).
+         """
+         s1 = 49.0 / 255.0 * lip_cons
+         s2 = noise_level_img
+         sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
+             np.float32
+         )
+         stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
+         lamb = 1.
+         return lamb, list(sigma_denoiser), list(stepsize), max_iter
+
+     def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
+         if self.name == "DRUNET":
+             return self.model(y, sigma=physics.noise_model.sigma)
+         elif self.name == "PnP-PGD-DRUNET":
+             return self.model(y, physics=physics)
+         elif self.name == "DPIR":
+             # Set the DPIR algorithm parameters
+             sigma_float = physics.noise_model.sigma.item()  # sigma should be a single value
+             max_iter = 8
+
+             sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter)
+             params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser}
+             early_stop = False  # do not stop the algorithm with a convergence criterion
+
+             # instantiate DPIR
+             model = dinv.optim.optim_builder(
+                 iteration="HQS",
+                 prior=self.prior,
+                 data_fidelity=dinv.optim.data_fidelity.L2(),
+                 early_stop=early_stop,
+                 max_iter=max_iter,
+                 verbose=True,
+                 params_algo=params_algo,
+             )
+             return model(y, physics=physics)
+         elif self.name == "DPIR_MRI":
+             sigma_float = max(physics.noise_model.sigma.item(), 0.015)  # sigma should be a single value
+             lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16)
+             stepsize = [stepsize[0]] * max_iter
+             params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
+             early_stop = False  # do not stop the algorithm with a convergence criterion
+
+             # Instantiate the algorithm class to solve the IP
+             model = dinv.optim.optim_builder(
+                 iteration="HQS",
+                 prior=self.prior,
+                 data_fidelity=dinv.optim.data_fidelity.L2(),
+                 early_stop=early_stop,
+                 max_iter=max_iter,
+                 verbose=True,
+                 params_algo=params_algo,
+             )
+             return model(y, physics=physics)
+         elif self.name == "DPIR_CT":
+             # Set the DPIR algorithm parameters
+             sigma_float = physics.noise_model.sigma.item()  # sigma should be a single value
+             lip_const = physics.compute_norm(physics.A_adjoint(y))
+             lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8,
+                                                                                lip_cons=lip_const.item())
+             params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
+             early_stop = False  # do not stop the algorithm with a convergence criterion
+
+             def custom_init(y, physic_op):
+                 x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4)
+                 return {"est": (x_init, x_init)}
+
+             # Instantiate the algorithm class to solve the IP
+             algo = dinv.optim.optim_builder(
+                 iteration="HQS",
+                 prior=self.prior,
+                 data_fidelity=dinv.optim.data_fidelity.L2(),
+                 early_stop=early_stop,
+                 max_iter=max_iter,
+                 verbose=True,
+                 params_algo=params_algo,
+                 custom_init=custom_init
+             )
+             return algo(y, physics=physics)
+         elif self.name == 'SWINIRx4':
+             window_size = 8
+             scale = 4
+             _, _, h_old, w_old = y.size()
+             # mirror-pad y so that its spatial size is a multiple of the window size
+             h_pad = (h_old // window_size + 1) * window_size - h_old
+             w_pad = (w_old // window_size + 1) * window_size - w_old
+             img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
+             img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
+             output = self.model(img_lq)
+             output = output[..., :h_old * scale, :w_old * scale]
+             output = self.circular_roll(output, -2, -2)
+             # crop to the shape of the adjoint so outputs stay comparable
+             x_adj = physics.A_adjoint(y)
+             output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
+             return output
+         elif self.name == 'SWINIRx2':
+             window_size = 8
+             scale = 2
+             _, _, h_old, w_old = y.size()
+             # mirror-pad y so that its spatial size is a multiple of the window size
+             h_pad = (h_old // window_size + 1) * window_size - h_old
+             w_pad = (w_old // window_size + 1) * window_size - w_old
+             img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
+             img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
+             output = self.model(img_lq)
+             output = output[..., :h_old * scale, :w_old * scale]
+             output = self.circular_roll(output, -1, -1)
+             # crop to the shape of the adjoint so outputs stay comparable
+             x_adj = physics.A_adjoint(y)
+             output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
+             return output
+         elif 'UNROLLED_DPIR' in self.name:
+             return self.model(y, physics=physics)
+         else:
+             return self.model(y)
+
+
+ class EvalDataset(torch.utils.data.Dataset):
+     """
+     Evaluation datasets. CT slices are expected to be 480x480; MRI samples
+     are center-cropped to (640, 320).
+     """
+     all_datasets = ["Natural", "MRI", "CT"]
+
+     def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
+         self.name = dataset_name
+         self.device_str = device_str
+         if self.name not in self.all_datasets:
+             raise ValueError(f"{self.name} is unavailable.")
+         if self.name == 'Natural':
+             self.root = 'datasets/LSDIR_samples'
+             self.transform = transforms.Compose([transforms.ToTensor()])
+             self.dataset = dinv.datasets.LsdirHR(root=self.root,
+                                                  download=False,
+                                                  transform=self.transform)
+         elif self.name == 'MRI':
+             self.root = 'datasets/FastMRI_samples'
+             self.transform = transforms.CenterCrop((640, 320))  # , pad_if_needed=True)
+             self.dataset = Preprocessed_fastMRI(root=self.root,
+                                                 transform=self.transform,
+                                                 preprocess=False)
+         elif self.name == "CT":
+             self.root = 'datasets/LIDC_IDRI_samples'
+             self.transform = None
+             self.dataset = Preprocessed_LIDCIDRI(root=self.root,
+                                                  transform=self.transform)
+
+     def __len__(self) -> int:
+         return len(self.dataset)
+
+     def __getitem__(self, idx: int) -> torch.Tensor:
+         return self.dataset[idx].to(self.device_str)
+
+
+ class Metric():
+     """Metrics and utilities."""
+     all_metrics = ["PSNR", "SSIM", "LPIPS"]
+
+     def __init__(self, metric_name: str, device_str: str = "cpu") -> None:
+         self.name = metric_name
+         if self.name not in self.all_metrics:
+             raise ValueError(f"{self.name} is unavailable.")
+         elif self.name == "PSNR":
+             self.metric = dinv.loss.metric.PSNR()
+         elif self.name == "SSIM":
+             self.metric = dinv.loss.metric.SSIM()
+         elif self.name == "LPIPS":
+             self.metric = dinv.loss.metric.LPIPS(device=device_str)
+
+     def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
+         # x_net and x may not have the same size, in which case each is
+         # center-cropped to the smaller size before computing the metric
+         # (the original diff centered both crops on x_net, which shifts the
+         # window on x when the shapes differ)
+         if x_net.shape[-1] != x.shape[-1]:
+             min_size = min(x_net.shape[-1], x.shape[-1])
+             x_net_crop = x_net[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2,
+                                x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2]
+             x_crop = x[..., x.shape[-2] // 2 - min_size // 2: x.shape[-2] // 2 + min_size // 2,
+                        x.shape[-1] // 2 - min_size // 2: x.shape[-1] // 2 + min_size // 2]
+         else:
+             x_net_crop = x_net
+             x_crop = x
+         return self.metric(x_net_crop, x_crop)
+
+     @classmethod
+     def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
+         metrics = []
+         for metric_name in metric_names:
+             metrics.append(cls(metric_name, device_str=device_str))
+         return metrics
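
Taken together, the classes above can be driven without the Gradio layer; a minimal sketch, where the random tensor stands in for a dataset sample (real models and datasets additionally require the ckpt/ and datasets/ folders):

import torch
from evals import PhysicsWithGenerator, Metric

physics = PhysicsWithGenerator("MotionBlur_easy", device_str="cpu")
x = torch.rand(1, 3, 256, 256)   # stand-in for a dataset sample
y = physics(x, use_gen=True)     # forward() draws a PSF and noise level

# Metric center-crops to the smaller tensor, so the "valid"-padding
# shrinkage of y is handled automatically
psnr = Metric("PSNR")
print(y.shape, psnr(y, x).item())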
img_samples/LSDIR_samples/0001000/0000007_s005.png ADDED
img_samples/LSDIR_samples/0001000/0000030_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000067_s005.png ADDED
img_samples/LSDIR_samples/0001000/0000082_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000110_s002.png ADDED
img_samples/LSDIR_samples/0001000/0000125_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000154_s007.png ADDED
img_samples/LSDIR_samples/0001000/0000247_s007.png ADDED
img_samples/LSDIR_samples/0001000/0000259_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000405_s008.png ADDED
img_samples/LSDIR_samples/0001000/0000578_s002.png ADDED
img_samples/LSDIR_samples/0001000/0000669_s010.png ADDED
img_samples/LSDIR_samples/0001000/0000689_s006.png ADDED
img_samples/LSDIR_samples/0001000/0000715_s011.png ADDED
img_samples/LSDIR_samples/0001000/0000752_s010.png ADDED
img_samples/LSDIR_samples/0001000/0000803_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000825_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000921_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000958_s004.png ADDED
img_samples/LSDIR_samples/0001000/0000994_s021.png ADDED
img_samples/LSDIR_samples/0009000/0008033_s006.png ADDED
img_samples/LSDIR_samples/0009000/0008068_s005.png ADDED
img_samples/LSDIR_samples/0009000/0008115_s004.png ADDED
img_samples/LSDIR_samples/0009000/0008217_s002.png ADDED
img_samples/LSDIR_samples/0009000/0008294_s010.png ADDED
img_samples/LSDIR_samples/0009000/0008315_s053.png ADDED
img_samples/LSDIR_samples/0009000/0008340_s015.png ADDED
img_samples/LSDIR_samples/0009000/0008361_s009.png ADDED
img_samples/LSDIR_samples/0009000/0008386_s007.png ADDED
img_samples/LSDIR_samples/0009000/0008491_s006.png ADDED
img_samples/LSDIR_samples/0009000/0008528_s007.png ADDED
img_samples/LSDIR_samples/0009000/0008571_s007.png ADDED
img_samples/LSDIR_samples/0009000/0008573_s012.png ADDED
img_samples/LSDIR_samples/0009000/0008605_s007.png ADDED
img_samples/LSDIR_samples/0009000/0008611_s002.png ADDED
img_samples/LSDIR_samples/0009000/0008631_s005.png ADDED
img_samples/LSDIR_samples/0009000/0008681_s008.png ADDED
img_samples/LSDIR_samples/0009000/0008703_s013.png ADDED
img_samples/LSDIR_samples/0009000/0008714_s010.png ADDED
img_samples/LSDIR_samples/0009000/0008774_s004.png ADDED
img_samples/LSDIR_samples/0023000/0022020_s005.png ADDED
img_samples/LSDIR_samples/0023000/0022037_s011.png ADDED
img_samples/LSDIR_samples/0023000/0022059_s008.png ADDED
img_samples/LSDIR_samples/0023000/0022135_s002.png ADDED