import json
import os
import random
import time
from functools import partial
from pathlib import Path
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric

### Config
DEVICE_STR = 'cuda'  # run model inference on NVIDIA gpu
torch.set_grad_enabled(False)  # stops tracking values for gradients


### Gradio Utils
def generate_imgs_from_user(image, model: EvalModel, baseline: BaselineModel, physics: PhysicsWithGenerator,
                            use_gen: bool, metrics: List[Metric]):
    """Run the demo on an image supplied through the Gradio image widget.

    Returns the same 8-tuple as ``generate_imgs``. When no image has been
    provided, every element is ``None`` so all bound output widgets are cleared.
    """
    if image is None:
        return (None,) * 8
    # widget image -> torch.Tensor of shape (1, C, H, W) on the inference device
    x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
    return generate_imgs(x, model, baseline, physics, use_gen, metrics)


def generate_imgs_from_dataset(dataset: EvalDataset, idx: int, model: EvalModel, baseline: BaselineModel,
                               physics: PhysicsWithGenerator, use_gen: bool, metrics: List[Metric]):
    """Run the demo on sample ``idx`` of ``dataset``.

    Returns the same 8-tuple as ``generate_imgs``.
    """
    # A slider value may arrive as a float (e.g. 3.0); index the dataset with an int.
    x = dataset[int(idx)]  # shape : (C, H, W)
    # Add the batch dim and, like generate_imgs_from_user, make sure the tensor is on the
    # inference device (a no-op if the dataset already yields tensors there).
    x = x.unsqueeze(0).to(DEVICE_STR)  # shape : (1, C, H, W)
    return generate_imgs(x, model, baseline, physics, use_gen, metrics)


def generate_random_imgs_from_dataset(dataset: EvalDataset, model: EvalModel, baseline: BaselineModel,
                                      physics: PhysicsWithGenerator, use_gen: bool, metrics: List[Metric]):
    """Pick a uniformly random sample of ``dataset`` and run the demo on it.

    Returns the chosen index (so the caller can update the index slider)
    followed by the 8 outputs of ``generate_imgs_from_dataset``.
    """
    idx = random.randint(0, len(dataset) - 1)
    return (idx, *generate_imgs_from_dataset(dataset, idx, model, baseline, physics, use_gen, metrics))


def _synchronize_device():
    """Wait for all queued CUDA work when running on a CUDA device."""
    if DEVICE_STR.startswith("cuda") and torch.cuda.is_available():
        torch.cuda.synchronize()


def _timed_call(fn, **kwargs):
    """Call ``fn(**kwargs)`` and return ``(result, elapsed_seconds)``.

    CUDA kernel launches are asynchronous. Without synchronizing, the wall-clock
    delta would measure only launch overhead, not the actual inference.
    """
    _synchronize_device()  # don't bill previously queued GPU work to this call
    start = time.perf_counter()
    result = fn(**kwargs)
    _synchronize_device()  # wait until fn's kernels have actually finished
    return result, time.perf_counter() - start


def _metric_line(metric, estimate, target):
    """Format one ``"<name> = <value>"`` line (with trailing newline) for a metric textbox."""
    return f"{metric.name} = {metric(estimate, target).item():.4f}" + "\n"


def _to_display_pil(tensor):
    """Prepare a (1, C, H, W) tensor for display as a PIL image.

    The tensor is passed through deepinv's ``preprocess_img`` with
    ``rescale_mode="clip"`` (values outside [0, 1] are clipped). The first batch
    element is then moved to CPU and converted with ``ToPILImage``.
    """
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    return transforms.ToPILImage()(process_img(tensor)[0].to('cpu'))


def generate_imgs(x: torch.Tensor, model: EvalModel, baseline: BaselineModel, physics: PhysicsWithGenerator,
                  use_gen: bool, metrics: List[Metric]):
    """Simulate a measurement of ``x``, reconstruct it with both models, and score the results.

    Args:
        x: ground-truth batch of shape (1, C, H, W).
        model: the RAM model, called as ``model(y=..., physics=...)``.
        baseline: the DPIR baseline, called the same way.
        physics: forward-operator wrapper; ``physics(x, use_gen)`` produces the measurement.
        use_gen: forwarded to ``physics``. The UI labels it "Generate physics parameters during inference".
        metrics: metrics evaluated as ``metric(estimate, x)``.

    Returns:
        ``(x, y, out, out_baseline, params_str, metrics_y, metrics_out, metrics_out_baseline)``.
        The first four are PIL images. ``metrics_y`` is empty when ``y`` and ``x``
        differ in shape. The two model metric strings start with the inference time.
    """
    ### Compute y
    y = physics(x, use_gen)  # possible reduction in img shape due to Blurring

    ### Compute x_hat from RAM & DPIR (timed on-device)
    out, ram_time = _timed_call(model, y=y, physics=physics.physics)
    out_baseline, dpir_time = _timed_call(baseline, y=y, physics=physics.physics)

    ### Process tensors before metric computation
    if "Blur" in physics.name:
        # Centered crop to y's spatial size so every metric compares equally-sized tensors.
        top, bottom = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
        left, right = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
        x = x[..., top:bottom, left:right]
        out = out[..., top:bottom, left:right]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., top:bottom, left:right]

    ### Metrics
    # The observation is only scored when it has the same shape as the ground truth.
    if y.shape == x.shape:
        metrics_y = "".join(_metric_line(metric, y, x) for metric in metrics)
    else:
        metrics_y = ""
    metrics_out = f"Inference time = {ram_time:.3f}s" + "\n" + "".join(
        _metric_line(metric, out, x) for metric in metrics)
    metrics_out_baseline = f"Inference time = {dpir_time:.3f}s" + "\n" + "".join(
        _metric_line(metric, out_baseline, x) for metric in metrics)

    ### Process y when y shape is different from x shape
    if physics.name == "MRI":
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    elif physics.name == "CT":
        y_plot = physics.physics.A_adjoint(y)
    else:
        y_plot = y.clone()

    ### Processing images for plotting (clip, drop batch dim, tensor -> PIL)
    return (
        _to_display_pil(x),
        _to_display_pil(y_plot),
        _to_display_pil(out),
        _to_display_pil(out_baseline),
        physics.display_saved_params(),
        metrics_y,
        metrics_out,
        metrics_out_baseline,
    )


get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)

# Physics offered for every dataset other than MRI and CT. This is the single source of
# truth for both `get_dataset` and the initial `available_physics_placeholder`, which
# previously each duplicated this list by hand.
BLUR_PHYSICS = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']

# dataset name -> (selectable physics, default physics, baseline model name)
_DATASET_SETUP = {
    'MRI': (['MRI'], 'MRI', 'DPIR_MRI'),
    'CT': (['CT'], 'CT', 'DPIR_CT'),
}
# Used for any dataset name not listed in _DATASET_SETUP.
_DEFAULT_SETUP = (BLUR_PHYSICS, 'MotionBlur_easy', 'DPIR')


def get_dataset(dataset_name):
    """Load ``dataset_name`` along with its default physics and matching baseline.

    Returns ``(dataset, idx, physics, baseline, available_physics)``, in the order
    of the ``choose_dataset.change`` outputs. ``idx`` is always 0 (slider reset).
    """
    available_physics, physics_name, baseline_name = _DATASET_SETUP.get(dataset_name, _DEFAULT_SETUP)
    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    idx = 0
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    # Hand out a fresh list so the module-level constants are never aliased into session state.
    return dataset, idx, physics, baseline, list(available_physics)


### Gradio Blocks interface

title = "Inverse problem playground"  # displayed on gradio tab and in the gradio page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### DEFAULT VALUES
    # gr.State treats a callable initial value as a factory and calls it. A torch.nn.Module
    # is callable (it has __call__), so it cannot be passed directly. Wrap its construction
    # in a lambda instead.
    model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
    metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    available_physics_placeholder = gr.State(list(BLUR_PHYSICS))

    ### LAYOUT
    # Display images
    with gr.Row():
        gt_img = gr.Image(label="Ground-truth IMAGE", interactive=True)
        observed_img = gr.Image(label="Observed IMAGE", interactive=False)
        model_a_out = gr.Image(label="RAM OUTPUT", interactive=False)
        model_b_out = gr.Image(label="DPIR OUTPUT", interactive=False)

    # Re-rendered whenever the selected dataset, physics, or list of available physics changes.
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        ### LAYOUT
        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column():
                run_button = gr.Button("Demo on above image")
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets, label="Datasets", value=dataset.name)
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset)-1, step=1, label="Sample index", key=0)
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset")
                    load_random_button = gr.Button("Run on random image from dataset")
            with gr.Column():
                observed_metrics = gr.Textbox(label="PSNR(Observed, Ground-truth)", lines=1)
            with gr.Column():
                out_a_metric = gr.Textbox(label="PSNR(RAM, Ground-truth)", lines=1)
            with gr.Column():
                out_b_metric = gr.Textbox(label="PSNR(DPIR, Ground-truth)", lines=1)

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics, label="Physics", value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference")
            with gr.Column(scale=1):
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Parameter Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value")
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters", lines=5, value=physics.display_saved_params())

        ### Event listeners
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder, model_b_placeholder,
                                       available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        # Bound to the physics object captured by this render pass.
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text],
                            outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user,
                         inputs=[gt_img, model_a_placeholder, model_b_placeholder, physics_placeholder,
                                 use_generator_button, metrics_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out, physics_params,
                                  observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset,
                          inputs=[dataset_placeholder, idx_slider, model_a_placeholder, model_b_placeholder,
                                  physics_placeholder, use_generator_button, metrics_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out, physics_params,
                                   observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset,
                                 inputs=[dataset_placeholder, model_a_placeholder, model_b_placeholder,
                                         physics_placeholder, use_generator_button, metrics_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])

interface.launch()