Spaces:
Running
Running
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME | |
from rstor.analyzis.parser import get_models_parser | |
from batch_processing import Batch | |
from rstor.properties import ( | |
DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION, | |
BATCH_SIZE, SIZE, | |
REDUCTION_SKIP, | |
TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL, | |
SAMPLER_SATURATED, | |
CONFIG_DEGRADATION, | |
DATASET_DIV2K, | |
DATASET_DL_DIV2K_512, | |
DATASET_DL_EXTRAPRIMITIVES_DIV2K_512, | |
METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, | |
DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE | |
) | |
from rstor.data.dataloader import get_data_loader | |
from tqdm import tqdm | |
from pathlib import Path | |
import torch | |
from typing import Optional | |
import argparse | |
import sys | |
from rstor.analyzis.interactive.model_selection import get_default_models | |
from rstor.learning.metrics import compute_metrics | |
from interactive_pipe.data_objects.image import Image | |
from interactive_pipe.data_objects.parameters import Parameters | |
from typing import List | |
from itertools import product | |
import pandas as pd | |
ALL_TRACES = [TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS] | |
def parse_int_pairs(s): | |
try: | |
# Split the input string by spaces to separate pairs, then split each pair by ',' and convert to tuple of ints | |
return [tuple(map(int, item.split(','))) for item in s.split()] | |
except ValueError: | |
raise argparse.ArgumentTypeError("Must be a series of pairs 'a,b' separated by spaces.") | |
def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser: | |
parser = get_models_parser( | |
parser=parser, | |
help="Inference on validation set", | |
default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME) | |
if not batch_mode: | |
parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR / | |
INFERENCE_FOLDER_NAME, help="Output directory") | |
parser.add_argument("--cpu", action="store_true", help="Force CPU") | |
parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL], | |
help="Traces to be computed", default=TRACES_ALL) | |
parser.add_argument("--size", type=parse_int_pairs, | |
default=[(256, 256)], help="Size of the images like '256,512 512,512'") | |
parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)], | |
help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'") | |
parser.add_argument("-n", "--number-of-images", type=int, default=None, | |
required=False, help="Number of images to process") | |
parser.add_argument("-d", "--dataset", type=str, | |
choices=[None, DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512], | |
default=None), | |
parser.add_argument("-b", "--blur", action="store_true") | |
parser.add_argument("--blur-index", type=int, nargs="+", default=None) | |
return parser | |
def to_image(img: torch.Tensor): | |
return img.permute(0, 2, 3, 1).cpu().numpy() | |
def infer(model, dataloader, config, device, output_dir: Path, traces: List[str] = ALL_TRACES, number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES, | |
chosen_metrics=[METRIC_PSNR, METRIC_SSIM]): # add METRIC_LPIPS here! | |
img_index = 0 | |
if TRACES_ALL in traces: | |
traces = ALL_TRACES | |
if TRACES_METRICS in traces: | |
all_metrics = {} | |
else: | |
all_metrics = None | |
with torch.no_grad(): | |
model.eval() | |
for img_degraded, img_target in tqdm(dataloader): | |
img_degraded = img_degraded.to(device) | |
img_target = img_target.to(device) | |
img_restored = model(img_degraded) | |
if TRACES_METRICS in traces: | |
metrics_input_per_image = compute_metrics( | |
img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics) | |
metrics_per_image = compute_metrics( | |
img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics) | |
# print(metrics_per_image) | |
img_degraded = to_image(img_degraded) | |
img_target = to_image(img_target) | |
img_restored = to_image(img_restored) | |
for idx in range(img_restored.shape[0]): | |
degradation_parameters = dataloader.dataset.current_degradation[img_index] | |
common_prefix = f"{img_index:05d}_{img_degraded.shape[-3]:04d}x{img_degraded.shape[-2]:04d}" | |
common_prefix += f"_noise=[{config[DATALOADER][degradation_key]['noise_stddev'][0]:02d},{config[DATALOADER][degradation_key]['noise_stddev'][1]:02d}]" | |
suffix_deg = "" | |
if degradation_parameters['noise_stddev'] > 0: | |
suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}" | |
suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}" | |
#if degradation_parameters.get("blur_kernel_id", False) else "" | |
save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png" | |
save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png" | |
save_path_targ = output_dir/f"{common_prefix}_targ.png" | |
if TRACES_RESTORED in traces: | |
Image(img_restored[idx]).save(save_path_pred) | |
if TRACES_DEGRADED in traces: | |
Image(img_degraded[idx]).save(save_path_degr) | |
if TRACES_TARGET in traces: | |
Image(img_target[idx]).save(save_path_targ) | |
if TRACES_METRICS in traces: | |
# current_metrics = {"in": {}, "out": {}} | |
# for key, value in metrics_per_image.items(): | |
# print(f"{key}: {value[idx]:.3f}") | |
# current_metrics["in"][key] = metrics_input_per_image[key][idx].item() | |
# current_metrics["out"][key] = metrics_per_image[key][idx].item() | |
current_metrics = {} | |
for key, value in metrics_per_image.items(): | |
current_metrics["in_"+key] = metrics_input_per_image[key][idx].item() | |
current_metrics["out_"+key] = metrics_per_image[key][idx].item() | |
current_metrics["degradation"] = degradation_parameters | |
current_metrics["size"] = (img_degraded.shape[-3], img_degraded.shape[-2]) | |
current_metrics["deadleaves_config"] = config[DATALOADER][degradation_key] | |
current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix() | |
current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix() | |
current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix() | |
current_metrics["model"] = config[PRETTY_NAME] | |
current_metrics["model_id"] = config[NAME] | |
Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json") | |
# for key, value in all_metrics.items(): | |
all_metrics[img_index] = current_metrics | |
img_index += 1 | |
if number_of_images is not None and img_index > number_of_images: | |
return all_metrics | |
return all_metrics | |
def infer_main(argv, batch_mode=False): | |
parser = get_parser(batch_mode=batch_mode) | |
if batch_mode: | |
batch = Batch(argv) | |
batch.set_io_description( | |
input_help='input image files', | |
output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}', | |
) | |
batch.parse_args(parser) | |
else: | |
args = parser.parse_args(argv) | |
device = "cpu" if args.cpu else DEVICE | |
dataset = args.dataset | |
blur_flag = args.blur | |
for exp in args.experiments: | |
model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False) | |
# print(list(model_dict.keys())) | |
current_model_dict = model_dict[list(model_dict.keys())[0]] | |
model = current_model_dict["model"] | |
config = current_model_dict["config"] | |
for std_dev, size, blur_index in product(args.std_dev, args.size, args.blur_index): | |
if dataset is None: | |
config[DATALOADER][CONFIG_DEAD_LEAVES] = dict( | |
blur_kernel_half_size=[0, 0], | |
ds_factor=1, | |
noise_stddev=list(std_dev), | |
sampler=SAMPLER_SATURATED | |
) | |
config[DATALOADER]["gpu_gen"] = True | |
config[DATALOADER][SIZE] = size | |
config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4 | |
else: | |
config[DATALOADER][CONFIG_DEGRADATION] = dict( | |
noise_stddev=list(std_dev), | |
degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE, | |
blur_index=blur_index | |
) | |
config[DATALOADER][NAME] = dataset | |
config[DATALOADER][SIZE] = size | |
config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4 | |
dataloader = get_data_loader(config, frozen_seed=42) | |
# print(config) | |
output_dir = Path(args.output_dir)/(config[NAME] + "_" + | |
config[PRETTY_NAME]) # + "_" + f"{size[0]:04d}x{size[1]:04d}") | |
output_dir.mkdir(parents=True, exist_ok=True) | |
all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir, | |
traces=args.traces, number_of_images=args.number_of_images, | |
degradation_key=CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION) | |
if all_metrics is not None: | |
# print(all_metrics) | |
df = pd.DataFrame(all_metrics).T | |
prefix = f"{size[0]:04d}x{size[1]:04d}" | |
if not (std_dev[0] == 0 and std_dev[1] == 0): | |
prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]" | |
if blur_index is not None: | |
prefix += f"_blur={blur_index:02d}" | |
prefix += f"_{config[PRETTY_NAME]}" | |
df.to_csv(output_dir/f"__{prefix}_metrics_.csv", index=False) | |
# Normally this could go into another script to handle the metrics analyzis | |
# print(df) | |
if __name__ == "__main__": | |
infer_main(sys.argv[1:]) | |