Commit: initiate demo

Note: this view is limited to 50 files because the commit contains too many changes; see the raw diff for the full change set.
- __dataset/sample/0000.png +0 -0
- __dataset/sample/0001.jpg +0 -0
- __dataset/sample/0002.png +0 -0
- app.py +66 -0
- requirements.txt +6 -0
- scripts/configuration.py +18 -0
- scripts/infer.py +205 -0
- scripts/interactive_inference_natural.py +65 -0
- scripts/interactive_inference_synthetic.py +25 -0
- scripts/metrics_analyzis.ipynb +0 -0
- scripts/quantitative_results.ipynb +0 -0
- scripts/remote_training.py +116 -0
- scripts/remote_training_template.ipynb +1 -0
- scripts/save_deadleaves.py +172 -0
- scripts/train.py +171 -0
- src/rstor/__init__.py +0 -0
- src/rstor/analyzis/interactive/crop.py +75 -0
- src/rstor/analyzis/interactive/degradation.py +71 -0
- src/rstor/analyzis/interactive/images.py +10 -0
- src/rstor/analyzis/interactive/inference.py +12 -0
- src/rstor/analyzis/interactive/metrics.py +36 -0
- src/rstor/analyzis/interactive/model_selection.py +58 -0
- src/rstor/analyzis/interactive/pipelines.py +61 -0
- src/rstor/analyzis/metrics_plots.py +73 -0
- src/rstor/analyzis/parser.py +26 -0
- src/rstor/architecture/base.py +56 -0
- src/rstor/architecture/convolution_blocks.py +47 -0
- src/rstor/architecture/nafnet.py +299 -0
- src/rstor/architecture/selector.py +19 -0
- src/rstor/architecture/stacked_convolutions.py +30 -0
- src/rstor/data/augmentation.py +27 -0
- src/rstor/data/dataloader.py +120 -0
- src/rstor/data/degradation.py +156 -0
- src/rstor/data/stored_images_dataloader.py +156 -0
- src/rstor/data/synthetic_dataloader.py +187 -0
- src/rstor/learning/experiments.py +24 -0
- src/rstor/learning/experiments_definition.py +489 -0
- src/rstor/learning/loss.py +25 -0
- src/rstor/learning/metrics.py +140 -0
- src/rstor/properties.py +67 -0
- src/rstor/synthetic_data/color_sampler.py +73 -0
- src/rstor/synthetic_data/dead_leaves_cpu.py +79 -0
- src/rstor/synthetic_data/dead_leaves_gpu.py +176 -0
- src/rstor/synthetic_data/dead_leaves_sampler.py +74 -0
- src/rstor/synthetic_data/interactive/interactive_dead_leaves.py +81 -0
- src/rstor/utils.py +12 -0
- test/test_dataloader.py +49 -0
- test/test_dataloader_gpu.py +56 -0
- test/test_dataloader_stored.py +67 -0
- test/test_dead_leaves.py +44 -0
__dataset/sample/0000.png
ADDED (binary image, no preview)
__dataset/sample/0001.jpg
ADDED (binary image, no preview)
__dataset/sample/0002.png
ADDED (binary image, no preview)
app.py
ADDED
@@ -0,0 +1,66 @@
import sys
sys.path.append("src")
from interactive_pipe import interactive_pipeline
from rstor.analyzis.interactive.pipelines import natural_inference_pipeline, morph_canvas, CANVAS
from rstor.analyzis.interactive.model_selection import get_default_models
from pathlib import Path
from rstor.analyzis.parser import get_parser
import argparse
from batch_processing import Batch
from interactive_pipe.data_objects.image import Image
from rstor.analyzis.interactive.images import image_selector
from rstor.analyzis.interactive.crop import plug_crop_selector
from rstor.analyzis.interactive.metrics import plug_configure_metrics
from interactive_pipe import interactive, KeyboardControl


def plug_morph_canvas():
    interactive(
        canvas=KeyboardControl(CANVAS[0], CANVAS, name="canvas", keyup="p", modulo=True)
    )(morph_canvas)


def image_loading_batch(input: Path, args: argparse.Namespace) -> dict:
    """Wrapper to load image files from a directory using batch_processing"""
    if not args.disable_preload:
        img = Image.from_file(input).data
        return {"name": input.name, "path": input, "buffer": img}
    else:
        return {"name": input.name, "path": input, "buffer": None}


def main(argv):
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input image files',
        output_help=argparse.SUPPRESS
    )
    parser = get_parser()
    parser.add_argument("-nop", "--disable-preload", action="store_true", help="Disable images preload")
    args = batch.parse_args(parser)
    # batch.set_multiprocessing_enabled(False)
    img_list = batch.run(image_loading_batch)
    if args.keyboard:
        image_control = KeyboardControl(0, [0, len(img_list)-1], keydown="3", keyup="9", modulo=True)
    else:
        image_control = (0, [0, len(img_list)-1])
    interactive(image_index=image_control)(image_selector)
    plug_crop_selector(num_pad=args.keyboard)
    plug_configure_metrics(key_shortcut="a")  # "a" if args.keyboard else None
    plug_morph_canvas()
    model_dict = get_default_models(args.experiments, Path(args.models_storage), keyboard_control=args.keyboard)
    interactive_pipeline(
        gui=args.backend,
        cache=True,
        safe_input_buffer_deepcopy=False
    )(natural_inference_pipeline)(
        img_list,
        model_dict
    )


if __name__ == "__main__":
    # main(sys.argv[1:])
    main(["-e", "6002", "-i", "__dataset/sample/*.*g", "-b", "gradio"])
requirements.txt
ADDED
@@ -0,0 +1,6 @@
interactive-pipe>=0.7.8
opencv_python_headless==4.8.0.74
torch>=2.0.0
tqdm
scripts/configuration.py
ADDED
@@ -0,0 +1,18 @@
from pathlib import Path

NB_ID = "blind-deblurring-from-synthetic-data"  # This will be the name which appears on Kaggle.
GIT_USER = "balthazarneveu"  # Your git user name
GIT_REPO = "blind-deblurring-from-synthetic-data"  # Your current git repo
# Keep free unless you need to access kaggle datasets. You'll need to modify the remote_training_template.ipynb.
KAGGLE_DATASET_LIST = [
    "balthazarneveu/deadleaves-div2k-512",  # Deadleaves classic
    "balthazarneveu/deadleaves-primitives-div2k-512",  # Deadleaves with extra primitives
    "balthazarneveu/motion-blur-kernels",  # Motion blur kernels
    "joe1995/div2k-dataset",
]
WANDBSPACE = "deblur-from-deadleaves"
TRAIN_SCRIPT = "scripts/train.py"  # Location of the training script

ROOT_DIR = Path(__file__).parent
OUTPUT_FOLDER_NAME = "__output"
INFERENCE_FOLDER_NAME = "__inference"
scripts/infer.py
ADDED
@@ -0,0 +1,205 @@
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME
from rstor.analyzis.parser import get_models_parser
from batch_processing import Batch
from rstor.properties import (
    DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION,
    BATCH_SIZE, SIZE,
    REDUCTION_SKIP,
    TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL,
    SAMPLER_SATURATED,
    CONFIG_DEGRADATION,
    DATASET_DIV2K,
    DATASET_DL_DIV2K_512,
    DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
    METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
    DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE
)
from rstor.data.dataloader import get_data_loader
from tqdm import tqdm
from pathlib import Path
import torch
from typing import Optional
import argparse
import sys
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.learning.metrics import compute_metrics
from interactive_pipe.data_objects.image import Image
from interactive_pipe.data_objects.parameters import Parameters
from typing import List
from itertools import product
import pandas as pd

ALL_TRACES = [TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS]


def parse_int_pairs(s):
    try:
        # Split the input string by spaces to separate pairs, then split each pair by ',' and convert to tuple of ints
        return [tuple(map(int, item.split(','))) for item in s.split()]
    except ValueError:
        raise argparse.ArgumentTypeError("Must be a series of pairs 'a,b' separated by spaces.")


def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser:
    parser = get_models_parser(
        parser=parser,
        help="Inference on validation set",
        default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME)
    if not batch_mode:
        parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR /
                            INFERENCE_FOLDER_NAME, help="Output directory")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL],
                        help="Traces to be computed", default=TRACES_ALL)
    parser.add_argument("--size", type=parse_int_pairs,
                        default=[(256, 256)], help="Size of the images like '256,512 512,512'")
    parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)],
                        help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'")
    parser.add_argument("-n", "--number-of-images", type=int, default=None,
                        required=False, help="Number of images to process")
    parser.add_argument("-d", "--dataset", type=str,
                        choices=[None, DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
                        default=None)
    parser.add_argument("-b", "--blur", action="store_true")
    parser.add_argument("--blur-index", type=int, nargs="+", default=None)
    return parser


def to_image(img: torch.Tensor):
    return img.permute(0, 2, 3, 1).cpu().numpy()


def infer(model, dataloader, config, device, output_dir: Path, traces: List[str] = ALL_TRACES,
          number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES,
          chosen_metrics=[METRIC_PSNR, METRIC_SSIM]):  # add METRIC_LPIPS here!
    img_index = 0
    if TRACES_ALL in traces:
        traces = ALL_TRACES
    if TRACES_METRICS in traces:
        all_metrics = {}
    else:
        all_metrics = None
    with torch.no_grad():
        model.eval()
        for img_degraded, img_target in tqdm(dataloader):
            img_degraded = img_degraded.to(device)
            img_target = img_target.to(device)
            img_restored = model(img_degraded)
            if TRACES_METRICS in traces:
                metrics_input_per_image = compute_metrics(
                    img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
                metrics_per_image = compute_metrics(
                    img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
            img_degraded = to_image(img_degraded)
            img_target = to_image(img_target)
            img_restored = to_image(img_restored)
            for idx in range(img_restored.shape[0]):
                degradation_parameters = dataloader.dataset.current_degradation[img_index]
                common_prefix = f"{img_index:05d}_{img_degraded.shape[-3]:04d}x{img_degraded.shape[-2]:04d}"
                common_prefix += f"_noise=[{config[DATALOADER][degradation_key]['noise_stddev'][0]:02d},{config[DATALOADER][degradation_key]['noise_stddev'][1]:02d}]"
                suffix_deg = ""
                if degradation_parameters['noise_stddev'] > 0:
                    suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}"
                suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}"
                save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png"
                save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png"
                save_path_targ = output_dir/f"{common_prefix}_targ.png"
                if TRACES_RESTORED in traces:
                    Image(img_restored[idx]).save(save_path_pred)
                if TRACES_DEGRADED in traces:
                    Image(img_degraded[idx]).save(save_path_degr)
                if TRACES_TARGET in traces:
                    Image(img_target[idx]).save(save_path_targ)
                if TRACES_METRICS in traces:
                    current_metrics = {}
                    for key, value in metrics_per_image.items():
                        current_metrics["in_"+key] = metrics_input_per_image[key][idx].item()
                        current_metrics["out_"+key] = metrics_per_image[key][idx].item()
                    current_metrics["degradation"] = degradation_parameters
                    current_metrics["size"] = (img_degraded.shape[-3], img_degraded.shape[-2])
                    current_metrics["deadleaves_config"] = config[DATALOADER][degradation_key]
                    current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix()
                    current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix()
                    current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix()
                    current_metrics["model"] = config[PRETTY_NAME]
                    current_metrics["model_id"] = config[NAME]
                    Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json")
                    all_metrics[img_index] = current_metrics
                img_index += 1
                if number_of_images is not None and img_index > number_of_images:
                    return all_metrics
    return all_metrics


def infer_main(argv, batch_mode=False):
    parser = get_parser(batch_mode=batch_mode)
    if batch_mode:
        batch = Batch(argv)
        batch.set_io_description(
            input_help='input image files',
            output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}',
        )
        batch.parse_args(parser)
    else:
        args = parser.parse_args(argv)
    device = "cpu" if args.cpu else DEVICE
    dataset = args.dataset
    blur_flag = args.blur
    for exp in args.experiments:
        model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False)
        current_model_dict = model_dict[list(model_dict.keys())[0]]
        model = current_model_dict["model"]
        config = current_model_dict["config"]
        # Guard against --blur-index being left unset (product() does not accept None)
        blur_indices = args.blur_index if args.blur_index is not None else [None]
        for std_dev, size, blur_index in product(args.std_dev, args.size, blur_indices):
            if dataset is None:
                config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
                    blur_kernel_half_size=[0, 0],
                    ds_factor=1,
                    noise_stddev=list(std_dev),
                    sampler=SAMPLER_SATURATED
                )
                config[DATALOADER]["gpu_gen"] = True
                config[DATALOADER][SIZE] = size
                config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
            else:
                config[DATALOADER][CONFIG_DEGRADATION] = dict(
                    noise_stddev=list(std_dev),
                    degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE,
                    blur_index=blur_index
                )
                config[DATALOADER][NAME] = dataset
                config[DATALOADER][SIZE] = size
                config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
            dataloader = get_data_loader(config, frozen_seed=42)
            output_dir = Path(args.output_dir)/(config[NAME] + "_" + config[PRETTY_NAME])
            output_dir.mkdir(parents=True, exist_ok=True)

            all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir,
                                traces=args.traces, number_of_images=args.number_of_images,
                                degradation_key=CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION)
            if all_metrics is not None:
                df = pd.DataFrame(all_metrics).T
                prefix = f"{size[0]:04d}x{size[1]:04d}"
                if not (std_dev[0] == 0 and std_dev[1] == 0):
                    prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]"
                if blur_index is not None:
                    prefix += f"_blur={blur_index:02d}"
                prefix += f"_{config[PRETTY_NAME]}"
                df.to_csv(output_dir/f"__{prefix}_metrics_.csv", index=False)
                # Normally this could go into another script to handle the metrics analysis


if __name__ == "__main__":
    infer_main(sys.argv[1:])
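Note: for orientation, a minimal sketch of driving this script programmatically (run from the scripts/ directory; the experiment id 6002 and the flag values are illustrative placeholders, and -e/--experiments is assumed to be defined by get_models_parser as in the interactive scripts):

# Hedged usage sketch for scripts/infer.py
from infer import infer_main

infer_main([
    "-e", "6002",          # experiment id(s), resolved via get_default_models
    "--size", "256,256",   # parsed by parse_int_pairs into [(256, 256)]
    "--std-dev", "0,50",   # noise stddev range pairs
    "-n", "8",             # stop after 8 images
])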
scripts/interactive_inference_natural.py
ADDED
@@ -0,0 +1,65 @@
import sys
sys.path.append("src")
from interactive_pipe import interactive_pipeline
from rstor.analyzis.interactive.pipelines import natural_inference_pipeline, morph_canvas, CANVAS
from rstor.analyzis.interactive.model_selection import get_default_models
from pathlib import Path
from rstor.analyzis.parser import get_parser
import argparse
from batch_processing import Batch
from interactive_pipe.data_objects.image import Image
from rstor.analyzis.interactive.images import image_selector
from rstor.analyzis.interactive.crop import plug_crop_selector
from rstor.analyzis.interactive.metrics import plug_configure_metrics
from interactive_pipe import interactive, KeyboardControl


def plug_morph_canvas():
    interactive(
        canvas=KeyboardControl(CANVAS[0], CANVAS, name="canvas", keyup="p", modulo=True)
    )(morph_canvas)


def image_loading_batch(input: Path, args: argparse.Namespace) -> dict:
    """Wrapper to load image files from a directory using batch_processing"""
    if not args.disable_preload:
        img = Image.from_file(input).data
        return {"name": input.name, "path": input, "buffer": img}
    else:
        return {"name": input.name, "path": input, "buffer": None}


def main(argv):
    batch = Batch(argv)
    batch.set_io_description(
        input_help='input image files',
        output_help=argparse.SUPPRESS
    )
    parser = get_parser()
    parser.add_argument("-nop", "--disable-preload", action="store_true", help="Disable images preload")
    args = batch.parse_args(parser)
    # batch.set_multiprocessing_enabled(False)
    img_list = batch.run(image_loading_batch)
    if args.keyboard:
        image_control = KeyboardControl(0, [0, len(img_list)-1], keydown="3", keyup="9", modulo=True)
    else:
        image_control = (0, [0, len(img_list)-1])
    interactive(image_index=image_control)(image_selector)
    plug_crop_selector(num_pad=args.keyboard)
    plug_configure_metrics(key_shortcut="a")  # "a" if args.keyboard else None
    plug_morph_canvas()
    model_dict = get_default_models(args.experiments, Path(args.models_storage), keyboard_control=args.keyboard)
    interactive_pipeline(
        gui=args.backend,
        cache=True,
        safe_input_buffer_deepcopy=False
    )(natural_inference_pipeline)(
        img_list,
        model_dict
    )


if __name__ == "__main__":
    main(sys.argv[1:])
scripts/interactive_inference_synthetic.py
ADDED
@@ -0,0 +1,25 @@
import sys
sys.path.append("src")
from interactive_pipe import interactive_pipeline
from rstor.synthetic_data.interactive.interactive_dead_leaves import dead_leave_plugin
from rstor.analyzis.interactive.pipelines import deadleave_inference_pipeline
from rstor.analyzis.interactive.model_selection import get_default_models
from rstor.analyzis.interactive.crop import plug_crop_selector
from rstor.analyzis.interactive.metrics import plug_configure_metrics
from pathlib import Path
from rstor.analyzis.parser import get_parser


def main(argv):
    parser = get_parser()
    args = parser.parse_args(argv)
    plug_crop_selector(num_pad=args.keyboard)
    model_dict = get_default_models(args.experiments, Path(args.models_storage))
    dead_leave_plugin(ds=1)
    plug_configure_metrics(key_shortcut="a")
    interactive_pipeline(gui="auto", cache=True, safe_input_buffer_deepcopy=False)(
        deadleave_inference_pipeline)(model_dict)


if __name__ == "__main__":
    main(sys.argv[1:])
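Note: a minimal launch sketch (run from the scripts/ directory; the -e experiment flag matches the one app.py's __main__ uses, and 6002 is a placeholder id):

from interactive_inference_synthetic import main

# Open the synthetic dead-leaves inspection GUI on experiment 6002.
main(["-e", "6002"])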
scripts/metrics_analyzis.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
scripts/quantitative_results.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
scripts/remote_training.py
ADDED
@@ -0,0 +1,116 @@
import kaggle
from pathlib import Path, PurePosixPath
import json
try:
    from __kaggle_login import kaggle_users
except ImportError:
    raise ImportError("Please create a __kaggle_login.py file with a kaggle_users " +
                      "dict containing your Kaggle credentials.")
import argparse
import sys
import subprocess
from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME
from train import get_parser as get_train_parser
from typing import Optional
from configuration import KAGGLE_DATASET_LIST, NB_ID, GIT_USER, GIT_REPO, TRAIN_SCRIPT


def get_git_branch_name():
    try:
        branch_name = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode()
        return branch_name
    except subprocess.CalledProcessError:
        return "Error: Could not determine the Git branch name."


def prepare_notebook(
    output_nb_path: Path,
    exp: int,
    branch: str,
    git_user: str = None,
    git_repo: str = None,
    template_nb_path: Path = Path(__file__).parent/"remote_training_template.ipynb",
    wandb_flag: bool = False,
    output_dir: Path = "scripts/"+OUTPUT_FOLDER_NAME,
    dataset_files: Optional[list] = None,
    train_script: str = TRAIN_SCRIPT
):
    assert git_user is not None, "Please provide a git username for the repo"
    assert git_repo is not None, "Please provide a git repo name for the repo"
    expressions = [
        ("exp", f"{exp}"),
        ("branch", f"'{branch}'"),
        ("git_user", f"'{git_user}'"),
        ("git_repo", f"'{git_repo}'"),
        ("wandb_flag", "True" if wandb_flag else "False"),
        ("output_dir", "None" if output_dir is None else f"'{output_dir}'"),
        ("dataset_files", "None" if dataset_files is None else f"{dataset_files}"),
        ("train_script", "'"+train_script+"'")
    ]
    with open(template_nb_path) as f:
        template_nb = f.readlines()
    for line_idx, li in enumerate(template_nb):
        for expr, expr_replace in expressions:
            if f"!!!{expr}!!!" in li:
                template_nb[line_idx] = template_nb[line_idx].replace(f"!!!{expr}!!!", expr_replace)
    template_nb = "".join(template_nb)
    with open(output_nb_path, "w") as w:
        w.write(template_nb)


def main(argv):
    parser = argparse.ArgumentParser(description="Train a model on Kaggle using a script")
    parser.add_argument("-n", "--nb_id", type=str, help="Notebook name in kaggle", default=NB_ID)
    parser.add_argument("-u", "--user", type=str, help="Kaggle user", choices=list(kaggle_users.keys()))
    parser.add_argument("--branch", type=str, help="Git branch name", default=get_git_branch_name())
    parser.add_argument("-p", "--push", action="store_true", help="Push")
    parser.add_argument("-d", "--download", action="store_true", help="Download results")
    get_train_parser(parser)
    args = parser.parse_args(argv)
    nb_id = args.nb_id
    exp_str = "_".join(f"{exp:04d}" for exp in args.exp)
    kaggle_user = kaggle_users[args.user]
    uname_kaggle = kaggle_user["username"]
    kaggle.api._load_config(kaggle_user)
    if args.download:
        tmp_dir = ROOT_DIR/f"__tmp_{exp_str}"
        tmp_dir.mkdir(exist_ok=True, parents=True)
        kaggle.api.kernels_output_cli(f"{kaggle_user['username']}/{nb_id}", path=str(tmp_dir))
        subprocess.run(["tar", "-xzf", tmp_dir/"output.tgz"])
        # @FIXME: windows probably does not have the tar command
        import shutil
        shutil.rmtree(tmp_dir, ignore_errors=True)
        return
    kernel_root = ROOT_DIR/f"__nb_{uname_kaggle}"
    kernel_root.mkdir(exist_ok=True, parents=True)

    kernel_path = kernel_root/exp_str
    kernel_path.mkdir(exist_ok=True, parents=True)
    branch = args.branch
    config = {
        "id": str(PurePosixPath(f"{kaggle_user['username']}")/nb_id),
        "title": nb_id.lower(),
        "code_file": f"{nb_id}.ipynb",
        "language": "python",
        "kernel_type": "notebook",
        "is_private": "true",
        "enable_gpu": "true" if not args.cpu else "false",
        "enable_tpu": "false",
        "enable_internet": "true",
        "dataset_sources": KAGGLE_DATASET_LIST,
        "competition_sources": [],
        "kernel_sources": [],
        "model_sources": []
    }
    prepare_notebook((kernel_path/nb_id).with_suffix(".ipynb"), args.exp, branch,
                     git_user=GIT_USER, git_repo=GIT_REPO, wandb_flag=not args.no_wandb)
    assert (kernel_path/nb_id).with_suffix(".ipynb").exists()
    with open(kernel_path/"kernel-metadata.json", "w") as f:
        json.dump(config, f, indent=4)

    if args.push:
        kaggle.api.kernels_push_cli(str(kernel_path))


if __name__ == '__main__':
    main(sys.argv[1:])
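Note: a sketch of the intended push/pull workflow, assuming a __kaggle_login.py whose kaggle_users dict has an entry named "alice" (placeholder) and that experiment 1000 exists:

from remote_training import main

# Generate and push a training kernel for experiment 1000, then later fetch its outputs.
main(["-u", "alice", "-e", "1000", "--push"])
main(["-u", "alice", "-e", "1000", "--download"])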
scripts/remote_training_template.ipynb
ADDED
@@ -0,0 +1 @@
{"cells":[{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["exp = !!!exp!!!\n","branch = !!!branch!!!\n","git_user = !!!git_user!!!\n","git_repo = !!!git_repo!!!\n","wandb_flag = !!!wandb_flag!!!\n","output_dir = !!!output_dir!!!\n","dataset_files = !!!dataset_files!!!\n","train_script = !!!train_script!!!"]},{"cell_type":"markdown","metadata":{},"source":["# Clone git repo"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["%cd ~\n","!git clone https://github.com/$git_user/$git_repo >/dev/null\n","%cd $git_repo\n","!git checkout $branch\n","!pip install -e ."]},{"cell_type":"markdown","metadata":{},"source":["# Load Kaggle datasets"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/deadleaves_div2k_512\n","!cp /kaggle/input/deadleaves-div2k-512/deadleaves_div2k_512/* \"__dataset/deadleaves_div2k_512\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/deadleaves_primitives_div2k_512\n","!cp /kaggle/input/deadleaves-primitives-div2k-512/deadleaves_primitives_div2k_512/* \"__dataset/deadleaves_primitives_div2k_512\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/div2k\n","!cp -r \"/kaggle/input/div2k-dataset/DIV2K_train_HR\" \"__dataset/div2k/\"\n","!cp -r \"/kaggle/input/div2k-dataset/DIV2K_valid_HR/\" \"__dataset/div2k/\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/kernels\n","!cp \"/kaggle/input/motion-blur-kernels/custom_blur_centered.mat\" __dataset/kernels/"]},{"cell_type":"markdown","metadata":{},"source":["# Setup weights and biases"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if wandb_flag:\n"," from kaggle_secrets import UserSecretsClient\n"," user_secrets = UserSecretsClient()\n"," wandb_api_key = user_secrets.get_secret(\"wandb_api_key\")\n","\n"," !pip install wandb >/dev/null\n"," !wandb login $wandb_api_key"]},{"cell_type":"markdown","metadata":{},"source":["# Launch training"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["exp_str = ' '.join([str(e) for e in exp])\n","wb_ext = \"-nowb\" if not wandb_flag else \"\"\n","!python $train_script -e $exp_str $wb_ext"]},{"cell_type":"markdown","metadata":{},"source":["# Prepare outputs"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if output_dir is not None:\n"," !tar -cvzf /kaggle/working/output.tgz $output_dir"]}],"metadata":{"kaggle":{"accelerator":"gpu","dataSources":[{"datasetId":4234777,"sourceId":7299921,"sourceType":"datasetVersion"}],"dockerImageVersionId":30626,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
scripts/save_deadleaves.py
ADDED
@@ -0,0 +1,172 @@
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 23 15:38:28 2024

@author: jamyl
"""
import cv2
from pathlib import Path
from time import perf_counter
import matplotlib.pyplot as plt
from typing import Tuple

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
from numba import cuda
from tqdm import tqdm
import argparse
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
from rstor.properties import (DATASET_PATH, DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
                              SAMPLER_NATURAL, SAMPLER_UNIFORM, DATASET_DL_DIV2K_512,
                              DATASET_DL_EXTRAPRIMITIVES_DIV2K_512)


class DeadLeavesDatasetGPU(Dataset):
    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        ds_factor: int = 5,
        **config_dead_leaves
    ):
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves

        # downsample kernel
        sigma = 3/5
        k_size = 5  # This fits with sigma = 3/5, the cutoff value is 0.0038 (negligible)
        x = (torch.arange(k_size) - 2).to('cuda')
        kernel = torch.stack(torch.meshgrid((x, x), indexing='ij'))
        dist_sq = kernel[0]**2 + kernel[1]**2
        kernel = (-dist_sq / (2*sigma**2)).exp()  # Gaussian: exp(-d^2 / (2 sigma^2))
        kernel = kernel / kernel.sum()
        self.downsample_kernel = kernel.repeat(3, 1, 1, 1)  # shape [3, 1, k_size, k_size]

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single deadleave chart and its degraded version.

        Args:
            idx (int): index of the item to retrieve

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None

        # Returns a numba device array
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor > 1:
            # Downsample using strided gaussian conv (sigma=3/5)
            th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE,
                                       device="cuda").permute(2, 0, 1)[None]  # [b, c, h, w]
            th_chart = F.pad(th_chart,
                             pad=(2, 2, 0, 0),
                             mode="replicate")
            th_chart = F.conv2d(th_chart,
                                self.downsample_kernel,
                                padding='valid',
                                groups=3,
                                stride=self.ds_factor)

            # Convert back to numba
            numba_chart = cuda.as_cuda_array(th_chart.permute(0, 2, 3, 1))  # [b, h, w, c]

        # convert back to numpy (temporary for legacy)
        chart = numba_chart.copy_to_host()[0]

        return chart


def generate_images(path: Path, dataset: Dataset, imin=0):
    for i in tqdm(range(imin, dataset.length)):
        img = dataset[i]
        img = (img * 255).astype(np.uint8)
        out_path = path / "{:04d}.png".format(i)
        cv2.imwrite(out_path.as_posix(), img)


def bench(dataset):
    print("dataset initialised")
    t1 = perf_counter()
    chart = dataset[0]

    d = (perf_counter()-t1)
    print(f"generation done {d}")
    print(f"{d*1_000/60} min for 1_000")
    plt.imshow(chart)
    plt.show()


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-o", "--output-dir", type=str, default=str(DATASET_PATH))
    argparser.add_argument(
        "-n", "--name", type=str,
        choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
                 DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
        default=DATASET_DL_RANDOMRGB_1024
    )
    argparser.add_argument("-b", "--benchmark", action="store_true")
    default_config = dict(
        size=(1_024, 1_024),
        length=1_000,
        frozen_seed=42,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=5,
        radius_max=2_000,
        ds_factor=5,
    )

    args = argparser.parse_args()
    dataset_dir = args.output_dir
    name = args.name
    path = Path(dataset_dir)/name
    path.mkdir(parents=True, exist_ok=True)
    if name == DATASET_DL_RANDOMRGB_1024:
        config = default_config
        config["sampler"] = SAMPLER_UNIFORM
    elif name == DATASET_DL_DIV2K_1024:
        config = default_config
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = sorted(
            list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
        )
    elif name == DATASET_DL_DIV2K_512:
        config = default_config
        config["size"] = (512, 512)
        config["rmin"] = 3
        config["length"] = 4000
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = sorted(
            list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
        )
    elif name == DATASET_DL_EXTRAPRIMITIVES_DIV2K_512:
        config = default_config
        config["size"] = (512, 512)
        config["sampler"] = SAMPLER_NATURAL
        config["circle_primitives"] = False
        config["length"] = 4000
        config["natural_image_list"] = sorted(
            list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
        )
    else:
        raise NotImplementedError
    dataset = DeadLeavesDatasetGPU(**config)
    if args.benchmark:
        bench(dataset)
    else:
        generate_images(path, dataset)
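Note: a sketch of pre-rendering one dataset to disk with this script; the dataset name must be one of the -n choices above, so the constant is imported from rstor.properties rather than hard-coding its string value:

import subprocess
import sys
from rstor.properties import DATASET_PATH, DATASET_DL_DIV2K_512

# Render the 512x512 DIV2K-sampled dead-leaves charts under <DATASET_PATH>/<name>/.
subprocess.run([sys.executable, "scripts/save_deadleaves.py",
                "-o", str(DATASET_PATH), "-n", DATASET_DL_DIV2K_512], check=True)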
scripts/train.py
ADDED
@@ -0,0 +1,171 @@
import sys
import argparse
from typing import Optional
import torch
import logging
from pathlib import Path
import json
from tqdm import tqdm
from rstor.properties import (
    ID, NAME, NB_EPOCHS,
    TRAIN, VALIDATION, LR,
    LOSS_MSE, METRIC_PSNR, METRIC_SSIM,
    DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
    REDUCTION_SUM,
    SELECTED_METRICS,
    LOSS
)
from rstor.learning.metrics import compute_metrics
from rstor.learning.loss import compute_loss
from torch.optim.lr_scheduler import ReduceLROnPlateau
from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config

WANDB_AVAILABLE = False
try:
    WANDB_AVAILABLE = True
    import wandb
except ImportError:
    logging.warning("Could not import wandb. Disabling wandb.")


def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
    if parser is None:
        parser = argparse.ArgumentParser(description="Train a model")
    parser.add_argument("-e", "--exp", nargs="+", type=int, required=True, help="Experiment id")
    parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME, help="Output directory")
    parser.add_argument("-nowb", "--no-wandb", action="store_true", help="Disable weights and biases")
    parser.add_argument("--cpu", action="store_true", help="Force CPU")
    return parser


def training_loop(
    model,
    optimizer,
    dl_dict: dict,
    config: dict,
    scheduler=None,
    device: str = DEVICE,
    wandb_flag: bool = False,
    output_dir: Path = None,
):
    best_accuracy = 0.
    chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
    for n_epoch in tqdm(range(config[NB_EPOCHS])):
        current_metrics = {
            TRAIN: 0.,
            VALIDATION: 0.,
            LR: optimizer.param_groups[0]['lr'],
        }
        for met in chosen_metrics:
            current_metrics[met] = 0.
        for phase in [TRAIN, VALIDATION]:
            total_elements = 0
            if phase == TRAIN:
                model.train()
            else:
                model.eval()
            for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
                x, y = x.to(device), y.to(device)
                optimizer.zero_grad()
                with torch.set_grad_enabled(phase == TRAIN):
                    y_pred = model(x)
                    loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
                    if torch.isnan(loss):
                        print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
                        continue
                    if phase == TRAIN:
                        loss.backward()
                        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
                        optimizer.step()
                current_metrics[phase] += loss.item()
                if phase == VALIDATION:
                    metrics_on_batch = compute_metrics(
                        y_pred,
                        y,
                        chosen_metrics=chosen_metrics,
                        reduction=REDUCTION_SUM
                    )
                    total_elements += y_pred.shape[0]
                    for k, v in metrics_on_batch.items():
                        current_metrics[k] += v

            current_metrics[phase] /= (len(dl_dict[phase]))
            if phase == VALIDATION:
                for k, v in metrics_on_batch.items():
                    current_metrics[k] /= total_elements
                    try:
                        current_metrics[k] = current_metrics[k].item()
                    except AttributeError:
                        pass
                debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
                for k, v in current_metrics.items():
                    if k not in [TRAIN, VALIDATION, LR]:
                        debug_print += f"{k}: {v:.3} |"
                print(debug_print)
        if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
            scheduler.step(current_metrics[VALIDATION])
        if output_dir is not None:
            with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
                json.dump(current_metrics, f)
        if wandb_flag:
            wandb.log(current_metrics)
        if best_accuracy < current_metrics[METRIC_PSNR]:
            best_accuracy = current_metrics[METRIC_PSNR]
            if output_dir is not None:
                print("new best model saved!")
                torch.save(model.state_dict(), output_dir/"best_model.pt")
    if output_dir is not None:
        torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
    return model


def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
    logging.basicConfig(level=logging.INFO)
    logging.info(f"Training experiment {config[ID]} on device {device}...")
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir/"config.json", "w") as f:
        json.dump(config, f)
    model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
    model.to(device)
    if wandb_flag:
        import wandb
        wandb.init(
            project=WANDBSPACE,
            entity="balthazarneveu",
            name=config[NAME],
            tags=["debug"],
            # tags=["base"],
            config=config
        )
    scheduler = None
    if config.get(SCHEDULER, False):
        scheduler_config = config[SCHEDULER_CONFIGURATION]
        if config[SCHEDULER] == REDUCELRONPLATEAU:
            scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
        else:
            raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
    model = training_loop(model, optimizer, dl_dict, config, scheduler=scheduler, device=device,
                          wandb_flag=wandb_flag, output_dir=output_dir)

    if wandb_flag:
        wandb.finish()


def train_main(argv):
    parser = get_parser()
    args = parser.parse_args(argv)
    if not WANDB_AVAILABLE:
        args.no_wandb = True
    device = "cpu" if args.cpu else DEVICE
    for exp in args.exp:
        config = get_experiment_config(exp)
        print(config)
        output_dir = Path(args.output_dir)/config[NAME]
        logging.info(f"Training experiment {config[ID]} on device {device}...")
        train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)


if __name__ == "__main__":
    train_main(sys.argv[1:])
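Note: a minimal sketch of launching training from Python rather than the shell; the flags come from get_parser above and the experiment ids 1000/1001 are placeholders defined in experiments_definition.py:

from train import train_main

# Train experiments 1000 and 1001 sequentially, on CPU, without wandb logging.
train_main(["-e", "1000", "1001", "--cpu", "-nowb"])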
src/rstor/__init__.py
ADDED
File without changes
src/rstor/analyzis/interactive/crop.py
ADDED
@@ -0,0 +1,75 @@
import cv2
import numpy as np
from interactive_pipe import interactive


def get_color_channel_offset(image):
    if len(image.shape) == 2:
        offset = 0
    elif len(image.shape) == 3:
        channel_guesser_max_size = 4
        if image.shape[0] <= channel_guesser_max_size:  # channel first C,H,W
            offset = 0
        elif image.shape[-1] <= channel_guesser_max_size:  # channel last or numpy H,W,C
            offset = 1
    else:
        raise NameError(f"Not supported shape {image.shape}")
    return offset


def crop_selector(image, center_x=0.5, center_y=0.5, size=9., global_params={}):
    # size is defined as a power of 2
    offset = get_color_channel_offset(image)
    crop_size_pixels = int(2.**(size)/2.)
    h, w = image.shape[-2-offset], image.shape[-1-offset]
    ar = w/h
    half_crop_h, half_crop_w = crop_size_pixels, int(ar*crop_size_pixels)

    def round(val):
        return int(np.round(val))
    center_x_int = round(half_crop_w + center_x*(w-2*half_crop_w))
    center_y_int = round(half_crop_h + center_y*(h-2*half_crop_h))
    start_x = max(0, center_x_int-half_crop_w)
    start_y = max(0, center_y_int-half_crop_h)
    end_x = min(start_x+2*half_crop_w, w-1)
    end_y = min(start_y+2*half_crop_h, h-1)
    start_x = max(0, end_x-2*half_crop_w)
    start_y = max(0, end_y-2*half_crop_h)
    MAX_ALLOWED_SIZE = 512
    w_resize = int(min(MAX_ALLOWED_SIZE, w))
    h_resize = int(w_resize/w*h)
    h_resize = int(min(MAX_ALLOWED_SIZE, h_resize))
    w_resize = int(h_resize/h*w)
    global_params["crop"] = (start_x, start_y, end_x, end_y)
    global_params["resize"] = (w_resize, h_resize)
    return


def plug_crop_selector(num_pad: bool = False):
    interactive(
        center_x=(0.5, [0., 1.], "cx", ["4" if num_pad else "left", "6" if num_pad else "right"]),
        center_y=(0.5, [0., 1.], "cy", ["8" if num_pad else "up", "2" if num_pad else "down"]),
        size=(9., [6., 13., 0.3], "crop size", ["+", "-"])
    )(crop_selector)


def crop(*images, global_params={}):
    images_resized = []
    for image in images:
        offset = get_color_channel_offset(image)
        start_x, start_y, end_x, end_y = global_params["crop"]
        w_resize, h_resize = global_params["resize"]
        if offset == 0:
            crop = image[..., start_y:end_y, start_x:end_x]
        if offset == 1:
            crop = image[..., start_y:end_y, start_x:end_x, :]
        image_resized = cv2.resize(crop, (w_resize, h_resize), interpolation=cv2.INTER_NEAREST)
        images_resized.append(image_resized)
    return tuple(images_resized)


def rescale_thumbnail(image, global_params={}):
    if image is None:  # support no blur kernel!
        return None
    resize_dim = max(global_params.get("resize", (512, 512)))
    return cv2.resize(image, (resize_dim, resize_dim), interpolation=cv2.INTER_NEAREST)
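Note: the `size` slider maps to pixels exponentially: the half-crop height is int(2**size / 2), so each +1 step on the slider doubles the window. A small worked example of the mapping used by crop_selector:

# size=6.0 -> int(2**6 / 2) = 32 px half-crop; size=9.0 -> 256 px, i.e. a 512 px tall window.
for size in (6., 9., 13.):
    print(size, int(2.**size / 2.))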
src/rstor/analyzis/interactive/degradation.py
ADDED
@@ -0,0 +1,71 @@
import numpy as np
from interactive_pipe import interactive
from skimage.filters import gaussian
from rstor.properties import DATASET_BLUR_KERNEL_PATH
from scipy.io import loadmat
import cv2


@interactive(
    sigma=(3/5, [0., 2.])
)
def downsample(chart: np.ndarray, sigma=3/5, global_params={}):
    ds_factor = global_params.get("ds_factor", 5)
    if sigma > 0.:
        ds_chart = gaussian(chart, sigma=(sigma, sigma, 0), mode='nearest', cval=0, preserve_range=True, truncate=4.0)
    else:
        ds_chart = chart.copy()
    ds_chart = ds_chart[ds_factor//2::ds_factor, ds_factor//2::ds_factor]
    return ds_chart


@interactive(
    k_size_x=(0, [0, 10]),
    k_size_y=(0, [0, 10]),
)
def degrade_blur_gaussian(chart: np.ndarray, k_size_x: int = 1, k_size_y: int = 1):
    if k_size_x == 0 and k_size_y == 0:
        return chart  # no blur requested
    blurred = cv2.GaussianBlur(chart, (2*k_size_x+1, 2*k_size_y+1), 0)
    return blurred


@interactive(
    noise_stddev=(0., [0., 50.])
)
def degrade_noise(img: np.ndarray, noise_stddev=0., global_params={}):
    seed = global_params.get("seed", 42)
    np.random.seed(seed)
    if noise_stddev > 0.:
        noise = np.random.normal(0, noise_stddev/255., img.shape)
        img = img.copy()+noise
    return img


@interactive(
    ksize=(3, [1, 10])
)
def get_blur_kernel_box(ksize=3):
    return np.ones((ksize, ksize), dtype=np.float32) / (1.*ksize**2)


@interactive(
    blur_index=(-1, [-1, 1000])
)
def get_blur_kernel(blur_index: int = -1, global_params={}):
    if blur_index == -1:
        return None
    blur_mat = global_params.get("blur_mat", False)
    if blur_mat is False:
        blur_mat = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze()
        global_params["blur_mat"] = blur_mat
    blur_k = blur_mat[blur_index]
    blur_k = blur_k/blur_k.sum()
    return blur_k


def degrade_blur(img: np.ndarray, blur_kernel: np.ndarray, global_params={}):
    if blur_kernel is None:
        return img
    img_blur = cv2.filter2D(img, -1, blur_kernel)
    return img_blur
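Note: a sketch of chaining these stages outside the GUI, assuming the @interactive decorator leaves the wrapped functions directly callable (the stand-in image is a placeholder):

import numpy as np

chart = np.random.rand(256, 256, 3).astype(np.float32)  # stand-in image in [0, 1]
blur_k = get_blur_kernel_box(ksize=3)                    # 3x3 normalized box kernel
degraded = degrade_blur(chart, blur_k)                   # convolve with the kernel
degraded = degrade_noise(degraded, noise_stddev=25., global_params={"seed": 0})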
src/rstor/analyzis/interactive/images.py
ADDED
@@ -0,0 +1,10 @@
from interactive_pipe.data_objects.image import Image
from typing import List


def image_selector(image_list: List[dict], image_index: int = 0) -> dict:
    current_image = image_list[image_index % len(image_list)]
    img = current_image.get("buffer", None)
    if img is None:
        img = Image.from_file(current_image["path"]).data
    return img
src/rstor/analyzis/interactive/inference.py
ADDED
@@ -0,0 +1,12 @@
import torch
import numpy as np
from rstor.properties import DEVICE


def infer(degraded: np.ndarray, model: torch.nn.Module):
    degraded_tensor = torch.from_numpy(degraded).permute(-1, 0, 1).float().unsqueeze(0)
    model.eval()
    with torch.no_grad():
        output = model(degraded_tensor.to(DEVICE))
        output = output.squeeze().permute(1, 2, 0).cpu().numpy()
    return np.ascontiguousarray(output)
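Note: a usage sketch for the helper above; the input is an HWC float image as produced by the degradation stages, and the identity network is a stand-in for a trained restorer:

import numpy as np
import torch

degraded = np.random.rand(64, 64, 3).astype(np.float32)  # placeholder HWC input
dummy_model = torch.nn.Identity()                         # stand-in for a trained model
restored = infer(degraded, dummy_model)                   # returns an HWC numpy array
assert restored.shape == degraded.shape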
src/rstor/analyzis/interactive/metrics.py
ADDED
@@ -0,0 +1,36 @@
from rstor.learning.metrics import compute_metrics, ALL_METRICS
import torch
import numpy as np
from rstor.properties import METRIC_PSNR, METRIC_SSIM
from interactive_pipe import interactive, KeyboardControl
from typing import Optional


def plug_configure_metrics(key_shortcut: Optional[str] = None) -> None:
    interactive(
        advanced_metrics=KeyboardControl(False, keydown=key_shortcut) if key_shortcut is not None else (True,)
    )(configure_metrics)


def configure_metrics(advanced_metrics=False, global_params={}) -> None:
    chosen_metrics = ALL_METRICS if advanced_metrics else [METRIC_PSNR, METRIC_SSIM]
    global_params["chosen_metrics"] = chosen_metrics


def get_metrics(prediction: torch.Tensor, target: torch.Tensor,
                image_name: str,  # use functools.partial to root where you want the title to appear
                global_params: dict = {}) -> None:
    if isinstance(prediction, np.ndarray):
        prediction_ = torch.from_numpy(prediction).permute(-1, 0, 1).float().unsqueeze(0)
    else:
        prediction_ = prediction
    if isinstance(target, np.ndarray):
        target_ = torch.from_numpy(target).permute(-1, 0, 1).float().unsqueeze(0)
    else:
        target_ = target
    chosen_metrics = global_params.get("chosen_metrics", [METRIC_PSNR])
    metrics = compute_metrics(prediction_, target_, chosen_metrics=chosen_metrics)
    global_params["metrics"] = metrics
    title = f"{image_name}: "
    title += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
    global_params["__output_styles"][image_name] = {"title": title, "image_name": image_name}
src/rstor/analyzis/interactive/model_selection.py
ADDED
@@ -0,0 +1,58 @@
import torch
from interactive_pipe import KeyboardControl
from rstor.learning.experiments import get_training_content
from rstor.learning.experiments_definition import get_experiment_config
from rstor.properties import DEVICE, PRETTY_NAME
from tqdm import tqdm
from pathlib import Path
from typing import List, Tuple

from interactive_pipe import interactive
MODELS_PATH = Path("scripts")/"__output"


def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
    if isinstance(model_name, str):
        current_model = models_dict[model_name]
    elif isinstance(model_name, int):
        model_names = [name for name in models_dict.keys()]
        current_model = models_dict[model_names[model_name % len(model_names)]]
    else:
        raise ValueError(f"Model name {model_name} not understood")
    global_params["model_config"] = current_model["config"]
    return current_model["model"]


def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
    config = get_experiment_config(exp)
    model, _, _ = get_training_content(config, training_mode=False)
    state_dict = torch.load(model_storage/f"{exp:04d}"/"best_model.pt")  # torch.load returns the state dict
    assert state_dict is not None, f"Model {exp} not found"
    model.load_state_dict(state_dict)
    model = model.to(device)
    return model, config


def get_default_models(
    exp_list: List[int] = [1000, 1001],
    model_storage: Path = MODELS_PATH,
    keyboard_control: bool = False,
    interactive_flag: bool = True
) -> dict:
    model_dict = {}
    assert model_storage.exists(), f"Model storage {model_storage} does not exist"
    for exp in tqdm(exp_list, desc="Loading models"):
        model, config = get_model_from_exp(exp, model_storage=model_storage)
        name = config.get(PRETTY_NAME, f"{exp:04d}")
        model_dict[name] = {
            "model": model,
            "config": config
        }
    exp_names = [name for name in model_dict.keys()]
    if interactive_flag:
        if keyboard_control:
            model_control = KeyboardControl(0, [0, len(exp_names)-1], keydown="pagedown", keyup="pageup", modulo=True)
        else:
            model_control = (exp_names[0], exp_names)
        interactive(model_name=model_control)(model_selector)  # Create the model selection dialog
    return model_dict
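A sketch of non-interactive usage, assuming checkpoints for experiments 1000 and 1001 exist under scripts/__output: load them once, then pick a model by name or by integer index (integers cycle modulo the number of loaded models):

models = get_default_models(exp_list=[1000, 1001], interactive_flag=False)
params = {}
model = model_selector(models, global_params=params, model_name=0)  # 0 -> first loaded model
print(type(model).__name__, list(models.keys()))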
src/rstor/analyzis/interactive/pipelines.py
ADDED
@@ -0,0 +1,61 @@
from rstor.synthetic_data.interactive.interactive_dead_leaves import generate_deadleave
from rstor.analyzis.interactive.crop import crop_selector, crop, rescale_thumbnail
from rstor.analyzis.interactive.inference import infer
from rstor.analyzis.interactive.degradation import degrade_noise, degrade_blur, downsample, degrade_blur_gaussian, get_blur_kernel
from rstor.analyzis.interactive.model_selection import model_selector
from rstor.analyzis.interactive.images import image_selector
from rstor.analyzis.interactive.metrics import get_metrics, configure_metrics
from interactive_pipe import interactive, KeyboardControl
from typing import Tuple, List
from functools import partial
import numpy as np


get_metrics_restored = partial(get_metrics, image_name="restored")
get_metrics_degraded = partial(get_metrics, image_name="degraded")


def deadleave_inference_pipeline(models_dict: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    groundtruth = generate_deadleave()
    groundtruth = downsample(groundtruth)
    model = model_selector(models_dict)
    degraded = degrade_blur_gaussian(groundtruth)
    degraded = degrade_noise(degraded)
    restored = infer(degraded, model)
    crop_selector(restored)
    groundtruth, degraded, restored = crop(groundtruth, degraded, restored)
    configure_metrics()
    get_metrics_restored(restored, groundtruth)
    get_metrics_degraded(degraded, groundtruth)
    return groundtruth, degraded, restored


CANVAS_DICT = {
    "demo": [["degraded", "restored"]],
    "landscape_light": [["degraded", "restored", "groundtruth"]],
    "landscape": [["degraded", "restored", "blur_kernel", "groundtruth"]],
    "full": [["degraded", "restored"], ["blur_kernel", "groundtruth"]]
}
CANVAS = list(CANVAS_DICT.keys())


def morph_canvas(canvas=CANVAS[0], global_params={}):
    global_params["__pipeline"].outputs = CANVAS_DICT[canvas]
    return None


def natural_inference_pipeline(input_image_list: List[np.ndarray], models_dict: dict):
    model = model_selector(models_dict)
    img_clean = image_selector(input_image_list)
    crop_selector(img_clean)
    groundtruth = crop(img_clean)
    blur_kernel = get_blur_kernel()
    degraded = degrade_blur(groundtruth, blur_kernel)
    degraded = degrade_noise(degraded)
    blur_kernel = rescale_thumbnail(blur_kernel)
    restored = infer(degraded, model)
    configure_metrics()
    get_metrics_restored(restored, groundtruth)
    get_metrics_degraded(degraded, groundtruth)
    morph_canvas()
    return [[degraded, restored], [blur_kernel, groundtruth]]
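A wiring sketch, under the assumption that interactive_pipe's interactive_pipeline decorator accepts a gui backend argument the way it is used in this repo: decorate one of the pipelines above, then call it with the loaded models to open an interactive session:

from interactive_pipe import interactive_pipeline
from rstor.analyzis.interactive.model_selection import get_default_models

models = get_default_models(exp_list=[1000])  # registers the model selection control
demo = interactive_pipeline(gui="qt")(deadleave_inference_pipeline)
demo(models)  # opens the window with sliders, crop control and metric titles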
src/rstor/analyzis/metrics_plots.py
ADDED
@@ -0,0 +1,73 @@
import json
import matplotlib.pyplot as plt
import numpy as np
from pathlib import Path
import pandas as pd


def snr_to_sigma(snr):
    return 10**(-snr/20.)*255.


def sigma_to_snr(sigma):
    return -20.*np.log10(sigma/255.)


def plot_results(selected_paths, title=None, diff=True, ylim=None):
    all_stats = {}
    fig, ax = plt.subplots(layout='constrained', figsize=(10, 10))
    for selected_path, selected_regex in selected_paths:
        selected_path = Path(selected_path)
        assert selected_path.exists()
        results_path = sorted(list(selected_path.glob(selected_regex)))
        stats = []
        for result_path in results_path:
            df = pd.read_csv(result_path)
            in_psnr = df["in_PSNR"].mean()
            out_psnr = df["out_PSNR"].mean()
            out_ssim = df["out_SSIM"].mean()
            noise_stddev = np.array([
                float(el.replace("}", "").split(":")[1]) for el in df["degradation"]]).mean()
            stats.append({
                "in_psnr": in_psnr,
                "out_psnr": out_psnr,
                "noise_stddev": noise_stddev,
                "ssim": out_ssim
            })
            label = selected_path.name + " " + df["size"][0]

        stats_array = pd.DataFrame(stats)
        all_stats[label] = stats_array
        x_data = stats_array["in_psnr"].copy()
        x_data = snr_to_sigma(x_data)

        ax.plot(
            x_data,
            stats_array["out_psnr"]-stats_array["in_psnr"] if diff else stats_array["out_psnr"],
            "-o",
            label=label
        )
    if not diff:
        neutral_sigma = np.linspace(1, 80, 80)
        ax.plot(neutral_sigma, sigma_to_snr(neutral_sigma), "k--", alpha=0.1, label="Neutral")
    secax = ax.secondary_xaxis('top', functions=(sigma_to_snr, snr_to_sigma))
    secax.set_xlabel('PSNR [dB]')

    ax.set_xlabel("noise sigma (8-bit scale, /255)")
    ax.set_ylabel("PSNR improvement" if diff else "PSNR out")
    plt.xlim(1., 50.)
    if diff:
        plt.ylim(0, 15)
    else:
        if ylim is not None:
            plt.ylim(*ylim)
    if title is not None:
        plt.title(title)
    plt.legend()
    plt.grid()
    plt.show()

    return all_stats
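The two axis-conversion helpers are exact inverses of each other, sigma = 255 * 10^(-snr/20), which is what lets the secondary x-axis above show PSNR on top of the sigma scale. Quick sanity check:

print(snr_to_sigma(20.0))                # 25.5 (20 dB maps to sigma 25.5 out of 255)
print(sigma_to_snr(snr_to_sigma(20.0)))  # 20.0, the round trip is lossless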
src/rstor/analyzis/parser.py
ADDED
@@ -0,0 +1,26 @@
from rstor.analyzis.interactive.model_selection import MODELS_PATH
import argparse


def get_models_parser(parser: argparse.ArgumentParser = None, help: str = "Inference",
                      default_models_path: str = MODELS_PATH) -> argparse.ArgumentParser:
    if parser is None:
        parser = argparse.ArgumentParser(description=help)
    parser.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
                        help="Experiment indexes to be used at inference time")
    parser.add_argument("-m", "--models-storage", type=str, help="Model storage path", default=default_models_path)
    return parser


def get_parser(
    parser: argparse.ArgumentParser = None,
    help: str = "Live inference pipeline"
) -> argparse.ArgumentParser:
    """Generic parser for live interactive inference"""
    if parser is None:
        parser = argparse.ArgumentParser(description=help)
    get_models_parser(parser=parser, help=help)
    parser.add_argument("-k", "--keyboard", action="store_true", help="Keyboard control - fewer sliders")
    parser.add_argument("-b", "--backend", default="gradio", help="Backend to use for the GUI", choices=["gradio", "qt"])
    return parser
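A usage sketch of the composed parser; note that --models-storage lands on args.models_storage:

args = get_parser().parse_args(["-e", "1000", "1001", "-b", "qt"])
print(args.experiments)             # [1000, 1001]
print(args.models_storage)          # scripts/__output by default
print(args.keyboard, args.backend)  # False qt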
src/rstor/architecture/base.py
ADDED
@@ -0,0 +1,56 @@
import torch
from rstor.properties import LEAKY_RELU, RELU, SIMPLE_GATE
from typing import Optional, Tuple


class SimpleGate(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x1, x2 = x.chunk(2, dim=1)
        return x1 * x2


def get_non_linearity(activation: str):
    if activation == LEAKY_RELU:
        non_linearity = torch.nn.LeakyReLU()
    elif activation == RELU:
        non_linearity = torch.nn.ReLU()
    elif activation is None:
        non_linearity = torch.nn.Identity()
    elif activation == SIMPLE_GATE:
        non_linearity = SimpleGate()
    else:
        raise ValueError(f"Unknown activation {activation}")
    return non_linearity


class BaseModel(torch.nn.Module):
    """Base class for all restoration models with additional useful methods"""

    def count_parameters(self):
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def receptive_field(
        self,
        channels: Optional[int] = 3,
        size: Optional[int] = 256,
        device: Optional[str] = None
    ) -> Tuple[int, int]:
        """Compute the receptive field of the model.

        Returns:
            Tuple[int, int]: horizontal and vertical receptive field sizes
        """
        # Create the tensor directly on the target device so it stays a leaf and keeps its .grad
        input_tensor = torch.ones(1, channels, size, size, requires_grad=True, device=device)
        out = self.forward(input_tensor)
        grad = torch.zeros_like(out)
        grad[..., out.shape[-2]//2, out.shape[-1]//2] = torch.nan  # set NaN gradient at the middle of the output
        out.backward(gradient=grad)
        self.zero_grad()
        receptive_field_mask = input_tensor.grad.isnan()[0, 0]
        receptive_field_indexes = torch.where(receptive_field_mask)
        # Count NaN in the input
        receptive_x = 1+receptive_field_indexes[-1].max() - receptive_field_indexes[-1].min()  # Horizontal x
        receptive_y = 1+receptive_field_indexes[-2].max() - receptive_field_indexes[-2].min()  # Vertical y
        return receptive_x.item(), receptive_y.item()
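The NaN-gradient trick above works for any BaseModel subclass: a NaN injected at the output center propagates backward only to inputs inside the receptive field. A minimal check with two stacked 3x3 convolutions, which should report (5, 5):

class TwoConvs(BaseModel):
    def __init__(self):
        super().__init__()
        self.net = torch.nn.Sequential(
            torch.nn.Conv2d(3, 8, 3, padding=1),
            torch.nn.Conv2d(8, 3, 3, padding=1),
        )

    def forward(self, x):
        return self.net(x)


print(TwoConvs().receptive_field(size=32))  # (5, 5)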
src/rstor/architecture/convolution_blocks.py
ADDED
@@ -0,0 +1,47 @@
import torch
from rstor.properties import LEAKY_RELU
from rstor.architecture.base import get_non_linearity


class BaseConvolutionBlock(torch.nn.Module):
    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        k_size: int,
        activation=LEAKY_RELU,
        bias: bool = True
    ) -> None:
        super().__init__()
        self.conv = torch.nn.Conv2d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.conv_non_lin = torch.nn.Sequential(self.conv, self.non_linearity)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        return self.conv_non_lin(x_in)


class ResConvolutionBlock(torch.nn.Module):
    def __init__(
        self,
        ch_in: int,
        ch_out: int,
        k_size: int,
        activation=LEAKY_RELU,
        bias: bool = True,
        residual: bool = True
    ) -> None:
        super().__init__()
        self.conv1 = torch.nn.Conv2d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
        self.non_linearity = get_non_linearity(activation)
        self.conv2 = torch.nn.Conv2d(ch_out, ch_out, k_size, padding=k_size//2, bias=bias)
        self.residual = residual

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        y = self.conv1(x_in)
        y = self.non_linearity(y)
        y = self.conv2(y)
        if self.residual:
            y = x_in + y
        y = self.non_linearity(y)
        return y
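Both blocks preserve spatial resolution thanks to the k_size//2 padding; the residual block additionally needs ch_in == ch_out when residual=True so the skip addition is valid. Shape sanity check:

x = torch.randn(1, 16, 32, 32)
print(BaseConvolutionBlock(16, 32, 3)(x).shape)                # torch.Size([1, 32, 32, 32])
print(ResConvolutionBlock(16, 16, 3, residual=True)(x).shape)  # torch.Size([1, 16, 32, 32])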
src/rstor/architecture/nafnet.py
ADDED
@@ -0,0 +1,299 @@
"""
NAFNet: Nonlinear Activation Free Network for image restoration
Architecture adapted from "Simple Baselines for Image Restoration"
https://github.com/megvii-research/NAFNet/tree/main
"""
from torch import nn
import torch.nn.functional as F
import torch
from rstor.architecture.base import BaseModel, get_non_linearity
from typing import Optional, List
from rstor.properties import RELU, SIMPLE_GATE


class LayerNormFunction(torch.autograd.Function):

    @staticmethod
    def forward(ctx, x, weight, bias, eps):
        ctx.eps = eps
        N, C, H, W = x.size()
        mu = x.mean(1, keepdim=True)
        var = (x - mu).pow(2).mean(1, keepdim=True)
        y = (x - mu) / (var + eps).sqrt()
        ctx.save_for_backward(y, var, weight)
        y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
        return y

    @staticmethod
    def backward(ctx, grad_output):
        eps = ctx.eps

        N, C, H, W = grad_output.size()
        y, var, weight = ctx.saved_variables
        g = grad_output * weight.view(1, C, 1, 1)
        mean_g = g.mean(dim=1, keepdim=True)

        mean_gy = (g * y).mean(dim=1, keepdim=True)
        gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
        return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), \
            grad_output.sum(dim=3).sum(dim=2).sum(dim=0), None


class LayerNorm2d(nn.Module):
    def __init__(self, channels, eps=1e-6):
        super(LayerNorm2d, self).__init__()
        self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
        self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
        self.eps = eps

    def forward(self, x):
        return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)


class NAFBlock(nn.Module):
    def __init__(
        self,
        c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.,
        activation: Optional[str] = SIMPLE_GATE,
        layer_norm_flag: Optional[bool] = True,
        channel_attention_flag: Optional[bool] = True,
    ):
        super().__init__()
        self.layer_norm_flag = layer_norm_flag
        self.channel_attention_flag = channel_attention_flag
        dw_channel = c * DW_Expand
        half_dw_channel = dw_channel // 2
        self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1,
                               padding=0, stride=1, groups=1, bias=True)
        self.conv2 = nn.Conv2d(
            in_channels=dw_channel,
            out_channels=dw_channel if activation == SIMPLE_GATE else half_dw_channel,
            kernel_size=3,
            padding=1, stride=1,
            groups=dw_channel if activation == SIMPLE_GATE else half_dw_channel,
            bias=True
        )
        # To keep the same number of parameters between the SimpleGate and ReLU versions,
        # conv2 has to halve the channel count while still using a grouped convolution:
        # w -> w/2 is not strictly a depthwise convolution but grouped by pairs of channels.
        self.conv3 = nn.Conv2d(in_channels=half_dw_channel, out_channels=c,
                               kernel_size=1, padding=0, stride=1, groups=1, bias=True)

        # Simplified Channel Attention
        if self.channel_attention_flag:
            self.sca = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels=half_dw_channel, out_channels=half_dw_channel, kernel_size=1,
                          padding=0, stride=1,
                          groups=1, bias=True),
            )

        # SimpleGate
        self.sg = get_non_linearity(activation)
        ffn_channel = FFN_Expand * c  # hidden FFN width scales with the channel count
        half_ffn_channel = ffn_channel // 2 if activation == SIMPLE_GATE else ffn_channel
        self.conv4 = nn.Conv2d(
            in_channels=c,
            out_channels=ffn_channel if activation == SIMPLE_GATE else half_ffn_channel,
            kernel_size=1,
            padding=0, stride=1, groups=1, bias=True)
        self.conv5 = nn.Conv2d(in_channels=half_ffn_channel, out_channels=c,
                               kernel_size=1, padding=0, stride=1, groups=1, bias=True)
        if self.layer_norm_flag:
            self.norm1 = LayerNorm2d(c)
            self.norm2 = LayerNorm2d(c)

        self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
        self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()

        self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
        self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)

    def forward(self, inp):
        x = inp
        if self.layer_norm_flag:
            x = self.norm1(x)

        x = self.conv1(x)
        x = self.conv2(x)
        x = self.sg(x)
        if self.channel_attention_flag:
            x = x * self.sca(x)
        x = self.conv3(x)

        x = self.dropout1(x)

        y = inp + x * self.beta

        x = self.conv4(self.norm2(y) if self.layer_norm_flag else y)
        x = self.sg(x)
        x = self.conv5(x)

        x = self.dropout2(x)

        return y + x * self.gamma


class NAFNet(BaseModel):
    def __init__(
        self,
        img_channel: Optional[int] = 3,
        width: Optional[int] = 16,
        middle_blk_num: Optional[int] = 1,
        enc_blk_nums: List[int] = [],
        dec_blk_nums: List[int] = [],
        activation: Optional[str] = SIMPLE_GATE,
        layer_norm_flag: Optional[bool] = True,
        channel_attention_flag: Optional[bool] = True,
    ) -> None:
        super().__init__()

        self.intro = nn.Conv2d(
            in_channels=img_channel,
            out_channels=width,
            kernel_size=3,
            padding=1, stride=1,
            groups=1,
            bias=True
        )
        config_block = {
            "activation": activation,
            "layer_norm_flag": layer_norm_flag,
            "channel_attention_flag": channel_attention_flag
        }
        self.ending = nn.Conv2d(
            in_channels=width, out_channels=img_channel, kernel_size=3,
            padding=1, stride=1, groups=1,
            bias=True)

        self.encoders = nn.ModuleList()
        self.decoders = nn.ModuleList()
        self.ups = nn.ModuleList()
        self.downs = nn.ModuleList()

        chan = width
        for num in enc_blk_nums:
            self.encoders.append(
                nn.Sequential(
                    *[NAFBlock(chan, **config_block) for _ in range(num)]
                )
            )
            self.downs.append(
                nn.Conv2d(chan, 2*chan, 2, 2)
            )
            chan = chan * 2

        self.middle_blks = nn.Sequential(
            *[NAFBlock(chan, **config_block) for _ in range(middle_blk_num)]
        )

        for num in dec_blk_nums:
            self.ups.append(
                nn.Sequential(
                    nn.Conv2d(chan, chan * 2, 1, bias=False),
                    nn.PixelShuffle(2)
                )
            )
            chan = chan // 2
            self.decoders.append(
                nn.Sequential(
                    *[NAFBlock(chan, **config_block) for _ in range(num)]
                )
            )

        self.padder_size = 2 ** len(self.encoders)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        B, C, H, W = inp.shape
        inp = self.sanitize_image_size(inp)

        x = self.intro(inp)

        encs = []

        for encoder, down in zip(self.encoders, self.downs):
            x = encoder(x)
            encs.append(x)
            x = down(x)

        x = self.middle_blks(x)

        for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
            x = up(x)
            x = x + enc_skip
            x = decoder(x)

        x = self.ending(x)
        x = x + inp

        return x[:, :, :H, :W]

    def sanitize_image_size(self, x: torch.Tensor) -> torch.Tensor:
        _, _, h, w = x.size()
        mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
        mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
        x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
        return x


class UNet(NAFNet):
    def __init__(
        self,
        activation: Optional[str] = RELU,
        layer_norm_flag: Optional[bool] = False,
        channel_attention_flag: Optional[bool] = False,
        **kwargs):
        super().__init__(
            activation=activation,
            layer_norm_flag=layer_norm_flag,
            channel_attention_flag=channel_attention_flag, **kwargs)


if __name__ == '__main__':
    tiny_receptive_field = True
    if tiny_receptive_field:
        enc_blks = [1, 1, 2]
        middle_blk_num = 1
        dec_blks = [1, 1, 1]
        width = 16
        # Receptive field is 208x208
    else:
        enc_blks = [1, 1, 1, 28]
        middle_blk_num = 1
        dec_blks = [1, 1, 1, 1]
        width = 2
        # Receptive field is 544x544
    device = "cpu"

    for model_name in ["NAFNet", "UNet"]:
        if model_name == "NAFNet":
            model = NAFNet(
                img_channel=3,
                width=width,
                middle_blk_num=middle_blk_num,
                enc_blk_nums=enc_blks,
                dec_blk_nums=dec_blks,
                activation=SIMPLE_GATE,
                layer_norm_flag=False,
                channel_attention_flag=False
            )
        if model_name == "UNet":
            model = UNet(
                img_channel=3,
                width=width,
                middle_blk_num=middle_blk_num,
                enc_blk_nums=enc_blks,
                dec_blk_nums=dec_blks
            )
        model.to(device)
        with torch.no_grad():
            x = torch.randn(1, 3, 256, 256).to(device)
            y = model(x)

        print(f"{model.count_parameters()/1E3:.2f}k parameters")
        print(model.receptive_field(size=256 if tiny_receptive_field else 1024, device=device))
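Because sanitize_image_size pads inputs up to a multiple of 2**len(encoders) and forward crops back to the original (H, W), the network accepts arbitrary input sizes. A small sketch (illustrative hyper-parameters):

model = NAFNet(width=8, middle_blk_num=1, enc_blk_nums=[1, 1], dec_blk_nums=[1, 1])
with torch.no_grad():
    out = model(torch.randn(1, 3, 70, 95))  # 70 and 95 are not multiples of 4
print(out.shape)  # torch.Size([1, 3, 70, 95])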
src/rstor/architecture/selector.py
ADDED
@@ -0,0 +1,19 @@
from rstor.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE
from rstor.architecture.stacked_convolutions import StackedConvolutions
from rstor.architecture.nafnet import NAFNet, UNet
import torch


def load_architecture(config: dict) -> torch.nn.Module:
    conf_model = config[MODEL][ARCHITECTURE]
    if config[MODEL][NAME] == StackedConvolutions.__name__:
        model = StackedConvolutions(**conf_model)
    elif config[MODEL][NAME] == NAFNet.__name__:
        model = NAFNet(**conf_model)
    elif config[MODEL][NAME] == UNet.__name__:
        model = UNet(**conf_model)
    else:
        raise ValueError(f"Unknown model {config[MODEL][NAME]}")
    config[MODEL][N_PARAMS] = model.count_parameters()
    config[MODEL]["receptive_field"] = model.receptive_field()
    return model
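A configuration sketch for load_architecture; the dict keys are the constants imported above from rstor.properties, and the architecture kwargs are illustrative. Note that load_architecture also writes the parameter count and receptive field back into the config:

config = {
    MODEL: {
        NAME: "NAFNet",  # must match the class __name__
        ARCHITECTURE: dict(width=8, middle_blk_num=1, enc_blk_nums=[1, 1], dec_blk_nums=[1, 1]),
    }
}
model = load_architecture(config)
print(config[MODEL][N_PARAMS], config[MODEL]["receptive_field"])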
src/rstor/architecture/stacked_convolutions.py
ADDED
@@ -0,0 +1,30 @@
from rstor.architecture.base import BaseModel
from rstor.architecture.convolution_blocks import BaseConvolutionBlock, ResConvolutionBlock
from rstor.properties import LEAKY_RELU
import torch


class StackedConvolutions(BaseModel):
    def __init__(self,
                 ch_in: int = 3,
                 ch_out: int = 3,
                 h_dim: int = 64,
                 num_layers: int = 8,
                 k_size: int = 3,
                 activation: str = LEAKY_RELU,
                 bias: bool = True,
                 ) -> None:
        super().__init__()
        assert num_layers % 2 == 0, "Number of layers should be even"
        self.conv_in_modality = BaseConvolutionBlock(
            ch_in, h_dim, k_size, activation=activation, bias=bias)
        conv_list = []
        for _i in range(num_layers-2):
            conv_list.append(ResConvolutionBlock(
                h_dim, h_dim, k_size, activation=activation, bias=bias, residual=True))
        self.conv_out_modality = BaseConvolutionBlock(
            h_dim, ch_out, k_size, activation=None, bias=bias)
        self.conv_stack = torch.nn.Sequential(self.conv_in_modality, *conv_list, self.conv_out_modality)

    def forward(self, x_in: torch.Tensor) -> torch.Tensor:
        return self.conv_stack(x_in)
src/rstor/data/augmentation.py
ADDED
@@ -0,0 +1,27 @@
import torch
from typing import Tuple, Optional


def augment_flip(
    img: torch.Tensor,
    flip: Optional[Tuple[bool, bool]] = None
) -> torch.Tensor:
    """Randomly flip an image horizontally and/or vertically.

    Args:
        img (torch.Tensor): [N, 3, H, W] image tensor
        flip (Optional[Tuple[bool, bool]], optional): forced (flip_h, flip_v) values. Defaults to None.
            If not provided, random flip_h, flip_v values are drawn.

    Returns:
        torch.Tensor: flipped image
    """
    if flip is None:
        flip = torch.randint(0, 2, (2,))
    flipped_img = img
    if flip[0] > 0:
        flipped_img = torch.flip(flipped_img, (-1,))
    if flip[1] > 0:
        flipped_img = torch.flip(flipped_img, (-2,))
    return flipped_img
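Passing an explicit flip pattern makes the augmentation reproducible, which is handy in tests:

img = torch.arange(12.).reshape(1, 3, 2, 2)
print(torch.equal(augment_flip(img, flip=(False, False)), img))  # True (identity)
both = augment_flip(img, flip=(True, True))  # horizontal then vertical flip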
src/rstor/data/dataloader.py
ADDED
@@ -0,0 +1,120 @@
from torch.utils.data import DataLoader
from rstor.data.synthetic_dataloader import DeadLeavesDataset, DeadLeavesDatasetGPU
from rstor.data.stored_images_dataloader import RestorationDataset
from rstor.properties import (
    DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE, NAME, CONFIG_DEGRADATION,
    DATASET_SYNTH_LIST, DATASET_DIV2K,
    DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,  # needed by the __main__ demo below
    DATASET_PATH
)
from typing import Optional
from random import seed, shuffle


def get_data_loader_synthetic(config, frozen_seed=42):
    if config[DATALOADER].get("gpu_gen", False):
        print("Using GPU dead leaves generator")
        ds = DeadLeavesDatasetGPU
    else:
        ds = DeadLeavesDataset
    dl_train = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][TRAIN],
                  frozen_seed=None, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_valid = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][VALIDATION],
                  frozen_seed=frozen_seed, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def create_dataloaders(config, dl_train, dl_valid) -> dict:
    dl_dict = {
        TRAIN: DataLoader(
            dl_train,
            shuffle=True,
            batch_size=config[DATALOADER][BATCH_SIZE][TRAIN],
        ),
        VALIDATION: DataLoader(
            dl_valid,
            shuffle=False,
            batch_size=config[DATALOADER][BATCH_SIZE][VALIDATION]
        ),
    }
    return dl_dict


def get_data_loader_from_disk(config, frozen_seed: Optional[int] = 42) -> dict:
    ds = RestorationDataset
    dataset_name = config[DATALOADER][NAME]  # NAME shall be here!
    if dataset_name == DATASET_DIV2K:
        dataset_root = DATASET_PATH/DATASET_DIV2K
        train_root = dataset_root/"DIV2K_train_HR"/"DIV2K_train_HR"
        valid_root = dataset_root/"DIV2K_valid_HR"/"DIV2K_valid_HR"
        train_files = sorted(list(train_root.glob("*.png")))
        train_files = 5*train_files  # Just to get 4000 elements...
        valid_files = sorted(list(valid_root.glob("*.png")))
    elif dataset_name in DATASET_SYNTH_LIST:
        dataset_root = DATASET_PATH/dataset_name
        all_files = sorted(list(dataset_root.glob("*.png")))
        seed(frozen_seed)
        shuffle(all_files)  # Easy way to perform cross validation if needed
        cut_index = int(0.9*len(all_files))
        train_files = all_files[:cut_index]
        valid_files = all_files[cut_index:]
    else:
        raise ValueError(f"Unknown dataset {dataset_name}")  # guard against undefined file lists
    dl_train = ds(
        train_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=None,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_valid = ds(
        valid_files,
        size=config[DATALOADER][SIZE],
        frozen_seed=frozen_seed,
        **config[DATALOADER].get(CONFIG_DEGRADATION, {})
    )
    dl_dict = create_dataloaders(config, dl_train, dl_valid)
    return dl_dict


def get_data_loader(config, frozen_seed=42):
    dataset_name = config[DATALOADER].get(NAME, False)
    if dataset_name:
        return get_data_loader_from_disk(config, frozen_seed)
    else:
        return get_data_loader_synthetic(config, frozen_seed)


if __name__ == "__main__":
    # Example of usage: synthetic dataset (None entry) and stored images datasets
    for dataset_name in [DATASET_DIV2K, None, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024]:
        if dataset_name is None:
            dead_leaves_dataset = DeadLeavesDatasetGPU(colored=True)
            dl = DataLoader(dead_leaves_dataset, batch_size=4, shuffle=True)
        else:
            config = {
                DATALOADER: {
                    NAME: dataset_name,
                    SIZE: (128, 128),
                    BATCH_SIZE: {
                        TRAIN: 4,
                        VALIDATION: 4
                    },
                }
            }
            dl_dict = get_data_loader(config)
            dl = dl_dict[TRAIN]
        for i, (batch_inp, batch_target) in enumerate(dl):
            print(batch_inp.shape, batch_target.shape)  # [batch_size, 3, size[0], size[1]] per batch
            if i == 1:  # Just to break the loop after two batches for demonstration
                import matplotlib.pyplot as plt
                plt.subplot(1, 2, 1)
                plt.imshow(batch_inp.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Degraded")
                plt.subplot(1, 2, 2)
                plt.imshow(batch_target.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
                plt.title("Target")
                plt.show()
                break
src/rstor/data/degradation.py
ADDED
@@ -0,0 +1,156 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Mar 24 01:21:46 2024

@author: jamyl
"""
import torch
from rstor.properties import DATASET_BLUR_KERNEL_PATH
import random
from scipy.io import loadmat
import cv2


class Degradation():
    def __init__(self,
                 length: int = 1000,
                 frozen_seed: int = None):
        self.frozen_seed = frozen_seed
        self.current_degradation = {}


class DegradationNoise(Degradation):
    def __init__(self,
                 length: int = 1000,
                 noise_stddev: float = [0., 50.],
                 frozen_seed: int = None):
        super().__init__(length, frozen_seed)
        self.noise_stddev = noise_stddev

        if frozen_seed is not None:
            random.seed(frozen_seed)
            self.noise_stddev = [(self.noise_stddev[1] - self.noise_stddev[0]) *
                                 random.random() + self.noise_stddev[0] for _ in range(length)]

    def __call__(self, x: torch.Tensor, idx: int):
        # WARNING! IN-PLACE OPERATION on x!
        # expects x of shape [b, c, h, w]
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]

        if self.frozen_seed is not None:
            std_dev = self.noise_stddev[idx]
        else:
            std_dev = (self.noise_stddev[1] - self.noise_stddev[0]) * random.random() + self.noise_stddev[0]

        if std_dev > 0.:
            x += (std_dev/255.)*torch.randn(*x.shape, device=x.device)
        self.current_degradation[idx] = {
            "noise_stddev": std_dev
        }
        return x


class DegradationBlurMat(Degradation):
    def __init__(self,
                 length: int = 1000,
                 frozen_seed: int = None,
                 blur_index: int = None):
        super().__init__(length, frozen_seed)

        kernels = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze()
        # conversion to torch (the shape of the kernels is not constant)
        self.kernels = tuple([
            torch.from_numpy(kernel/kernel.sum(keepdims=True)).unsqueeze(0).unsqueeze(0)
            for kernel in kernels] + [torch.ones((1, 1)).unsqueeze(0).unsqueeze(0)])
        self.n_kernels = len(self.kernels)

        if frozen_seed is not None:
            random.seed(frozen_seed)
            self.kernel_ids = [random.randint(0, self.n_kernels-1) for _ in range(length)]
        if blur_index is not None:
            self.frozen_seed = 42
            self.kernel_ids = [blur_index for _ in range(length)]

    def __call__(self, x: torch.Tensor, idx: int):
        # expects x of shape [b, c, h, w]
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]
        device = x.device

        if self.frozen_seed is not None:
            kernel_id = self.kernel_ids[idx]
        else:
            kernel_id = random.randint(0, self.n_kernels-1)

        kernel = self.kernels[kernel_id].to(device).repeat(3, 1, 1, 1).float()  # repeat for grouped conv
        _, _, kh, kw = kernel.shape
        # padding="same" makes sure that the output size does not depend on the kernel.
        # An nn.Conv2d layer is used (rather than F.conv2d) so that both
        # the padding mode and the padding value can be chosen.
        conv_layer = torch.nn.Conv2d(in_channels=x.shape[1],
                                     out_channels=x.shape[1],
                                     kernel_size=(kh, kw),
                                     padding="same",
                                     padding_mode='replicate',
                                     groups=3,
                                     bias=False)

        # Set the predefined kernel as weights and freeze the parameters
        with torch.no_grad():
            conv_layer.weight = torch.nn.Parameter(kernel)
            conv_layer.weight.requires_grad = False
            x = conv_layer(x)
        # Alternative functional version with zero padding:
        # x = F.conv2d(x, kernel, padding="same", groups=3)

        self.current_degradation[idx] = {
            "blur_kernel_id": kernel_id
        }
        return x


class DegradationBlurGauss(Degradation):
    def __init__(self,
                 length: int = 1000,
                 blur_kernel_half_size: int = [0, 2],
                 frozen_seed: int = None):
        super().__init__(length, frozen_seed)

        self.blur_kernel_half_size = blur_kernel_half_size
        # pre-draw one (half_size_x, half_size_y) pair per item when the seed is frozen
        if frozen_seed is not None:
            random.seed(self.frozen_seed)
            self.blur_kernel_half_size = [
                (
                    random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]),
                    random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
                ) for _ in range(length)
            ]

    def __call__(self, x: torch.Tensor, idx: int):
        # expects x of shape [b, c, h, w]
        assert x.ndim == 4
        assert x.shape[1] in [1, 3]
        device = x.device

        if self.frozen_seed is not None:
            k_size_x, k_size_y = self.blur_kernel_half_size[idx]
        else:
            k_size_x = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
            k_size_y = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])

        k_size_x = 2 * k_size_x + 1
        k_size_y = 2 * k_size_y + 1

        x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
        x = cv2.GaussianBlur(x, (k_size_x, k_size_y), 0)
        x = torch.from_numpy(x).to(device).permute(2, 0, 1).unsqueeze(0)

        self.current_degradation[idx] = {
            "blur_kernel_half_size": (k_size_x, k_size_y),  # full kernel sizes after the 2k+1 conversion
        }
        return x
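Each degradation records what it applied in current_degradation, keyed by dataset index, which is how the evaluation scripts later recover the exact noise level. A quick sketch with the noise degradation (a frozen seed makes the per-index sigma reproducible):

import torch

noise = DegradationNoise(length=10, noise_stddev=[0., 50.], frozen_seed=42)
x = torch.zeros(1, 3, 32, 32)
y = noise(x, idx=0)  # beware: also modifies x in place
print(noise.current_degradation[0])  # {'noise_stddev': ...}
print(float(y.std()) * 255.)         # roughly the recorded sigma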
src/rstor/data/stored_images_dataloader.py
ADDED
@@ -0,0 +1,156 @@
import torch
from torch.utils.data import DataLoader, Dataset
from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
from rstor.properties import DEVICE, AUGMENTATION_FLIP, AUGMENTATION_ROTATE, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
from typing import Tuple, Optional, Union
from torchvision import transforms
from pathlib import Path
from tqdm import tqdm
from time import time
from torchvision.io import read_image
IMAGES_FOLDER = "images"


def load_image(path):
    return read_image(str(path))


class RestorationDataset(Dataset):
    def __init__(
        self,
        images_path: Path,
        size: Tuple[int, int] = (128, 128),
        device: str = DEVICE,
        preloaded: bool = False,
        augmentation_list: Optional[list] = [],
        frozen_seed: int = None,  # useful for validation set!
        blur_kernel_half_size: int = [0, 2],
        noise_stddev: float = [0., 50.],
        degradation_blur=DEGRADATION_BLUR_NONE,
        blur_index=None,
        **_extra_kwargs
    ):
        self.preloaded = preloaded
        self.augmentation_list = augmentation_list
        self.device = device
        self.frozen_seed = frozen_seed
        if not isinstance(images_path, list):
            self.path_list = sorted(list(images_path.glob("*.png")))
        else:
            self.path_list = images_path

        self.length = len(self.path_list)
        self.n_samples = len(self.path_list)
        # If everything fits in memory, preload the decoded images once
        if preloaded:
            self.data_list = [load_image(pth) for pth in tqdm(self.path_list)]
        else:
            self.data_list = self.path_list

        self.transforms = []
        if self.frozen_seed is None:
            if AUGMENTATION_FLIP in self.augmentation_list:
                self.transforms.append(transforms.RandomHorizontalFlip(p=0.5))
                self.transforms.append(transforms.RandomVerticalFlip(p=0.5))
            if AUGMENTATION_ROTATE in self.augmentation_list:
                self.transforms.append(transforms.RandomRotation(degrees=180))

        crop = transforms.RandomCrop(size) if frozen_seed is None else transforms.CenterCrop(size)
        self.transforms.append(crop)
        self.transforms = transforms.Compose(self.transforms)

        self.degradation_blur_type = degradation_blur
        if degradation_blur == DEGRADATION_BLUR_GAUSS:
            self.degradation_blur = DegradationBlurGauss(self.length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
            self.blur_deg_str = "blur_kernel_half_size"
        elif degradation_blur == DEGRADATION_BLUR_MAT:
            self.degradation_blur = DegradationBlurMat(self.length,
                                                       frozen_seed,
                                                       blur_index)
            self.blur_deg_str = "blur_kernel_id"
        elif degradation_blur == DEGRADATION_BLUR_NONE:
            pass
        else:
            raise ValueError(f"Unknown degradation blur {degradation_blur}")

        self.degradation_noise = DegradationNoise(self.length,
                                                  noise_stddev,
                                                  frozen_seed)
        self.current_degradation = {}

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        """Access a specific image from the dataset, augment and degrade it.

        Args:
            index (int): access index

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded and clean image tensors [C, H, W]
        """
        if self.preloaded:
            img_data = self.data_list[index]
        else:
            img_data = load_image(self.data_list[index])
        img_data = img_data.to(self.device)

        img_data = self.transforms(img_data)
        img_data = img_data.float()/255.
        degraded_img = img_data.clone().unsqueeze(0)

        self.current_degradation[index] = {}
        if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
            degraded_img = self.degradation_blur(degraded_img, index)
            self.current_degradation[index][self.blur_deg_str] = \
                self.degradation_blur.current_degradation[index][self.blur_deg_str]

        degraded_img = self.degradation_noise(degraded_img, index)
        self.current_degradation[index]["noise_stddev"] = \
            self.degradation_noise.current_degradation[index]["noise_stddev"]

        degraded_img = degraded_img.squeeze(0)
        return degraded_img, img_data

    def __len__(self):
        return self.n_samples


if __name__ == "__main__":
    dataset_restoration = RestorationDataset(
        Path("__dataset/div2k/DIV2K_train_HR/DIV2K_train_HR/"),
        preloaded=True,
    )
    dataloader = DataLoader(
        dataset_restoration,
        batch_size=16,
        shuffle=True
    )
    start = time()
    total = 0
    for degraded_batch, target_batch in tqdm(dataloader):
        torch.cuda.synchronize()
        total += degraded_batch.shape[0]
    end = time()
    print(f"Time elapsed: {(end-start)/total*1000.:.2f}ms/image")
src/rstor/data/synthetic_dataloader.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
from torch.utils.data import Dataset
|
5 |
+
from typing import Tuple
|
6 |
+
from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
|
7 |
+
from rstor.properties import DEVICE, AUGMENTATION_FLIP, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
|
8 |
+
from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
|
9 |
+
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
|
10 |
+
import cv2
|
11 |
+
from skimage.filters import gaussian
|
12 |
+
import random
|
13 |
+
import numpy as np
|
14 |
+
|
15 |
+
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
|
16 |
+
|
17 |
+
|
18 |
+
class DeadLeavesDataset(Dataset):
|
19 |
+
def __init__(
|
20 |
+
self,
|
21 |
+
size: Tuple[int, int] = (128, 128),
|
22 |
+
length: int = 1000,
|
23 |
+
frozen_seed: int = None, # useful for validation set!
|
24 |
+
blur_kernel_half_size: int = [0, 2],
|
25 |
+
ds_factor: int = 5,
|
26 |
+
noise_stddev: float = [0., 50.],
|
27 |
+
degradation_blur=DEGRADATION_BLUR_NONE,
|
28 |
+
**config_dead_leaves
|
29 |
+
# number_of_circles: int = -1,
|
30 |
+
# background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
|
31 |
+
# colored: Optional[bool] = False,
|
32 |
+
# radius_mean: Optional[int] = -1,
|
33 |
+
# radius_stddev: Optional[int] = -1,
|
34 |
+
):
|
35 |
+
|
36 |
+
self.frozen_seed = frozen_seed
|
37 |
+
self.ds_factor = ds_factor
|
38 |
+
self.size = (size[0]*ds_factor, size[1]*ds_factor)
|
39 |
+
self.length = length
|
40 |
+
self.config_dead_leaves = config_dead_leaves
|
41 |
+
self.blur_kernel_half_size = blur_kernel_half_size
|
42 |
+
self.noise_stddev = noise_stddev
|
43 |
+
|
44 |
+
|
45 |
+
self.degradation_blur_type = degradation_blur
|
46 |
+
if degradation_blur == DEGRADATION_BLUR_GAUSS:
|
47 |
+
self.degradation_blur = DegradationBlurGauss(self.length,
|
48 |
+
blur_kernel_half_size,
|
49 |
+
frozen_seed)
|
50 |
+
self.blur_deg_str = "blur_kernel_half_size"
|
51 |
+
elif degradation_blur == DEGRADATION_BLUR_MAT:
|
52 |
+
self.degradation_blur = DegradationBlurMat(self.length,
|
53 |
+
frozen_seed)
|
54 |
+
self.blur_deg_str = "blur_kernel_id"
|
55 |
+
elif degradation_blur == DEGRADATION_BLUR_NONE:
|
56 |
+
pass
|
57 |
+
else:
|
58 |
+
raise ValueError(f"Unknown degradation blur {degradation_blur}")
|
59 |
+
|
60 |
+
self.degradation_noise = DegradationNoise(self.length,
|
61 |
+
noise_stddev,
|
62 |
+
frozen_seed)
|
63 |
+
self.current_degradation = {}
|
64 |
+
|
65 |
+
def __len__(self):
|
66 |
+
return self.length
|
67 |
+
|
68 |
+
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
|
69 |
+
# TODO there is a bug on this cpu version, the dead leaved dont appear ot be right
|
70 |
+
seed = self.frozen_seed + idx if self.frozen_seed is not None else None
|
71 |
+
chart = cpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
|
72 |
+
|
73 |
+
if self.ds_factor > 1:
|
74 |
+
# print(f"Downsampling {chart.shape} with factor {self.ds_factor}...")
|
75 |
+
sigma = 3/5
|
76 |
+
chart = gaussian(
|
77 |
+
chart, sigma=(sigma, sigma, 0), mode='nearest',
|
78 |
+
cval=0, preserve_range=True, truncate=4.0)
|
79 |
+
chart = chart[::self.ds_factor, ::self.ds_factor]
|
80 |
+
|
81 |
+
th_chart = torch.from_numpy(chart).permute(2, 0, 1).unsqueeze(0)
|
82 |
+
degraded_chart = th_chart
|
83 |
+
|
84 |
+
self.current_degradation[idx] = {}
|
85 |
+
if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
|
86 |
+
degraded_chart = self.degradation_blur(degraded_chart, idx)
|
87 |
+
self.current_degradation[idx][self.blur_deg_str] = self.degradation_blur.current_degradation[idx][self.blur_deg_str]
|
88 |
+
|
89 |
+
degraded_chart = self.degradation_noise(degraded_chart, idx)
|
90 |
+
self.current_degradation[idx]["noise_stddev"] = self.degradation_noise.current_degradation[idx]["noise_stddev"]
|
91 |
+
|
92 |
+
        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)

        return degraded_chart, th_chart


class DeadLeavesDatasetGPU(Dataset):
    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        blur_kernel_half_size: Tuple[int, int] = (0, 2),
        ds_factor: int = 5,
        noise_stddev: Tuple[float, float] = (0., 50.),
        use_gaussian_kernel: bool = True,
        **config_dead_leaves
        # number_of_circles: int = -1,
        # background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
        # colored: Optional[bool] = False,
        # radius_mean: Optional[int] = -1,
        # radius_stddev: Optional[int] = -1,
    ):
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves

        # Downsample kernel: 5x5 gaussian with sigma = 3/5
        sigma = 3/5
        k_size = 5  # fits sigma = 3/5: the value at the kernel edge is 0.0038 (negligible)
        x = (torch.arange(k_size) - 2).to('cuda')
        kernel = torch.stack(torch.meshgrid((x, x), indexing='ij'))
        kernel.requires_grad = False
        dist_sq = kernel[0]**2 + kernel[1]**2
        # dist_sq already holds the squared distance: gaussian is exp(-d^2 / (2 sigma^2))
        kernel = (-dist_sq/(2*sigma**2)).exp()
        kernel = kernel / kernel.sum()
        self.downsample_kernel = kernel.repeat(3, 1, 1, 1)  # shape [3, 1, k_size, k_size]
        self.downsample_kernel.requires_grad = False
        self.use_gaussian_kernel = use_gaussian_kernel
        if use_gaussian_kernel:
            self.degradation_blur = DegradationBlurGauss(length,
                                                         blur_kernel_half_size,
                                                         frozen_seed)
        else:
            self.degradation_blur = DegradationBlurMat(length,
                                                       frozen_seed)

        self.degradation_noise = DegradationNoise(length,
                                                  noise_stddev,
                                                  frozen_seed)
        self.current_degradation = {}

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """Get a single dead leaves chart and its degraded version.

        Args:
            idx (int): index of the item to retrieve

        Returns:
            Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None

        # Returns a numba device array
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda")[
            None].permute(0, 3, 1, 2)  # [1, c, h, w]
        if self.ds_factor > 1:
            # Downsample using strided gaussian conv (sigma=3/5); only the width is replicate-padded
            th_chart = F.pad(th_chart,
                             pad=(2, 2, 0, 0),
                             mode="replicate")
            th_chart = F.conv2d(th_chart,
                                self.downsample_kernel,
                                padding='valid',
                                groups=3,
                                stride=self.ds_factor)

        degraded_chart = self.degradation_blur(th_chart, idx)
        degraded_chart = self.degradation_noise(degraded_chart, idx)

        blur_deg_str = "blur_kernel_half_size" if self.use_gaussian_kernel else "blur_kernel_id"
        self.current_degradation[idx] = {
            blur_deg_str: self.degradation_blur.current_degradation[idx][blur_deg_str],
            "noise_stddev": self.degradation_noise.current_degradation[idx]["noise_stddev"]
        }

        degraded_chart = degraded_chart.squeeze(0)
        th_chart = th_chart.squeeze(0)

        return degraded_chart, th_chart
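A minimal usage sketch of the GPU dataset above, assuming a CUDA-capable machine and the `src` folder on the Python path; the `DataLoader` settings below are illustrative, not taken from the diff:

# Usage sketch (illustrative, assumes a CUDA device)
from torch.utils.data import DataLoader
from rstor.data.synthetic_dataloader import DeadLeavesDatasetGPU

dataset = DeadLeavesDatasetGPU(size=(128, 128), length=100, frozen_seed=42)
degraded, target = dataset[0]  # two [3, h, w] CUDA tensors
# Samples are generated on the GPU inside __getitem__, so keep num_workers=0
loader = DataLoader(dataset, batch_size=8, num_workers=0)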
src/rstor/learning/experiments.py
ADDED
@@ -0,0 +1,24 @@
from rstor.properties import DEVICE, OPTIMIZER, PARAMS
from rstor.architecture.selector import load_architecture
from rstor.data.dataloader import get_data_loader
from typing import Tuple
import torch


def get_training_content(
        config: dict,
        training_mode: bool = False,
        device=DEVICE) -> Tuple[torch.nn.Module, torch.optim.Optimizer, dict]:
    model = load_architecture(config)
    optimizer, dl_dict = None, None
    if training_mode:
        optimizer = torch.optim.Adam(model.parameters(), **config[OPTIMIZER][PARAMS])
        dl_dict = get_data_loader(config)
    return model, optimizer, dl_dict


if __name__ == "__main__":
    from rstor.learning.experiments_definition import get_experiment_config
    config = get_experiment_config(-1)  # -1 = tiny debug preset defined in experiments_definition
    model, optimizer, dl_dict = get_training_content(config, training_mode=True)
    print(config)
src/rstor/learning/experiments_definition.py
ADDED
@@ -0,0 +1,489 @@
from rstor.properties import (NB_EPOCHS, DATALOADER, BATCH_SIZE, SIZE, LENGTH,
                              TRAIN, VALIDATION, SCHEDULER, REDUCELRONPLATEAU,
                              MODEL, ARCHITECTURE, ID, NAME, SCHEDULER_CONFIGURATION, OPTIMIZER, PARAMS, LR,
                              LOSS, LOSS_MSE, CONFIG_DEAD_LEAVES,
                              SELECTED_METRICS, METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
                              DATASET_DL_DIV2K_512, DATASET_DIV2K,
                              CONFIG_DEGRADATION,
                              PRETTY_NAME,
                              DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS,
                              AUGMENTATION_FLIP, AUGMENTATION_ROTATE,
                              DATASET_DL_EXTRAPRIMITIVES_DIV2K_512)


from typing import Tuple


def model_configurations(config, model_preset="StackedConvolutions", bias: bool = True) -> dict:
    if model_preset == "StackedConvolutions":
        config[MODEL] = {
            ARCHITECTURE: dict(
                num_layers=8,
                k_size=3,
                h_dim=16,
                bias=bias
            ),
            NAME: "StackedConvolutions"
        }
    elif model_preset == "NAFNet" or model_preset == "UNet":
        # https://github.com/megvii-research/NAFNet/blob/main/options/test/GoPro/NAFNet-width64.yml
        config[MODEL] = {
            ARCHITECTURE: dict(
                width=64,
                enc_blk_nums=[1, 1, 1, 28],
                middle_blk_num=1,
                dec_blk_nums=[1, 1, 1, 1],
            ),
            NAME: model_preset
        }
    else:
        raise ValueError(f"Unknown model preset {model_preset}")


def presets_experiments(
    exp: int,
    b: int = 32,
    n: int = 50,
    bias: bool = True,
    length: int = 5000,
    data_size: Tuple[int, int] = (128, 128),
    model_preset: str = "StackedConvolutions",
    lpips: bool = False
) -> dict:
    config = {
        ID: exp,
        NAME: f"{exp:04d}",
        NB_EPOCHS: n
    }
    config[DATALOADER] = {
        BATCH_SIZE: {
            TRAIN: b,
            VALIDATION: b
        },
        SIZE: data_size,  # (width, height)
        LENGTH: {
            TRAIN: length,
            VALIDATION: 800
        }
    }
    config[OPTIMIZER] = {
        NAME: "Adam",
        PARAMS: {
            LR: 1e-3
        }
    }
    model_configurations(config, model_preset=model_preset, bias=bias)
    config[SCHEDULER] = REDUCELRONPLATEAU
    config[SCHEDULER_CONFIGURATION] = {
        "factor": 0.8,
        "patience": 5
    }
    config[LOSS] = LOSS_MSE
    config[SELECTED_METRICS] = [METRIC_PSNR, METRIC_SSIM]
    if lpips:
        config[SELECTED_METRICS].append(METRIC_LPIPS)
    return config


def get_experiment_config(exp: int) -> dict:
    if exp == -1:
        config = presets_experiments(exp, length=10, n=2)
    elif exp == -2:
        config = presets_experiments(exp, length=10, n=2, lpips=True)
    elif exp == -3:
        config = presets_experiments(exp, n=20)
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50"
    elif exp == -4:
        config = presets_experiments(exp, b=4, n=20)
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla exp from disk - noisy 0-50"
    elif exp == 1000:
        config = presets_experiments(exp, n=60)
        config[PRETTY_NAME] = "Vanilla small blur"
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=1, noise_stddev=[0., 0.])
    elif exp == 1001:
        config = presets_experiments(exp, n=60)
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 6], ds_factor=1, noise_stddev=[0., 0.])
        config[PRETTY_NAME] = "Vanilla large blur 0 - 6"
    elif exp == 1002:
        config = presets_experiments(exp, n=6)  # Fewer epochs because of the large downsample factor
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=5, noise_stddev=[0., 0.])
        config[PRETTY_NAME] = "Vanilla small blur - ds=5"
    elif exp == 1003:
        config = presets_experiments(exp, n=6)  # Fewer epochs because of the large downsample factor
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=5, noise_stddev=[0., 50.])
        config[PRETTY_NAME] = "Vanilla small blur - ds=5 - noisy 0-50"
    elif exp == 1004:
        config = presets_experiments(exp, n=60)
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 0], ds_factor=1, noise_stddev=[0., 50.])
        config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50"
    elif exp == 1005:
        config = presets_experiments(exp, bias=False, n=60)
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 0], ds_factor=1, noise_stddev=[0., 50.])
        config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50 - bias free"
    elif exp == 1006:
        config = presets_experiments(exp, n=60)
        config[PRETTY_NAME] = "Vanilla small blur - noisy 0-50"
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=1, noise_stddev=[0., 50.])
    elif exp == 1007:
        config = presets_experiments(exp, n=60)
        config[PRETTY_NAME] = "Vanilla large blur 0 - 6 - noisy 0-50"
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 6], ds_factor=1, noise_stddev=[0., 50.])
    elif exp == 2000:
        config = presets_experiments(exp, n=60, b=16, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet denoise 0-50"
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
    elif exp == 2001:
        config = presets_experiments(exp, n=60, b=16, model_preset="UNet")
        config[PRETTY_NAME] = "UNet denoise 0-50"
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
    elif exp == 2002:
        config = presets_experiments(exp, n=20, b=8, data_size=(256, 256), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl 256x256"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
    elif exp == 2003:
        config = presets_experiments(exp, n=20, b=8, data_size=(128, 128), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl - 128x128"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
    elif exp == 2004:
        config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet Light denoise 0-50 gpu dl - 128x128"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1, 1],
        )
    elif exp == 2005:
        config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet TresLight denoise 0-50 gpu dl - 128x128"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=1,
            noise_stddev=[0., 50.]
        )
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
    elif exp == 2006:
        config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet TresLight denoise 0-50 ds=5 gpu dl - 128x128"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=5,
            noise_stddev=[0., 50.]
        )
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
    elif exp == 2007:
        config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl -ds=5 128x128"
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=5,
            noise_stddev=[0., 50.]
        )
    elif exp == 1008:
        config = presets_experiments(exp, n=20)
        config[DATALOADER]["gpu_gen"] = True
        config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
            blur_kernel_half_size=[0, 0],
            ds_factor=5,
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla denoise only - ds=5 - noisy 0-50"
    # ---------------------------------
    # Pure DL DENOISING trainings!
    # ---------------------------------
    elif exp == 3000:
        config = presets_experiments(exp, n=30, b=4, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet denoise - DL_DIV2K_512 0-50"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3001:  # ENABLE GRADIENT CLIPPING
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_DIV2K_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3002:  # ENABLE GRADIENT CLIPPING
        config = presets_experiments(exp, n=30, b=16, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_DIV2K_512 0-50 128x128"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[DATALOADER][SIZE] = (128, 128)
    elif exp == 3010 or exp == 3011:  # exp 3011 = REDO with Gradient clipping
        config = presets_experiments(exp, n=50, b=4, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DL_DIV2K_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3020:
        config = presets_experiments(exp, b=32, n=50)
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla denoise DL 0-50 - noisy 0-50"
    # ---------------------------------
    # Pure DIV2K DENOISING trainings!
    # ---------------------------------
    elif exp == 3120:
        config = presets_experiments(exp, b=32, n=50)
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla DIV2K_512 0-50 - noisy 0-50"
    elif exp == 3111:
        config = presets_experiments(exp, n=50, b=4, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DIV2K_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3101:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet41.4M denoise - DIV2K_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[DATALOADER][SIZE] = (256, 256)
    # ---------------------------------
    # Pure EXTRA PRIMITIVES
    # ---------------------------------
    elif exp == 3030:
        config = presets_experiments(exp, b=128, n=50)
        config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.]
        )
        config[PRETTY_NAME] = "Vanilla DL_PRIMITIVES_512 0-50 - noisy 0-50"
        # config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3040:
        config = presets_experiments(exp, n=50, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DL_PRIMITIVES_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 3050:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_PRIMITIVES_512 0-50 256x256"
        config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
        config[DATALOADER][SIZE] = (256, 256)
    # ---------------------------------
    # DEBLURRING
    # ---------------------------------
    elif exp == 5000:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet deblur - DL_DIV2K_512 256x256"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 5001:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 5002:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet deblur - DL_DIV2K_512 256x256"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 5003:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 5004:
        config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
        config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 5005:
        config = presets_experiments(exp, n=30, b=8, model_preset="UNet")
        config[PRETTY_NAME] = "UNet deblur - DL_512 256x256"
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    # elif exp == 6000:  # -> FAILED, no kernels normalization!
    #     config = presets_experiments(exp, b=32, n=50)
    #     config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
    #     config[DATALOADER][CONFIG_DEGRADATION] = dict(
    #         noise_stddev=[0., 50.],
    #         degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
    #         augmentation_list=[AUGMENTATION_FLIP]
    #     )
    #     config[PRETTY_NAME] = "Vanilla deblur DL_DIV2K_512"
    elif exp == 6002:
        config = presets_experiments(exp, b=128, n=50)
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[PRETTY_NAME] = "Vanilla deblur DL_DIV2K_512"
    # elif exp == 6001:  # -> FAILED, no kernels normalization!
    #     config = presets_experiments(exp, b=32, n=50)
    #     config[DATALOADER][NAME] = DATASET_DIV2K
    #     config[DATALOADER][CONFIG_DEGRADATION] = dict(
    #         noise_stddev=[0., 50.],
    #         degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
    #         augmentation_list=[AUGMENTATION_FLIP]
    #     )
    #     config[PRETTY_NAME] = "Vanilla deblur DIV2K_512"
    elif exp == 6003:
        config = presets_experiments(exp, b=128, n=50)
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[PRETTY_NAME] = "Vanilla deblur DIV2K_512"
    elif exp == 7000:
        config = presets_experiments(exp, b=16, n=30, model_preset="NAFNet")
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
        config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[PRETTY_NAME] = "NAFNet Light deblur DL"
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 7001:
        config = presets_experiments(exp, b=16, n=50, model_preset="NAFNet")
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[MODEL][ARCHITECTURE] = dict(
            width=64,
            enc_blk_nums=[1, 1, 2],
            middle_blk_num=1,
            dec_blk_nums=[1, 1, 1],
        )
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 50.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[PRETTY_NAME] = "NAFNet Light deblur DIV2K"
        config[DATALOADER][SIZE] = (256, 256)
    elif exp == 7002:
        config = presets_experiments(exp, n=20, b=8, model_preset="UNet")
        config[PRETTY_NAME] = "UNet deblur - DIV2K"
        config[DATALOADER][NAME] = DATASET_DIV2K
        config[DATALOADER][CONFIG_DEGRADATION] = dict(
            noise_stddev=[0., 0.],
            degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
            augmentation_list=[AUGMENTATION_FLIP]
        )
        config[DATALOADER][SIZE] = (256, 256)
    else:
        raise ValueError(f"Experiment {exp} not found")
    return config
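A quick sketch of how an experiment id resolves to a nested configuration dictionary; the keys are the string constants defined in rstor.properties, and the printed values follow the 1000 preset above:

# Sketch: resolving preset 1000 to its config dict
from rstor.learning.experiments_definition import get_experiment_config
from rstor.properties import DATALOADER, CONFIG_DEAD_LEAVES, PRETTY_NAME

config = get_experiment_config(1000)
print(config[PRETTY_NAME])                     # Vanilla small blur
print(config[DATALOADER][CONFIG_DEAD_LEAVES])  # {'blur_kernel_half_size': [0, 2], 'ds_factor': 1, 'noise_stddev': [0.0, 0.0]}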
src/rstor/learning/loss.py
ADDED
@@ -0,0 +1,25 @@
import torch
from typing import Optional
from rstor.properties import LOSS_MSE


def compute_loss(
    predic: torch.Tensor,
    target: torch.Tensor,
    mode: Optional[str] = LOSS_MSE
) -> torch.Tensor:
    """
    Compute loss based on the predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values
        target (torch.Tensor): [N, C, H, W] target values.
        mode (Optional[str], optional): mode of loss computation.

    Returns:
        torch.Tensor: The computed loss.
    """
    assert mode in [LOSS_MSE], f"Mode {mode} not supported"
    if mode == LOSS_MSE:
        loss = torch.nn.functional.mse_loss(predic, target)
    return loss
src/rstor/learning/metrics.py
ADDED
@@ -0,0 +1,140 @@
import torch
from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM
from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
from typing import List, Optional
ALL_METRICS = [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]


def compute_psnr(
    predic: torch.Tensor,
    target: torch.Tensor,
    clamp_mse=1e-10,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the average PSNR metric for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: The average PSNR value for the batch.
    """
    with torch.no_grad():
        mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
        mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
        psnr_per_image = 10 * torch.log10(1 / mse_per_image)
        if reduction == REDUCTION_AVERAGE:
            average_psnr = torch.mean(psnr_per_image)
        elif reduction == REDUCTION_SUM:
            average_psnr = torch.sum(psnr_per_image)
        elif reduction == REDUCTION_SKIP:
            average_psnr = psnr_per_image
        else:
            raise ValueError(f"Unknown reduction {reduction}")
    return average_psnr


def compute_ssim(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE
) -> torch.Tensor:
    """
    Compute the average SSIM metric for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: The average SSIM value for the batch.
    """
    with torch.no_grad():
        reduction_mode = {
            REDUCTION_SKIP: None,
            REDUCTION_AVERAGE: "elementwise_mean",
            REDUCTION_SUM: "sum"
        }[reduction]
        ssim = SSIM(data_range=1.0, reduction=reduction_mode).to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        ssim_value = ssim(predic, target)
    return ssim_value


def compute_lpips(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
) -> torch.Tensor:
    """
    Compute the average LPIPS metric for a batch of predicted and true values.
    https://richzhang.github.io/PerceptualSimilarity/

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.

    Returns:
        torch.Tensor: The average LPIPS value for the batch.
    """
    reduction_mode = {
        REDUCTION_SKIP: "sum",  # does not really matter, per-image values are stacked below
        REDUCTION_AVERAGE: "mean",
        REDUCTION_SUM: "sum"
    }[reduction]

    with torch.no_grad():
        lpip_metrics = LearnedPerceptualImagePatchSimilarity(
            reduction=reduction_mode,
            normalize=True  # If set to True will instead expect input to be in the [0,1] range.
        ).to(predic.device)
        assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
        assert predic.device == target.device, f"{predic.device} != {target.device}"
        if reduction == REDUCTION_SKIP:
            lpip_value = []
            for idx in range(predic.shape[0]):
                lpip_value.append(lpip_metrics(
                    predic[idx, ...].unsqueeze(0).clip(0, 1),
                    target[idx, ...].unsqueeze(0).clip(0, 1)
                ))
            lpip_value = torch.stack(lpip_value)
        elif reduction in [REDUCTION_SUM, REDUCTION_AVERAGE]:
            lpip_value = lpip_metrics(predic.clip(0, 1), target.clip(0, 1))
    return lpip_value


def compute_metrics(
    predic: torch.Tensor,
    target: torch.Tensor,
    reduction: Optional[str] = REDUCTION_AVERAGE,
    chosen_metrics: Optional[List[str]] = ALL_METRICS) -> dict:
    """
    Compute the metrics for a batch of predicted and true values.

    Args:
        predic (torch.Tensor): [N, C, H, W] predicted values.
        target (torch.Tensor): [N, C, H, W] target values.
        reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
        chosen_metrics (list): List of metrics to compute, default [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]

    Returns:
        dict: computed metrics.
    """
    metrics = {}
    if METRIC_PSNR in chosen_metrics:
        average_psnr = compute_psnr(predic, target, reduction=reduction)
        metrics[METRIC_PSNR] = average_psnr.item() if reduction != REDUCTION_SKIP else average_psnr
    if METRIC_SSIM in chosen_metrics:
        ssim_value = compute_ssim(predic, target, reduction=reduction)
        metrics[METRIC_SSIM] = ssim_value.item() if reduction != REDUCTION_SKIP else ssim_value
    if METRIC_LPIPS in chosen_metrics:
        lpip_value = compute_lpips(predic, target, reduction=reduction)
        metrics[METRIC_LPIPS] = lpip_value.item() if reduction != REDUCTION_SKIP else lpip_value
    return metrics
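A minimal sketch of the compute_metrics entry point above, assuming torchmetrics is installed and inputs in the [0, 1] range; the tensor shapes are illustrative:

# Sketch: batch metrics on random tensors
import torch
from rstor.learning.metrics import compute_metrics
from rstor.properties import METRIC_PSNR, METRIC_SSIM, REDUCTION_AVERAGE

predic = torch.rand(4, 3, 64, 64)
target = torch.rand(4, 3, 64, 64)
scores = compute_metrics(predic, target, reduction=REDUCTION_AVERAGE,
                         chosen_metrics=[METRIC_PSNR, METRIC_SSIM])
print(scores)  # {'PSNR': ..., 'SSIM': ...}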
src/rstor/properties.py
ADDED
@@ -0,0 +1,67 @@
import torch
from pathlib import Path
RELU = "ReLU"
LEAKY_RELU = "LeakyReLU"
SIMPLE_GATE = "simple_gate"
LOSS = "loss"
LOSS_MSE = "MSE"
METRIC_PSNR = "PSNR"
METRIC_SSIM = "SSIM"
METRIC_LPIPS = "LPIPS"
SELECTED_METRICS = "selected_metrics"
DATALOADER = "data_loader"
BATCH_SIZE = "batch_size"
SIZE = "size"
TRAIN, VALIDATION, TEST = "train", "validation", "test"
LENGTH = "length"
ID = "id"
NAME = "name"
PRETTY_NAME = "pretty_name"
NB_EPOCHS = "nb_epochs"
ARCHITECTURE = "architecture"
MODEL = "model"
N_PARAMS = "n_params"
OPTIMIZER = "optimizer"
LR = "lr"
PARAMS = "parameters"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SCHEDULER_CONFIGURATION = "scheduler_configuration"
SCHEDULER = "scheduler"
REDUCELRONPLATEAU = "ReduceLROnPlateau"
CONFIG_DEAD_LEAVES = "config_dead_leaves"
CONFIG_DEGRADATION = "config_degradation"
REDUCTION_SUM = "reduction_sum"
REDUCTION_AVERAGE = "reduction_average"
REDUCTION_SKIP = "reduction_skip"
TRACES_TARGET = "target"
TRACES_DEGRADED = "degraded"
TRACES_RESTORED = "restored"
TRACES_METRICS = "metrics"
TRACES_ALL = "all"

DEGRADATION_BLUR_NONE = "none"
DEGRADATION_BLUR_MAT = "mat"
DEGRADATION_BLUR_GAUSS = "gauss"


SAMPLER_SATURATED = "saturated"
SAMPLER_UNIFORM = "uniform"
SAMPLER_NATURAL = "natural"
SAMPLER_DIV2K = "div2k"

DATASET_FOLDER = "__dataset"
DATASET_PATH = Path(__file__).parent.parent.parent/DATASET_FOLDER
DATASET_DL_RANDOMRGB_1024 = "deadleaves_randomrgb_1024"
DATASET_DL_DIV2K_1024 = "deadleaves_div2k_1024"
DATASET_DL_DIV2K_512 = "deadleaves_div2k_512"
DATASET_DL_EXTRAPRIMITIVES_DIV2K_512 = "deadleaves_primitives_div2k_512"
DATASET_SYNTH_LIST = [DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,
                      DATASET_DL_RANDOMRGB_1024, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512]
DATASET_BLUR_KERNEL_PATH = DATASET_PATH / "kernels" / "custom_blur_centered.mat"
AUGMENTATION_FLIP = "flip"
AUGMENTATION_ROTATE = "rotate"


DATASET_DIV2K = "div2k"
src/rstor/synthetic_data/color_sampler.py
ADDED
@@ -0,0 +1,73 @@
import numpy as np
import cv2

from pathlib import Path
from rstor.properties import DATASET_PATH
from typing import List


def sample_uniform_rgb(size: int, seed: int = None) -> np.ndarray:
    """
    Generate `size` random RGB values.

    Args:
        size (int): number of colors to sample
        seed (int, optional): Seed for the random number generator. Defaults to None.

    Returns:
        np.ndarray: Random RGB values as a numpy array.
    """
    # https://github.com/numpy/numpy/issues/17079
    # https://numpy.org/devdocs/reference/random/new-or-different.html#new-or-different
    rng = np.random.default_rng(np.random.SeedSequence(seed))

    random_samples = rng.uniform(size=(size, 3))
    rgb = random_samples

    # Below, old version with saturation
    # lab = (random_samples + np.array([0., -0.5, -0.5])[None]) * np.array([100., 127 * 2, 127 * 2])[None]
    # rgb = cv2.cvtColor(lab[None, :].astype(np.float32), cv2.COLOR_Lab2RGB)
    return rgb.squeeze()


def sample_saturated_color(size: int, seed: int = None) -> np.ndarray:
    """
    Generate `size` saturated RGB values.

    Args:
        size (int): number of colors to sample
        seed (int, optional): Seed for the random number generator. Defaults to None.

    Returns:
        np.ndarray: Random RGB values as a numpy array.
    """
    # https://github.com/numpy/numpy/issues/17079
    # https://numpy.org/devdocs/reference/random/new-or-different.html#new-or-different
    rng = np.random.default_rng(np.random.SeedSequence(seed))

    random_samples = rng.uniform(size=(size, 3))

    lab = (random_samples + np.array([0., -0.5, -0.5])[None]) * np.array([100., 127 * 2, 127 * 2])[None]
    rgb = cv2.cvtColor(lab[None, :].astype(np.float32), cv2.COLOR_Lab2RGB)
    return rgb.squeeze()


def sample_color_from_images(size: int, seed: int = None, path_to_images: List[Path] = []) -> np.ndarray:
    print("path : ", path_to_images)
    assert len(path_to_images) > 0, "Please provide a list of images to sample colors from."
    rng = np.random.default_rng(np.random.SeedSequence(seed))

    # Randomly pick an image and load it
    img_id = rng.integers(0, len(path_to_images))

    img = cv2.imread(path_to_images[img_id].as_posix())
    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255

    pixels = img.reshape(-1, 3)
    n_pixels = pixels.shape[0]

    # sample a pixel color for each disc
    pixel_ids = rng.integers(0, n_pixels, size)
    colors = pixels[pixel_ids, :]

    return colors
src/rstor/synthetic_data/dead_leaves_cpu.py
ADDED
@@ -0,0 +1,79 @@
from typing import Tuple, Optional, List
import numpy as np
import cv2
from rstor.synthetic_data.dead_leaves_sampler import define_dead_leaves_chart
from rstor.properties import SAMPLER_UNIFORM
from pathlib import Path


def cpu_dead_leaves_chart(size: Tuple[int, int] = (100, 100),
                          number_of_circles: int = -1,
                          background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
                          colored: Optional[bool] = True,
                          radius_min: Optional[int] = -1,
                          radius_max: Optional[int] = -1,
                          radius_alpha: Optional[int] = 3,
                          seed: int = None,
                          reverse: Optional[bool] = True,
                          sampler=SAMPLER_UNIFORM,
                          natural_image_list: List[Path] = []) -> np.ndarray:
    """
    Generation of a dead leaves chart by splatting circles on top of each other.

    Args:
        size (Tuple[int, int], optional): size of the generated chart. Defaults to (100, 100).
        number_of_circles (int, optional): number of circles to generate.
            If negative, it is computed based on the size. Defaults to -1.
        background_color (Optional[Tuple[float, float, float]], optional):
            background color of the chart. Defaults to gray (0.5, 0.5, 0.5).
        colored (Optional[bool], optional): Whether to generate colored circles. Defaults to True.
        radius_min (Optional[int], optional): minimum radius of the circles. Defaults to -1. (=> 1)
        radius_max (Optional[int], optional): maximum radius of the circles. Defaults to -1. (=> 2000)
        radius_alpha (Optional[int], optional): power-law exponent of the radius distribution. Defaults to 3.
        seed (int, optional): seed for the random number generator. Defaults to None
        reverse: (Optional[bool], optional): View circles from the back view
            by reversing order. Defaults to True.
            WARNING: This option is extremely slow on CPU.
        sampler: color sampler (SAMPLER_UNIFORM/SAMPLER_SATURATED/SAMPLER_NATURAL).
        natural_image_list (List[Path], optional): images to sample colors from (SAMPLER_NATURAL).

    Returns:
        np.ndarray: generated dead leaves chart as a NumPy array.
    """
    center_x, center_y, radius, color = define_dead_leaves_chart(
        size,
        number_of_circles,
        colored,
        radius_min,
        radius_max,
        radius_alpha,
        seed,
        sampler=sampler,
        natural_image_list=natural_image_list
    )
    # When number_of_circles is negative, the sampler picks the actual count: use it for the loops below.
    number_of_circles = len(radius)
    if not colored:
        color = np.concatenate((color, color, color), axis=1)

    if reverse:
        chart = np.zeros((size[0], size[1], 3), dtype=np.float32)
        buffer = np.zeros_like(chart)
        is_not_covered_mask = np.ones((*chart.shape[:2], 1))
        for i in range(number_of_circles):
            cv2.circle(buffer, (center_x[i], center_y[i]), radius[i], color[i], -1)
            chart += buffer * is_not_covered_mask
            is_not_covered_mask = cv2.circle(is_not_covered_mask, (center_x[i], center_y[i]), radius[i], 0, -1)

            if not np.any(is_not_covered_mask):
                break

        chart += np.multiply(background_color, np.ones((size[0], size[1], 3), dtype=np.float32)) * is_not_covered_mask
    else:
        chart = np.multiply(background_color, np.ones((size[0], size[1], 3), dtype=np.float32))
        for i in range(number_of_circles):
            # circle draws inplace
            cv2.circle(chart, (center_x[i], center_y[i]), radius[i], color[i], -1)

    chart = chart.clip(0, 1)

    if not colored:
        chart = chart[:, :, 0, None]  # return shape [h, w, 1] in gray mode
    return chart
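A short sketch of the CPU generator above, using reverse=False (the fast path on CPU); size and seed are illustrative:

# Sketch: CPU dead leaves chart
from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart

chart = cpu_dead_leaves_chart(size=(256, 256), seed=42, reverse=False)
print(chart.shape)  # (256, 256, 3), values clipped to [0, 1]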
src/rstor/synthetic_data/dead_leaves_gpu.py
ADDED
@@ -0,0 +1,176 @@
from rstor.utils import DEFAULT_NUMPY_FLOAT_TYPE, THREADS_PER_BLOCK
from rstor.properties import SAMPLER_UNIFORM
from typing import Tuple, Optional
from rstor.synthetic_data.dead_leaves_cpu import define_dead_leaves_chart
import numpy as np
from numba import cuda
import math


def gpu_dead_leaves_chart(
    size: Tuple[int, int] = (100, 100),
    number_of_circles: int = -1,
    background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
    colored: Optional[bool] = True,
    radius_min: Optional[int] = -1,
    radius_max: Optional[int] = -1,
    radius_alpha: Optional[int] = 3,
    seed: int = None,
    reverse=True,
    sampler=SAMPLER_UNIFORM,
    natural_image_list=None,
    circle_primitives: bool = True,
    anisotropy: float = 1.,
    angle: float = 0.
) -> np.ndarray:
    center_x, center_y, radius, color = define_dead_leaves_chart(
        size,
        number_of_circles,
        colored,
        radius_min,
        radius_max,
        radius_alpha,
        seed,
        sampler=sampler,
        natural_image_list=natural_image_list
    )

    # Generate on gpu
    chart = _generate_dead_leaves(
        size,
        centers=np.stack((center_x, center_y), axis=-1),
        radia=radius,
        colors=color,
        background=background_color,
        reverse=reverse,
        circle_primitives=circle_primitives,
        anisotropy=anisotropy,
        angle=angle
    )

    return chart


def _generate_dead_leaves(size, centers, radia, colors, background, reverse,
                          circle_primitives: bool, anisotropy: float = 1., angle: float = 0.):
    assert centers.ndim == 2
    ny, nx = size
    nc = colors.shape[-1]

    # Init empty array on GPU
    generation_ = cuda.device_array((ny, nx, nc), DEFAULT_NUMPY_FLOAT_TYPE)
    # Move useful arrays to GPU
    centers_ = cuda.to_device(centers)
    radia_ = cuda.to_device(radia)
    colors_ = cuda.to_device(colors)

    # Dispatch threads
    threadsperblock = (THREADS_PER_BLOCK, THREADS_PER_BLOCK)
    blockspergrid_x = math.ceil(nx/threadsperblock[1])
    blockspergrid_y = math.ceil(ny/threadsperblock[0])
    blockspergrid = (blockspergrid_x, blockspergrid_y)

    if reverse:
        cuda_dead_leaves_gen_reversed[blockspergrid, threadsperblock](
            generation_,
            centers_,
            radia_,
            colors_,
            background,
            circle_primitives,
            anisotropy,
            np.deg2rad(angle)
        )
    else:
        cuda_dead_leaves_gen[blockspergrid, threadsperblock](
            generation_,
            centers_,
            radia_,
            colors_,
            background)

    return generation_


@cuda.jit(cache=False)
def cuda_dead_leaves_gen_reversed(generation, centers, radia, colors, background,
                                  circle_primitives: bool, anisotropy: float, angle: float):
    idx, idy = cuda.grid(2)
    ny, nx, nc = generation.shape

    n_discs = centers.shape[0]

    # Out of bound threads
    if idx >= nx or idy >= ny:
        return

    for disc_id in range(n_discs):
        dx_ = idx - centers[disc_id, 0]
        dy_ = idy - centers[disc_id, 1]
        dx = math.cos(angle)*dx_ + math.sin(angle)*dy_
        dy = -math.sin(angle)*dx_ + math.cos(angle)*dy_
        dx = dx * anisotropy
        dist_sq = dx*dx + dy*dy

        # Naive thread diverging version
        r = radia[disc_id]
        r_sq = r*r
        if circle_primitives:
            if dist_sq <= r_sq:
                # Copy back to global memory
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c]
                return
        else:
            if (disc_id % 4) == 0 and dist_sq <= r_sq:
                # Copy back to global memory
                alpha = dist_sq/r_sq
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
                return
            elif (disc_id % 4) == 1 and (abs(dx)+abs(dy)) <= r:
                # Copy back to global memory
                alpha = dist_sq/r_sq
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
                return
            elif (disc_id % 4) == 2 and abs(dx) <= r and abs(dy) <= r:
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c]
                return
            elif (disc_id % 200) == 3 and abs(dy) <= r//5:
                alpha = dist_sq/r_sq  # alpha has to be defined in this branch as well
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
                return
            elif (disc_id % 200) == 4 and abs(dx) <= r//5:
                alpha = dist_sq/r_sq  # alpha has to be defined in this branch as well
                for c in range(nc):
                    generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
                return
    for c in range(nc):
        generation[idy, idx, c] = background[c]


@cuda.jit(cache=False)
def cuda_dead_leaves_gen(generation, centers, radia, colors, background):
    # Note: the third grid axis indexes the channel; with the 2D launch above, c is always 0.
    idx, idy, c = cuda.grid(3)
    ny, nx, nc = generation.shape

    n_discs = centers.shape[0]

    # Out of bound threads
    if idx >= nx or idy >= ny:
        return

    out = background[c]
    for disc_id in range(n_discs):
        dx = idx - centers[disc_id, 0]
        dy = idy - centers[disc_id, 1]
        dist_sq = dx*dx + dy*dy

        # Naive thread diverging version
        r = radia[disc_id]
        r_sq = r*r

        if dist_sq <= r_sq:
            out = colors[disc_id, c]

    # Copy back to global memory
    generation[idy, idx, c] = out
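The GPU generator returns a numba device array; a small sketch of getting a numpy image back on the host, mirroring the copy_to_host() call in the interactive script further down:

# Sketch: GPU dead leaves chart, copied back to host memory
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart

chart_gpu = gpu_dead_leaves_chart(size=(256, 256), seed=42)  # numba device array
chart = chart_gpu.copy_to_host()                             # (256, 256, 3) numpy array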
src/rstor/synthetic_data/dead_leaves_sampler.py
ADDED
@@ -0,0 +1,74 @@
from typing import Tuple, Optional, List
from rstor.properties import SAMPLER_SATURATED, SAMPLER_NATURAL, SAMPLER_UNIFORM
from rstor.synthetic_data.color_sampler import sample_uniform_rgb, sample_saturated_color, sample_color_from_images
import numpy as np
from pathlib import Path


def define_dead_leaves_chart(
    size: Tuple[int, int] = (100, 100),
    number_of_circles: int = -1,
    colored: Optional[bool] = True,
    radius_min: Optional[int] = -1,
    radius_max: Optional[int] = -1,
    radius_alpha: Optional[int] = 3,
    seed: int = None,
    sampler=SAMPLER_UNIFORM,
    natural_image_list: Optional[List[Path]] = None
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """
    Defines the geometric and color properties of the primitives in the dead leaves chart to later be sampled.

    Args:
        size (Tuple[int, int], optional): size of the generated chart. Defaults to (100, 100).
        number_of_circles (int, optional): number of circles to generate.
            If negative, it is computed based on the size. Defaults to -1.
        colored (Optional[bool], optional): Whether to generate colored circles. Defaults to True.
        radius_min (Optional[int], optional): minimum radius of the circles. Defaults to -1. (=> 1)
        radius_max (Optional[int], optional): maximum radius of the circles. Defaults to -1. (=> 2000)
        radius_alpha (Optional[int], optional): power-law exponent of the radius distribution. Defaults to 3.
        seed (int, optional): seed for the random number generator. Defaults to None

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: center_x, center_y, radius, color
    """
    rng = np.random.default_rng(np.random.SeedSequence(seed))

    if number_of_circles < 0:
        number_of_circles = 30 * max(size)
    if radius_min < 0.:
        radius_min = 1.
    if radius_max < 0.:
        radius_max = 2000.

    # Pick random circle centers and radii
    center_x = rng.integers(0, size[1], size=number_of_circles)
    center_y = rng.integers(0, size[0], size=number_of_circles)

    # Sample from a power-law distribution p(radius=r) ~ r^(-radius_alpha) on [radius_min, radius_max]
    radius = rng.uniform(
        low=radius_max ** (1 - radius_alpha),
        high=radius_min ** (1 - radius_alpha),
        size=number_of_circles
    )
    # Using the change of variables formula for random variables.
    radius = radius ** (-1/(radius_alpha - 1))
    radius = radius.round().astype(int)

    # Pick random colors
    if colored:
        if sampler == SAMPLER_UNIFORM:
            color = sample_uniform_rgb(number_of_circles, seed=rng.integers(0, 1e10)).astype(float)
        elif sampler == SAMPLER_SATURATED:
            color = sample_saturated_color(number_of_circles, seed=rng.integers(0, 1e10)).astype(float)
        elif sampler == SAMPLER_NATURAL:
            assert natural_image_list is not None, "Please provide a list of images to sample colors from."
            color = sample_color_from_images(number_of_circles, seed=rng.integers(0, 1e10),
                                             path_to_images=natural_image_list).astype(float)
        else:
            raise NotImplementedError(f"Unknown color sampler {sampler}")
    else:
        color = rng.uniform(0.25, 0.75, size=(number_of_circles, 1))
    return center_x, center_y, radius, color
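The radius sampling above is inverse-transform sampling of a power law: if u is uniform on [radius_max^(1-alpha), radius_min^(1-alpha)], then r = u^(-1/(alpha-1)) has density proportional to r^(-alpha) on [radius_min, radius_max]. A small numeric check (a sketch; the bin count is illustrative):

# Sketch: empirical log-log slope of the sampled radii should be close to -alpha
import numpy as np

alpha, r_min, r_max = 3.0, 1.0, 2000.0
rng = np.random.default_rng(0)
u = rng.uniform(r_max**(1 - alpha), r_min**(1 - alpha), size=1_000_000)
r = u ** (-1 / (alpha - 1))
hist, edges = np.histogram(r, bins=np.geomspace(r_min, r_max, 50), density=True)
mid = np.sqrt(edges[:-1] * edges[1:])  # geometric bin centers
slope = np.polyfit(np.log(mid[hist > 0]), np.log(hist[hist > 0]), 1)[0]
print(slope)  # close to -3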
src/rstor/synthetic_data/interactive/interactive_dead_leaves.py
ADDED
@@ -0,0 +1,81 @@
1 |
+
from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
|
2 |
+
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
|
3 |
+
from rstor.properties import SAMPLER_UNIFORM, SAMPLER_DIV2K, SAMPLER_NATURAL, SAMPLER_SATURATED, DATASET_PATH
|
4 |
+
import sys
|
5 |
+
import numpy as np
|
6 |
+
from interactive_pipe import interactive_pipeline, interactive
|
7 |
+
from typing import Optional
|
8 |
+
|
9 |
+
|
10 |
+
def dead_leave_plugin(ds=1):
|
11 |
+
interactive(
|
12 |
+
background_intensity=(0.5, [0., 1.]),
|
13 |
+
number_of_circles=(-1, [-1, 10000]),
|
14 |
+
colored=(True,),
|
15 |
+
radius_alpha=(3., [2., 10.]),
|
16 |
+
seed=(0, [-1, 42]),
|
17 |
+
ds=(ds, [1, 5]),
|
18 |
+
numba_flag=(True,), # Default CPU to avoid issues by default
|
19 |
+
sampler=(SAMPLER_UNIFORM, [SAMPLER_UNIFORM, SAMPLER_DIV2K, SAMPLER_SATURATED]),
|
20 |
+
circle_primitives=(True,),
|
21 |
+
anisotropy=(1., [0.1, 10.]),
|
22 |
+
angle=(0., [-180., 180.])
|
23 |
+
# ds=(ds, [1, 5])
|
24 |
+
)(generate_deadleave)
|
25 |
+
|
26 |
+
|
27 |
+
def generate_deadleave(
|
28 |
+
background_intensity: float = 0.5,
|
29 |
+
number_of_circles: int = -1,
|
30 |
+
colored: Optional[bool] = False,
|
31 |
+
radius_alpha: Optional[int] = 3,
|
32 |
+
seed=0,
|
33 |
+
ds=3,
|
34 |
+
numba_flag=True,
|
35 |
+
sampler=SAMPLER_UNIFORM,
|
36 |
+
circle_primitives=True,
|
37 |
+
anisotropy=1.,
|
38 |
+
angle=0.,
|
39 |
+
global_params={}
|
40 |
+
) -> np.ndarray:
|
41 |
+
global_params["ds_factor"] = ds
|
42 |
+
bg_color = (background_intensity, background_intensity, background_intensity)
|
43 |
+
natural_image_list = None
|
44 |
+
if sampler == SAMPLER_DIV2K:
|
45 |
+
sampler = SAMPLER_NATURAL
|
46 |
+
div2k_path = DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR"
|
47 |
+
natural_image_list = sorted([file for file in div2k_path.glob("*.png")])
|
48 |
+
if not numba_flag:
|
49 |
+
chart = cpu_dead_leaves_chart((512*ds, 512*ds), number_of_circles, bg_color, colored,
|
50 |
+
radius_alpha=radius_alpha,
|
51 |
+
seed=None if seed < 0 else seed,
|
52 |
+
sampler=sampler,
|
53 |
+
reverse=False,
|
54 |
+
natural_image_list=natural_image_list)
|
55 |
+
else:
|
56 |
+
chart = gpu_dead_leaves_chart((512*ds, 512*ds), number_of_circles, bg_color, colored,
|
57 |
+
radius_alpha=radius_alpha,
|
58 |
+
seed=None if seed < 0 else seed,
|
59 |
+
sampler=sampler,
|
60 |
+
natural_image_list=natural_image_list,
|
61 |
+
circle_primitives=circle_primitives,
|
62 |
+
anisotropy=anisotropy,
|
63 |
+
angle=angle).copy_to_host()
|
64 |
+
if chart.shape[-1] == 1:
|
65 |
+
chart = chart.repeat(3, axis=-1)
|
66 |
+
# Required to switch from colors to gray scale visualization.
|
67 |
+
return chart
|
68 |
+
|
69 |
+
|
70 |
+
def deadleave_pipeline():
|
71 |
+
deadleave_chart = generate_deadleave()
|
72 |
+
return deadleave_chart
|
73 |
+
|
74 |
+
|
75 |
+
def main(argv):
|
76 |
+
dead_leave_plugin(ds=1)
|
77 |
+
interactive_pipeline(gui="auto")(deadleave_pipeline)()
|
78 |
+
|
79 |
+
|
80 |
+
if __name__ == "__main__":
|
81 |
+
main(sys.argv[1:])
|
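A quick way to exercise the module above without launching the GUI (a minimal sketch, assuming `src` is on `PYTHONPATH`; `numba_flag=False` keeps everything on the CPU, so no CUDA device is required):

    from rstor.synthetic_data.interactive.interactive_dead_leaves import generate_deadleave

    # Deterministic 512x512 colored chart, rendered on the CPU.
    chart = generate_deadleave(number_of_circles=1000, colored=True, seed=0, ds=1, numba_flag=False)
    print(chart.shape)  # expected: (512, 512, 3)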
src/rstor/utils.py
ADDED
@@ -0,0 +1,12 @@
+import numpy as np
+import numba
+import torch
+
+THREADS_PER_BLOCK = 32  # 32 or 16
+DEFAULT_NUMPY_FLOAT_TYPE = np.float32
+DEFAULT_CUDA_FLOAT_TYPE = numba.float32
+DEFAULT_TORCH_FLOAT_TYPE = torch.float32
+
+
+DEFAULT_NUMPY_INT_TYPE = np.int32
+DEFAULT_CUDA_INT_TYPE = numba.int32
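These dtype and block-size constants are presumably shared by the numba CUDA kernels elsewhere in the package; a hypothetical illustration of how such constants are typically used to size a launch grid (requires a CUDA-capable machine; the kernel name is made up):

    import numpy as np
    from numba import cuda
    from rstor.utils import DEFAULT_NUMPY_FLOAT_TYPE, THREADS_PER_BLOCK

    img = np.zeros((256, 256), dtype=DEFAULT_NUMPY_FLOAT_TYPE)
    d_img = cuda.to_device(img)  # stage the buffer on the device
    threads = (THREADS_PER_BLOCK, THREADS_PER_BLOCK)
    blocks = ((img.shape[1] + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK,
              (img.shape[0] + THREADS_PER_BLOCK - 1) // THREADS_PER_BLOCK)
    # my_kernel[blocks, threads](d_img)  # hypothetical launch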
test/test_dataloader.py
ADDED
@@ -0,0 +1,49 @@
+import torch
+from rstor.data.synthetic_dataloader import DeadLeavesDataset
+
+
+def test_dead_leaves_dataset():
+    # Test case 1: Default parameters
+    dataset = DeadLeavesDataset(noise_stddev=(0, 0), ds_factor=1)
+    assert len(dataset) == 1000
+    assert dataset.size == (128, 128)
+    assert dataset.frozen_seed is None
+    assert dataset.config_dead_leaves == {}
+
+    # Test case 2: Custom parameters
+    dataset = DeadLeavesDataset(size=(256, 256), length=500, frozen_seed=42, number_of_circles=5,
+                                background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+                                noise_stddev=(0, 0), ds_factor=1)
+    assert len(dataset) == 500
+    assert dataset.size == (256, 256)
+    assert dataset.frozen_seed == 42
+    assert dataset.config_dead_leaves == {
+        'number_of_circles': 5,
+        'background_color': (0.2, 0.4, 0.6),
+        'colored': True,
+        'radius_min': 1,
+        'radius_alpha': 3
+    }
+
+    # Test case 3: Check item retrieval
+    item, item_tgt = dataset[0]
+    assert isinstance(item, torch.Tensor)
+    assert item.shape == (3, 256, 256)
+
+    # Test case 4: Repeatable results with frozen seed
+    dataset1 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+    dataset2 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+    item1, item_tgt1 = dataset1[0]
+    item2, item_tgt2 = dataset2[0]
+    assert torch.all(torch.eq(item1, item2))
+
+    # Test case 5: Visualize
+    # dataset = DeadLeavesDataset(size=(256, 256), length=500, frozen_seed=43,
+    #                             background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+    #                             noise_stddev=(0, 0), ds_factor=1)
+    # item, item_tgt = dataset[0]
+    # import matplotlib.pyplot as plt
+    # plt.figure()
+    # plt.imshow(item.permute(1, 2, 0).detach().cpu())
+    # plt.show()
+    # print("done")
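Beyond these unit tests, the dataset drops straight into a standard `torch.utils.data.DataLoader`; a minimal sketch (batch size is arbitrary):

    from torch.utils.data import DataLoader
    from rstor.data.synthetic_dataloader import DeadLeavesDataset

    dataset = DeadLeavesDataset(noise_stddev=(0, 0), ds_factor=1)
    loader = DataLoader(dataset, batch_size=4, shuffle=True)
    degraded, target = next(iter(loader))  # both of shape (4, 3, 128, 128)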
test/test_dataloader_gpu.py
ADDED
@@ -0,0 +1,56 @@
+import torch
+from rstor.data.synthetic_dataloader import DeadLeavesDatasetGPU
+import numba
+
+
+def test_dead_leaves_dataset_gpu():
+    if not numba.cuda.is_available():
+        return  # no CUDA device on this machine: nothing to test
+
+    # Test case 1: Default parameters
+    dataset = DeadLeavesDatasetGPU(noise_stddev=(0, 0), ds_factor=1)
+    assert len(dataset) == 1000
+    assert dataset.size == (128, 128)
+    assert dataset.frozen_seed is None
+    assert dataset.config_dead_leaves == {}
+
+    # Test case 2: Custom parameters
+    dataset = DeadLeavesDatasetGPU(size=(256, 256), length=500, frozen_seed=42, number_of_circles=5,
+                                   background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+                                   noise_stddev=(0, 0), ds_factor=1)
+    assert len(dataset) == 500
+    assert dataset.size == (256, 256)
+    assert dataset.frozen_seed == 42
+    assert dataset.config_dead_leaves == {
+        'number_of_circles': 5,
+        'background_color': (0.2, 0.4, 0.6),
+        'colored': True,
+        'radius_min': 1,
+        'radius_alpha': 3
+    }
+
+    # Test case 3: Check item retrieval
+    item, item_tgt = dataset[0]
+    assert isinstance(item, torch.Tensor)
+    assert item.shape == (3, 256, 256)
+
+    # Test case 4: Repeatable results with frozen seed
+    dataset1 = DeadLeavesDatasetGPU(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+    dataset2 = DeadLeavesDatasetGPU(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+    item1, item_tgt1 = dataset1[0]
+    item2, item_tgt2 = dataset2[0]
+    assert torch.all(torch.eq(item1, item2))
+
+
+
+# Test case 5: Visualize
+# dataset = DeadLeavesDatasetGPU(size=(256, 256), length=500, frozen_seed=44, number_of_circles=10_000,
+#                                background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+#                                noise_stddev=(0, 0), ds_factor=1)
+# item, item_tgt = dataset[0]
+# import matplotlib.pyplot as plt
+# plt.figure()
+# plt.imshow(item.permute(1, 2, 0).detach().cpu())
+# plt.show()
+# print("done")
+
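Side note: the early `return` above makes this test pass silently on CPU-only machines. An alternative sketch using pytest's skip marker (assuming pytest is the runner) would report the skip explicitly instead:

    import numba
    import pytest

    @pytest.mark.skipif(not numba.cuda.is_available(), reason="requires a CUDA-capable GPU")
    def test_dead_leaves_dataset_gpu():
        ...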
test/test_dataloader_stored.py
ADDED
@@ -0,0 +1,67 @@
+import torch
+from rstor.data.stored_images_dataloader import RestorationDataset
+from numba import cuda
+from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE
+
+
+def test_dataloader_stored():
+    if not cuda.is_available():
+        print("cuda unavailable, exiting")
+        return
+
+    # Test case 1: Default parameters
+    dataset = RestorationDataset(noise_stddev=(0, 0),
+                                 images_path=DATASET_PATH/"sample")
+    assert len(dataset) == 2
+    assert dataset.frozen_seed is None
+
+    # Test case 2: Custom parameters
+    dataset = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                 size=(64, 64),
+                                 frozen_seed=42,
+                                 noise_stddev=(0, 0))
+    assert len(dataset) == 2
+    assert dataset.frozen_seed == 42
+
+    # Test case 3: Check item retrieval
+    item, item_tgt = dataset[0]
+    assert isinstance(item, torch.Tensor)
+    assert item.shape == item_tgt.shape
+    assert item.shape == (3, 64, 64)
+
+    # Test case 4: Repeatable results with frozen seed
+    dataset1 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                  frozen_seed=42, noise_stddev=(0, 0))
+    dataset2 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                  frozen_seed=42, noise_stddev=(0, 0))
+    item1, item_tgt1 = dataset1[0]
+    item2, item_tgt2 = dataset2[0]
+
+    assert torch.all(torch.eq(item1, item2))
+
+    # Test case 5: Repeatable results with frozen seed and augmentation
+    augmentation_list = [AUGMENTATION_FLIP, AUGMENTATION_ROTATE]
+    dataset1 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                  frozen_seed=42, noise_stddev=(0, 0),
+                                  augmentation_list=augmentation_list)
+    dataset2 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                  frozen_seed=42, noise_stddev=(0, 0),
+                                  augmentation_list=augmentation_list)
+    item1, item_tgt1 = dataset1[0]
+    item2, item_tgt2 = dataset2[0]
+    assert torch.all(torch.eq(item1, item2))
+
+
+
+    # Test case 6: Visualize
+    # dataset = RestorationDataset(images_path=DATASET_PATH/"sample",
+    #                              noise_stddev=(0, 0),
+    #                              augmentation_list=augmentation_list)
+    # item, item_tgt = dataset[0]
+    # import matplotlib.pyplot as plt
+    # plt.figure()
+    # plt.imshow(item.permute(1, 2, 0).detach().cpu())
+    # plt.show()
+    # breakpoint()
+    print("done")
+
test/test_dead_leaves.py
ADDED
@@ -0,0 +1,44 @@
+import numpy as np
+from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
+from rstor.properties import SAMPLER_NATURAL, DATASET_PATH
+
+
+def test_dead_leaves_chart():
+    # Test case 1: Default parameters
+    chart = cpu_dead_leaves_chart()
+    assert isinstance(chart, np.ndarray)
+    assert chart.shape == (100, 100, 3)
+
+    # Test case 2: Custom size and number of circles
+    chart = cpu_dead_leaves_chart(size=(200, 150), number_of_circles=10)
+    assert isinstance(chart, np.ndarray)
+    assert chart.shape == (200, 150, 3)
+
+    # Test case 3: Colored circles
+    chart = cpu_dead_leaves_chart(colored=True, number_of_circles=300)
+    assert isinstance(chart, np.ndarray)
+    assert chart.shape == (100, 100, 3)
+
+    # Test case 4: Custom radius distribution (radius_min, radius_alpha)
+    chart = cpu_dead_leaves_chart(radius_min=5, radius_alpha=2, number_of_circles=300)
+    assert isinstance(chart, np.ndarray)
+    assert chart.shape == (100, 100, 3)
+
+    # Test case 5: Custom background color
+    chart = cpu_dead_leaves_chart(background_color=(0.2, 0.4, 0.6), number_of_circles=300)
+    assert isinstance(chart, np.ndarray)
+    assert chart.shape == (100, 100, 3)
+
+    # Test case 6: Custom seed
+    chart1 = cpu_dead_leaves_chart(seed=42, number_of_circles=300)
+    chart2 = cpu_dead_leaves_chart(seed=42, number_of_circles=300)
+    assert np.array_equal(chart1, chart2)
+
+
+def test_dead_leaves_color_sampler():
+    img_list = sorted(
+        list((DATASET_PATH / "sample").glob("*.png"))
+    )
+    _gen = cpu_dead_leaves_chart(number_of_circles=300, sampler=SAMPLER_NATURAL, natural_image_list=img_list)
+    # from interactive_pipe.data_objects.image import Image
+    # Image(_gen).show()
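All of the suites above are plain pytest-style functions; running `pytest test/` from the repository root should execute them, with the CUDA-dependent ones bailing out early on machines without a GPU.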