balthou's picture
update slider names, add md
7aaaf1f
from rstor.synthetic_data.interactive.interactive_dead_leaves import generate_deadleave
from rstor.analyzis.interactive.crop import crop_selector, crop, rescale_thumbnail
from rstor.analyzis.interactive.inference import infer
from rstor.analyzis.interactive.degradation import degrade_noise, degrade_blur, downsample, degrade_blur_gaussian, get_blur_kernel
from rstor.analyzis.interactive.model_selection import model_selector
from rstor.analyzis.interactive.images import image_selector
from rstor.analyzis.interactive.metrics import get_metrics, configure_metrics
from interactive_pipe import interactive, KeyboardControl
from typing import Tuple, List
from functools import partial
import numpy as np
# Metric filters bound to a fixed `image_name`, so the restored and the degraded
# images can each be scored against the ground truth as separate pipeline steps.
# `image_name` presumably labels the reported metrics — verify in get_metrics.
# Kept as functools.partial objects (not def wrappers) since they are handed
# directly to interactive_pipe as filters.
get_metrics_restored = partial(get_metrics, image_name="restored")
get_metrics_degraded = partial(get_metrics, image_name="degraded")
def deadleave_inference_pipeline(models_dict: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    # Interactive restoration demo on a synthetic dead-leaves image:
    # generate a ground truth, degrade it (Gaussian blur, then noise), restore
    # it with the selected model, crop all three images and compute metrics.
    # Returns (groundtruth, degraded, restored).
    # NOTE(review): this body is presumably parsed by interactive_pipe to build
    # its filter graph, and the local variable names match the output names used
    # in CANVAS_DICT — keep one call per statement and do not rename variables.
    groundtruth = generate_deadleave()
    groundtruth = downsample(groundtruth)
    # models_dict is forwarded to the model selector; presumably maps model
    # names to loaded models — verify against model_selector.
    model = model_selector(models_dict)
    degraded = degrade_blur_gaussian(groundtruth)
    degraded = degrade_noise(degraded)
    restored = infer(degraded, model)
    # crop_selector is fed the restored image; the resulting crop is then
    # applied to all three images together so they stay aligned.
    crop_selector(restored)
    groundtruth, degraded, restored = crop(groundtruth, degraded, restored)
    configure_metrics()
    # Score both the restored and the degraded image against the ground truth,
    # which makes the gain brought by the model visible.
    get_metrics_restored(restored, groundtruth)
    get_metrics_degraded(degraded, groundtruth)
    return groundtruth, degraded, restored
# Display layouts: each preset maps to rows of output names, presumably the
# pipeline variable names that interactive_pipe shows side by side.
CANVAS_DICT = dict(
    demo=[["degraded", "restored"]],
    landscape_light=[["degraded", "restored", "groundtruth"]],
    landscape=[["degraded", "restored", "blur_kernel", "groundtruth"]],
    full=[["degraded", "restored"], ["blur_kernel", "groundtruth"]],
)
# Preset names in declaration order; the first one is morph_canvas's default.
CANVAS = list(CANVAS_DICT)
def morph_canvas(canvas=CANVAS[0], global_params={}):
    """Switch the running pipeline's display layout to a named preset.

    Args:
        canvas: a key of CANVAS_DICT; defaults to the first preset.
        global_params: dict expected to hold the pipeline object under the
            ``"__pipeline"`` key (presumably injected by interactive_pipe).
            It is only read, never mutated, so the ``{}`` default is not shared
            mutable state in practice.

    Returns:
        None. The effect is the assignment to the pipeline's ``outputs``.

    Raises:
        KeyError: if ``canvas`` is not a known preset, or if no pipeline was
            provided in ``global_params``.
    """
    # Resolve the layout first, mirroring the original evaluation order.
    layout = CANVAS_DICT[canvas]
    pipeline = global_params["__pipeline"]
    pipeline.outputs = layout
def visualize_kernel(kernel, global_params={}):
    """Return a contrast-boosted copy of a blur kernel for display.

    Kernel taps are multiplied by 10 and clipped to [0, 1], so small
    weights become visible. The input is copied first and is never modified.

    Args:
        kernel: array supporting ``.copy()``, ``* 10`` and ``.clip()``
            (presumably a numpy array from get_blur_kernel — verify).
        global_params: unused here; presumably kept to match the
            interactive_pipe filter signature.

    Returns:
        A new array with values clipped to [0, 1].
    """
    scaled = kernel.copy() * 10
    return scaled.clip(0, 1)
def natural_inference_pipeline(input_image_list: List[np.ndarray], models_dict: dict):
    # Interactive restoration demo on natural images: select an image and a
    # crop, blur the crop with a selectable kernel, then restore it with the
    # selected model.
    # Returns the list [degraded, restored, groundtruth, kernel_amplif].
    # NOTE(review): this body is presumably parsed by interactive_pipe to build
    # its filter graph, and local variable names act as output names — keep
    # one call per statement and do not rename variables.
    model = model_selector(models_dict)
    img_clean = image_selector(input_image_list)
    # The crop is taken on the clean image before any degradation, so the blur
    # below is applied to the cropped region only.
    crop_selector(img_clean)
    groundtruth = crop(img_clean)
    blur_kernel = get_blur_kernel()
    degraded = degrade_blur(groundtruth, blur_kernel)
    # Noise degradation is currently disabled in this pipeline.
    # degraded = degrade_noise(degraded)
    # Display-only version of the kernel: amplified (x10, clipped to [0, 1])
    # and rescaled as a thumbnail. The degradation above uses the raw
    # blur_kernel.
    kernel_amplif = visualize_kernel(blur_kernel)
    kernel_amplif = rescale_thumbnail(kernel_amplif)
    restored = infer(degraded, model)
    # Metrics and dynamic canvas switching are currently disabled here.
    # NOTE(review): CANVAS_DICT layouts reference "blur_kernel", whereas this
    # pipeline returns the display version `kernel_amplif`. Re-enabling
    # morph_canvas() presumably requires reconciling those names — verify.
    # configure_metrics()
    # get_metrics_restored(restored, groundtruth)
    # get_metrics_degraded(degraded, groundtruth)
    # morph_canvas()
    # return [[degraded, restored], [blur_kernel, groundtruth]]
    return [degraded, restored, groundtruth, kernel_amplif]