import random
import time
from functools import partial
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric


### Config
# Run model inference on an NVIDIA GPU when available
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'
# Disable gradient tracking globally (the demo only runs inference)
torch.set_grad_enabled(False)


### Gradio Utils
def generate_imgs_from_user(image,
                            physics: PhysicsWithGenerator, use_gen: bool,
                            baseline: BaselineModel, model: EvalModel,
                            metrics: List[Metric]):
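    """Run the full demo pipeline on a user-uploaded PIL image.

    Returns the 8 values expected by the Gradio outputs (images, physics
    parameters, metric strings), or 8 Nones when no image was provided.
    """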
    # Happens when user image is missing
    if image is None:
        return None, None, None, None, None, None, None, None

    # PIL image -> torch.Tensor / (1, C, H, W) / move to DEVICE_STR
    x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)

    return generate_imgs(x, physics, use_gen, baseline, model, metrics)

def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               physics: PhysicsWithGenerator, use_gen: bool,
                               baseline: BaselineModel, model: EvalModel,
                               metrics: List[Metric]):
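    """Run the demo pipeline on the dataset sample at index `idx`."""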
    ### Load 1 image
    x = dataset[idx]    # shape : (C, H, W)
    x = x.unsqueeze(0)  # shape : (1, C, H, W)

    return generate_imgs(x, physics, use_gen, baseline, model, metrics)

def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      physics: PhysicsWithGenerator,
                                      use_gen: bool,
                                      baseline: BaselineModel,
                                      model: EvalModel,
                                      metrics: List[Metric]):
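    """Draw a random dataset index and run the demo pipeline on that sample.

    The drawn index is returned first so that the UI slider can be updated.
    """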
    idx = random.randint(0, len(dataset)-1)
    x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
        dataset, idx, physics, use_gen, baseline, model, metrics
        )
    return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline
    
def generate_imgs(x: torch.Tensor,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  baseline: BaselineModel, model: EvalModel,
                  metrics: List[Metric]):
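    """Core pipeline: simulate the measurement y, reconstruct it with the RAM
    model and the DPIR baseline, compute the metrics, and convert every tensor
    to a PIL image for display.
    """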
    ### Compute y
    y = physics(x, use_gen)  # blur operators may reduce the spatial size of y

    ### Compute x_hat from RAM & DPIR
    ram_time = time.time()
    out = model(y=y, physics=physics.physics)
    ram_time = time.time() - ram_time

    dpir_time = time.time()
    out_baseline = baseline(y=y, physics=physics.physics)
    dpir_time = time.time() - dpir_time

    ### Process tensors before metric computation
    if "Blur" in physics.name:
        # Blur operators may shrink the measurement; center-crop x and the reconstructions
        # so that all tensors share y's spatial size before computing metrics.
        h_1, h_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
        w_1, w_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2

        x = x[..., h_1:h_2, w_1:w_2]
        out = out[..., h_1:h_2, w_1:w_2]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., h_1:h_2, w_1:w_2]

    ### Metrics
    metrics_y = ""
    metrics_out = f"Inference time = {ram_time:.3f}s" + "\n"
    metrics_out_baseline = f"Inference time = {dpir_time:.3f}s" + "\n"
    for metric in metrics:
        if y.shape == x.shape:
            metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
        metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
        metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"

    ### Build a displayable version of y when it does not live in image space
    ### (k-space for MRI, sinogram for CT)
    if physics.name == "MRI":
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    elif physics.name == "CT":
        y_plot = physics.physics.A_adjoint(y)
    else:
        y_plot = y.clone()

    ### Prepare images for plotting:
    #     - clip values outside of [0, 1]
    #     - shape (1, C, H, W) -> (C, H, W)
    #     - torch.Tensor -> PIL Image
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()
    x = to_pil(process_img(x)[0].to('cpu'))
    y = to_pil(process_img(y_plot)[0].to('cpu'))
    out = to_pil(process_img(out)[0].to('cpu'))
    out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))

    return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline


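# Factory helpers with the inference device pre-bound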
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)

def get_dataset(dataset_name):
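    """Return the dataset, a reset sample index, and the matching physics,
    baseline model and list of available physics for `dataset_name`."""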
    if dataset_name == 'MRI':
        available_physics = ['MRI']
        physics_name = 'MRI'
        baseline_name = 'DPIR_MRI'
    elif dataset_name == 'CT':
        available_physics = ['CT']
        physics_name = 'CT'
        baseline_name = 'DPIR_CT'
    else:
        available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        physics_name = 'MotionBlur_easy'
        baseline_name = 'DPIR'

    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    idx = 0
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, idx, physics, baseline, available_physics


# Global objects shared across all user sessions (instantiated once at startup)
ram_model = EvalModel("unext_emb_physics_config_C", device_str=DEVICE_STR)
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)

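# Pre-bind the shared RAM model and metrics so event handlers only pass user-specific inputs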
generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)


### Gradio Blocks interface

title = "Inverse problem playground"  # shown in the browser tab and at the top of the page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### USER-SPECIFIC VARIABLES
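    # gr.State keeps a separate copy of these values for each user session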
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
    # Issue: passing a `torch.nn.Module` directly to `gr.State(...)` fails because the module is
    # callable (has a `__call__` method) and Gradio invokes callables to build the default value.
    # Solution: wrap it in a lambda so the lambda is what gets called.
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))


    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        ### LAYOUT

        # Display images
        with gr.Row():
            gt_img = gr.Image(label="Ground-truth image", interactive=True, key=0)
            observed_img = gr.Image(label="Observed image", interactive=False, key=1)
            model_a_out = gr.Image(label="RAM output", interactive=False, key=2)
            model_b_out = gr.Image(label="DPIR output", interactive=False, key=3)

        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=4)
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="Observed metrics", lines=2, key=5)
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="RAM output metrics", lines=2, key=6)
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="DPIR output metrics", lines=2, key=7)

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=8)
            with gr.Column(scale=1):
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value", size='md')
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters",
                                            lines=5,
                                            value=physics.display_saved_params())  

        ### Event listeners

        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder, available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
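        # `physics` is the instance captured by this render pass, i.e. the object
        # currently stored in `physics_placeholder`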
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text], outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user_partial,
                         inputs=[gt_img,
                                 physics_placeholder,
                                 use_generator_button,
                                 model_b_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset_partial,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  physics_placeholder,
                                  use_generator_button,
                                  model_b_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
                                 inputs=[dataset_placeholder,
                                         physics_placeholder,
                                         use_generator_button,
                                         model_b_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])


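# `launch()` starts a local Gradio server with default options; arguments such as
# `server_name` or `share=True` can be passed here when deploying elsewhere.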
interface.launch()