File size: 14,052 Bytes
4dc3e99
1eb9e66
4dc3e99
 
 
b6d8eef
4dc3e99
b6d8eef
4dc3e99
 
a0ca102
4dc3e99
 
a0ca102
499f595
 
ff76a8d
 
4dc3e99
4e6590f
 
 
 
 
 
 
 
 
 
92dc96b
 
499f595
92dc96b
499f595
0fb7549
 
 
499f595
a0ca102
4e6590f
 
a0ca102
4e6590f
 
 
 
 
 
a0ca102
4e6590f
6f0291c
a0ca102
 
499f595
a0ca102
 
 
 
92dc96b
499f595
a0ca102
 
 
 
499f595
 
a0ca102
 
 
499f595
a0ca102
 
92dc96b
 
 
499f595
92dc96b
384859e
 
 
 
 
1eb9e66
384859e
 
1eb9e66
 
 
384859e
 
1eb9e66
 
 
384859e
 
1eb9e66
 
 
 
 
 
 
 
 
 
 
 
6f0291c
35c18b7
6f0291c
 
 
 
1eb9e66
 
384859e
 
1eb9e66
6f0291c
 
1eb9e66
 
384859e
 
4dc3e99
 
 
 
 
 
 
2776aea
 
 
 
4dc3e99
2776aea
f3bcaf9
2776aea
 
384859e
 
 
 
4dc3e99
2776aea
4dc3e99
 
3a575e4
499f595
3a575e4
4e7aed4
01bc330
1eb9e66
ff64068
1eb9e66
01bc330
1eb9e66
ff64068
1eb9e66
4e7aed4
88d3587
a0ca102
88d3587
1eb9e66
a0ca102
 
1eb9e66
a0ca102
 
1eb9e66
ff64068
4dc3e99
499f595
81c09b8
384859e
499f595
 
 
 
 
 
384859e
9eb8ea4
 
384859e
 
 
4dc3e99
906193d
4dc3e99
0e089a6
499f595
a0ca102
c7ae131
1eb9e66
499f595
 
c7ae131
499f595
a0ca102
384859e
 
9eb8ea4
4e6590f
 
1eb9e66
 
a0ca102
499f595
 
2776aea
 
 
 
499f595
a0ca102
4dc3e99
e0ec252
 
a0ca102
 
 
2776aea
4dc3e99
e0ec252
 
 
384859e
e0ec252
384859e
e0ec252
384859e
3a575e4
a0ca102
 
 
1eb9e66
a0ca102
 
2776aea
a0ca102
4dc3e99
3a575e4
499f595
a0ca102
499f595
a0ca102
 
109f096
a0ca102
3a575e4
4dc3e99
a0ca102
01bc330
4dc3e99
1eb9e66
ff64068
4dc3e99
ff64068
a0ca102
 
499f595
a0ca102
92dc96b
 
499f595
a0ca102
 
499f595
3a575e4
 
4dc3e99
 
499f595
a0ca102
 
499f595
3a575e4
4dc3e99
 
499f595
a0ca102
 
 
288f88f
499f595
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
import random
import time
from functools import partial
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric


### Config
# Device string handed to every model, physics operator and dataset built in
# this file: the default CUDA device when PyTorch can see an NVIDIA GPU,
# otherwise the CPU.
if torch.cuda.is_available():
    DEVICE_STR = 'cuda'
else:
    DEVICE_STR = 'cpu'


### Gradio Utils
def resize_tensor_within_box(tensor_img: torch.Tensor, max_size: int = 512):
    """Downscale an image tensor so it fits inside a ``max_size`` x ``max_size`` box.

    The aspect ratio is preserved and images are never upscaled: a tensor that
    already fits is returned as-is (same object, no copy).

    Args:
        tensor_img: image tensor whose last two dimensions are ``(H, W)``,
            e.g. ``(1, C, H, W)`` or ``(C, H, W)``.
        max_size: side length, in pixels, of the bounding box.

    Returns:
        ``tensor_img`` if it already fits, otherwise an antialiased resized
        tensor whose height and width are both ``<= max_size``.
    """
    h, w = tensor_img.shape[-2:]
    scale = min(max_size / h, max_size / w)

    if scale < 1.0:
        # round() rather than int(): h * scale can land a hair below an integer
        # (511.999...) and truncation would silently drop a pixel. Rounding cannot
        # overshoot the box because both h * scale and w * scale are <= max_size.
        # max(1, ...) keeps very thin images (e.g. 1 x 5000) from being asked to
        # resize to a zero-sized side, which would make the resize call fail.
        new_h = max(1, round(h * scale))
        new_w = max(1, round(w * scale))
        tensor_img = transforms.functional.resize(tensor_img, [new_h, new_w], antialias=True)

    return tensor_img

def generate_imgs_from_user(image,
                            physics: PhysicsWithGenerator, use_gen: bool,
                            baseline: BaselineModel, model: EvalModel,
                            metrics: List[Metric]):
    """Run the demo pipeline on an image supplied by the user through the UI.

    The image is converted with ``transforms.ToTensor()``, given a batch axis,
    moved to ``DEVICE_STR`` and shrunk to fit a 512x512 box. For the CT and MRI
    physics, a 3-channel input is first converted to grayscale. For MRI, a
    zero-filled second channel is then appended (presumably the imaginary
    part — TODO confirm).

    Returns:
        Eight ``None`` values when no image was provided, otherwise the 8-tuple
        produced by ``generate_imgs``.
    """
    # Nothing in the ground-truth widget: clear every output component.
    if image is None:
        return (None,) * 8

    # Image -> float tensor (C, H, W) -> batch (1, C, H, W) on the target device,
    # then bounded to a 512x512 box.
    batch = resize_tensor_within_box(
        transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
    )

    if batch.shape[1] == 3 and physics.name in ('CT', 'MRI'):
        batch = transforms.Grayscale(num_output_channels=1)(batch)
        if physics.name == 'MRI':
            # NOTE: the MRI path is known not to work for uploads, because the
            # MRI physics has a fixed image size.
            batch = torch.cat((batch, torch.zeros_like(batch)), dim=1)

    return generate_imgs(batch, physics, use_gen, baseline, model, metrics)

def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               physics: PhysicsWithGenerator, use_gen: bool,
                               baseline: BaselineModel, model: EvalModel,
                               metrics: List[Metric]):
    """Run the demo pipeline on sample ``idx`` of ``dataset``.

    Args:
        dataset: indexable dataset whose items are single image tensors
            (assumed ``(C, H, W)`` — TODO confirm against EvalDataset).
        idx: sample index. It is wired to a Gradio slider and may therefore
            arrive as a float, so it is coerced with ``int()`` before indexing.
        physics, use_gen, baseline, model, metrics: forwarded unchanged to
            ``generate_imgs``.

    Returns:
        The 8-tuple produced by ``generate_imgs``.
    """
    # Python lists and torch tensors both reject float indices such as 3.0.
    sample = dataset[int(idx)]   # shape : (C, H, W)
    batch = sample.unsqueeze(0)  # shape : (1, C, H, W)

    return generate_imgs(batch, physics, use_gen, baseline, model, metrics)

def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      physics: PhysicsWithGenerator,
                                      use_gen: bool,
                                      baseline: BaselineModel,
                                      model: EvalModel,
                                      metrics: List[Metric]):
    """Run the demo pipeline on a uniformly drawn sample of ``dataset``.

    Returns:
        A 9-tuple: the drawn index (so the UI slider can be moved to it),
        followed by the 8 values returned by ``generate_imgs_from_dataset``.
    """
    last_idx = len(dataset) - 1
    picked = random.randint(0, last_idx)
    results = generate_imgs_from_dataset(dataset, picked, physics, use_gen,
                                         baseline, model, metrics)
    return (picked, *results)
    
def _log_cuda_memory(stage: str):
    """Print current and peak CUDA allocator statistics tagged with *stage*."""
    mib = 1024 ** 2
    print(f"[{stage}] CUDA current allocated: {torch.cuda.memory_allocated() / mib:.2f} MB")
    print(f"[{stage}] CUDA current reserved: {torch.cuda.memory_reserved() / mib:.2f} MB")
    print(f"[{stage}] CUDA max allocated: {torch.cuda.max_memory_allocated() / mib:.2f} MB")
    print(f"[{stage}] CUDA max reserved: {torch.cuda.max_memory_reserved() / mib:.2f} MB")

def _timed_no_grad(run, sync_cuda: bool):
    """Call ``run()`` under ``torch.no_grad()`` and return ``(result, elapsed_seconds)``.

    CUDA kernels launch asynchronously. Without synchronising, the clock stops
    as soon as the kernels are queued, and the unfinished work gets billed to
    whatever is timed next. With *sync_cuda* set, the device is drained before
    starting and before stopping the clock, so each call is timed on its own
    work only.
    """
    if sync_cuda:
        torch.cuda.synchronize()
    start = time.perf_counter()
    with torch.no_grad():
        result = run()
    if sync_cuda:
        torch.cuda.synchronize()
    return result, time.perf_counter() - start

def _center_crop_to(img: torch.Tensor, height: int, width: int) -> torch.Tensor:
    """Center-crop the last two dimensions of *img* to ``(height, width)``.

    A tensor that already has that spatial size is returned unchanged. A
    dimension that is already no larger than the target is left whole.
    """
    h, w = img.shape[-2:]
    if (h, w) == (height, width):
        return img
    top = max((h - height) // 2, 0)
    left = max((w - width) // 2, 0)
    return img[..., top:top + height, left:left + width]

def generate_imgs(x: torch.Tensor,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  baseline: BaselineModel, model: EvalModel,
                  metrics: List[Metric]):
    """Simulate a measurement of ``x`` and reconstruct it with RAM and the baseline.

    Args:
        x: ground-truth batch holding one image, shape ``(1, C, H, W)``.
        physics: wrapper called as ``physics(x, use_gen)`` to produce the
            measurement. Its ``.physics`` attribute is the operator handed to
            both reconstruction models.
        use_gen: forwarded to ``physics``. Per the UI label, it controls
            whether physics parameters are generated during inference.
        baseline: baseline reconstruction model (a DPIR variant in this app).
        model: main reconstruction model (RAM in this app).
        metrics: metrics evaluated as ``metric(estimate, x)``; each must expose
            ``.name`` and return a one-element tensor.

    Returns:
        An 8-tuple containing:
        - ``x_pil``, ``y_pil``, ``out_pil``, ``out_baseline_pil``: the four
          images as PIL images, clipped to [0, 1];
        - ``params_str``: the physics parameter summary;
        - ``metrics_y``, ``metrics_out``, ``metrics_out_baseline``: one metrics
          string per image. The RAM and baseline strings end with their
          inference time.
    """
    _log_cuda_memory("Before inference")
    # Synchronising is only meaningful (and only allowed) for GPU tensors.
    sync_cuda = x.is_cuda

    ### Compute y
    with torch.no_grad():
        y = physics(x, use_gen)  # possible reduction in img shape due to Blurring

    ### Compute x_hat from RAM & DPIR, each timed on its own work
    out, ram_time = _timed_no_grad(
        partial(model, y=y, physics=physics.physics), sync_cuda)
    out_baseline, dpir_time = _timed_no_grad(
        partial(baseline, y=y, physics=physics.physics), sync_cuda)

    with torch.no_grad():
        ### Align spatial sizes before metric computation
        # A blur may shrink y (e.g. valid padding). x, and any reconstruction
        # still at x's size, are center-cropped to y's height/width. Tensors
        # already at y's size are left untouched.
        if "Blur" in physics.name:
            height, width = y.shape[-2:]
            x = _center_crop_to(x, height, width)
            out = _center_crop_to(out, height, width)
            out_baseline = _center_crop_to(out_baseline, height, width)

        ### Process y when y shape is different from x shape
        # MRI/CT measurements are not images of x's shape, so an image-domain
        # version is computed for display and metrics. The large gamma makes it
        # presumably close to a least-squares inverse — TODO confirm.
        if physics.name == 'MRI' or physics.name == 'CT':
            y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
        else:
            y_plot = y

        ### Metrics
        metrics_y = ""
        metrics_out = ""
        metrics_out_baseline = ""
        for metric in metrics:
            metrics_y += f"{metric.name} = {metric(y_plot, x).item():.4f}\n"
            metrics_out += f"{metric.name} = {metric(out, x).item():.4f}\n"
            metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}\n"
        metrics_out += f"Inference time = {ram_time:.3f}s"
        metrics_out_baseline += f"Inference time = {dpir_time:.3f}s"

    ### Processing images for plotting :
    #     - clip value outside of [0,1]
    #     - shape (1, C, H, W) -> (C, H, W)
    #     - torch.Tensor object -> Pil object
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()
    x_pil, y_pil, out_pil, out_baseline_pil = (
        to_pil(process_img(img)[0].to('cpu')) for img in (x, y_plot, out, out_baseline)
    )

    ### Free memory (drop device tensors, then release cached CUDA blocks)
    del x, y, out, out_baseline, y_plot
    torch.cuda.empty_cache()
    _log_cuda_memory("After inference")

    return x_pil, y_pil, out_pil, out_baseline_pil, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline

def _bind_device(factory):
    """Return *factory* with its ``device_str`` keyword pre-bound to DEVICE_STR."""
    return partial(factory, device_str=DEVICE_STR)

# Constructors that always place the built object on DEVICE_STR.
get_dataset_on_DEVICE_STR = _bind_device(EvalDataset)
get_physics_on_DEVICE_STR = _bind_device(PhysicsWithGenerator)
get_baseline_model_on_DEVICE_STR = _bind_device(BaselineModel)

def get_dataset(dataset_name):
    """Build the per-session objects that go with *dataset_name*.

    'MRI' and 'CT' each get a single physics of the same name and a matching
    DPIR baseline. Any other name is treated as a natural-image dataset paired
    with the blur physics family and the plain 'DPIR' baseline.

    Returns:
        ``(dataset, idx, physics, baseline, available_physics)``, where ``idx``
        is 0 (the slider reset value) and ``available_physics`` is a fresh list
        of the physics names offered for this dataset.
    """
    # dataset name -> (physics name, baseline name) for the single-physics datasets
    medical_setups = {
        'MRI': ('MRI', 'DPIR_MRI'),
        'CT': ('CT', 'DPIR_CT'),
    }

    if dataset_name in medical_setups:
        physics_name, baseline_name = medical_setups[dataset_name]
        available_physics = [physics_name]
    else:
        available_physics = ['MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        # NOTE(review): the initial session state uses 'MotionBlur_medium';
        # confirm that this different default is intended.
        physics_name = 'MotionBlur_hard'
        baseline_name = 'DPIR'

    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, 0, physics, baseline, available_physics


# global variables shared by all users
# Built once at import time and shared by every Gradio session, unlike the
# per-session `gr.State` values created inside the Blocks interface.
ram_model = EvalModel(device_str=DEVICE_STR)
ram_model.eval()  # inference mode; assumes EvalModel is a torch.nn.Module — TODO confirm
# Despite its name, `psnr` holds what `get_list_metrics` returns: presumably a
# list of Metric objects (here just PSNR). It is passed as the `metrics` argument.
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)

# Pre-bind the shared model and metrics so the Gradio callbacks only receive the
# per-session inputs (image or dataset, physics, generator flag, baseline).
generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)


### Gradio Blocks interface

# Debug trace: peak CUDA memory after the shared model has been loaded.
print(f"[Init] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
print(f"[Init] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")

title = "Inverse problem playground"  # displayed on gradio tab and in the gradio page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### USER-SPECIFIC VARIABLES
    # Each `gr.State` holds a per-session value, so every browser session gets its own
    # dataset / physics / baseline. Only the RAM model and the metrics are global.
    # Default session: the "Natural" dataset with the blur physics family.
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    available_physics_placeholder = gr.State(['MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
    # `gr.State(...)` treats a callable initial value as a factory and calls it. The physics
    # and baseline objects are callable (presumably `torch.nn.Module`s, which define
    # `__call__`), so passing one directly would not store it as-is. Wrapping construction
    # in a lambda makes the lambda the factory, and the object it returns becomes the state.
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_medium"))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))

    # Debug trace: peak CUDA memory after the default session objects were built.
    print(f"[Render] CUDA max allocated: {torch.cuda.max_memory_allocated() / 1024**2:.2f} MB")
    print(f"[Render] CUDA max reserved: {torch.cuda.max_memory_reserved() / 1024**2:.2f} MB")

    # Rebuild the whole layout whenever the session's dataset or physics changes, so that
    # the radios, the slider range and the updatable-parameter dropdown match the state.
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder],
               triggers=[dataset_placeholder.change, physics_placeholder.change])
    def dynamic_layout(dataset, physics, available_physics):
        """Build the UI for the current session state and wire its event listeners.

        Args:
            dataset: the session's dataset (its `.name` and `len()` are used here).
            physics: the session's physics wrapper (its `.name`, `.saved_params`,
                `display_saved_params()` and `update_and_display_params` are used here).
            available_physics: physics names offered in the "Physics" radio.
        """
        ### LAYOUT

        # Display images
        # `key=` gives a component a stable identity across re-renders, so Gradio can
        # keep its value when the layout is rebuilt.
        with gr.Row():
            gt_img = gr.Image(label="Ground-truth image", interactive=True, key='gt_img')
            observed_img = gr.Image(label="Observed image", interactive=False, key='observed_img')
            model_a_out = gr.Image(label="RAM output", interactive=False, key='ram_out')
            model_b_out = gr.Image(label="DPIR output", interactive=False, key='dpir_out')

        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                # Slider range follows the current dataset's length.
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key='idx_slider')
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="Observed metric", lines=2, key='metrics')
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="RAM output metrics", lines=2, key='ram_metrics')
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="DPIR output metrics", lines=2, key='dpir_metrics')

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key='use_gen')
            with gr.Column(scale=1):
                with gr.Row():
                    # Only the keys listed under "updatable_params" can be edited manually.
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value", size='md')
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters",
                                            lines=5,
                                            value=physics.display_saved_params())

        ### Event listeners

        # Switching dataset rebuilds the whole session state (dataset, slider reset to 0,
        # default physics, baseline, physics choices). Writing dataset_placeholder /
        # physics_placeholder then re-triggers `dynamic_layout`.
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder, available_physics_placeholder])
        # Switching physics builds a fresh physics object (the baseline is kept).
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        # Bound to the physics instance captured by this render.
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text], outputs=physics_params)
        # Run on the user's image. gt_img is also an output: it receives the
        # resized (and, for CT/MRI, grayscale) image that was actually used.
        run_button.click(fn=generate_imgs_from_user_partial,
                         inputs=[gt_img,
                                 physics_placeholder,
                                 use_generator_button,
                                 model_b_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        # Run on the dataset sample selected by the slider.
        load_button.click(fn=generate_imgs_from_dataset_partial,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  physics_placeholder,
                                  use_generator_button,
                                  model_b_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        # Run on a random dataset sample; the drawn index is written back to the slider.
        load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
                                 inputs=[dataset_placeholder,
                                         physics_placeholder,
                                         use_generator_button,
                                         model_b_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])


# Start the Gradio web server.
interface.launch()