import json
import os
import random
import time
from functools import partial
from pathlib import Path
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric


### Config

# run model inference on an NVIDIA GPU if available
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'

# disable gradient tracking globally (inference only)
torch.set_grad_enabled(False)


### Gradio Utils

def generate_imgs_from_user(image,
                            physics: PhysicsWithGenerator, use_gen: bool,
                            baseline: BaselineModel, model: EvalModel,
                            metrics: List[Metric]):
    # Happens when the user image is missing
    if image is None:
        return None, None, None, None, None, None, None, None

    # PIL image -> torch.Tensor of shape (1, C, H, W), moved to DEVICE_STR
    x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)

    return generate_imgs(x, physics, use_gen, baseline, model, metrics)


def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               physics: PhysicsWithGenerator, use_gen: bool,
                               baseline: BaselineModel, model: EvalModel,
                               metrics: List[Metric]):
    ### Load 1 image
    x = dataset[idx]    # shape : (C, H, W)
    x = x.unsqueeze(0)  # shape : (1, C, H, W)

    return generate_imgs(x, physics, use_gen, baseline, model, metrics)


def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      physics: PhysicsWithGenerator, use_gen: bool,
                                      baseline: BaselineModel, model: EvalModel,
                                      metrics: List[Metric]):
    idx = random.randint(0, len(dataset) - 1)
    x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
        dataset, idx, physics, use_gen, baseline, model, metrics
    )
    return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline


def generate_imgs(x: torch.Tensor,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  baseline: BaselineModel, model: EvalModel,
                  metrics: List[Metric]):
    print(torch.cuda.memory_allocated() / 1024**2)  # debug: allocated GPU memory (MiB)

    ### Compute y
    y = physics(x, use_gen)  # possible reduction in img shape due to blurring

    ### Compute x_hat from RAM & DPIR
    ram_time = time.time()
    out = model(y=y, physics=physics.physics)
    ram_time = time.time() - ram_time

    dpir_time = time.time()
    out_baseline = baseline(y=y, physics=physics.physics)
    dpir_time = time.time() - dpir_time

    ### Process tensors before metric computation
    if "Blur" in physics.name:
        # center-crop x and the reconstructions to the (smaller) spatial shape of y
        w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
        h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2

        x = x[..., w_1:w_2, h_1:h_2]
        out = out[..., w_1:w_2, h_1:h_2]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]

    ### Metrics
    metrics_y = ""
    metrics_out = f"Inference time = {ram_time:.3f}s" + "\n"
    metrics_out_baseline = f"Inference time = {dpir_time:.3f}s" + "\n"
    for metric in metrics:
        if y.shape == x.shape:
            metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
        metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
        metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"

    ### Process y when y shape is different from x shape
    if physics.name == "MRI":
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    elif physics.name == "CT":
        y_plot = physics.physics.A_adjoint(y)
    else:
        y_plot = y.clone()

    ### Process images for plotting:
    #   - clip values outside of [0, 1]
    #   - shape (1, C, H, W) -> (C, H, W)
    #   - torch.Tensor object -> PIL object
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()
    x = to_pil(process_img(x)[0].to('cpu'))
    y = to_pil(process_img(y_plot)[0].to('cpu'))
    out = to_pil(process_img(out)[0].to('cpu'))
    out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))

    print(torch.cuda.memory_allocated() / 1024**2)  # debug: allocated GPU memory (MiB)
    return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline


get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)


def get_dataset(dataset_name):
    if dataset_name == 'MRI':
        available_physics = ['MRI']
        physics_name = 'MRI'
        baseline_name = 'DPIR_MRI'
    elif dataset_name == 'CT':
        available_physics = ['CT']
        physics_name = 'CT'
        baseline_name = 'DPIR_CT'
    else:
        available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        physics_name = 'MotionBlur_easy'
        baseline_name = 'DPIR'

    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    idx = 0
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, idx, physics, baseline, available_physics


# global variables shared by all users
ram_model = EvalModel("unext_emb_physics_config_C", device_str=DEVICE_STR)
psnr = Metric.get_list_metrics(["PSNR"], device_str=DEVICE_STR)

generate_imgs_from_user_partial = partial(generate_imgs_from_user, model=ram_model, metrics=psnr)
generate_imgs_from_dataset_partial = partial(generate_imgs_from_dataset, model=ram_model, metrics=psnr)
generate_random_imgs_from_dataset_partial = partial(generate_random_imgs_from_dataset, model=ram_model, metrics=psnr)

print(torch.cuda.memory_allocated() / 1024**2)  # debug: allocated GPU memory (MiB)


### Gradio Blocks interface

title = "Inverse problem playground"  # displayed on the gradio tab and in the gradio page
with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### USER-SPECIFIC VARIABLES
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])
    # Issue: passing a `torch.nn.Module` directly to `gr.State(...)` is problematic since it has a __call__ method
    # Solution: wrap it in a lambda expression
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))

    print(torch.cuda.memory_allocated() / 1024**2)  # debug: allocated GPU memory (MiB)

    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        ### LAYOUT

        # Display images
        with gr.Row():
            gt_img = gr.Image(label="Ground-truth image", interactive=True, key=0)
            observed_img = gr.Image(label="Observed image", interactive=False, key=1)
            model_a_out = gr.Image(label="RAM output", interactive=False, key=2)
            model_b_out = gr.Image(label="DPIR output", interactive=False, key=3)

        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1, label="Sample index", key=4)
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="Observed metric", lines=2, key=5)
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="RAM output metrics", lines=2, key=6)
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="DPIR output metrics", lines=2, key=7)

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True, key=8)
            with gr.Column(scale=1):
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value", size='md')
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters", lines=5, value=physics.display_saved_params())

        ### Event listeners
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder,
                                       model_b_placeholder, available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text],
                            outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user_partial,
                         inputs=[gt_img, physics_placeholder, use_generator_button, model_b_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset_partial,
                          inputs=[dataset_placeholder, idx_slider, physics_placeholder,
                                  use_generator_button, model_b_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset_partial,
                                 inputs=[dataset_placeholder, physics_placeholder,
                                         use_generator_button, model_b_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])

interface.launch()