new layout
- app.py +102 -90
- evals.py → factories.py +3 -2
- models/unrolled_dpir.py +0 -304
app.py
CHANGED
@@ -11,24 +11,16 @@ import torch
 from PIL import Image
 from torchvision import transforms
 
-from
+from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric
 
 
-
+### Config
+DEVICE_STR = 'cuda'  # run model inference on NVIDIA gpu
+torch.set_grad_enabled(False)  # stops tracking values for gradients
 
 
 ### Gradio Utils
 
-def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
-                               model: EvalModel, baseline: BaselineModel,
-                               physics: PhysicsWithGenerator, use_gen: bool,
-                               metrics: List[Metric]):
-    ### Load 1 image
-    x = dataset[idx]   # shape : (3, 256, 256)
-    x = x.unsqueeze(0) # shape : (1, 3, 256, 256)
-
-    return generate_imgs(x, model, baseline, physics, use_gen, metrics)
-
 def generate_imgs_from_user(image,
                             model: EvalModel, baseline: BaselineModel,
                             physics: PhysicsWithGenerator, use_gen: bool,
@@ -37,9 +29,31 @@ def generate_imgs_from_user(image,
         return None, None, None, None, None, None, None, None
 
     # PIL image -> torch.Tensor
-    x = transforms.ToTensor()(image).unsqueeze(0).to(
+    x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
+
+    return generate_imgs(x, model, baseline, physics, use_gen, metrics)
+
+def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
+                               model: EvalModel, baseline: BaselineModel,
+                               physics: PhysicsWithGenerator, use_gen: bool,
+                               metrics: List[Metric]):
+    ### Load 1 image
+    x = dataset[idx]   # shape : (C, H, W)
+    x = x.unsqueeze(0) # shape : (1, C, H, W)
 
     return generate_imgs(x, model, baseline, physics, use_gen, metrics)
+
+def generate_random_imgs_from_dataset(dataset: EvalDataset,
+                                      model: EvalModel,
+                                      baseline: BaselineModel,
+                                      physics: PhysicsWithGenerator,
+                                      use_gen: bool,
+                                      metrics: List[Metric]):
+    idx = random.randint(0, len(dataset)-1)
+    x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
+        dataset, idx, model, baseline, physics, use_gen, metrics
+    )
+    return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
 
 def generate_imgs(x: torch.Tensor,
                   model: EvalModel, baseline: BaselineModel,
@@ -75,7 +89,7 @@ def generate_imgs(x: torch.Tensor,
         metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"
 
     ### Process y when y shape is different from x shape
-    if physics.name == "MRI":
+    if "MRI" in physics.name:
         y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
     else:
         y_plot = y.clone()
@@ -93,18 +107,6 @@ def generate_imgs(x: torch.Tensor,
 
     return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
 
-def generate_random_imgs_from_dataset(dataset: EvalDataset,
-                                      model: EvalModel,
-                                      baseline: BaselineModel,
-                                      physics: PhysicsWithGenerator,
-                                      use_gen: bool,
-                                      metrics: List[Metric]):
-    idx = random.randint(0, len(dataset)-1)
-    x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
-        dataset, idx, model, baseline, physics, use_gen, metrics
-    )
-    return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
-
 
 get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
 get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
@@ -112,7 +114,8 @@ get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
 get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
 get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
 
-AVAILABLE_PHYSICS =
+AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
+                     'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
 def get_dataset(dataset_name):
     global AVAILABLE_PHYSICS
     if dataset_name == 'MRI':
@@ -124,10 +127,15 @@ def get_dataset(dataset_name):
         baseline_name = 'DPIR_CT'
         physics_name = 'CT'
     else:
-        AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
+        AVAILABLE_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
+                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
         baseline_name = 'DPIR'
         physics_name = 'MotionBlur_easy'
-
+
+    dataset = get_dataset_on_DEVICE_STR(dataset_name)
+    physics = get_physics_on_DEVICE_STR(physics_name)
+    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
+    return dataset, physics, baseline
 
 
 ### Gradio Blocks interface
@@ -135,9 +143,9 @@ def get_dataset(dataset_name):
 # Define custom CSS
 custom_css = """
 .fixed-textbox textarea {
-    height:
-    overflow: scroll;
-    resize: none;
+    height: 100px !important; /* Adjust height to fit exactly 4 lines */
+    overflow: scroll; /* Add a scroll bar if necessary */
+    resize: none; /* User can resize vertically the textbox */
 }
 """
 
@@ -145,87 +153,88 @@ title = "Inverse problem playground"  # displayed on gradio tab and in the gradi
 with gr.Blocks(title=title, css=custom_css) as interface:
     gr.Markdown("## " + title)
 
-    #
-
-
-
-
-
+    # DEFAULT VALUES
+    # Issue: giving directly a `torch.nn.module` to `gr.State(...)` since it has __call__ method
+    # Solution: using lambda expression
+    model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
+    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
+    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
+    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
+    idx_placeholder = gr.State(0)
+
+    metric_names = ["PSNR"]
+    metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(metric_names))
 
-    @gr.render(inputs=[dataset_placeholder, physics_placeholder
-    def dynamic_layout(dataset, physics
+    @gr.render(inputs=[dataset_placeholder, physics_placeholder])
+    def dynamic_layout(dataset, physics):
         ### LAYOUT
-        dataset_name = dataset.name
-        physics_name = physics.name
-        metric_names = [metric.name for metric in metrics]
 
-        #
+        # Display images
+        with gr.Row():
+            gt_img = gr.Image(label=f"Ground-truth IMAGE", interactive=True)
+            observed_img = gr.Image(label=f"Observed IMAGE", interactive=False)
+            model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
+            model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
+
+        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column():
+                run_button = gr.Button("Demo on above image")
+                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
+                                          label="Datasets",
+                                          value=dataset.name)
+                idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index")
                with gr.Row():
-
-
-
-
-
-
+                    load_button = gr.Button("Run on index image from dataset")
+                    load_random_button = gr.Button("Run on random image from dataset")
+            with gr.Column():
+                observed_metrics = gr.Textbox(label="PSNR(Observed, Ground-truth)",
+                                              elem_classes=["fixed-textbox"])
+            with gr.Column():
+                out_a_metric = gr.Textbox(label="PSNR(RAM(Observed, Ground-truth)",
+                                          elem_classes=["fixed-textbox"])
+            with gr.Column():
+                out_b_metric = gr.Textbox(label="PSNR(DPIR(Observed, Ground-truth)",
+                                          elem_classes=["fixed-textbox"])
 
+        # Manage physics
+        with gr.Row():
+            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=AVAILABLE_PHYSICS,
-                                          label="
-                                          value=
+                                          label="Physics",
+                                          value=physics.name)
+                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference")
+            with gr.Column(scale=1):
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
-
-
-
-
-
-
-
-            with gr.Column():
-                model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
-                out_a_metric = gr.Textbox(label="Metrics(RAM(y, physics), x)", elem_classes=["fixed-textbox"])
-            with gr.Column():
-                model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)
-                out_b_metric = gr.Textbox(label="Metrics(DPIR(y, physics), x)", elem_classes=["fixed-textbox"])
-        with gr.Row():
-            choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
-                                      label="List of EvalDataset",
-                                      value=dataset_name,
-                                      scale=2)
-            idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", scale=1)
+                                               label="Updatable Parameter Key")
+                    value_text = gr.Textbox(label="Update Value")
+                update_button = gr.Button("Manually update parameter value")
+            with gr.Column(scale=2):
+                physics_params = gr.Textbox(label="Physics parameters",
+                                            elem_classes=["fixed-textbox"],
+                                            value=physics.display_saved_params())
 
-        # Components: Load Metric + Load image Buttons
-        with gr.Row():
-            with gr.Column(scale=3):
-                choose_metrics = gr.CheckboxGroup(choices=Metric.all_metrics,
-                                                  value=metric_names,
-                                                  label="Choose metrics you are interested")
-            use_generator_button = gr.Checkbox(label="Generate valid physics parameters", scale=1)
-            run_button = gr.Button("Run current image", scale=1)
-            with gr.Column(scale=1):
-                load_button = gr.Button("Load images from dataset...")
-                load_random_button = gr.Button("Load randomly from dataset...")
 
        ### Event listeners
+
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, physics_placeholder, model_b_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
-        update_button.click(fn=physics.update_and_display_params,
-
-                            inputs=choose_metrics,
-                            outputs=metrics_placeholder)
+        update_button.click(fn=physics.update_and_display_params,
+                            inputs=[key_selector, value_text], outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user,
-                         inputs=[
+                         inputs=[gt_img,
                                 model_a_placeholder,
                                 model_b_placeholder,
                                 physics_placeholder,
                                 use_generator_button,
                                 metrics_placeholder],
-                         outputs=[
+                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
+                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset,
                          inputs=[dataset_placeholder,
                                  idx_slider,
@@ -234,7 +243,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
                                  physics_placeholder,
                                  use_generator_button,
                                  metrics_placeholder],
-                          outputs=[
+                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
+                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset,
                                 inputs=[dataset_placeholder,
                                         model_a_placeholder,
@@ -242,6 +252,8 @@ with gr.Blocks(title=title, css=custom_css) as interface:
                                         physics_placeholder,
                                         use_generator_button,
                                         metrics_placeholder],
-                                 outputs=[idx_slider,
+                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
+                                          physics_params, observed_metrics, out_a_metric, out_b_metric])
+
 
    interface.launch()
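Aside (not part of the commit): the `# Issue / # Solution` comments in the new layout refer to how `gr.State` treats callables. `torch.nn.Module` defines `__call__`, so passing a constructed model directly would make Gradio call it as an initial-value factory; wrapping it in a lambda defers that. A minimal sketch of the pattern, assuming only the public Gradio API; `make_model` is a hypothetical stand-in for the Space's model factories:

import gradio as gr
import torch

def make_model():
    # Hypothetical stand-in for EvalModel / BaselineModel construction.
    return torch.nn.Identity()

with gr.Blocks() as demo:
    # The lambda, not the module, is what gr.State calls to build the initial value;
    # it simply returns the constructed module.
    model_placeholder = gr.State(lambda: make_model())
    physics_placeholder = gr.State("MotionBlur_easy")

    # @gr.render re-runs this function whenever one of its inputs changes,
    # rebuilding the components it creates (the dynamic layout used above).
    @gr.render(inputs=[physics_placeholder])
    def dynamic_layout(physics_name):
        gr.Markdown(f"Current physics: {physics_name}")

# demo.launch()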
evals.py → factories.py
RENAMED
@@ -8,6 +8,7 @@ from torchvision import transforms
 
 from datasets import Preprocessed_fastMRI, Preprocessed_LIDCIDRI, LsdirMiniDataset
 from model_factory import get_model
+from physics.blur_generator import GaussianBlurGenerator
 
 DEFAULT_MODEL_PARAMS = {
     "in_channels": [1, 2, 3],
@@ -159,7 +160,7 @@ class PhysicsWithGenerator(torch.nn.Module):
 
     def _update_save_params(self, key: str, value: Any) -> None:
         """Update value of an existing key in save_params."""
-        if key in list(self.saved_params["updatable_params"].keys()):
+        if value != "" and key in list(self.saved_params["updatable_params"].keys()):
            if type(value) == str:  # it may be only a str representation
                # type: str -> ???
                value = self.saved_params["updatable_params_converter"][key](value)
@@ -168,7 +169,7 @@ class PhysicsWithGenerator(torch.nn.Module):
                value = float(f"{value:.4f}")  # keeps only 4 significant digits
            self.saved_params["updatable_params"][key] = value
 
-    def update_and_display_params(self, key, value) -> str:
+    def update_and_display_params(self, key: str, value: Any) -> str:
        """_update_save_params + update physics with saved_params + display_saved_params"""
        self._update_save_params(key, value)
 
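Aside (not part of the commit): the guard added to `_update_save_params` skips empty submissions from the new `value_text` textbox before converting the string. A standalone sketch of the same update logic, with hypothetical names (`params`, `converters`) in place of the class attributes:

from typing import Any

def update_param(params: dict, converters: dict, key: str, value: Any) -> None:
    # Ignore empty textbox submissions and unknown keys.
    if value != "" and key in params:
        if isinstance(value, str):           # the UI sends a string representation
            value = converters[key](value)   # e.g. float("0.05") -> 0.05
        if isinstance(value, float):
            value = float(f"{value:.4f}")    # keep 4 decimal places
        params[key] = value

params = {"noise_level": 0.1}
converters = {"noise_level": float}
update_param(params, converters, "noise_level", "0.05")  # params["noise_level"] == 0.05
update_param(params, converters, "noise_level", "")      # ignored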
models/unrolled_dpir.py
DELETED
@@ -1,304 +0,0 @@
-import numpy as np
-import deepinv
-import torch
-import deepinv as dinv
-from deepinv.optim.data_fidelity import L2
-from deepinv.optim.prior import PnP
-from deepinv.unfolded import unfolded_builder
-import copy
-import deepinv.optim.utils
-
-class PoissonGaussianDistance(dinv.optim.Distance):
-    r"""
-    Implementation of :math:`\distancename` as the normalized :math:`\ell_2` norm
-
-    .. math::
-        f(x) = (x-y)^{T}\Sigma_y(x-y)
-
-    with :math:`\Sigma_y=\text{diag}(gamma y + \sigma^2)`
-
-    :param float sigma: Gaussian noise parameter. Default: 1.
-    :param float gain: Poisson noise parameter. Default 0.
-    """
-
-    def __init__(self, sigma=1.0, gain=0.):
-        super().__init__()
-        self.sigma = sigma
-        self.gain = gain
-
-    def fn(self, x, y, *args, **kwargs):
-        r"""
-        Computes the distance :math:`\distance{x}{y}` i.e.
-
-        .. math::
-
-            \distance{x}{y} = \frac{1}{2}\|x-y\|^2
-
-
-        :param torch.Tensor u: Variable :math:`x` at which the data fidelity is computed.
-        :param torch.Tensor y: Data :math:`y`.
-        :return: (:class:`torch.Tensor`) data fidelity :math:`\datafid{u}{y}` of size `B` with `B` the size of the batch.
-        """
-        norm = 1.0 / (self.sigma**2 + y * self.gain)
-        z = (x - y) * norm
-        d = 0.5 * torch.norm(z.reshape(z.shape[0], -1), p=2, dim=-1) ** 2
-        return d
-
-    def grad(self, x, y, *args, **kwargs):
-        r"""
-        Computes the gradient of :math:`\distancename`, that is :math:`\nabla_{x}\distance{x}{y}`, i.e.
-
-        .. math::
-
-            \nabla_{x}\distance{x}{y} = \frac{1}{\sigma^2} x-y
-
-
-        :param torch.Tensor x: Variable :math:`x` at which the gradient is computed.
-        :param torch.Tensor y: Observation :math:`y`.
-        :return: (:class:`torch.Tensor`) gradient of the distance function :math:`\nabla_{x}\distance{x}{y}`.
-        """
-        norm = 1.0 / (self.sigma**2 + y * self.gain)
-        return (x - y) * norm
-
-    def prox(self, x, y, *args, gamma=1.0, **kwargs):
-        r"""
-        Proximal operator of :math:`\gamma \distance{x}{y} = \frac{\gamma}{2 \sigma^2} \|x-y\|^2`.
-
-        Computes :math:`\operatorname{prox}_{\gamma \distancename}`, i.e.
-
-        .. math::
-
-            \operatorname{prox}_{\gamma \distancename} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|u-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
-
-
-        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
-        :param torch.Tensor y: Data :math:`y`.
-        :param float gamma: thresholding parameter.
-        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \distancename}(x)`.
-        """
-        norm = 1.0 / (self.sigma**2 + y * self.gain)
-        return (x + norm * gamma * y) / (1 + gamma * norm)
-
-
-class PoissonGaussianDataFidelity(dinv.optim.DataFidelity):
-    r"""
-    Implementation of the data-fidelity as the normalized :math:`\ell_2` norm
-
-    .. math::
-
-        f(x) = \|\forw{x}-y\|^2_{\text{diag}(\sigma^2 + y \gamma)}
-
-    It can be used to define a log-likelihood function associated with Poisson Gaussian noise
-    by setting an appropriate noise level :math:`\sigma`.
-
-    :param float sigma: Standard deviation of the noise to be used as a normalisation factor.
-    :param float gain: Gain factor of the data-fidelity term.
-    """
-
-    def __init__(self, sigma=1.0, gain=0.):
-        super().__init__()
-        self.d = PoissonGaussianDistance(sigma=sigma, gain=gain)
-        self.gain = gain
-        self.sigma = sigma
-
-    def prox(self, x, y, physics, gamma=1.0, *args, **kwargs):
-        r"""
-        Proximal operator of :math:`\gamma \datafid{Ax}{y} = \frac{\gamma}{2\sigma^2}\|Ax-y\|^2`.
-
-        Computes :math:`\operatorname{prox}_{\gamma \datafidname}`, i.e.
-
-        .. math::
-
-            \operatorname{prox}_{\gamma \datafidname} = \underset{u}{\text{argmin}} \frac{\gamma}{2\sigma^2}\|Au-y\|_2^2+\frac{1}{2}\|u-x\|_2^2
-
-
-        :param torch.Tensor x: Variable :math:`x` at which the proximity operator is computed.
-        :param torch.Tensor y: Data :math:`y`.
-        :param deepinv.physics.Physics physics: physics model.
-        :param float gamma: stepsize of the proximity operator.
-        :return: (:class:`torch.Tensor`) proximity operator :math:`\operatorname{prox}_{\gamma \datafidname}(x)`.
-        """
-        assert isinstance(physics, dinv.physics.LinearPhysics), "not implemented for non-linear physics"
-        if isinstance(physics, dinv.physics.StackedPhysics):
-            device = y[0].device
-            noise_model = physics[-1].noise_model
-        else:
-            device = y.device
-            noise_model = physics.noise_model
-        if hasattr(noise_model, "gain"):
-            self.gain = noise_model.gain.detach().to(device)
-        if hasattr(noise_model, "sigma"):
-            self.sigma = noise_model.sigma.detach().to(device)
-        # Ensure sigma is a tensor and reshape if necessary
-        if isinstance(self.sigma, float):
-            self.sigma = torch.tensor([self.sigma], device=device)
-        if self.sigma.ndim == 0:
-            self.sigma = self.sigma.unsqueeze(0).to(device)
-        # Ensure gain is a tensor and reshape if necessary
-        if isinstance(self.gain, float):
-            self.gain = torch.tensor([self.gain], device=device)
-        if self.gain.ndim == 0:
-            self.gain = self.gain.unsqueeze(0).to(device)
-        if self.gain[0] > 0:
-            norm = gamma / (self.sigma[:, None, None, None]**2 + y * self.gain[:, None, None, None])
-        else:
-            norm = gamma / (self.sigma[:, None, None, None]**2)
-        A = lambda u: physics.A_adjoint(physics.A(u) * norm) + u
-        b = physics.A_adjoint(norm * y) + x
-        return deepinv.optim.utils.conjugate_gradient(A, b, init=x, max_iter=3, tol=1e-3)
-
-from deepinv.optim.optim_iterators import OptimIterator, fStep, gStep
-
-class myHQSIteration(OptimIterator):
-    r"""
-    Single iteration of half-quadratic splitting.
-
-    Class for a single iteration of the Half-Quadratic Splitting (HQS) algorithm for minimising :math:`f(x) + \lambda \regname(x)`.
-    The iteration is given by
-
-
-    .. math::
-        \begin{equation*}
-        \begin{aligned}
-        u_{k} &= \operatorname{prox}_{\gamma f}(x_k) \\
-        x_{k+1} &= \operatorname{prox}_{\sigma \lambda \regname}(u_k).
-        \end{aligned}
-        \end{equation*}
-
-
-    where :math:`\gamma` and :math:`\sigma` are step-sizes. Note that this algorithm does not converge to
-    a minimizer of :math:`f(x) + \lambda \regname(x)`, but instead to a minimizer of
-    :math:`\gamma\, ^1f+\sigma \lambda \regname`, where :math:`^1f` denotes
-    the Moreau envelope of :math:`f`
-
-    """
-
-    def __init__(self, **kwargs):
-        super(myHQSIteration, self).__init__(**kwargs)
-        self.g_step = mygStepHQS(**kwargs)
-        self.f_step = myfStepHQS(**kwargs)
-        self.requires_prox_g = True
-
-class myfStepHQS(fStep):
-    r"""
-    HQS fStep module.
-    """
-
-    def __init__(self, **kwargs):
-        super(myfStepHQS, self).__init__(**kwargs)
-
-    def forward(self, x, cur_data_fidelity, cur_params, y, physics):
-        r"""
-        Single proximal step on the data-fidelity term :math:`f`.
-
-        :param torch.Tensor x: Current iterate :math:`x_k`.
-        :param deepinv.optim.DataFidelity cur_data_fidelity: Instance of the DataFidelity class defining the current data_fidelity.
-        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
-        :param torch.Tensor y: Input data.
-        :param deepinv.physics.Physics physics: Instance of the physics modeling the data-fidelity term.
-        """
-        return cur_data_fidelity.prox(x, y, physics, gamma=cur_params["stepsize"])
-
-class mygStepHQS(gStep):
-    r"""
-    HQS gStep module.
-    """
-
-    def __init__(self, **kwargs):
-        super(mygStepHQS, self).__init__(**kwargs)
-
-    def forward(self, x, cur_prior, cur_params):
-        r"""
-        Single proximal step on the prior term :math:`\lambda \regname`.
-
-        :param torch.Tensor x: Current iterate :math:`x_k`.
-        :param dict cur_prior: Class containing the current prior.
-        :param dict cur_params: Dictionary containing the current parameters of the algorithm.
-        """
-        return cur_prior.prox(
-            x,
-            sigma_denoiser=cur_params["g_param"],
-            gain_denoiser=cur_params["gain_param"],
-            gamma=cur_params["lambda"] * cur_params["stepsize"],
-        )
-
-
-def get_unrolled_architecture(gain_param_init=1e-3, weight_tied=True, model=None, device='cpu'):
-
-    # Unrolled optimization algorithm parameters
-    max_iter = 8  # number of unfolded layers
-
-    # Select the data fidelity term
-
-
-    # Set up the trainable denoising prior
-    # Here the prior model is common for all iterations
-    if model is not None:
-        denoiser = model.to(device)
-    else:
-        denoiser = dinv.models.DRUNet(
-            pretrained='/lustre/fswork/projects/rech/nyd/commun/mterris/base_checkpoints/drunet_deepinv_color_finetune_22k.pth',
-        ).to(device)
-
-    class myPnP(PnP):
-        r"""
-        Gradient-Step Denoiser prior.
-        """
-
-        def __init__(self, *args, **kwargs):
-            super().__init__(*args, **kwargs)
-
-        def prox(self, x, sigma_denoiser, gain_denoiser, *args, **kwargs):
-            if not self.training:
-                pad = (-x.size(-2) % 8, -x.size(-1) % 8)
-                x = torch.nn.functional.pad(x, (0, pad[1], 0, pad[0]), mode="constant")
-            out = self.denoiser(x, sigma=sigma_denoiser, gamma=gain_denoiser)
-            if not self.training:
-                out = out[..., : -pad[0] or None, : -pad[1] or None]
-            return out
-
-    data_fidelity = PoissonGaussianDataFidelity()
-
-    if not weight_tied:
-        prior = [myPnP(denoiser=copy.deepcopy(denoiser)) for i in range(max_iter)]
-    else:
-        prior = [myPnP(denoiser=denoiser)]
-
-    def get_DPIR_params(noise_level_img, max_iter=8):
-        r"""
-        Default parameters for the DPIR Plug-and-Play algorithm.
-
-        :param float noise_level_img: Noise level of the input image.
-        """
-        s1 = 49.0 / 255.0
-        s2 = noise_level_img
-        sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(
-            np.float32
-        )
-        stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2
-        lamb = 1 / 0.23
-        return list(sigma_denoiser), list(lamb * stepsize)
-
-    sigma_denoiser, stepsize = get_DPIR_params(0.05)
-    stepsize = torch.tensor(stepsize) * (torch.tensor(sigma_denoiser)**2)
-    gain_denoiser = [gain_param_init]*len(sigma_denoiser)
-    params_algo = {"stepsize": stepsize, "g_param": sigma_denoiser, "gain_param": gain_denoiser}
-
-    trainable_params = [
-        "g_param",
-        "gain_param"
-        "stepsize",
-    ]  # define which parameters from 'params_algo' are trainable
-
-    # Define the unfolded trainable model.
-    model = unfolded_builder(
-        iteration=myHQSIteration(),
-        params_algo=params_algo.copy(),
-        trainable_params=trainable_params,
-        data_fidelity=data_fidelity,
-        max_iter=max_iter,
-        prior=prior,
-        device=device,
-    )
-
-    return model.to(device)
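Aside (not part of the commit): the deleted module built its unrolled HQS schedule from the usual DPIR recipe. A standalone sketch of that schedule, mirroring the removed `get_DPIR_params` (numpy only; names are illustrative):

import numpy as np

def dpir_schedule(noise_level_img: float, max_iter: int = 8):
    # Denoiser strength decays log-linearly from 49/255 down to the image noise
    # level, and the HQS stepsize grows as (sigma_k / noise)^2 scaled by 1/0.23.
    s1, s2 = 49.0 / 255.0, noise_level_img
    sigma_denoiser = np.logspace(np.log10(s1), np.log10(s2), max_iter).astype(np.float32)
    stepsize = (sigma_denoiser / max(0.01, noise_level_img)) ** 2 / 0.23
    return list(sigma_denoiser), list(stepsize)

sigmas, steps = dpir_schedule(0.05)  # 8 denoiser sigmas from ~0.192 down to 0.05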