Yonuts commited on
Commit
cd12993
·
2 Parent(s): 71c2965 12a4d59

new gradio demo with ram

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +1 -0
  2. .gitignore +2 -0
  3. app.py +256 -65
  4. ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth +3 -0
  5. ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth +3 -0
  6. ckpt/drunet_deepinv_color_finetune_22k.pth +3 -0
  7. ckpt/drunet_gray.pth +3 -0
  8. ckpt/pdnet.pth.tar +3 -0
  9. ckpt/ram_ckp_10.pth.tar +3 -0
  10. datasets.py +84 -0
  11. evals.py +564 -0
  12. img_samples/FastMRI_samples/file_brain_AXT1POST_209_6001231_11.pt +3 -0
  13. img_samples/FastMRI_samples/file_brain_AXT2_205_2050122_7.pt +3 -0
  14. img_samples/FastMRI_samples/file_brain_AXT2_205_2050160_10.pt +3 -0
  15. img_samples/FastMRI_samples/file_brain_AXT2_210_6001888_6.pt +3 -0
  16. img_samples/FastMRI_samples/file_brain_AXT2_210_6001947_5.pt +3 -0
  17. img_samples/LIDC-IDRI_samples/LIDC-IDRI-0032_01-01-2000-NA-NA-53482_3000537.000000-NA-91689_1-236.pt +3 -0
  18. img_samples/LIDC-IDRI_samples/LIDC-IDRI-0083_01-01-2000-NA-NA-22049_3000646.000000-NA-60532_1-027.pt +3 -0
  19. img_samples/LIDC-IDRI_samples/LIDC-IDRI-0144_01-01-2000-NA-NA-61308_3000703.000000-NA-75826_1-079.pt +3 -0
  20. img_samples/LIDC-IDRI_samples/LIDC-IDRI-0152_01-01-2000-NA-NA-78489_3000696.000000-NA-27171_1-083.pt +3 -0
  21. img_samples/LIDC-IDRI_samples/LIDC-IDRI-0298_01-01-2000-NA-NA-11572_3000663.000000-NA-48288_1-004.pt +3 -0
  22. img_samples/LSDIR_samples/0001000/0000007_s005.png +0 -0
  23. img_samples/LSDIR_samples/0001000/0000030_s003.png +0 -0
  24. img_samples/LSDIR_samples/0001000/0000067_s005.png +0 -0
  25. img_samples/LSDIR_samples/0001000/0000082_s003.png +0 -0
  26. img_samples/LSDIR_samples/0001000/0000110_s002.png +0 -0
  27. img_samples/LSDIR_samples/0001000/0000125_s003.png +0 -0
  28. img_samples/LSDIR_samples/0001000/0000154_s007.png +0 -0
  29. img_samples/LSDIR_samples/0001000/0000247_s007.png +0 -0
  30. img_samples/LSDIR_samples/0001000/0000259_s003.png +0 -0
  31. img_samples/LSDIR_samples/0001000/0000405_s008.png +0 -0
  32. img_samples/LSDIR_samples/0001000/0000578_s002.png +0 -0
  33. img_samples/LSDIR_samples/0001000/0000669_s010.png +0 -0
  34. img_samples/LSDIR_samples/0001000/0000689_s006.png +0 -0
  35. img_samples/LSDIR_samples/0001000/0000715_s011.png +0 -0
  36. img_samples/LSDIR_samples/0001000/0000752_s010.png +0 -0
  37. img_samples/LSDIR_samples/0001000/0000803_s012.png +0 -0
  38. img_samples/LSDIR_samples/0001000/0000825_s012.png +0 -0
  39. img_samples/LSDIR_samples/0001000/0000921_s012.png +0 -0
  40. img_samples/LSDIR_samples/0001000/0000958_s004.png +0 -0
  41. img_samples/LSDIR_samples/0001000/0000994_s021.png +0 -0
  42. img_samples/LSDIR_samples/0009000/0008033_s006.png +0 -0
  43. img_samples/LSDIR_samples/0009000/0008068_s005.png +0 -0
  44. img_samples/LSDIR_samples/0009000/0008115_s004.png +0 -0
  45. img_samples/LSDIR_samples/0009000/0008217_s002.png +0 -0
  46. img_samples/LSDIR_samples/0009000/0008294_s010.png +0 -0
  47. img_samples/LSDIR_samples/0009000/0008315_s053.png +0 -0
  48. img_samples/LSDIR_samples/0009000/0008340_s015.png +0 -0
  49. img_samples/LSDIR_samples/0009000/0008361_s009.png +0 -0
  50. img_samples/LSDIR_samples/0009000/0008386_s007.png +0 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.pth.tar filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ .ipynb
2
+ __pycache__
app.py CHANGED
@@ -1,78 +1,269 @@
1
- import gradio as gr
 
 
 
 
 
 
2
  import deepinv as dinv
 
3
  import torch
4
- import numpy as np
5
- import PIL.Image
 
 
 
 
 
 
 
 
 
 
 
 
6
 
 
 
 
7
 
8
- def pil_to_torch(image, ref_size=512):
9
- image = np.array(image)
10
- image = image.transpose((2, 0, 1))
11
- image = torch.tensor(image).float() / 255
12
- image = image.unsqueeze(0)
13
 
14
- if ref_size == 256:
15
- size = (ref_size, ref_size)
16
- elif image.shape[2] > image.shape[3]:
17
- size = (ref_size, ref_size * image.shape[3]//image.shape[2])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  else:
19
- size = (ref_size * image.shape[2]//image.shape[3], ref_size)
20
-
21
- image = torch.nn.functional.interpolate(image, size=size, mode='bilinear')
22
- return image
23
-
24
-
25
- def torch_to_pil(image):
26
- image = image.squeeze(0).cpu().detach().numpy()
27
- image = image.transpose((1, 2, 0))
28
- image = (np.clip(image, 0, 1) * 255).astype(np.uint8)
29
- image = PIL.Image.fromarray(image)
30
- return image
31
-
32
-
33
- def image_mod(image, noise_level, denoiser):
34
- image = pil_to_torch(image, ref_size=256 if denoiser == 'DiffUNet' else 512)
35
- if denoiser == 'DnCNN':
36
- den = dinv.models.DnCNN()
37
- sigma0 = 2/255
38
- denoiser = lambda x, sigma: den(x*sigma0/sigma)*sigma/sigma0
39
- elif denoiser == 'MedianFilter':
40
- denoiser = dinv.models.MedianFilter(kernel_size=5)
41
- elif denoiser == 'BM3D':
42
- denoiser = dinv.models.BM3D()
43
- elif denoiser == 'TV':
44
- denoiser = dinv.models.TVDenoiser()
45
- elif denoiser == 'TGV':
46
- denoiser = dinv.models.TGVDenoiser()
47
- elif denoiser == 'Wavelets':
48
- denoiser = dinv.models.WaveletPrior()
49
- elif denoiser == 'DiffUNet':
50
- denoiser = dinv.models.DiffUNet()
51
- elif denoiser == 'DRUNet':
52
- denoiser = dinv.models.DRUNet()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  else:
54
- raise ValueError("Invalid denoiser")
55
- noisy = image + torch.randn_like(image) * noise_level
56
- estimated = denoiser(noisy, noise_level)
57
- return torch_to_pil(noisy), torch_to_pil(estimated)
 
 
 
 
 
 
 
 
 
58
 
 
 
 
59
 
60
- input_image = gr.Image(label='Input Image')
61
- output_images = gr.Image(label='Denoised Image')
62
- noise_image = gr.Image(label='Noisy Image')
63
- input_image_output = gr.Image(label='Input Image')
 
 
64
 
65
- noise_levels = gr.Dropdown(choices=[0.05, 0.1, 0.2, 0.3, 0.5, 1], value=0.1, label='Noise Level')
 
 
 
 
 
 
 
 
 
66
 
67
- denoiser = gr.Dropdown(choices=['DnCNN', 'DRUNet', 'DiffUNet', 'BM3D', 'MedianFilter', 'TV', 'TGV', 'Wavelets'], value='DRUNet', label='Denoiser')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
 
69
- demo = gr.Interface(
70
- image_mod,
71
- inputs=[input_image, noise_levels, denoiser],
72
- examples=[['https://upload.wikimedia.org/wikipedia/commons/b/b4/Lionel-Messi-Argentina-2022-FIFA-World-Cup_%28cropped%29.jpg', 0.1, 'DRUNet']],
73
- outputs=[noise_image, output_images],
74
- title="Image Denoising with DeepInverse",
75
- description="Denoise an image using a variety of denoisers and noise levels using the deepinverse library (https://deepinv.github.io/). We only include lightweight models like DnCNN and MedianFilter as this example is intended to be run on a CPU. We also automatically resize the input image to 512 pixels to reduce the computation time. For more advanced models, please run the code locally.",
76
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77
 
78
- demo.launch()
 
 
1
+ import json
2
+ import os
3
+ import random
4
+ from functools import partial
5
+ from pathlib import Path
6
+ from typing import List
7
+
8
  import deepinv as dinv
9
+ import gradio as gr
10
  import torch
11
+ from PIL import Image
12
+ from torchvision import transforms
13
+
14
+ from evals import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
15
+
16
+
17
+ ### Gradio Utils
18
+ def generate_imgs(dataset: EvalDataset, idx: int,
19
+ model: EvalModel, baseline: BaselineModel,
20
+ physics: PhysicsWithGenerator, use_gen: bool,
21
+ metrics: List[Metric]):
22
+ ### Load 1 image
23
+ x = dataset[idx] # shape : (3, 256, 256)
24
+ x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
25
 
26
+ with torch.no_grad():
27
+ ### Compute y
28
+ y = physics(x, use_gen) # possible reduction in img shape due to Blurring
29
 
30
+ ### Compute x_hat
31
+ out = model(y=y, physics=physics.physics)
32
+ out_baseline = baseline(y=y, physics=physics.physics)
 
 
33
 
34
+ ### Process tensors before metric computation
35
+ if "Blur" in physics.name:
36
+ w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
37
+ h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
38
+
39
+ x = x[..., w_1:w_2, h_1:h_2]
40
+ out = out[..., w_1:w_2, h_1:h_2]
41
+ if out_baseline.shape != out.shape:
42
+ out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
43
+
44
+ ### Metrics
45
+ metrics_y = ""
46
+ metrics_out = ""
47
+ metrics_out_baseline = ""
48
+ for metric in metrics:
49
+ if y.shape == x.shape:
50
+ metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
51
+ metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
52
+ metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
53
+
54
+ ### Process y when y shape is different from x shape
55
+ if physics.name == "MRI" or "SR" in physics.name:
56
+ y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
57
  else:
58
+ y_plot = y.clone()
59
+
60
+ ### Processing images for plotting :
61
+ # - clip value outside of [0,1]
62
+ # - shape (1, C, H, W) -> (C, H, W)
63
+ # - torch.Tensor object -> Pil object
64
+ process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
65
+ to_pil = transforms.ToPILImage()
66
+ x = to_pil(process_img(x)[0].to('cpu'))
67
+ y = to_pil(process_img(y_plot)[0].to('cpu'))
68
+ out = to_pil(process_img(out)[0].to('cpu'))
69
+ out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))
70
+
71
+
72
+ return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
73
+
74
+ def update_random_idx_and_generate_imgs(dataset: EvalDataset,
75
+ model: EvalModel,
76
+ baseline: BaselineModel,
77
+ physics: PhysicsWithGenerator,
78
+ use_gen: bool,
79
+ metrics: List[Metric]):
80
+ idx = random.randint(0, len(dataset))
81
+ x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs(dataset,
82
+ idx,
83
+ model,
84
+ baseline,
85
+ physics,
86
+ use_gen,
87
+ metrics)
88
+ return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
89
+
90
+ def save_imgs(dataset: EvalDataset, idx: int, physics: PhysicsWithGenerator,
91
+ model_a: EvalModel | BaselineModel, model_b: EvalModel | BaselineModel,
92
+ x: Image.Image, y: Image.Image,
93
+ out_a: Image.Image, out_b: Image.Image,
94
+ y_metrics_str: str,
95
+ out_a_metric_str : str, out_b_metric_str: str) -> None:
96
+
97
+ ### PROCESSES STR
98
+ physics_params_str = ""
99
+ for param_name, param_value in physics.saved_params["updatable_params"].items():
100
+ physics_params_str += f"{param_name}_{param_value}-"
101
+ physics_params_str = physics_params_str[:-1] if physics_params_str.endswith("-") else physics_params_str
102
+ y_metrics_str = y_metrics_str.replace(" = ", "_").replace("\n", "-")
103
+ y_metrics_str = y_metrics_str[:-1] if y_metrics_str.endswith("-") else y_metrics_str
104
+ out_a_metric_str = out_a_metric_str.replace(" = ", "_").replace("\n", "-")
105
+ out_a_metric_str = out_a_metric_str[:-1] if out_a_metric_str.endswith("-") else out_a_metric_str
106
+ out_b_metric_str = out_b_metric_str.replace(" = ", "_").replace("\n", "-")
107
+ out_b_metric_str = out_b_metric_str[:-1] if out_b_metric_str.endswith("-") else out_b_metric_str
108
+
109
+ save_path = SAVE_IMG_DIR / f"{dataset.name}+{idx}+{physics.name}+{physics_params_str}+{y_metrics_str}+{model_a.name}+{out_a_metric_str}+{model_b.name}+{out_b_metric_str}.png"
110
+ titles = [f"{dataset.name}[{idx}]",
111
+ f"y = {physics.name}(x)",
112
+ f"{model_a.name}",
113
+ f"{model_b.name}"]
114
+
115
+ # Pil object -> torch.Tensor
116
+ to_tensor = transforms.ToTensor()
117
+ x = to_tensor(x)
118
+ y = to_tensor(y)
119
+ out_a = to_tensor(out_a)
120
+ out_b = to_tensor(out_b)
121
+
122
+ dinv.utils.plot([x, y, out_a, out_b], titles=titles, show=False, save_fn=save_path)
123
+
124
+ get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
125
+ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
126
+ get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
127
+ get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
128
+ get_physics_generator_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
129
+
130
+ def get_model(model_name, ckpt_pth):
131
+ if model_name in BaselineModel.all_baselines:
132
+ return get_baseline_model_on_DEVICE_STR(model_name)
133
  else:
134
+ return get_eval_model_on_DEVICE_STR(model_name, ckpt_pth)
135
+
136
+
137
+ ### Gradio Blocks interface
138
+
139
+ # Define custom CSS
140
+ custom_css = """
141
+ .fixed-textbox textarea {
142
+ height: 90px !important; /* Adjust height to fit exactly 4 lines */
143
+ overflow: scroll; /* Add a scroll bar if necessary */
144
+ resize: none; /* User can resize vertically the textbox */
145
+ }
146
+ """
147
 
148
+ title = "Inverse problem playground" # displayed on gradio tab and in the gradio page
149
+ with gr.Blocks(title=title, css=custom_css) as interface:
150
+ gr.Markdown("## " + title)
151
 
152
+ # Loading things
153
+ model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")) # lambda expression to instanciate a callable in a gr.State
154
+ model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DRUNET")) # lambda expression to instanciate a callable in a gr.State
155
+ dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("DIV2K_valid_HR"))
156
+ physics_placeholder = gr.State(lambda: get_physics_generator_on_DEVICE_STR("Denoising")) # lambda expression to instanciate a callable in a gr.State
157
+ metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
158
 
159
+ @gr.render(inputs=[model_a_placeholder, model_b_placeholder, dataset_placeholder, physics_placeholder, metrics_placeholder])
160
+ def dynamic_layout(model_a, model_b, dataset, physics, metrics):
161
+ ### LAYOUT
162
+ model_a_name = model_a.base_name
163
+ model_a_full_name = model_a.name
164
+ model_b_name = model_b.base_name
165
+ model_b_full_name = model_b.name
166
+ dataset_name = dataset.name
167
+ physics_name = physics.name
168
+ metric_names = [metric.name for metric in metrics]
169
 
170
+ # Components: Inputs/Outputs + Load EvalDataset/PhysicsWithGenerator/EvalModel/BaselineModel
171
+ with gr.Row():
172
+ with gr.Column():
173
+ with gr.Row():
174
+ with gr.Column():
175
+ clean = gr.Image(label=f"{dataset_name} IMAGE", interactive=False)
176
+ physics_params = gr.Textbox(label="Physics parameters", elem_classes=["fixed-textbox"], value=physics.display_saved_params())
177
+ with gr.Column():
178
+ y_image = gr.Image(label=f"{physics_name} IMAGE", interactive=False)
179
+ y_metrics = gr.Textbox(label="Metrics(y, x)", elem_classes=["fixed-textbox"],)
180
+ with gr.Row():
181
+ choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
182
+ label="List of EvalDataset",
183
+ value=dataset_name,
184
+ scale=2)
185
+ idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
186
+
187
+ choose_physics = gr.Radio(choices=PhysicsWithGenerator.all_physics,
188
+ label="List of PhysicsWithGenerator",
189
+ value=physics_name)
190
+ with gr.Row():
191
+ key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
192
+ label="Updatable Parameter Key",
193
+ scale=2)
194
+ value_text = gr.Textbox(label="Update Value", scale=2)
195
+ with gr.Column(scale=1):
196
+ update_button = gr.Button("Update Param")
197
+ use_generator_button = gr.Checkbox(label="Use param generator")
198
+
199
+ with gr.Column():
200
+ with gr.Row():
201
+ with gr.Column():
202
+ model_a_out = gr.Image(label=f"{model_a_full_name} OUTPUT", interactive=False)
203
+ out_a_metric = gr.Textbox(label="Metrics(model_a(y), x)", elem_classes=["fixed-textbox"])
204
+ load_model_a = gr.Button("Load model A...", scale=1)
205
+ with gr.Column():
206
+ model_b_out = gr.Image(label=f"{model_b_full_name} OUTPUT", interactive=False)
207
+ out_b_metric = gr.Textbox(label="Metrics(model_b(y), x)", elem_classes=["fixed-textbox"])
208
+ load_model_b = gr.Button("Load model B...", scale=1)
209
+ with gr.Row():
210
+ choose_model_a = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
211
+ label="List of Model A",
212
+ value=model_a_name,
213
+ scale=2)
214
+ path_a_str = gr.Textbox(value=model_a.ckpt_pth, label="Checkpoint path", scale=3)
215
+ with gr.Row():
216
+ choose_model_b = gr.Dropdown(choices=EvalModel.all_models + BaselineModel.all_baselines,
217
+ label="List of Model B",
218
+ value=model_b_name,
219
+ scale=2)
220
+ path_b_str = gr.Textbox(value=model_b.ckpt_pth, label="Checkpoint path", scale=3)
221
+
222
+ # Components: Load Metric + Load/Save Buttons
223
+ with gr.Row():
224
+ with gr.Column():
225
+ choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
226
+ value=metric_names,
227
+ label="Choose metrics you are interested")
228
+ with gr.Column():
229
+ load_button = gr.Button("Load images...")
230
+ load_random_button = gr.Button("Load randomly...")
231
+ save_button = gr.Button("Save images...")
232
 
233
+ ### Event listeners
234
+ choose_dataset.change(fn=get_dataset_on_DEVICE_STR,
235
+ inputs=choose_dataset,
236
+ outputs=dataset_placeholder)
237
+ choose_physics.change(fn=get_physics_generator_on_DEVICE_STR,
238
+ inputs=choose_physics,
239
+ outputs=physics_placeholder)
240
+ update_button.click(fn=physics.update_and_display_params, inputs=[key_selector, value_text], outputs=physics_params)
241
+ load_model_a.click(fn=get_model,
242
+ inputs=[choose_model_a, path_a_str],
243
+ outputs=model_a_placeholder)
244
+ load_model_b.click(fn=get_model,
245
+ inputs=[choose_model_b, path_b_str],
246
+ outputs=model_b_placeholder)
247
+ choose_metrics.change(fn=get_list_metrics_on_DEVICE_STR,
248
+ inputs=choose_metrics,
249
+ outputs=metrics_placeholder)
250
+ load_button.click(fn=generate_imgs,
251
+ inputs=[dataset_placeholder,
252
+ idx_slider,
253
+ model_a_placeholder,
254
+ model_b_placeholder,
255
+ physics_placeholder,
256
+ use_generator_button,
257
+ metrics_placeholder],
258
+ outputs=[clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
259
+ load_random_button.click(fn=update_random_idx_and_generate_imgs,
260
+ inputs=[dataset_placeholder,
261
+ model_a_placeholder,
262
+ model_b_placeholder,
263
+ physics_placeholder,
264
+ use_generator_button,
265
+ metrics_placeholder],
266
+ outputs=[idx_slider, clean, y_image, model_a_out, model_b_out, physics_params, y_metrics, out_a_metric, out_b_metric])
267
 
268
+ if __name__ == "__main__":
269
+ interface.launch()
ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2032ebf8f401dd3ce2fae5f3852117cb72101ec6ed8358faa64c2a3fa09ed4ac
3
+ size 67277475
ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4e78e33f22c1aa8a773db0cf4a7381bae97c2362c717f155439ebc690cbd9215
3
+ size 67869037
ckpt/drunet_deepinv_color_finetune_22k.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20296845d272d3d786b89ea3c1208d5f2ceb57658a499d4dd28073cbb73508aa
3
+ size 130585443
ckpt/drunet_gray.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e27fb29456732c604c8ee3ac3e92ecadd2a7a4e36a8be675e92ebbf93a240de4
3
+ size 130569961
ckpt/pdnet.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d9b898d9d7deac148c3ec722c956231259d13d5329e08d42ba15f212e1967fe0
3
+ size 5756509
ckpt/ram_ckp_10.pth.tar ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b07c87641c0045112461a11723c4f32b1aa1d351ac070f76b40c0c84c86d2a01
3
+ size 427825566
datasets.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Callable, Optional
3
+ import os
4
+
5
+ import torch
6
+ from torch.utils.data import Dataset
7
+
8
+
9
+ class Preprocessed_fastMRI(torch.utils.data.Dataset):
10
+ """FastMRI from preprocessed data for faster lading."""
11
+
12
+ def __init__(
13
+ self,
14
+ root: str,
15
+ transform: Optional[Callable] = None,
16
+ preprocess: bool = False,
17
+ ) -> None:
18
+ self.root = root
19
+ self.transform = transform
20
+ self.preprocess = preprocess
21
+
22
+ # should contain all the information to load a data sample from the storage
23
+ self.sample_identifiers = []
24
+
25
+ # append all filenames in self.root ending with .pt
26
+ for root, _, files in os.walk(self.root):
27
+ for file in files:
28
+ if file.endswith(".pt"):
29
+ self.sample_identifiers.append(file)
30
+
31
+ def __len__(self) -> int:
32
+ return len(self.sample_identifiers)
33
+
34
+ def __getitem__(self, idx: int):
35
+ fname = self.sample_identifiers[idx]
36
+
37
+ tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
38
+ img = tensor['data'].float()
39
+
40
+ if self.transform is not None:
41
+ img = self.transform(img)
42
+
43
+ if not self.preprocess:
44
+ return img
45
+
46
+ else:
47
+ # remove extension and prefix from filename
48
+ fname = Path(fname).stem
49
+ return img, fname
50
+
51
+
52
+ class Preprocessed_LIDCIDRI(torch.utils.data.Dataset):
53
+ """FastMRI from preprocessed data for faster lading."""
54
+
55
+ def __init__(
56
+ self,
57
+ root: str,
58
+ transform: Optional[Callable] = None,
59
+ ) -> None:
60
+ self.root = root
61
+ self.transform = transform
62
+
63
+ # should contain all the information to load a data sample from the storage
64
+ self.sample_identifiers = []
65
+
66
+ # append all filenames in self.root ending with .pt
67
+ for root, _, files in os.walk(self.root):
68
+ for file in files:
69
+ if file.endswith(".pt"):
70
+ self.sample_identifiers.append(file)
71
+
72
+ def __len__(self) -> int:
73
+ return len(self.sample_identifiers)
74
+
75
+ def __getitem__(self, idx: int):
76
+ fname = self.sample_identifiers[idx]
77
+
78
+ tensor = torch.load(os.path.join(self.root, fname), weights_only=True)
79
+ img = tensor['data'].float()
80
+
81
+ if self.transform is not None:
82
+ img = self.transform(img)
83
+
84
+ img = img.unsqueeze(0) # add channel dim
evals.py ADDED
@@ -0,0 +1,564 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Any, List
2
+
3
+ import deepinv as dinv
4
+ import numpy as np
5
+ import torch
6
+ from deepinv.physics.generator import MotionBlurGenerator, SigmaGenerator
7
+ from torchvision import transforms
8
+
9
+ from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI
10
+ from utils import get_model
11
+
12
+ DEFAULT_MODEL_PARAMS = {
13
+ "in_channels": [1, 2, 3],
14
+ "grayscale": False,
15
+ "conv_type": "base",
16
+ "pool_type": "base",
17
+ "layer_scale_init_value": 1e-6,
18
+ "init_type": "ortho",
19
+ "gain_init_conv": 1.0,
20
+ "gain_init_linear": 1.0,
21
+ "drop_prob": 0.0,
22
+ "replk": False,
23
+ "mult_fact": 4,
24
+ "antialias": "gaussian",
25
+ "nc_base": 64,
26
+ "cond_type": "base",
27
+ "blind": False,
28
+ "pretrained_pth": None,
29
+ "N": 2,
30
+ "c_mult": 2,
31
+ "depth_encoding": 2,
32
+ "relu_in_encoding": False,
33
+ "skip_in_encoding": True
34
+ }
35
+
36
+
37
+ class PhysicsWithGenerator(torch.nn.Module):
38
+ """Interface between Physics, Generator and Gradio."""
39
+ all_physics = ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur",
40
+ "MRI", "CT"]
41
+
42
+ def __init__(self, physics_name: str, device_str: str = "cpu") -> None:
43
+ super().__init__()
44
+
45
+ self.name = physics_name
46
+ if self.name not in self.all_physics:
47
+ raise ValueError(f"{self.name} is unavailable.")
48
+
49
+ self.sigma_generator = SigmaGenerator(sigma_min=0.001, sigma_max=0.2, device=device_str)
50
+ if self.name == "MotionBlur_easy":
51
+ psf_size = 31
52
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.01), padding="valid",
53
+ device=device_str)
54
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.1, sigma=0.1, device=device_str) + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
55
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.01, sigma_max=0.01, device=device_str)
56
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
57
+ "updatable_params_converter": {"sigma": float},
58
+ "fixed_params": {"noise_sigma_min": 0.01, "noise_sigma_max": 0.01,
59
+ "psf_size": 31, "motion_gen_l": 0.1, "motion_gen_s": 0.1}}
60
+ elif self.name == "MotionBlur_medium":
61
+ psf_size = 31
62
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.05), padding="valid",
63
+ device=device_str)
64
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=0.6, sigma=0.5, device=device_str) + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
65
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.05, sigma_max=0.05, device=device_str)
66
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
67
+ "updatable_params_converter": {"sigma": float},
68
+ "fixed_params": {"noise_sigma_min": 0.05, "noise_sigma_max": 0.05,
69
+ "psf_size": 31, "motion_gen_l": 0.6, "motion_gen_s": 0.5}}
70
+ elif self.name == "MotionBlur_hard":
71
+ psf_size = 31
72
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=.1), padding="valid",
73
+ device=device_str)
74
+ self.physics_generator = MotionBlurGenerator((psf_size, psf_size), l=1.2, sigma=1.0, device=device_str) + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
75
+ self.generator = self.physics_generator + SigmaGenerator(sigma_min=0.1, sigma_max=0.1, device=device_str)
76
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
77
+ "updatable_params_converter": {"sigma": float},
78
+ "fixed_params": {"noise_sigma_min": 0.1, "noise_sigma_max": 0.1,
79
+ "psf_size": 31, "motion_gen_l": 1.2, "motion_gen_s": 1.0}}
80
+ elif self.name == "GaussianBlur":
81
+ psf_size = 31
82
+ self.physics = dinv.physics.Blur(noise_model=dinv.physics.GaussianNoise(sigma=0.05), padding="valid",
83
+ device=device_str)
84
+ self.physics_generator = GaussianBlurGenerator(psf_size=(psf_size, psf_size), num_channels=1,
85
+ device=device_str)
86
+ self.generator = self.physics_generator + self.sigma_generator
87
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
88
+ "updatable_params_converter": {"sigma": float},
89
+ "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
90
+ "psf_size": 31, "num_channels": 1}}
91
+ elif self.name == "MRI":
92
+ self.physics = dinv.physics.MRI(img_size=(640, 320), noise_model=dinv.physics.GaussianNoise(sigma=.01),
93
+ device=device_str)
94
+ self.physics_generator = dinv.physics.generator.RandomMaskGenerator((2, 640, 320), acceleration_factor=4)
95
+ self.generator = self.physics_generator # + self.sigma_generator
96
+ self.saved_params = {"updatable_params": {"sigma": 0.05},
97
+ "updatable_params_converter": {"sigma": float},
98
+ "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.2,
99
+ "acceleration_factor": 4}}
100
+ elif self.name == "CT":
101
+ acceleration_factor = 10
102
+ img_h = 480
103
+ angles = int(img_h / acceleration_factor)
104
+ # angles = torch.linspace(0, 180, steps=10)
105
+ self.physics = dinv.physics.Tomography(
106
+ img_width=img_h,
107
+ angles=angles,
108
+ circle=False,
109
+ normalize=True,
110
+ device=device_str,
111
+ noise_model=dinv.physics.GaussianNoise(sigma=1e-4),
112
+ max_iter=10,
113
+ )
114
+ self.physics_generator = None
115
+ self.generator = self.sigma_generator
116
+ self.saved_params = {"updatable_params": {"sigma": 0.1},
117
+ "updatable_params_converter": {"sigma": float},
118
+ "fixed_params": {"noise_sigma_min": 0.001, "noise_sigma_max": 0.,
119
+ "angles": angles, "max_iter": 10}}
120
+
121
+ def display_saved_params(self) -> str:
122
+ """Printable version of saved_params."""
123
+ updatable_params_str = "Updatable parameters:\n"
124
+ for param_name, param_value in self.saved_params["updatable_params"].items():
125
+ updatable_params_str += f"\t\t{param_name} = {param_value}" + "\n"
126
+
127
+ fixed_params_str = "Fixed parameters:\n"
128
+ for param_name, param_value in self.saved_params["fixed_params"].items():
129
+ fixed_params_str += f"\t\t{param_name} = {param_value}" + "\n"
130
+
131
+ return updatable_params_str + fixed_params_str
132
+
133
+ def _update_save_params(self, key: str, value: Any) -> None:
134
+ """Update value of an existing key in save_params."""
135
+ if key in list(self.saved_params["updatable_params"].keys()):
136
+ if type(value) == str: # it may be only a str representation
137
+ # type: str -> ???
138
+ value = self.saved_params["updatable_params_converter"][key](value)
139
+ elif isinstance(value, torch.Tensor):
140
+ value = value.item() # type: torch.Tensor -> float
141
+ value = float(f"{value:.4f}") # keeps only 4 significant digits
142
+ self.saved_params["updatable_params"][key] = value
143
+
144
+ def update_and_display_params(self, key, value) -> str:
145
+ """_update_save_params + update physics with saved_params + display_saved_params"""
146
+ self._update_save_params(key, value)
147
+
148
+ if self.name == "Denoising":
149
+ self.physics.noise_model.update_parameters(**self.saved_params["updatable_params"])
150
+ else:
151
+ self.physics.update_parameters(**self.saved_params["updatable_params"])
152
+
153
+ return self.display_saved_params()
154
+
155
+ def update_saved_params_and_physics(self, **kwargs) -> None:
156
+ """Update save_params and update physics."""
157
+ for key, value in kwargs.items():
158
+ self._update_save_params(key, value)
159
+
160
+ self.physics.update(**kwargs)
161
+
162
+ def forward(self, x: torch.Tensor, use_gen: bool) -> torch.Tensor:
163
+ if self.name in ["MotionBlur_easy", "MotionBlur_medium", "MotionBlur_hard", "GaussianBlur"] and not hasattr(self.physics, "filter"):
164
+ use_gen = True
165
+ elif self.name in ["MRI"] and not hasattr(self.physics, "mask"):
166
+ use_gen = True
167
+
168
+ if use_gen:
169
+ kwargs = self.generator.step(batch_size=x.shape[0]) # generate a set of params for each sample
170
+ self.update_saved_params_and_physics(**kwargs)
171
+
172
+ return self.physics(x)
173
+
174
+
175
+ class EvalModel(torch.nn.Module):
176
+ """Eval model.
177
+
178
+ Is there a difference with BaselineModel ?
179
+ -> BaselineModel should be models that are already trained and will have fixed weights.
180
+ -> Eval model will change depending on differents checkpoints.
181
+ """
182
+ all_models = ["unext_emb_physics_config_C"]
183
+
184
+ def __init__(self, model_name: str, ckpt_pth: str = "", device_str: str = "cpu") -> None:
185
+ """Load the model we want to evaluate."""
186
+ super().__init__()
187
+ self.base_name = model_name
188
+ self.ckpt_pth = ckpt_pth
189
+ self.name = self.base_name
190
+ if self.base_name not in self.all_models:
191
+ raise ValueError(f"{self.base_name} is unavailable.")
192
+ if self.base_name == "unext_emb_physics_config_C":
193
+ if self.ckpt_pth == "":
194
+ self.ckpt_pth = "ckpt/ram_ckp_10.pth.tar"
195
+ self.model = get_model(model_name=self.base_name,
196
+ device='cpu',
197
+ **DEFAULT_MODEL_PARAMS)
198
+
199
+ # load model checkpoint
200
+ state_dict = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)[
201
+ 'state_dict'] # load on cpu
202
+ self.model.load_state_dict(state_dict)
203
+ self.model.to(device_str)
204
+ self.model.eval()
205
+
206
+ # add epoch in the model name
207
+ epoch = torch.load(self.ckpt_pth, map_location=lambda storage, loc: storage)['epoch']
208
+ self.name = self.name + f"+{epoch}"
209
+
210
+ def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
211
+ return self.model(y, physics=physics)
212
+
213
+
214
+ class BaselineModel(torch.nn.Module):
215
+ """Baseline model.
216
+
217
+ Is there a difference with EvalModel ?
218
+ -> BaselineModel should be models that are already trained and will have fixed weights.
219
+ -> Eval model will change depending on differents checkpoints.
220
+ """
221
+ all_baselines = ["DRUNET", "PnP-PGD-DRUNET", "SWINIRx2", "SWINIRx4", "DPIR",
222
+ "DPIR_MRI", "DPIR_CT", "PDNET"]
223
+
224
+ def __init__(self, model_name: str, device_str: str = "cpu") -> None:
225
+ super().__init__()
226
+ self.base_name = model_name
227
+ self.ckpt_pth = ""
228
+ self.name = self.base_name
229
+ if self.name not in self.all_baselines:
230
+ raise ValueError(f"{self.name} is unavailable.")
231
+ elif self.name == "DRUNET":
232
+ n_channels = 3
233
+ ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
234
+ self.model = dinv.models.DRUNet(in_channels=n_channels,
235
+ out_channels=n_channels,
236
+ device=device_str,
237
+ pretrained=ckpt_pth)
238
+ self.model.eval() # Set the model to evaluation mode
239
+ elif self.name == 'PDNET':
240
+ ckpt_pth = "ckpt/pdnet.pth.tar"
241
+ self.model = get_model(model_name='pdnet',
242
+ device=device_str)
243
+ self.model.eval()
244
+ self.model.load_state_dict(torch.load(ckpt_pth, map_location=lambda storage, loc: storage)['state_dict'])
245
+ elif self.name == "SWINIRx2":
246
+ n_channels = 3
247
+ scale = 2
248
+ ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x2.pth"
249
+ upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
250
+ self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
251
+ img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
252
+ num_heads=[6, 6, 6, 6, 6, 6],
253
+ mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
254
+ pretrained=ckpt_pth)
255
+ self.model.to(device_str)
256
+ self.model.eval() # Set the model to evaluation mode
257
+ elif self.name == "SWINIRx4":
258
+ n_channels = 3
259
+ scale = 4
260
+ ckpt_pth = "ckpt/001_classicalSR_DF2K_s64w8_SwinIR-M_x4.pth"
261
+ upsampler = 'nearest+conv' if 'realSR' in ckpt_pth else 'pixelshuffle'
262
+ self.model = dinv.models.SwinIR(upscale=scale, in_chans=n_channels, img_size=64, window_size=8,
263
+ img_range=1., depths=[6, 6, 6, 6, 6, 6], embed_dim=180,
264
+ num_heads=[6, 6, 6, 6, 6, 6],
265
+ mlp_ratio=2, upsampler=upsampler, resi_connection='1conv',
266
+ pretrained=ckpt_pth)
267
+ self.model.to(device_str)
268
+ self.model.eval() # Set the model to evaluation mode
269
+
270
+ elif self.name == "PnP-PGD-DRUNET":
271
+ n_channels = 3
272
+ ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
273
+ drunet = dinv.models.DRUNet(in_channels=n_channels,
274
+ out_channels=n_channels,
275
+ device=device_str,
276
+ pretrained=ckpt_pth)
277
+ drunet.eval() # Set the model to evaluation mode
278
+ self.model = dinv.optim.optim_builder(iteration="PGD",
279
+ prior=dinv.optim.PnP(drunet).to(device_str),
280
+ data_fidelity=dinv.optim.L2(),
281
+ max_iter=20,
282
+ params_algo={'stepsize': 1., 'g_param': .05})
283
+ elif self.name == "DPIR":
284
+ n_channels = 3
285
+ ckpt_pth = "ckpt/drunet_deepinv_color_finetune_22k.pth"
286
+ drunet = dinv.models.DRUNet(in_channels=n_channels,
287
+ out_channels=n_channels,
288
+ device=device_str,
289
+ pretrained=ckpt_pth)
290
+ drunet.eval() # Set the model to evaluation mode
291
+
292
+ # Specify the denoising prior
293
+ self.prior = dinv.optim.prior.PnP(denoiser=drunet)
294
+ elif self.name == "DPIR_MRI":
295
+ class ComplexDenoiser(torch.nn.Module):
296
+ def __init__(self, denoiser):
297
+ super().__init__()
298
+ self.denoiser = denoiser
299
+
300
+ def forward(self, x, sigma):
301
+ noisy_batch = torch.cat((x[:, 0:1, ...], x[:, 1:2, ...]), 0)
302
+ input_min = noisy_batch.min()
303
+ denoised_batch = self.denoiser(noisy_batch - input_min, sigma)
304
+ denoised_batch = denoised_batch + input_min
305
+ denoised = torch.cat((denoised_batch[0:1, ...], denoised_batch[1:2, ...]), 1)
306
+ return denoised
307
+
308
+ # Load PnP denoiser backbone
309
+ n_channels = 1
310
+ ckpt_pth = "ckpt/drunet_gray.pth"
311
+ drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
312
+ pretrained=ckpt_pth)
313
+ complex_drunet = ComplexDenoiser(drunet)
314
+ complex_drunet.eval()
315
+
316
+ # Specify the denoising prior
317
+ self.prior = dinv.optim.prior.PnP(denoiser=complex_drunet)
318
+ elif self.name == "DPIR_CT":
319
+ class CTDenoiser(torch.nn.Module):
320
+ def __init__(self, denoiser):
321
+ super().__init__()
322
+ self.denoiser = denoiser
323
+
324
+ def forward(self, x, sigma):
325
+ x = x - x.min()
326
+ denoised = self.denoiser(x, sigma)
327
+ denoised = denoised + x.min()
328
+ return denoised
329
+
330
+ # Load PnP denoiser backbone
331
+ n_channels = 1
332
+ ckpt_pth = "ckpt/drunet_gray.pth"
333
+ drunet = dinv.models.DRUNet(in_channels=n_channels, out_channels=n_channels, device=device_str,
334
+ pretrained=ckpt_pth)
335
+ ct_drunet = CTDenoiser(drunet)
336
+ ct_drunet.eval()
337
+
338
+ # Specify the denoising prior
339
+ self.prior = dinv.optim.prior.PnP(denoiser=ct_drunet)
340
+
341
+ def circular_roll(self, tensor, p_h, p_w):
342
+ return tensor.roll(shifts=(p_h, p_w), dims=(-2, -1))
343
+
344
+ def get_DPIR_params(self, noise_level_img, max_iter=8):
345
+ r"""
346
+ Default parameters for the DPIR Plug-and-Play algorithm.
347
+
348
+ :param float noise_level_img: Noise level of the input image.
349
+ :return: tuple(list with denoiser noise level per iteration, list with stepsize per iteration, iterations).
350
+ """
351
+ max_iter = 8
352
+ s1 = 49.0 / 255.0
353
+ s2 = max(noise_level_img, 0.01)
354
+ sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
355
+ np.float32
356
+ )
357
+ stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
358
+ lamb = 1 / 0.23
359
+ return list(sigma_denoiser), list(lamb * stepsize)
360
+
361
+ def get_DPIR_MRI_params(self, noise_level_img: float, max_iter: int = 8):
362
+ r"""
363
+ Default parameters for the DPIR Plug-and-Play algorithm.
364
+
365
+ :param float noise_level_img: Noise level of the input image.
366
+ """
367
+ s1 = 49.0 / 255.0
368
+ s2 = noise_level_img
369
+ sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
370
+ np.float32
371
+ )
372
+ stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
373
+ lamb = 1.
374
+ return lamb, list(sigma_denoiser), list(stepsize), max_iter
375
+
376
+ def get_DPIR_CT_params(self, noise_level_img: float, max_iter: int = 8, lip_cons: float = 1.0):
377
+ r"""
378
+ Default parameters for the DPIR Plug-and-Play algorithm.
379
+
380
+ :param float noise_level_img: Noise level of the input image.
381
+ """
382
+ s1 = 49.0 / 255.0 * lip_cons
383
+ s2 = noise_level_img
384
+ sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
385
+ np.float32
386
+ )
387
+ stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 #
388
+ lamb = 1.
389
+ return lamb, list(sigma_denoiser), list(stepsize), max_iter
390
+
391
+ def forward(self, y: torch.Tensor, physics: torch.nn.Module) -> torch.Tensor:
392
+ if self.name == "DRUNET":
393
+ return self.model(y, sigma=physics.noise_model.sigma)
394
+ elif self.name == "PnP-PGD-DRUNET":
395
+ return self.model(y, physics=physics)
396
+ elif self.name == "DPIR":
397
+ # Set the DPIR algorithm parameters
398
+ sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
399
+ max_iter = 8
400
+
401
+ sigma_denoiser, stepsize = self.get_DPIR_params(sigma_float, max_iter=max_iter)
402
+ params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser}
403
+ early_stop = False # Do not stop algorithm with convergence criteria
404
+
405
+ # instantiate DPIR
406
+ model = dinv.optim.optim_builder(
407
+ iteration="HQS",
408
+ prior=self.prior,
409
+ data_fidelity=dinv.optim.data_fidelity.L2(),
410
+ early_stop=early_stop,
411
+ max_iter=max_iter,
412
+ verbose=True,
413
+ params_algo=params_algo,
414
+ )
415
+ return model(y, physics=physics)
416
+ elif self.name == "DPIR_MRI":
417
+ sigma_float = max(physics.noise_model.sigma.item(), 0.015) # sigma should be a single value
418
+ lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_MRI_params(sigma_float, max_iter=16)
419
+ stepsize = [stepsize[0]] * max_iter
420
+ params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
421
+ early_stop = False # Do not stop algorithm with convergence criteria
422
+
423
+ # Instantiate the algorithm class to solve the IP
424
+ model = dinv.optim.optim_builder(
425
+ iteration="HQS",
426
+ prior=self.prior,
427
+ data_fidelity=dinv.optim.data_fidelity.L2(),
428
+ early_stop=early_stop,
429
+ max_iter=max_iter,
430
+ verbose=True,
431
+ params_algo=params_algo,
432
+ )
433
+ return model(y, physics=physics)
434
+ elif self.name == "DPIR_CT":
435
+ # Set the DPIR algorithm parameters
436
+ sigma_float = physics.noise_model.sigma.item() # sigma should be a single value
437
+ lip_const = physics.compute_norm(physics.A_adjoint(y))
438
+ lamb, sigma_denoiser, stepsize, max_iter = self.get_DPIR_CT_params(sigma_float, max_iter=8,
439
+ lip_cons=lip_const.item())
440
+ params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "lambda": lamb}
441
+ early_stop = False # Do not stop algorithm with convergence criteria
442
+
443
+ def custom_init(y, physic_op):
444
+ x_init = physic_op.prox_l2(physic_op.A_adjoint(y), y, gamma=1e4)
445
+ return {"est": (x_init, x_init)}
446
+
447
+ # Instantiate the algorithm class to solve the IP
448
+ algo = dinv.optim.optim_builder(
449
+ iteration="HQS",
450
+ prior=self.prior,
451
+ data_fidelity=dinv.optim.data_fidelity.L2(),
452
+ early_stop=early_stop,
453
+ max_iter=max_iter,
454
+ verbose=True,
455
+ params_algo=params_algo,
456
+ custom_init=custom_init
457
+ )
458
+ return algo(y, physics=physics)
459
+ elif self.name == 'SWINIRx4':
460
+ window_size = 8
461
+ scale = 4
462
+ _, _, h_old, w_old = y.size()
463
+ h_pad = (h_old // window_size + 1) * window_size - h_old
464
+ w_pad = (w_old // window_size + 1) * window_size - w_old
465
+ img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
466
+ img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
467
+ output = self.model(img_lq)
468
+ output = output[..., :h_old * scale, :w_old * scale]
469
+ output = self.circular_roll(output, -2, -2)
470
+ # check shape of adjoint
471
+ x_adj = physics.A_adjoint(y)
472
+ output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
473
+ return output
474
+ elif self.name == 'SWINIRx2':
475
+ window_size = 8
476
+ scale = 2
477
+ _, _, h_old, w_old = y.size()
478
+ h_pad = (h_old // window_size + 1) * window_size - h_old
479
+ w_pad = (w_old // window_size + 1) * window_size - w_old
480
+ img_lq = torch.cat([y, torch.flip(y, [2])], 2)[:, :, :h_old + h_pad, :]
481
+ img_lq = torch.cat([img_lq, torch.flip(img_lq, [3])], 3)[:, :, :, :w_old + w_pad]
482
+ output = self.model(img_lq)
483
+ output = output[..., :h_old * scale, :w_old * scale]
484
+ output = self.circular_roll(output, -1, -1)
485
+ # check shape of adjoint
486
+ x_adj = physics.A_adjoint(y)
487
+ output = output[..., :x_adj.size(-2), :x_adj.size(-1)]
488
+ return output
489
+ elif 'UNROLLED_DPIR' in self.name:
490
+ return self.model(y, physics=physics)
491
+ else:
492
+ return self.model(y)
493
+
494
+
495
+ class EvalDataset(torch.utils.data.Dataset):
496
+ """
497
+ We expect that images are 480x480.
498
+ """
499
+ all_datasets = ["Natural", "MRI", "CT"]
500
+
501
+ def __init__(self, dataset_name: str, device_str: str = "cpu") -> None:
502
+ self.name = dataset_name
503
+ self.device_str = device_str
504
+ if self.name not in self.all_datasets:
505
+ raise ValueError(f"{self.name} is unavailable.")
506
+ if self.name == 'Natural':
507
+ self.root = 'datasets/LSDIR_samples'
508
+ self.transform = transforms.Compose([transforms.ToTensor()])
509
+ self.dataset = dinv.datasets.LsdirHR(root=self.root,
510
+ download=False,
511
+ transform=self.transform)
512
+ elif self.name == 'MRI':
513
+ self.root = 'datasets/FastMRI_samples'
514
+ self.transform = transforms.CenterCrop((640, 320)) # , pad_if_needed=True)
515
+ self.dataset = Preprocessed_fastMRI(root=self.root,
516
+ transform=self.transform,
517
+ preprocess=False)
518
+ elif self.name == "CT":
519
+ self.root = 'datasets/LIDC_IDRI_samples'
520
+ self.transform = None
521
+ self.dataset = Preprocessed_LIDCIDRI(root=self.root,
522
+ transform=self.transform)
523
+
524
+ def __len__(self) -> int:
525
+ return len(self.dataset)
526
+
527
+ def __getitem__(self, idx: int) -> torch.Tensor:
528
+ return self.dataset[idx].to(self.device_str)
529
+
530
+
531
+ class Metric():
532
+ """Metrics and utilities."""
533
+ all_metrics = ["PSNR", "SSIM", "LPIPS"]
534
+
535
+ def __init__(self, metric_name: str, device_str: str = "cpu") -> None:
536
+ self.name = metric_name
537
+ if self.name not in self.all_metrics:
538
+ raise ValueError(f"{self.name} is unavailable.")
539
+ elif self.name == "PSNR":
540
+ self.metric = dinv.loss.metric.PSNR()
541
+ elif self.name == "SSIM":
542
+ self.metric = dinv.loss.metric.SSIM()
543
+ elif self.name == "LPIPS":
544
+ self.metric = dinv.loss.metric.LPIPS(device=device_str)
545
+
546
+ def __call__(self, x_net: torch.Tensor, x: torch.Tensor, *args, **kwargs) -> torch.Tensor:
547
+ # it may happen that x_net and x do not have the same size, in which case we take the minimum size of both
548
+ if x_net.shape[-1] != x.shape[-1]:
549
+ min_size = min(x_net.shape[-1], x.shape[-1])
550
+ x_net_crop = x_net[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2,
551
+ x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2]
552
+ x_crop = x[..., x_net.shape[-2] // 2 - min_size // 2: x_net.shape[-2] // 2 + min_size // 2,
553
+ x_net.shape[-1] // 2 - min_size // 2: x_net.shape[-1] // 2 + min_size // 2]
554
+ else:
555
+ x_net_crop = x_net
556
+ x_crop = x
557
+ return self.metric(x_net_crop, x_crop)
558
+
559
+ @classmethod
560
+ def get_list_metrics(cls, metric_names: List[str], device_str: str = "cpu") -> List["Metric"]:
561
+ l = []
562
+ for metric_name in metric_names:
563
+ l.append(cls(metric_name, device_str=device_str))
564
+ return l
img_samples/FastMRI_samples/file_brain_AXT1POST_209_6001231_11.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3caf7165619d7c5f1e30c6ecca6f5239e318aeb3e070daacc9f8b7343d803fee
3
+ size 1639843
img_samples/FastMRI_samples/file_brain_AXT2_205_2050122_7.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f8dfa2385d380ccb22c8cd4a7045ee1262e57ad8ee64aa9a763d0f9fe9404116
3
+ size 1639818
img_samples/FastMRI_samples/file_brain_AXT2_205_2050160_10.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f10dacc30dade778bc255c9c61859887282bf36b4e967d0bb2828baf2bb2a914
3
+ size 1639823
img_samples/FastMRI_samples/file_brain_AXT2_210_6001888_6.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5cd2d841a101fbd9cc6b37e2c7dea35d868be01977cb08165446e8feb05c22ac
3
+ size 2434442
img_samples/FastMRI_samples/file_brain_AXT2_210_6001947_5.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e85534e07a07e07107cec4a256c3112f1ca6283f39280a603a559c182e534abb
3
+ size 2434442
img_samples/LIDC-IDRI_samples/LIDC-IDRI-0032_01-01-2000-NA-NA-53482_3000537.000000-NA-91689_1-236.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:177a0102e65a7fbf97161c431f77722dc842a202dfde0e8a4e9c75dcb9f4ab9a
3
+ size 2098888
img_samples/LIDC-IDRI_samples/LIDC-IDRI-0083_01-01-2000-NA-NA-22049_3000646.000000-NA-60532_1-027.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:002d727e599d79615f36be14a6136e6c89c6ee20a0c4bd943aff097bb1447270
3
+ size 2098888
img_samples/LIDC-IDRI_samples/LIDC-IDRI-0144_01-01-2000-NA-NA-61308_3000703.000000-NA-75826_1-079.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7a66e185bbc1e5bbd57c12b0c4c55ee3b6572154d968a8511b00564159ef0837
3
+ size 2098888
img_samples/LIDC-IDRI_samples/LIDC-IDRI-0152_01-01-2000-NA-NA-78489_3000696.000000-NA-27171_1-083.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e0b731e35e6debaf7241ee6ae7a89915163851178b9e4d54c4104cb8a4426076
3
+ size 2098888
img_samples/LIDC-IDRI_samples/LIDC-IDRI-0298_01-01-2000-NA-NA-11572_3000663.000000-NA-48288_1-004.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:6d7b5a7b1bfaf79f203fb492e735434964c71a9408caa9c6df407579f0df6000
3
+ size 2098888
img_samples/LSDIR_samples/0001000/0000007_s005.png ADDED
img_samples/LSDIR_samples/0001000/0000030_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000067_s005.png ADDED
img_samples/LSDIR_samples/0001000/0000082_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000110_s002.png ADDED
img_samples/LSDIR_samples/0001000/0000125_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000154_s007.png ADDED
img_samples/LSDIR_samples/0001000/0000247_s007.png ADDED
img_samples/LSDIR_samples/0001000/0000259_s003.png ADDED
img_samples/LSDIR_samples/0001000/0000405_s008.png ADDED
img_samples/LSDIR_samples/0001000/0000578_s002.png ADDED
img_samples/LSDIR_samples/0001000/0000669_s010.png ADDED
img_samples/LSDIR_samples/0001000/0000689_s006.png ADDED
img_samples/LSDIR_samples/0001000/0000715_s011.png ADDED
img_samples/LSDIR_samples/0001000/0000752_s010.png ADDED
img_samples/LSDIR_samples/0001000/0000803_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000825_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000921_s012.png ADDED
img_samples/LSDIR_samples/0001000/0000958_s004.png ADDED
img_samples/LSDIR_samples/0001000/0000994_s021.png ADDED
img_samples/LSDIR_samples/0009000/0008033_s006.png ADDED
img_samples/LSDIR_samples/0009000/0008068_s005.png ADDED
img_samples/LSDIR_samples/0009000/0008115_s004.png ADDED
img_samples/LSDIR_samples/0009000/0008217_s002.png ADDED
img_samples/LSDIR_samples/0009000/0008294_s010.png ADDED
img_samples/LSDIR_samples/0009000/0008315_s053.png ADDED
img_samples/LSDIR_samples/0009000/0008340_s015.png ADDED
img_samples/LSDIR_samples/0009000/0008361_s009.png ADDED
img_samples/LSDIR_samples/0009000/0008386_s007.png ADDED