# balthou's picture / initiate demo / cec5823  (web page header residue, not part of the program)
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME
from rstor.analyzis.parser import get_models_parser
from batch_processing import Batch
from rstor.properties import (
DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION,
BATCH_SIZE, SIZE,
REDUCTION_SKIP,
TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL,
SAMPLER_SATURATED,
CONFIG_DEGRADATION,
DATASET_DIV2K,
DATASET_DL_DIV2K_512,
DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE
)
from rstor.data.dataloader import get_data_loader
from tqdm import tqdm
from pathlib import Path
import torch
from typing import Optional
import argparse
import sys
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.learning.metrics import compute_metrics
from interactive_pipe.data_objects.image import Image
from interactive_pipe.data_objects.parameters import Parameters
from typing import List
from itertools import product
import pandas as pd
# Every individual trace kind; `infer` substitutes this list when TRACES_ALL is requested.
ALL_TRACES = [
    TRACES_TARGET,
    TRACES_DEGRADED,
    TRACES_RESTORED,
    TRACES_METRICS,
]
def parse_int_pairs(s):
    """Parse a whitespace-separated series of integer pairs such as ``"0,50 8,8"``.

    Meant to be used as an ``argparse`` ``type=`` callable.

    Args:
        s: String made of ``a,b`` items separated by whitespace.

    Returns:
        A list of ``(a, b)`` integer tuples, in input order. A blank string
        yields an empty list.

    Raises:
        argparse.ArgumentTypeError: If an item is not made of exactly two
            comma-separated integers.
    """
    error_message = "Must be a series of pairs 'a,b' separated by spaces."
    pairs = []
    for item in s.split():
        try:
            values = tuple(map(int, item.split(',')))
        except ValueError as err:
            raise argparse.ArgumentTypeError(error_message) from err
        # "1,2,3" or "7" convert fine to ints but are not pairs: reject them too.
        if len(values) != 2:
            raise argparse.ArgumentTypeError(error_message)
        pairs.append(values)
    return pairs
def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser:
    """Build the command line parser of the validation-set inference script.

    The parser returned by ``get_models_parser`` is extended with the
    inference options (output directory, device, traces, image sizes, noise
    ranges, number of images, dataset and blur settings).

    Args:
        parser: Optional existing parser, forwarded to ``get_models_parser``.
        batch_mode: When True, ``-o/--output-dir`` is not registered here
            (presumably supplied by the batch front-end -- TODO confirm).

    Returns:
        The parser with the inference options added.
    """
    parser = get_models_parser(
        parser=parser,
        help="Inference on validation set",
        default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME)
    if not batch_mode:
        parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR / INFERENCE_FOLDER_NAME,
                            help="Output directory")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    # nargs="+" yields a list, so the default is a list as well; a bare string
    # default only worked through substring matching in `TRACES_ALL in traces`.
    parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL],
                        help="Traces to be computed", default=[TRACES_ALL])
    parser.add_argument("--size", type=parse_int_pairs,
                        default=[(256, 256)], help="Size of the images like '256,512 512,512'")
    parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)],
                        help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'")
    parser.add_argument("-n", "--number-of-images", type=int, default=None,
                        required=False, help="Number of images to process")
    parser.add_argument("-d", "--dataset", type=str,
                        choices=[None, DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
                        default=None)
    parser.add_argument("-b", "--blur", action="store_true")
    # Left as None when omitted; consumers must handle the "no blur index" case.
    parser.add_argument("--blur-index", type=int, nargs="+", default=None)
    return parser
def to_image(img: torch.Tensor):
    """Convert a 4-D tensor to a NumPy array with axis 1 moved to the last position.

    The axes are reordered as (0, 2, 3, 1), the result is brought to the CPU
    and returned as a NumPy array.
    """
    # presumably (batch, channels, height, width) -> (batch, height, width, channels) -- TODO confirm
    reordered = img.permute(0, 2, 3, 1)
    on_host = reordered.cpu()
    return on_host.numpy()
def infer(model, dataloader, config, device, output_dir: Path, traces: List[str] = ALL_TRACES,
          number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES,
          chosen_metrics=[METRIC_PSNR, METRIC_SSIM]):  # TODO: add METRIC_LPIPS here!
    """Run ``model`` over ``dataloader`` and write the requested traces to ``output_dir``.

    The model is put in eval mode and called under ``torch.no_grad()``. For each
    image of each batch, the restored / degraded / target images and a JSON
    metrics file are saved depending on ``traces``.

    Args:
        model: Callable restoration model; must expose ``eval()``.
        dataloader: Iterable yielding ``(img_degraded, img_target)`` batches.
            ``dataloader.dataset.current_degradation`` is indexed with the
            running image index to get per-image degradation parameters.
        config: Experiment configuration; ``config[DATALOADER][degradation_key]``,
            ``config[PRETTY_NAME]`` and ``config[NAME]`` are read.
        device: Device the input batches are moved to.
        output_dir: Directory where all traces are written.
        traces: Traces to produce; ``TRACES_ALL`` expands to ``ALL_TRACES``.
            The default list is only read, never mutated.
        number_of_images: Maximum number of images to process (``None`` = all).
        degradation_key: Key of the degradation section inside ``config[DATALOADER]``.
        chosen_metrics: Metrics forwarded to ``compute_metrics`` (only read here).

    Returns:
        A dict mapping the image index to its metrics record when
        ``TRACES_METRICS`` is requested, otherwise ``None``.
    """
    if TRACES_ALL in traces:
        traces = ALL_TRACES
    with_metrics = TRACES_METRICS in traces
    all_metrics = {} if with_metrics else None
    if number_of_images is not None and number_of_images <= 0:
        # Nothing requested: do not process (and save) a first image.
        return all_metrics
    img_index = 0
    with torch.no_grad():
        model.eval()
        for img_degraded, img_target in tqdm(dataloader):
            img_degraded = img_degraded.to(device)
            img_target = img_target.to(device)
            img_restored = model(img_degraded)
            if with_metrics:
                # Per-image metrics (no reduction over the batch), before and after restoration.
                metrics_input_per_image = compute_metrics(
                    img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
                metrics_per_image = compute_metrics(
                    img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
            img_degraded = to_image(img_degraded)
            img_target = to_image(img_target)
            img_restored = to_image(img_restored)
            # Identical for every image of the batch, so computed once per batch.
            degradation_config = config[DATALOADER][degradation_key]
            noise_range = degradation_config['noise_stddev']
            size = (img_degraded.shape[-3], img_degraded.shape[-2])
            for idx in range(img_restored.shape[0]):
                degradation_parameters = dataloader.dataset.current_degradation[img_index]
                common_prefix = f"{img_index:05d}_{size[0]:04d}x{size[1]:04d}"
                common_prefix += f"_noise=[{noise_range[0]:02d},{noise_range[1]:02d}]"
                suffix_deg = ""
                if degradation_parameters['noise_stddev'] > 0:
                    suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}"
                # NOTE(review): assumes every degradation record carries 'blur_kernel_id' -- confirm
                # for all datasets (the blur suffix is always appended).
                suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}"
                save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png"
                save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png"
                save_path_targ = output_dir/f"{common_prefix}_targ.png"
                if TRACES_RESTORED in traces:
                    Image(img_restored[idx]).save(save_path_pred)
                if TRACES_DEGRADED in traces:
                    Image(img_degraded[idx]).save(save_path_degr)
                if TRACES_TARGET in traces:
                    Image(img_target[idx]).save(save_path_targ)
                if with_metrics:
                    current_metrics = {}
                    for key in metrics_per_image:
                        current_metrics["in_"+key] = metrics_input_per_image[key][idx].item()
                        current_metrics["out_"+key] = metrics_per_image[key][idx].item()
                    current_metrics["degradation"] = degradation_parameters
                    current_metrics["size"] = size
                    current_metrics["deadleaves_config"] = degradation_config
                    # Image paths are stored relative to the output directory.
                    current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix()
                    current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix()
                    current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix()
                    current_metrics["model"] = config[PRETTY_NAME]
                    current_metrics["model_id"] = config[NAME]
                    Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json")
                    all_metrics[img_index] = current_metrics
                img_index += 1
                # img_index now counts the processed images: stop at exactly
                # `number_of_images` (a strict `>` would process one image too many).
                if number_of_images is not None and img_index >= number_of_images:
                    return all_metrics
    return all_metrics
def infer_main(argv, batch_mode=False):
    """Command line entry point: run inference for every requested experiment.

    For each experiment, the model and its configuration are loaded, then
    inference is run for every combination of noise range, image size and blur
    index. When metrics are collected, one CSV file per combination is written
    next to the traces.

    Args:
        argv: Command line arguments (without the program name).
        batch_mode: When True, arguments are parsed through ``Batch``.
    """
    parser = get_parser(batch_mode=batch_mode)
    if batch_mode:
        batch = Batch(argv)
        batch.set_io_description(
            input_help='input image files',
            output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}',
        )
        # The parsed result used to be discarded, leaving `args` undefined below.
        # NOTE(review): assumes Batch.parse_args returns the parsed namespace -- confirm.
        args = batch.parse_args(parser)
    else:
        args = parser.parse_args(argv)
    device = "cpu" if args.cpu else DEVICE
    dataset = args.dataset
    blur_flag = args.blur
    # --blur-index defaults to None, which itertools.product cannot iterate:
    # fall back to a single "no blur index" entry.
    blur_indices = [None] if args.blur_index is None else args.blur_index
    degradation_key = CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION
    for exp in args.experiments:
        model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False)
        # A single experiment was requested, so only the first entry is used.
        current_model_dict = next(iter(model_dict.values()))
        model = current_model_dict["model"]
        config = current_model_dict["config"]
        for std_dev, size, blur_index in product(args.std_dev, args.size, blur_indices):
            if dataset is None:
                config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
                    blur_kernel_half_size=[0, 0],
                    ds_factor=1,
                    noise_stddev=list(std_dev),
                    sampler=SAMPLER_SATURATED
                )
                config[DATALOADER]["gpu_gen"] = True
            else:
                config[DATALOADER][CONFIG_DEGRADATION] = dict(
                    noise_stddev=list(std_dev),
                    degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE,
                    blur_index=blur_index
                )
                config[DATALOADER][NAME] = dataset
            config[DATALOADER][SIZE] = size
            # Smaller validation batches for images larger than 512.
            config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
            dataloader = get_data_loader(config, frozen_seed=42)
            output_dir = Path(args.output_dir)/(config[NAME] + "_" + config[PRETTY_NAME])
            output_dir.mkdir(parents=True, exist_ok=True)
            all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir,
                                traces=args.traces, number_of_images=args.number_of_images,
                                degradation_key=degradation_key)
            if all_metrics is not None:
                # One row per image.
                df = pd.DataFrame(all_metrics).T
                prefix = f"{size[0]:04d}x{size[1]:04d}"
                if not (std_dev[0] == 0 and std_dev[1] == 0):
                    prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]"
                if blur_index is not None:
                    prefix += f"_blur={blur_index:02d}"
                prefix += f"_{config[PRETTY_NAME]}"
                df.to_csv(output_dir/f"__{prefix}_metrics_.csv", index=False)
                # Normally the metrics analysis could go into another script.
# Script entry point: forward the command line arguments (program name excluded).
if __name__ == "__main__":
    cli_arguments = sys.argv[1:]
    infer_main(cli_arguments)