File size: 12,037 Bytes
4dc3e99
 
 
1eb9e66
4dc3e99
 
 
 
b6d8eef
4dc3e99
b6d8eef
4dc3e99
 
 
a0ca102
4dc3e99
 
a0ca102
e0ec252
a0ca102
ff76a8d
 
4dc3e99
92dc96b
 
 
 
0fb7549
 
 
92dc96b
a0ca102
 
 
 
 
 
 
 
 
 
 
92dc96b
 
a0ca102
 
 
 
 
 
 
 
 
 
 
 
92dc96b
 
 
 
 
 
1eb9e66
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4dc3e99
 
1eb9e66
4dc3e99
1eb9e66
 
5a3ed26
4dc3e99
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3a575e4
4dc3e99
3a575e4
 
4e7aed4
01bc330
1eb9e66
ff64068
1eb9e66
01bc330
1eb9e66
ff64068
1eb9e66
4e7aed4
1eb9e66
a0ca102
ff64068
1eb9e66
a0ca102
 
1eb9e66
a0ca102
 
1eb9e66
ff64068
4dc3e99
 
 
 
906193d
4dc3e99
0e089a6
1eb9e66
a0ca102
 
 
 
1eb9e66
 
a0ca102
 
1eb9e66
 
a0ca102
0e089a6
1eb9e66
 
 
e0ec252
 
 
 
c472fe6
1eb9e66
 
 
a0ca102
 
4dc3e99
e0ec252
 
a0ca102
 
 
1eb9e66
4dc3e99
e0ec252
 
 
 
109f096
e0ec252
 
109f096
e0ec252
 
109f096
3a575e4
a0ca102
 
 
1eb9e66
a0ca102
 
e0ec252
a0ca102
4dc3e99
3a575e4
1eb9e66
a0ca102
 
 
 
109f096
a0ca102
3a575e4
288f88f
4dc3e99
a0ca102
01bc330
4dc3e99
1eb9e66
ff64068
4dc3e99
ff64068
a0ca102
 
92dc96b
a0ca102
92dc96b
 
 
 
 
a0ca102
 
92dc96b
3a575e4
 
4dc3e99
 
 
 
3a575e4
a0ca102
 
92dc96b
3a575e4
4dc3e99
 
 
 
caa6a5d
a0ca102
 
 
288f88f
caa6a5d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
import json
import os
import random
import time
from functools import partial
from pathlib import Path
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric


### Config
# Run model inference on an NVIDIA GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    DEVICE_STR = 'cuda'
else:
    DEVICE_STR = 'cpu'
# Disable autograd tracking. NOTE: PyTorch grad mode is thread-local, so this only
# affects the thread that imports this module, not worker threads running callbacks.
torch.set_grad_enabled(False)


### Gradio Utils
def generate_imgs_from_user(image,
                            model: EvalModel, baseline: BaselineModel,
                            physics: PhysicsWithGenerator, use_gen: bool,
                            metrics: List[Metric]):
    """Run the demo on an image supplied through the image widget.

    Returns the 8-tuple produced by ``generate_imgs``, or eight ``None`` values
    (one per bound output component) when no image has been provided.
    """
    if image is not None:
        # Image -> torch.Tensor, then add a leading batch axis and move to the inference device.
        to_tensor = transforms.ToTensor()
        batch = to_tensor(image).unsqueeze(0).to(DEVICE_STR)
        return generate_imgs(batch, model, baseline, physics, use_gen, metrics)
    return (None,) * 8

def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               model: EvalModel, baseline: BaselineModel,
                               physics: PhysicsWithGenerator, use_gen: bool,
                               metrics: List[Metric]):
    """Run the demo on sample ``idx`` of ``dataset``.

    Args:
        dataset: dataset to draw the ground-truth image from.
        idx: sample index. Coerced with ``int()`` because it may come from a
            Gradio Slider, whose value is documented as a float; float indices
            are rejected by list/tensor indexing.
        model, baseline, physics, use_gen, metrics: forwarded to ``generate_imgs``.

    Returns:
        The 8-tuple produced by ``generate_imgs``.
    """
    ### Load 1 image
    x = dataset[int(idx)]  # shape : (C, H, W)
    x = x.unsqueeze(0)     # shape : (1, C, H, W)

    return generate_imgs(x, model, baseline, physics, use_gen, metrics)

def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      model: EvalModel,
                                      baseline: BaselineModel,
                                      physics: PhysicsWithGenerator,
                                      use_gen: bool,
                                      metrics: List[Metric]):
    """Run the demo on a uniformly drawn sample of ``dataset``.

    Returns the drawn index followed by the 8 outputs of
    ``generate_imgs_from_dataset`` (9 values in total), so the index slider can
    be updated together with the images.
    """
    sample_idx = random.randrange(len(dataset))
    outputs = generate_imgs_from_dataset(
        dataset, sample_idx, model, baseline, physics, use_gen, metrics
    )
    return (sample_idx, *outputs)
    
def _synchronize_cuda():
    """Block until all queued CUDA work has finished (no-op when CUDA is unavailable)."""
    if torch.cuda.is_available():
        torch.cuda.synchronize()


def _timed(fn):
    """Call ``fn()`` and return ``(result, elapsed_seconds)``.

    CUDA kernels are launched asynchronously, so the device is synchronized
    before the clock starts and before it stops. Without this the measurement
    only covers kernel launches, and work still pending from one model is
    charged to the next one being timed. ``perf_counter`` is monotonic and
    high-resolution, unlike ``time.time``.
    """
    _synchronize_cuda()
    start = time.perf_counter()
    result = fn()
    _synchronize_cuda()
    return result, time.perf_counter() - start


def _center_crop_slices(full_shape, target_shape):
    """Return (row, col) slices that center-crop dims 2 and 3 of ``full_shape`` to ``target_shape``."""
    rows = slice((full_shape[2] - target_shape[2]) // 2, (full_shape[2] + target_shape[2]) // 2)
    cols = slice((full_shape[3] - target_shape[3]) // 2, (full_shape[3] + target_shape[3]) // 2)
    return rows, cols


def _as_text(lines):
    """Join ``lines`` into one string, each line terminated by a newline."""
    return "".join(f"{line}\n" for line in lines)


# Gradio runs sync callbacks in worker threads and PyTorch grad mode is
# thread-local, so the module-level torch.set_grad_enabled(False) does not
# reach them; disable autograd explicitly for the whole inference pipeline.
@torch.no_grad()
def generate_imgs(x: torch.Tensor,
                  model: EvalModel, baseline: BaselineModel,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  metrics: List[Metric]):
    """Simulate a measurement of ``x``, reconstruct it with both models and score them.

    Args:
        x: ground-truth batch of shape (1, C, H, W).
        model: evaluated model (labelled "RAM"), called as ``model(y=..., physics=...)``.
        baseline: baseline model (labelled "DPIR"), called the same way.
        physics: wrapper whose call produces the measurement ``y`` and whose
            ``.physics`` attribute is passed to the models.
        use_gen: forwarded to ``physics``; whether physics parameters are generated.
        metrics: metrics evaluated as ``metric(estimate, x)``.

    Returns:
        ``(x, y, out, out_baseline, params_str, metrics_y, metrics_out,
        metrics_out_baseline)``: four PIL images (ground truth, displayed
        measurement, model output, baseline output), the physics parameter
        string, and three metric text blocks. ``metrics_y`` is empty when
        ``y`` and ``x`` have different shapes.
    """
    ### Compute y
    y = physics(x, use_gen)  # possible reduction in img shape due to Blurring

    ### Compute x_hat from RAM & DPIR (each timed separately)
    out, ram_time = _timed(lambda: model(y=y, physics=physics.physics))
    out_baseline, dpir_time = _timed(lambda: baseline(y=y, physics=physics.physics))

    ### Blur may shrink y: center-crop x and the outputs to y's spatial size before scoring
    if "Blur" in physics.name:
        rows, cols = _center_crop_slices(x.shape, y.shape)
        x = x[..., rows, cols]
        out = out[..., rows, cols]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., rows, cols]

    ### Metrics
    observed_lines = []
    out_lines = [f"Inference time = {ram_time:.3f}s"]
    baseline_lines = [f"Inference time = {dpir_time:.3f}s"]
    # y can only be scored against x when it lives on the same grid (not the case e.g. for MRI/CT).
    y_comparable = y.shape == x.shape
    for metric in metrics:
        if y_comparable:
            observed_lines.append(f"{metric.name} = {metric(y, x).item():.4f}")
        out_lines.append(f"{metric.name} = {metric(out, x).item():.4f}")
        baseline_lines.append(f"{metric.name} = {metric(out_baseline, x).item():.4f}")
    metrics_y = _as_text(observed_lines)
    metrics_out = _as_text(out_lines)
    metrics_out_baseline = _as_text(baseline_lines)

    ### Map y back to image space for display when its shape differs from x
    if physics.name == "MRI":
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    elif physics.name == "CT":
        y_plot = physics.physics.A_adjoint(y)
    else:
        y_plot = y.clone()

    ### Processing images for plotting
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()

    def to_display(img):
        # clip values to [0, 1], drop the batch axis, move to CPU, convert to PIL
        return to_pil(process_img(img)[0].to('cpu'))

    return (to_display(x), to_display(y_plot), to_display(out), to_display(out_baseline),
            physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline)


def _bind_device(factory):
    """Return ``factory`` with its ``device_str`` keyword pre-filled with DEVICE_STR."""
    return partial(factory, device_str=DEVICE_STR)


# Device-bound constructors: callers only pass the metric/model/dataset/physics name.
get_list_metrics_on_DEVICE_STR = _bind_device(Metric.get_list_metrics)
get_eval_model_on_DEVICE_STR = _bind_device(EvalModel)
get_baseline_model_on_DEVICE_STR = _bind_device(BaselineModel)
get_dataset_on_DEVICE_STR = _bind_device(EvalDataset)
get_physics_on_DEVICE_STR = _bind_device(PhysicsWithGenerator)

def get_dataset(dataset_name):
    """Build the UI state associated with a newly selected dataset.

    Returns ``(dataset, idx, physics, baseline, available_physics)``: the
    dataset, the sample index reset to 0, the default physics, the matching
    baseline model and the list of physics names the user may choose from.
    MRI and CT each have a single dedicated physics and DPIR variant; every
    other dataset uses the blur physics with the plain DPIR baseline.
    """
    # Datasets with one dedicated forward operator: name -> (physics name, baseline name).
    dedicated = {
        'MRI': ('MRI', 'DPIR_MRI'),
        'CT': ('CT', 'DPIR_CT'),
    }
    if dataset_name in dedicated:
        physics_name, baseline_name = dedicated[dataset_name]
        available_physics = [physics_name]
    else:
        available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        physics_name = 'MotionBlur_easy'
        baseline_name = 'DPIR'

    # Construction order (dataset, physics, baseline) is kept as-is.
    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, 0, physics, baseline, available_physics


### Gradio Blocks interface

# Build the Gradio Blocks UI: a static row of four images, plus a control panel
# that is re-rendered by `dynamic_layout` whenever the dataset/physics state changes.
title = "Inverse problem playground"  # displayed on gradio tab and in the gradio page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### DEFAULT VALUES
    # Per-session state shared by the event handlers below.
    # Issue: gr.State presumably treats a callable initial value as a factory and calls it,
    # so a `torch.nn.Module` (which has a __call__ method) cannot be passed to it directly.
    # Solution: wrap the constructor in a lambda, so the lambda is the factory instead.
    # NOTE(review): if the factory runs on every page load, each session builds its own
    # models — confirm the memory cost is acceptable.
    model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
    # Only PSNR is computed; the metric textbox labels below assume this.
    metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))

    # Initial dataset/physics selection; `available_physics_placeholder` must list the
    # same choices that `get_dataset` returns for a non-MRI/CT dataset.
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])


    ### LAYOUT
    # Display images (static: created once, outside the render function)
    with gr.Row():
        # Editable so the user can upload their own ground-truth image for "Demo on above image".
        gt_img = gr.Image(label="Ground-truth image", interactive=True)
        observed_img = gr.Image(label="Observed image", interactive=False)
        model_a_out = gr.Image(label="RAM output", interactive=False)
        model_b_out = gr.Image(label="DPIR output", interactive=False)

    # Presumably re-invoked by Gradio whenever one of these state inputs changes,
    # rebuilding the controls below from the current dataset/physics values.
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        ### LAYOUT

        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                # key=0 presumably lets Gradio preserve this slider across re-renders — confirm.
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=0)
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="PSNR Observed",
                                              lines=1)
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="PSNR RAM output",
                                          lines=1)
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="PSNR DPIR",
                                          lines=1)

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True)
            with gr.Column(scale=1):
                with gr.Row():
                    # Only keys listed under saved_params["updatable_params"] can be edited manually.
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Parameter Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value")
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters",
                                            lines=5,
                                            value=physics.display_saved_params())  


        ### Event listeners

        # get_dataset returns (dataset, idx=0, physics, baseline, available_physics),
        # mapped one-to-one onto these outputs; changing the state triggers a re-render.
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder, available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        # Bound to the `physics` object received by this render call.
        # NOTE(review): confirm in-place updates reach the object held in physics_placeholder.
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text], outputs=physics_params)
        # gt_img is both input and output: it is overwritten with the displayed
        # (possibly center-cropped) ground truth returned by generate_imgs.
        run_button.click(fn=generate_imgs_from_user,
                         inputs=[gt_img,
                                 model_a_placeholder,
                                 model_b_placeholder,
                                 physics_placeholder,
                                 use_generator_button,
                                 metrics_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  model_a_placeholder,
                                  model_b_placeholder,
                                  physics_placeholder,
                                  use_generator_button,
                                  metrics_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        # The random variant also returns the drawn index first, so the slider is updated too.
        load_random_button.click(fn=generate_random_imgs_from_dataset,
                                 inputs=[dataset_placeholder,
                                         model_a_placeholder,
                                         model_b_placeholder,
                                         physics_placeholder,
                                         use_generator_button,
                                         metrics_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])


# Starts the server at import time (there is no `if __name__ == "__main__":` guard).
interface.launch()