File size: 10,749 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME
from rstor.analyzis.parser import get_models_parser
from batch_processing import Batch
from rstor.properties import (
    DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION,
    BATCH_SIZE, SIZE,
    REDUCTION_SKIP,
    TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL,
    SAMPLER_SATURATED,
    CONFIG_DEGRADATION,
    DATASET_DIV2K,
    DATASET_DL_DIV2K_512,
    DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
    METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
    DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE
)
from rstor.data.dataloader import get_data_loader
from tqdm import tqdm
from pathlib import Path
import torch
from typing import Optional
import argparse
import sys
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.learning.metrics import compute_metrics
from interactive_pipe.data_objects.image import Image
from interactive_pipe.data_objects.parameters import Parameters
from typing import List
from itertools import product
import pandas as pd
# Every individual trace `infer` can produce; a TRACES_ALL request expands to this list.
ALL_TRACES = [
    TRACES_TARGET,
    TRACES_DEGRADED,
    TRACES_RESTORED,
    TRACES_METRICS,
]


def parse_int_pairs(s):
    """Parse a whitespace-separated list of integer pairs (argparse ``type=`` helper).

    Example: ``"256,512 512,512"`` -> ``[(256, 512), (512, 512)]``.

    Args:
        s: Raw command-line string. Pairs are separated by whitespace and the two
            integers of a pair are separated by a single ``','``.

    Returns:
        List of ``(a, b)`` integer tuples, in input order.

    Raises:
        argparse.ArgumentTypeError: if ``s`` holds no pair, if an item is not made
            of exactly two comma-separated values, or if a value is not an integer.
    """
    error_message = "Must be a series of pairs 'a,b' separated by spaces."
    items = s.split()
    if not items:
        # An empty sweep would silently run nothing downstream.
        raise argparse.ArgumentTypeError(error_message)
    pairs = []
    for item in items:
        parts = item.split(',')
        # Reject "256" or "1,2,3": callers index both [0] and [1] of every pair.
        if len(parts) != 2:
            raise argparse.ArgumentTypeError(error_message)
        try:
            pairs.append((int(parts[0]), int(parts[1])))
        except ValueError as err:
            raise argparse.ArgumentTypeError(error_message) from err
    return pairs


def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser:
    """Build (or extend) the command-line parser for validation-set inference.

    Args:
        parser: Optional existing parser, forwarded to ``get_models_parser``.
        batch_mode: When True, ``--output-dir`` is not registered here.

    Returns:
        The parser with the inference options added.
    """
    parser = get_models_parser(
        parser=parser,
        help="Inference on validation set",
        default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME)
    if not batch_mode:
        parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR /
                            INFERENCE_FOLDER_NAME, help="Output directory")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    # List default so `args.traces` is always a list (nargs="+" also yields lists);
    # a bare string default would turn membership tests into substring tests.
    parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL],
                        help="Traces to be computed", default=[TRACES_ALL])
    parser.add_argument("--size", type=parse_int_pairs,
                        default=[(256, 256)], help="Size of the images like '256,512 512,512'")
    parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)],
                        help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'")
    parser.add_argument("-n", "--number-of-images", type=int, default=None,
                        required=False, help="Number of images to process")
    # Omitting --dataset (default None) selects the dead leaves configuration in infer_main.
    parser.add_argument("-d", "--dataset", type=str,
                        choices=[DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
                        default=None,
                        help="Dataset to evaluate on; when omitted, the dead leaves configuration is used")
    parser.add_argument("-b", "--blur", action="store_true",
                        help="Enable blur degradation (only used together with --dataset)")
    # [None] (not None) so the parameter sweep still runs once when the flag is omitted.
    parser.add_argument("--blur-index", type=int, nargs="+", default=[None],
                        help="Blur kernel indices to sweep (only used together with --dataset)")
    return parser


def to_image(img: torch.Tensor):
    """Convert a batch of channel-first tensors into a channel-last NumPy array.

    Args:
        img: 4D tensor (presumably ``(batch, channels, height, width)``) on any device.

    Returns:
        CPU ``numpy.ndarray`` with dimension 1 moved last, i.e. ``(batch, height, width, channels)``.
    """
    # detach() lets tensors that still track gradients be converted outside torch.no_grad().
    return img.detach().permute(0, 2, 3, 1).cpu().numpy()


def infer(model, dataloader, config, device, output_dir: Path, traces: List[str] = ALL_TRACES, number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES,
          chosen_metrics=None):
    """Run ``model`` over ``dataloader`` and write the requested traces to ``output_dir``.

    Per image, depending on ``traces``:
      - TRACES_RESTORED: saves the model output as ``<prefix>_pred<suffix>_<pretty name>.png``
      - TRACES_DEGRADED: saves the model input as ``<prefix>_degr<suffix>.png``
      - TRACES_TARGET: saves the ground truth as ``<prefix>_targ.png``
      - TRACES_METRICS: computes ``chosen_metrics`` of the degraded input (``in_*``) and
        of the restored output (``out_*``) against the target, and stores them along with
        the degradation parameters in ``<prefix>_metrics.json``.

    Args:
        model: Restoration network, called as ``model(img_degraded)``.
        dataloader: Iterable of ``(img_degraded, img_target)`` batches. Its
            ``dataset.current_degradation`` must be indexable by the running image index.
        config: Model configuration. Reads ``config[DATALOADER][degradation_key]``
            (including ``"noise_stddev"``), ``config[PRETTY_NAME]`` and ``config[NAME]``.
        device: Device the batches are moved to before calling the model.
        output_dir: Directory the files are written to (not created here).
        traces: Trace names to produce. ``TRACES_ALL`` expands to ``ALL_TRACES``.
            A single trace name passed as a plain string is also accepted.
        number_of_images: Stop after exactly this many images (``None``: whole loader).
        degradation_key: Key of the degradation sub-config inside ``config[DATALOADER]``.
        chosen_metrics: Metric names forwarded to ``compute_metrics``; defaults to PSNR and SSIM.

    Returns:
        Dict mapping image index to its metrics record when TRACES_METRICS is requested,
        otherwise ``None``.
    """
    if chosen_metrics is None:
        chosen_metrics = [METRIC_PSNR, METRIC_SSIM]  # add METRIC_LPIPS here!
    if isinstance(traces, str):
        # A bare string would make the `in` checks below substring tests.
        traces = [traces]
    if TRACES_ALL in traces:
        traces = ALL_TRACES
    metrics_requested = TRACES_METRICS in traces
    all_metrics = {} if metrics_requested else None
    if number_of_images is not None and number_of_images <= 0:
        return all_metrics
    img_index = 0
    with torch.no_grad():
        model.eval()
        for img_degraded, img_target in tqdm(dataloader):
            img_degraded = img_degraded.to(device)
            img_target = img_target.to(device)
            img_restored = model(img_degraded)
            if metrics_requested:
                # Per-image scores (no reduction) for the model input and for the model output.
                metrics_input_per_image = compute_metrics(
                    img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
                metrics_per_image = compute_metrics(
                    img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
            img_degraded = to_image(img_degraded)
            img_target = to_image(img_target)
            img_restored = to_image(img_restored)
            for idx in range(img_restored.shape[0]):
                degradation_parameters = dataloader.dataset.current_degradation[img_index]
                noise_range = config[DATALOADER][degradation_key]['noise_stddev']
                # After to_image, shape[-3] x shape[-2] is the spatial size (assuming NCHW input).
                common_prefix = f"{img_index:05d}_{img_degraded.shape[-3]:04d}x{img_degraded.shape[-2]:04d}"
                common_prefix += f"_noise=[{noise_range[0]:02d},{noise_range[1]:02d}]"
                suffix_deg = ""
                if degradation_parameters['noise_stddev'] > 0:
                    suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}"
                suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}"
                save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png"
                save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png"
                save_path_targ = output_dir/f"{common_prefix}_targ.png"
                if TRACES_RESTORED in traces:
                    Image(img_restored[idx]).save(save_path_pred)
                if TRACES_DEGRADED in traces:
                    Image(img_degraded[idx]).save(save_path_degr)
                if TRACES_TARGET in traces:
                    Image(img_target[idx]).save(save_path_targ)
                if metrics_requested:
                    current_metrics = {}
                    for key in metrics_per_image:
                        current_metrics["in_"+key] = metrics_input_per_image[key][idx].item()
                        current_metrics["out_"+key] = metrics_per_image[key][idx].item()
                    current_metrics["degradation"] = degradation_parameters
                    current_metrics["size"] = (img_degraded.shape[-3], img_degraded.shape[-2])
                    current_metrics["deadleaves_config"] = config[DATALOADER][degradation_key]
                    # Paths are stored relative to output_dir so the folder stays relocatable.
                    current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix()
                    current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix()
                    current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix()
                    current_metrics["model"] = config[PRETTY_NAME]
                    current_metrics["model_id"] = config[NAME]
                    Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json")
                    all_metrics[img_index] = current_metrics
                img_index += 1
                # `>=` (not `>`): stop once exactly `number_of_images` images were processed.
                if number_of_images is not None and img_index >= number_of_images:
                    return all_metrics
    return all_metrics


def _configure_dataloader(config, dataset, std_dev, size, blur_index, blur_flag):
    """Patch ``config[DATALOADER]`` in place for one (noise, size, blur) setting."""
    dataloader_config = config[DATALOADER]
    if dataset is None:
        dataloader_config[CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=list(std_dev),
            sampler=SAMPLER_SATURATED
        )
        dataloader_config["gpu_gen"] = True
    else:
        dataloader_config[CONFIG_DEGRADATION] = dict(
            noise_stddev=list(std_dev),
            degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE,
            blur_index=blur_index
        )
        dataloader_config[NAME] = dataset
    dataloader_config[SIZE] = size
    # Images wider than 512 in their first dimension are evaluated one at a time.
    dataloader_config[BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4


def _metrics_csv_name(size, std_dev, blur_index, pretty_name):
    """Return the file name of the per-setting metrics summary CSV."""
    prefix = f"{size[0]:04d}x{size[1]:04d}"
    if not (std_dev[0] == 0 and std_dev[1] == 0):
        prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]"
    if blur_index is not None:
        prefix += f"_blur={blur_index:02d}"
    prefix += f"_{pretty_name}"
    return f"__{prefix}_metrics_.csv"


def infer_main(argv, batch_mode=False):
    """Command-line entry point: run inference for every experiment and every setting.

    Each experiment in ``args.experiments`` is loaded once. Then every combination
    of ``--std-dev`` x ``--size`` x ``--blur-index`` is evaluated: the dataloader
    config is patched in place, ``infer`` writes images/metrics under
    ``<output-dir>/<model name>_<pretty name>/``, and a summary CSV is written there
    whenever metrics were traced.

    Args:
        argv: Command-line arguments (without the program name).
        batch_mode: Parse the arguments through ``Batch`` instead of plain argparse.

    Raises:
        ValueError: if no model could be loaded for an experiment.
    """
    parser = get_parser(batch_mode=batch_mode)
    if batch_mode:
        batch = Batch(argv)
        batch.set_io_description(
            input_help='input image files',
            output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}',
        )
        # The result used to be discarded, leaving `args` unbound in batch mode.
        # NOTE(review): assumes Batch.parse_args returns the parsed namespace — confirm.
        args = batch.parse_args(parser)
    else:
        args = parser.parse_args(argv)
    device = "cpu" if args.cpu else DEVICE
    dataset = args.dataset
    blur_flag = args.blur
    # `--blur-index` may be None when the flag is omitted; still run each setting once.
    blur_indices = args.blur_index if args.blur_index is not None else [None]
    degradation_key = CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION
    for exp in args.experiments:
        model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False)
        if not model_dict:
            raise ValueError(f"No model found for experiment {exp}")
        current_model_dict = next(iter(model_dict.values()))
        # Keep the model on the same device as the batches (matters with --cpu).
        model = current_model_dict["model"].to(device)
        config = current_model_dict["config"]
        for std_dev, size, blur_index in product(args.std_dev, args.size, blur_indices):
            _configure_dataloader(config, dataset, std_dev, size, blur_index, blur_flag)
            dataloader = get_data_loader(config, frozen_seed=42)
            output_dir = Path(args.output_dir)/(config[NAME] + "_" + config[PRETTY_NAME])
            output_dir.mkdir(parents=True, exist_ok=True)

            all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir,
                                traces=args.traces, number_of_images=args.number_of_images,
                                degradation_key=degradation_key)
            if all_metrics is not None:
                # One row per image (dict keys are image indices, hence the transpose).
                df = pd.DataFrame(all_metrics).T
                csv_name = _metrics_csv_name(size, std_dev, blur_index, config[PRETTY_NAME])
                df.to_csv(output_dir/csv_name, index=False)


if __name__ == "__main__":
    # Script entry point: forward the command-line arguments, minus the program name.
    cli_arguments = sys.argv[1:]
    infer_main(cli_arguments)