import json
import os
import random
import time
from functools import partial
from pathlib import Path
from typing import List

import deepinv as dinv
import gradio as gr
import torch
from PIL import Image
from torchvision import transforms

from factories import PhysicsWithGenerator, EvalModel, BaselineModel, EvalDataset, Metric

### Config
DEVICE_STR = 'cuda' if torch.cuda.is_available() else 'cpu'  # run inference on the GPU when available, otherwise on CPU
torch.set_grad_enabled(False)  # disable autograd globally: this demo only runs inference
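
# The `factories` module is not shown here; as used below, it is assumed to expose:
# - PhysicsWithGenerator: callable as physics(x, use_gen), wrapping a deepinv physics in
#   `.physics`, with `.name`, `.saved_params["updatable_params"]`, `.display_saved_params()`
#   and `.update_and_display_params(key, value)`.
# - EvalModel / BaselineModel: callable as model(y=..., physics=...), returning a reconstruction.
# - EvalDataset: indexable (returns a (C, H, W) tensor), with `len()`, `.name` and `.all_datasets`.
# - Metric: callable as metric(x_hat, x), with `.name` and the classmethod `get_list_metrics`.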

### Gradio Utils

def generate_imgs_from_user(image,
                            model: EvalModel, baseline: BaselineModel,
                            physics: PhysicsWithGenerator, use_gen: bool,
                            metrics: List[Metric]):
    if image is None:
        return None, None, None, None, None, None, None, None

    # PIL image -> torch.Tensor
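    # (ToTensor() turns a PIL image or HxWxC uint8 array into a float tensor in [0, 1]
    #  with shape (C, H, W); unsqueeze(0) adds the batch dimension, so a 256x256 RGB
    #  upload becomes a (1, 3, 256, 256) tensor on DEVICE_STR.)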
    x = transforms.ToTensor()(image).unsqueeze(0).to(DEVICE_STR)
    return generate_imgs(x, model, baseline, physics, use_gen, metrics)


def generate_imgs_from_dataset(dataset: EvalDataset, idx: int,
                               model: EvalModel, baseline: BaselineModel,
                               physics: PhysicsWithGenerator, use_gen: bool,
                               metrics: List[Metric]):
    ### Load 1 image
    x = dataset[idx]    # shape : (C, H, W)
    x = x.unsqueeze(0)  # shape : (1, C, H, W)
    return generate_imgs(x, model, baseline, physics, use_gen, metrics)


def generate_random_imgs_from_dataset(dataset: EvalDataset,
                                      model: EvalModel,
                                      baseline: BaselineModel,
                                      physics: PhysicsWithGenerator,
                                      use_gen: bool,
                                      metrics: List[Metric]):
    idx = random.randint(0, len(dataset) - 1)
    x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline = generate_imgs_from_dataset(
        dataset, idx, model, baseline, physics, use_gen, metrics
    )
    return idx, x, y, out, out_baseline, saved_params_str, metrics_y, metrics_out, metrics_out_baseline


def generate_imgs(x: torch.Tensor,
                  model: EvalModel, baseline: BaselineModel,
                  physics: PhysicsWithGenerator, use_gen: bool,
                  metrics: List[Metric]):
    ### Compute y
    y = physics(x, use_gen)  # possible reduction in img shape due to blurring

    ### Compute x_hat from RAM & DPIR
    ram_time = time.time()
    out = model(y=y, physics=physics.physics)
    ram_time = time.time() - ram_time

    dpir_time = time.time()
    out_baseline = baseline(y=y, physics=physics.physics)
    dpir_time = time.time() - dpir_time
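    # Note: these are wall-clock timings; on a CUDA device they are only approximate,
    # since GPU kernels run asynchronously (a torch.cuda.synchronize() before each
    # time.time() call would make them exact, at a small cost).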

    ### Process tensors before metric computation
    if "Blur" in physics.name:
        w_1, w_2 = (x.shape[2] - y.shape[2]) // 2, (x.shape[2] + y.shape[2]) // 2
        h_1, h_2 = (x.shape[3] - y.shape[3]) // 2, (x.shape[3] + y.shape[3]) // 2
        x = x[..., w_1:w_2, h_1:h_2]
        out = out[..., w_1:w_2, h_1:h_2]
        if out_baseline.shape != out.shape:
            out_baseline = out_baseline[..., w_1:w_2, h_1:h_2]
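    # (The slicing above is a center crop of x and out to y's spatial size: e.g. for
    #  x of shape (1, 3, 256, 256) and y of shape (1, 3, 240, 240), it keeps
    #  x[..., 8:248, 8:248].)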

    ### Metrics
    metrics_y = ""
    metrics_out = f"Inference time = {ram_time:.3f}s" + "\n"
    metrics_out_baseline = f"Inference time = {dpir_time:.3f}s" + "\n"
    for metric in metrics:
        if y.shape == x.shape:  # metrics on y are only computed when y lives in image space
            metrics_y += f"{metric.name} = {metric(y, x).item():.4f}" + "\n"
        metrics_out += f"{metric.name} = {metric(out, x).item():.4f}" + "\n"
        metrics_out_baseline += f"{metric.name} = {metric(out_baseline, x).item():.4f}" + "\n"

    ### Process y when y shape is different from x shape
    if physics.name == "MRI":
        y_plot = physics.physics.prox_l2(physics.physics.A_adjoint(y), y, 1e4)
    elif physics.name == "CT":
        y_plot = physics.physics.A_adjoint(y)
    else:
        y_plot = y.clone()

    ### Processing images for plotting:
    # - clip values outside of [0, 1]
    # - shape (1, C, H, W) -> (C, H, W)
    # - torch.Tensor object -> PIL object
    process_img = partial(dinv.utils.plotting.preprocess_img, rescale_mode="clip")
    to_pil = transforms.ToPILImage()
    x = to_pil(process_img(x)[0].to('cpu'))
    y = to_pil(process_img(y_plot)[0].to('cpu'))
    out = to_pil(process_img(out)[0].to('cpu'))
    out_baseline = to_pil(process_img(out_baseline)[0].to('cpu'))

    return x, y, out, out_baseline, physics.display_saved_params(), metrics_y, metrics_out, metrics_out_baseline
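
# The 4 PIL images + 4 strings returned by generate_imgs line up one-to-one with the 8
# Gradio output components wired in the event listeners below; the "random" variant
# prepends the sampled index so it can also update the index slider.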

get_list_metrics_on_DEVICE_STR = partial(Metric.get_list_metrics, device_str=DEVICE_STR)
get_eval_model_on_DEVICE_STR = partial(EvalModel, device_str=DEVICE_STR)
get_baseline_model_on_DEVICE_STR = partial(BaselineModel, device_str=DEVICE_STR)
get_dataset_on_DEVICE_STR = partial(EvalDataset, device_str=DEVICE_STR)
get_physics_on_DEVICE_STR = partial(PhysicsWithGenerator, device_str=DEVICE_STR)
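# For example, get_physics_on_DEVICE_STR("CT") is just PhysicsWithGenerator("CT",
# device_str=DEVICE_STR): the partials above only pin the device argument.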


def get_dataset(dataset_name):
    if dataset_name == 'MRI':
        available_physics = ['MRI']
        physics_name = 'MRI'
        baseline_name = 'DPIR_MRI'
    elif dataset_name == 'CT':
        available_physics = ['CT']
        physics_name = 'CT'
        baseline_name = 'DPIR_CT'
    else:
        available_physics = ['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                             'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard']
        physics_name = 'MotionBlur_easy'
        baseline_name = 'DPIR'

    dataset = get_dataset_on_DEVICE_STR(dataset_name)
    idx = 0
    physics = get_physics_on_DEVICE_STR(physics_name)
    baseline = get_baseline_model_on_DEVICE_STR(baseline_name)
    return dataset, idx, physics, baseline, available_physics
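
# Illustrative usage outside the Gradio UI (not executed here; same factory names as above):
#
#   dataset, idx, physics, baseline, _ = get_dataset("Natural")
#   model = get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", "")
#   metrics = get_list_metrics_on_DEVICE_STR(["PSNR"])
#   x, y, out, out_dpir, params, m_y, m_out, m_dpir = generate_imgs_from_dataset(
#       dataset, idx, model, baseline, physics, True, metrics)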


### Gradio Blocks interface
title = "Inverse problem playground"  # displayed in the browser tab and at the top of the page

with gr.Blocks(title=title, theme=gr.themes.Glass()) as interface:
    gr.Markdown("## " + title)

    ### DEFAULT VALUES
    # Issue: a `torch.nn.Module` cannot be handed directly to `gr.State(...)`, because Gradio
    # calls callable default values (and a module has a `__call__` method).
    # Solution: wrap the constructor in a lambda instead.
    model_a_placeholder = gr.State(lambda: get_eval_model_on_DEVICE_STR("unext_emb_physics_config_C", ""))
    model_b_placeholder = gr.State(lambda: get_baseline_model_on_DEVICE_STR("DPIR"))
    metrics_placeholder = gr.State(get_list_metrics_on_DEVICE_STR(["PSNR"]))
    dataset_placeholder = gr.State(get_dataset_on_DEVICE_STR("Natural"))
    physics_placeholder = gr.State(lambda: get_physics_on_DEVICE_STR("MotionBlur_easy"))
    available_physics_placeholder = gr.State(['MotionBlur_easy', 'MotionBlur_medium', 'MotionBlur_hard',
                                              'GaussianBlur_easy', 'GaussianBlur_medium', 'GaussianBlur_hard'])

    ### LAYOUT
    # Display images
    with gr.Row():
        gt_img = gr.Image(label="Ground-truth image", interactive=True)
        observed_img = gr.Image(label="Observed image", interactive=False)
        model_a_out = gr.Image(label="RAM output", interactive=False)
        model_b_out = gr.Image(label="DPIR output", interactive=False)

    # Re-create the rest of the layout whenever the dataset, physics, or list of available
    # physics changes (gr.render is needed here; otherwise dynamic_layout would never run).
    @gr.render(inputs=[dataset_placeholder, physics_placeholder, available_physics_placeholder])
    def dynamic_layout(dataset, physics, available_physics):
        ### LAYOUT
        # Manage datasets and display metric values
        with gr.Row():
            with gr.Column(scale=1, min_width=160):
                run_button = gr.Button("Demo on above image", size='md')
                choose_dataset = gr.Radio(choices=EvalDataset.all_datasets,
                                          label="Datasets",
                                          value=dataset.name)
                idx_slider = gr.Slider(minimum=0, maximum=len(dataset) - 1, step=1, label="Sample index", key=0)
                with gr.Row():
                    load_button = gr.Button("Run on index image from dataset", size='md')
                    load_random_button = gr.Button("Run on random image from dataset", size='md')
            with gr.Column(scale=1, min_width=160):
                observed_metrics = gr.Textbox(label="PSNR Observed", lines=1)
            with gr.Column(scale=1, min_width=160):
                out_a_metric = gr.Textbox(label="PSNR RAM output", lines=1)
            with gr.Column(scale=1, min_width=160):
                out_b_metric = gr.Textbox(label="PSNR DPIR", lines=1)

        # Manage physics
        with gr.Row():
            with gr.Column(scale=1):
                choose_physics = gr.Radio(choices=available_physics,
                                          label="Physics",
                                          value=physics.name)
                use_generator_button = gr.Checkbox(label="Generate physics parameters during inference", value=True)
            with gr.Column(scale=1):
                with gr.Row():
                    key_selector = gr.Dropdown(choices=list(physics.saved_params["updatable_params"].keys()),
                                               label="Updatable Parameter Key")
                    value_text = gr.Textbox(label="Update Value")
                update_button = gr.Button("Manually update parameter value")
            with gr.Column(scale=2):
                physics_params = gr.Textbox(label="Physics parameters",
                                            lines=5,
                                            value=physics.display_saved_params())

        ### Event listeners
        choose_dataset.change(fn=get_dataset,
                              inputs=choose_dataset,
                              outputs=[dataset_placeholder, idx_slider, physics_placeholder,
                                       model_b_placeholder, available_physics_placeholder])
        choose_physics.change(fn=get_physics_on_DEVICE_STR,
                              inputs=choose_physics,
                              outputs=[physics_placeholder])
        update_button.click(fn=physics.update_and_display_params,
                            inputs=[key_selector, value_text],
                            outputs=physics_params)
        run_button.click(fn=generate_imgs_from_user,
                         inputs=[gt_img,
                                 model_a_placeholder,
                                 model_b_placeholder,
                                 physics_placeholder,
                                 use_generator_button,
                                 metrics_placeholder],
                         outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                  physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_button.click(fn=generate_imgs_from_dataset,
                          inputs=[dataset_placeholder,
                                  idx_slider,
                                  model_a_placeholder,
                                  model_b_placeholder,
                                  physics_placeholder,
                                  use_generator_button,
                                  metrics_placeholder],
                          outputs=[gt_img, observed_img, model_a_out, model_b_out,
                                   physics_params, observed_metrics, out_a_metric, out_b_metric])
        load_random_button.click(fn=generate_random_imgs_from_dataset,
                                 inputs=[dataset_placeholder,
                                         model_a_placeholder,
                                         model_b_placeholder,
                                         physics_placeholder,
                                         use_generator_button,
                                         metrics_placeholder],
                                 outputs=[idx_slider, gt_img, observed_img, model_a_out, model_b_out,
                                          physics_params, observed_metrics, out_a_metric, out_b_metric])


interface.launch()
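
# Note (illustrative): interface.launch() with no arguments is typically sufficient on
# Hugging Face Spaces; when running locally, options such as share=True or
# server_name="0.0.0.0", server_port=7860 can be passed to expose the app differently.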