from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME
from rstor.analyzis.parser import get_models_parser
from batch_processing import Batch
from rstor.properties import (
DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION,
BATCH_SIZE, SIZE,
REDUCTION_SKIP,
TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL,
SAMPLER_SATURATED,
CONFIG_DEGRADATION,
DATASET_DIV2K,
DATASET_DL_DIV2K_512,
DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE
)
from rstor.data.dataloader import get_data_loader
from tqdm import tqdm
from pathlib import Path
import torch
from typing import Optional
import argparse
import sys
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.learning.metrics import compute_metrics
from interactive_pipe.data_objects.image import Image
from interactive_pipe.data_objects.parameters import Parameters
from typing import List
from itertools import product
import pandas as pd
# Every concrete trace kind; the TRACES_ALL sentinel expands to this list.
ALL_TRACES = [TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS]
def parse_int_pairs(s):
    """Parse a string of comma-joined int pairs into a list of tuples.

    Example: ``'256,512 512,512'`` -> ``[(256, 512), (512, 512)]``.

    Args:
        s: space-separated items, each item being integers joined by commas.

    Returns:
        A list with one tuple of ints per space-separated item (empty list
        for an empty/whitespace-only string).

    Raises:
        argparse.ArgumentTypeError: if any item contains a non-integer token.
    """
    try:
        # Split on whitespace to get items, then split each item on ',' and
        # convert every token to int.
        return [tuple(map(int, item.split(','))) for item in s.split()]
    except ValueError as exc:
        # Chain the original ValueError so the root cause stays visible
        # when argparse reports the error.
        raise argparse.ArgumentTypeError("Must be a series of pairs 'a,b' separated by spaces.") from exc
def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser:
    """Build the CLI parser for validation-set inference.

    Args:
        parser: optional pre-existing parser to extend; a new one is created
            by get_models_parser when None.
        batch_mode: when True, skip the ``--output-dir`` option (the Batch
            helper handles output paths itself).

    Returns:
        The configured argparse.ArgumentParser.
    """
    parser = get_models_parser(
        parser=parser,
        help="Inference on validation set",
        default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME)
    if not batch_mode:
        parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR /
                            INFERENCE_FOLDER_NAME, help="Output directory")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL],
                        help="Traces to be computed", default=TRACES_ALL)
    parser.add_argument("--size", type=parse_int_pairs,
                        default=[(256, 256)], help="Size of the images like '256,512 512,512'")
    parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)],
                        help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'")
    parser.add_argument("-n", "--number-of-images", type=int, default=None,
                        required=False, help="Number of images to process")
    # Fixed: a stray trailing comma previously turned this statement into a
    # discarded 1-tuple `(parser.add_argument(...),)` — harmless at runtime
    # but misleading to readers.
    parser.add_argument("-d", "--dataset", type=str,
                        choices=[None, DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
                        default=None)
    parser.add_argument("-b", "--blur", action="store_true")
    parser.add_argument("--blur-index", type=int, nargs="+", default=None)
    return parser
def to_image(img: torch.Tensor):
return img.permute(0, 2, 3, 1).cpu().numpy()
def infer(model, dataloader, config, device, output_dir: Path, traces: List[str] = ALL_TRACES, number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES,
          chosen_metrics=None):  # add METRIC_LPIPS here!
    """Run `model` over a dataloader and dump the requested traces to disk.

    For each image, the degraded/restored/target images are optionally saved
    as PNGs and per-image metrics are optionally computed and saved as JSON.

    Args:
        model: torch model mapping degraded batches to restored batches.
        dataloader: validation dataloader; its ``dataset.current_degradation``
            is read for per-image degradation parameters.
        config: experiment configuration dict (DATALOADER / NAME / PRETTY_NAME).
        device: torch device string/object for inference.
        output_dir: directory the PNG/JSON traces are written to.
        traces: subset of ALL_TRACES, or containing TRACES_ALL to select all.
        number_of_images: stop after this many images (None = whole loader).
        degradation_key: key of the degradation config inside config[DATALOADER].
        chosen_metrics: metrics to compute; defaults to [METRIC_PSNR, METRIC_SSIM].

    Returns:
        Dict mapping image index to its metrics dict when TRACES_METRICS was
        requested, otherwise None.
    """
    # Fixed: `chosen_metrics` used to be a mutable default argument.
    if chosen_metrics is None:
        chosen_metrics = [METRIC_PSNR, METRIC_SSIM]
    img_index = 0
    if TRACES_ALL in traces:
        traces = ALL_TRACES
    all_metrics = {} if TRACES_METRICS in traces else None
    with torch.no_grad():
        model.eval()
        for img_degraded, img_target in tqdm(dataloader):
            img_degraded = img_degraded.to(device)
            img_target = img_target.to(device)
            img_restored = model(img_degraded)
            if TRACES_METRICS in traces:
                # Metrics both for the degraded input (baseline) and the output.
                metrics_input_per_image = compute_metrics(
                    img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
                metrics_per_image = compute_metrics(
                    img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
            # From here on the tensors are NHWC numpy arrays:
            # shape[-3] = height, shape[-2] = width.
            img_degraded = to_image(img_degraded)
            img_target = to_image(img_target)
            img_restored = to_image(img_restored)
            for idx in range(img_restored.shape[0]):
                degradation_parameters = dataloader.dataset.current_degradation[img_index]
                common_prefix = f"{img_index:05d}_{img_degraded.shape[-3]:04d}x{img_degraded.shape[-2]:04d}"
                common_prefix += f"_noise=[{config[DATALOADER][degradation_key]['noise_stddev'][0]:02d},{config[DATALOADER][degradation_key]['noise_stddev'][1]:02d}]"
                suffix_deg = ""
                if degradation_parameters['noise_stddev'] > 0:
                    suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}"
                # NOTE(review): assumes 'blur_kernel_id' is always present in the
                # degradation parameters — confirm for the dead-leaves pipeline.
                suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}"
                save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png"
                save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png"
                save_path_targ = output_dir/f"{common_prefix}_targ.png"
                if TRACES_RESTORED in traces:
                    Image(img_restored[idx]).save(save_path_pred)
                if TRACES_DEGRADED in traces:
                    Image(img_degraded[idx]).save(save_path_degr)
                if TRACES_TARGET in traces:
                    Image(img_target[idx]).save(save_path_targ)
                if TRACES_METRICS in traces:
                    current_metrics = {}
                    # "in_" = metric on the degraded input, "out_" = on the restoration.
                    for key in metrics_per_image:
                        current_metrics["in_" + key] = metrics_input_per_image[key][idx].item()
                        current_metrics["out_" + key] = metrics_per_image[key][idx].item()
                    current_metrics["degradation"] = degradation_parameters
                    current_metrics["size"] = (img_degraded.shape[-3], img_degraded.shape[-2])
                    current_metrics["deadleaves_config"] = config[DATALOADER][degradation_key]
                    # Paths are stored relative to output_dir so the JSON stays portable.
                    current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix()
                    current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix()
                    current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix()
                    current_metrics["model"] = config[PRETTY_NAME]
                    current_metrics["model_id"] = config[NAME]
                    Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json")
                    all_metrics[img_index] = current_metrics
                img_index += 1
                # Fixed off-by-one: the previous `>` processed one extra image.
                if number_of_images is not None and img_index >= number_of_images:
                    return all_metrics
    return all_metrics
def infer_main(argv, batch_mode=False):
    """CLI entry point: run inference for every experiment over the cartesian
    product of requested noise levels, image sizes and blur indices, saving
    traces and a per-combination metrics CSV.

    Args:
        argv: command-line arguments (without the program name).
        batch_mode: when True, arguments go through the Batch helper.
    """
    parser = get_parser(batch_mode=batch_mode)
    if batch_mode:
        batch = Batch(argv)
        batch.set_io_description(
            input_help='input image files',
            output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}',
        )
        batch.parse_args(parser)
        # NOTE(review): `args` is never assigned on this branch, yet the code
        # below reads `args.cpu`, `args.dataset`, etc. — this would raise
        # NameError in batch mode unless Batch.parse_args has side effects;
        # confirm against the Batch API.
    else:
        args = parser.parse_args(argv)
    device = "cpu" if args.cpu else DEVICE
    dataset = args.dataset
    blur_flag = args.blur
    for exp in args.experiments:
        # Load the (single) model matching this experiment id.
        model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False)
        # print(list(model_dict.keys()))
        current_model_dict = model_dict[list(model_dict.keys())[0]]
        model = current_model_dict["model"]
        config = current_model_dict["config"]
        # NOTE(review): args.blur_index defaults to None, and product(...)
        # over None raises TypeError — the default presumably must be
        # overridden on the command line; verify intended usage.
        for std_dev, size, blur_index in product(args.std_dev, args.size, args.blur_index):
            if dataset is None:
                # Synthetic dead-leaves generation config.
                config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
                    blur_kernel_half_size=[0, 0],
                    ds_factor=1,
                    noise_stddev=list(std_dev),
                    sampler=SAMPLER_SATURATED
                )
                config[DATALOADER]["gpu_gen"] = True
                config[DATALOADER][SIZE] = size
                # Larger images are memory-hungry: drop validation batch size to 1.
                config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
            else:
                # Real-dataset degradation config.
                config[DATALOADER][CONFIG_DEGRADATION] = dict(
                    noise_stddev=list(std_dev),
                    degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE,
                    blur_index=blur_index
                )
                config[DATALOADER][NAME] = dataset
                config[DATALOADER][SIZE] = size
                config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
            dataloader = get_data_loader(config, frozen_seed=42)
            # print(config)
            # One output folder per (model id, pretty name) pair.
            output_dir = Path(args.output_dir)/(config[NAME] + "_" +
                                                config[PRETTY_NAME])  # + "_" + f"{size[0]:04d}x{size[1]:04d}")
            output_dir.mkdir(parents=True, exist_ok=True)
            all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir,
                                traces=args.traces, number_of_images=args.number_of_images,
                                degradation_key=CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION)
            if all_metrics is not None:
                # print(all_metrics)
                # Rows = images, columns = metric/metadata keys.
                df = pd.DataFrame(all_metrics).T
                prefix = f"{size[0]:04d}x{size[1]:04d}"
                if not (std_dev[0] == 0 and std_dev[1] == 0):
                    prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]"
                if blur_index is not None:
                    prefix += f"_blur={blur_index:02d}"
                prefix += f"_{config[PRETTY_NAME]}"
                df.to_csv(output_dir/f"__{prefix}_metrics_.csv", index=False)
                # Normally this could go into another script to handle the metrics analyzis
                # print(df)
if __name__ == "__main__":
    # Script entry point: forward CLI arguments (program name stripped).
    infer_main(sys.argv[1:])