balthou committed
Commit cec5823 · 1 Parent(s): 1e98db4

initiate demo

This view is limited to 50 files because the commit contains too many changes.
Files changed (50):
  1. __dataset/sample/0000.png +0 -0
  2. __dataset/sample/0001.jpg +0 -0
  3. __dataset/sample/0002.png +0 -0
  4. app.py +66 -0
  5. requirements.txt +6 -0
  6. scripts/configuration.py +18 -0
  7. scripts/infer.py +205 -0
  8. scripts/interactive_inference_natural.py +65 -0
  9. scripts/interactive_inference_synthetic.py +25 -0
  10. scripts/metrics_analyzis.ipynb +0 -0
  11. scripts/quantitative_results.ipynb +0 -0
  12. scripts/remote_training.py +116 -0
  13. scripts/remote_training_template.ipynb +1 -0
  14. scripts/save_deadleaves.py +172 -0
  15. scripts/train.py +171 -0
  16. src/rstor/__init__.py +0 -0
  17. src/rstor/analyzis/interactive/crop.py +75 -0
  18. src/rstor/analyzis/interactive/degradation.py +71 -0
  19. src/rstor/analyzis/interactive/images.py +10 -0
  20. src/rstor/analyzis/interactive/inference.py +12 -0
  21. src/rstor/analyzis/interactive/metrics.py +36 -0
  22. src/rstor/analyzis/interactive/model_selection.py +58 -0
  23. src/rstor/analyzis/interactive/pipelines.py +61 -0
  24. src/rstor/analyzis/metrics_plots.py +73 -0
  25. src/rstor/analyzis/parser.py +26 -0
  26. src/rstor/architecture/base.py +56 -0
  27. src/rstor/architecture/convolution_blocks.py +47 -0
  28. src/rstor/architecture/nafnet.py +299 -0
  29. src/rstor/architecture/selector.py +19 -0
  30. src/rstor/architecture/stacked_convolutions.py +30 -0
  31. src/rstor/data/augmentation.py +27 -0
  32. src/rstor/data/dataloader.py +120 -0
  33. src/rstor/data/degradation.py +156 -0
  34. src/rstor/data/stored_images_dataloader.py +156 -0
  35. src/rstor/data/synthetic_dataloader.py +187 -0
  36. src/rstor/learning/experiments.py +24 -0
  37. src/rstor/learning/experiments_definition.py +489 -0
  38. src/rstor/learning/loss.py +25 -0
  39. src/rstor/learning/metrics.py +140 -0
  40. src/rstor/properties.py +67 -0
  41. src/rstor/synthetic_data/color_sampler.py +73 -0
  42. src/rstor/synthetic_data/dead_leaves_cpu.py +79 -0
  43. src/rstor/synthetic_data/dead_leaves_gpu.py +176 -0
  44. src/rstor/synthetic_data/dead_leaves_sampler.py +74 -0
  45. src/rstor/synthetic_data/interactive/interactive_dead_leaves.py +81 -0
  46. src/rstor/utils.py +12 -0
  47. test/test_dataloader.py +49 -0
  48. test/test_dataloader_gpu.py +56 -0
  49. test/test_dataloader_stored.py +67 -0
  50. test/test_dead_leaves.py +44 -0
__dataset/sample/0000.png ADDED
__dataset/sample/0001.jpg ADDED
__dataset/sample/0002.png ADDED
app.py ADDED
@@ -0,0 +1,66 @@
+ import sys
+ sys.path.append("src")
+ from interactive_pipe import interactive_pipeline
+ from rstor.analyzis.interactive.pipelines import natural_inference_pipeline, morph_canvas, CANVAS
+ from rstor.analyzis.interactive.model_selection import get_default_models
+ from pathlib import Path
+ from rstor.analyzis.parser import get_parser
+ import argparse
+ from batch_processing import Batch
+ from interactive_pipe.data_objects.image import Image
+ from rstor.analyzis.interactive.images import image_selector
+ from rstor.analyzis.interactive.crop import plug_crop_selector
+ from rstor.analyzis.interactive.metrics import plug_configure_metrics
+ from interactive_pipe import interactive, KeyboardControl
+
+
+ def plug_morph_canvas():
+     interactive(
+         canvas=KeyboardControl(CANVAS[0], CANVAS, name="canvas", keyup="p", modulo=True)
+     )(morph_canvas)
+
+
+ def image_loading_batch(input: Path, args: argparse.Namespace) -> dict:
+     """Wrapper to load image files from a directory using batch_processing
+     """
+     if not args.disable_preload:
+         img = Image.from_file(input).data
+         return {"name": input.name, "path": input, "buffer": img}
+     else:
+         return {"name": input.name, "path": input, "buffer": None}
+
+
+ def main(argv):
+     batch = Batch(argv)
+     batch.set_io_description(
+         input_help='input image files',
+         output_help=argparse.SUPPRESS
+     )
+     parser = get_parser()
+     parser.add_argument("-nop", "--disable-preload", action="store_true", help="Disable image preload")
+     args = batch.parse_args(parser)
+     # batch.set_multiprocessing_enabled(False)
+     img_list = batch.run(image_loading_batch)
+     if args.keyboard:
+         image_control = KeyboardControl(0, [0, len(img_list)-1], keydown="3", keyup="9", modulo=True)
+     else:
+         image_control = (0, [0, len(img_list)-1])
+     interactive(image_index=image_control)(image_selector)
+     plug_crop_selector(num_pad=args.keyboard)
+     plug_configure_metrics(key_shortcut="a")  # alternative: "a" if args.keyboard else None
+     plug_morph_canvas()
+     model_dict = get_default_models(args.experiments, Path(args.models_storage), keyboard_control=args.keyboard)
+     interactive_pipeline(
+         gui=args.backend,
+         cache=True,
+         safe_input_buffer_deepcopy=False
+     )(natural_inference_pipeline)(
+         img_list,
+         model_dict
+     )
+
+
+ if __name__ == "__main__":
+     # main(sys.argv[1:])
+     main(["-e", "6002", "-i", "__dataset/sample/*.*g", "-b", "gradio"])
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ interactive-pipe>=0.7.8
+ opencv_python_headless==4.8.0.74
+ torch>=2.0.0
+ tqdm
+
+
scripts/configuration.py ADDED
@@ -0,0 +1,18 @@
+ from pathlib import Path
+
+ NB_ID = "blind-deblurring-from-synthetic-data"  # This will be the name which appears on Kaggle.
+ GIT_USER = "balthazarneveu"  # Your git user name
+ GIT_REPO = "blind-deblurring-from-synthetic-data"  # Your current git repo
+ # Keep empty unless you need to access Kaggle datasets. You'll need to modify the remote_training_template.ipynb.
+ KAGGLE_DATASET_LIST = [
+     "balthazarneveu/deadleaves-div2k-512",  # Deadleaves classic
+     "balthazarneveu/deadleaves-primitives-div2k-512",  # Deadleaves with extra primitives
+     "balthazarneveu/motion-blur-kernels",  # Motion blur kernels
+     "joe1995/div2k-dataset",
+ ]
+ WANDBSPACE = "deblur-from-deadleaves"
+ TRAIN_SCRIPT = "scripts/train.py"  # Location of the training script
+
+ ROOT_DIR = Path(__file__).parent
+ OUTPUT_FOLDER_NAME = "__output"
+ INFERENCE_FOLDER_NAME = "__inference"
scripts/infer.py ADDED
@@ -0,0 +1,205 @@
+ from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME, INFERENCE_FOLDER_NAME
+ from rstor.analyzis.parser import get_models_parser
+ from batch_processing import Batch
+ from rstor.properties import (
+     DEVICE, NAME, PRETTY_NAME, DATALOADER, CONFIG_DEAD_LEAVES, VALIDATION,
+     BATCH_SIZE, SIZE,
+     REDUCTION_SKIP,
+     TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS, TRACES_ALL,
+     SAMPLER_SATURATED,
+     CONFIG_DEGRADATION,
+     DATASET_DIV2K,
+     DATASET_DL_DIV2K_512,
+     DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
+     METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
+     DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_NONE
+ )
+ from rstor.data.dataloader import get_data_loader
+ from tqdm import tqdm
+ from pathlib import Path
+ import torch
+ from typing import Optional
+ import argparse
+ import sys
+ from rstor.analyzis.interactive.model_selection import get_default_models
+ from rstor.learning.metrics import compute_metrics
+ from interactive_pipe.data_objects.image import Image
+ from interactive_pipe.data_objects.parameters import Parameters
+ from typing import List
+ from itertools import product
+ import pandas as pd
+ ALL_TRACES = [TRACES_TARGET, TRACES_DEGRADED, TRACES_RESTORED, TRACES_METRICS]
+
+
+ def parse_int_pairs(s):
+     try:
+         # Split the input string by spaces to separate pairs, then split each pair by ',' into a tuple of ints
+         return [tuple(map(int, item.split(','))) for item in s.split()]
+     except ValueError:
+         raise argparse.ArgumentTypeError("Must be a series of pairs 'a,b' separated by spaces.")
+
+
+ def get_parser(parser: Optional[argparse.ArgumentParser] = None, batch_mode=False) -> argparse.ArgumentParser:
+     parser = get_models_parser(
+         parser=parser,
+         help="Inference on validation set",
+         default_models_path=ROOT_DIR/OUTPUT_FOLDER_NAME)
+     if not batch_mode:
+         parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR /
+                             INFERENCE_FOLDER_NAME, help="Output directory")
+     parser.add_argument("--cpu", action="store_true", help="Force CPU")
+     parser.add_argument("--traces", "-t", nargs="+", type=str, choices=ALL_TRACES+[TRACES_ALL],
+                         help="Traces to be computed", default=TRACES_ALL)
+     parser.add_argument("--size", type=parse_int_pairs,
+                         default=[(256, 256)], help="Size of the images like '256,512 512,512'")
+     parser.add_argument("--std-dev", type=parse_int_pairs, default=[(0, 50)],
+                         help="Noise standard deviation (a, b) as pairs separated by spaces, e.g., '0,50 8,8 6,10'")
+     parser.add_argument("-n", "--number-of-images", type=int, default=None,
+                         required=False, help="Number of images to process")
+     parser.add_argument("-d", "--dataset", type=str,
+                         choices=[None, DATASET_DL_DIV2K_512, DATASET_DIV2K, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
+                         default=None)
+     parser.add_argument("-b", "--blur", action="store_true")
+     parser.add_argument("--blur-index", type=int, nargs="+", default=None)
+     return parser
+
+
+ def to_image(img: torch.Tensor):
+     return img.permute(0, 2, 3, 1).cpu().numpy()
+
+
+ def infer(model, dataloader, config, device, output_dir: Path,
+           traces: List[str] = ALL_TRACES, number_of_images=None, degradation_key=CONFIG_DEAD_LEAVES,
+           chosen_metrics=[METRIC_PSNR, METRIC_SSIM]):  # add METRIC_LPIPS here!
+     img_index = 0
+     if TRACES_ALL in traces:
+         traces = ALL_TRACES
+     if TRACES_METRICS in traces:
+         all_metrics = {}
+     else:
+         all_metrics = None
+     with torch.no_grad():
+         model.eval()
+         for img_degraded, img_target in tqdm(dataloader):
+             img_degraded = img_degraded.to(device)
+             img_target = img_target.to(device)
+             img_restored = model(img_degraded)
+             if TRACES_METRICS in traces:
+                 metrics_input_per_image = compute_metrics(
+                     img_degraded, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
+                 metrics_per_image = compute_metrics(
+                     img_restored, img_target, reduction=REDUCTION_SKIP, chosen_metrics=chosen_metrics)
+             img_degraded = to_image(img_degraded)
+             img_target = to_image(img_target)
+             img_restored = to_image(img_restored)
+             for idx in range(img_restored.shape[0]):
+                 degradation_parameters = dataloader.dataset.current_degradation[img_index]
+                 common_prefix = f"{img_index:05d}_{img_degraded.shape[-3]:04d}x{img_degraded.shape[-2]:04d}"
+                 common_prefix += f"_noise=[{config[DATALOADER][degradation_key]['noise_stddev'][0]:02d},{config[DATALOADER][degradation_key]['noise_stddev'][1]:02d}]"
+                 suffix_deg = ""
+                 if degradation_parameters['noise_stddev'] > 0:
+                     suffix_deg += f"_noise={round(degradation_parameters['noise_stddev']):02d}"
+                 suffix_deg += f"_blur={degradation_parameters['blur_kernel_id']:04d}"
+                 # if degradation_parameters.get("blur_kernel_id", False) else ""
+                 save_path_pred = output_dir/f"{common_prefix}_pred{suffix_deg}_{config[PRETTY_NAME]}.png"
+                 save_path_degr = output_dir/f"{common_prefix}_degr{suffix_deg}.png"
+                 save_path_targ = output_dir/f"{common_prefix}_targ.png"
+                 if TRACES_RESTORED in traces:
+                     Image(img_restored[idx]).save(save_path_pred)
+                 if TRACES_DEGRADED in traces:
+                     Image(img_degraded[idx]).save(save_path_degr)
+                 if TRACES_TARGET in traces:
+                     Image(img_target[idx]).save(save_path_targ)
+                 if TRACES_METRICS in traces:
+                     current_metrics = {}
+                     for key, value in metrics_per_image.items():
+                         current_metrics["in_"+key] = metrics_input_per_image[key][idx].item()
+                         current_metrics["out_"+key] = metrics_per_image[key][idx].item()
+                     current_metrics["degradation"] = degradation_parameters
+                     current_metrics["size"] = (img_degraded.shape[-3], img_degraded.shape[-2])
+                     current_metrics["deadleaves_config"] = config[DATALOADER][degradation_key]
+                     current_metrics["restored"] = save_path_pred.relative_to(output_dir).as_posix()
+                     current_metrics["degraded"] = save_path_degr.relative_to(output_dir).as_posix()
+                     current_metrics["target"] = save_path_targ.relative_to(output_dir).as_posix()
+                     current_metrics["model"] = config[PRETTY_NAME]
+                     current_metrics["model_id"] = config[NAME]
+                     Parameters(current_metrics).save(output_dir/f"{common_prefix}_metrics.json")
+                     all_metrics[img_index] = current_metrics
+                 img_index += 1
+                 if number_of_images is not None and img_index > number_of_images:
+                     return all_metrics
+     return all_metrics
+
+
+ def infer_main(argv, batch_mode=False):
+     parser = get_parser(batch_mode=batch_mode)
+     if batch_mode:
+         batch = Batch(argv)
+         batch.set_io_description(
+             input_help='input image files',
+             output_help=f'output directory {str(ROOT_DIR/INFERENCE_FOLDER_NAME)}',
+         )
+         args = batch.parse_args(parser)
+     else:
+         args = parser.parse_args(argv)
+     device = "cpu" if args.cpu else DEVICE
+     dataset = args.dataset
+     blur_flag = args.blur
+     for exp in args.experiments:
+         model_dict = get_default_models([exp], Path(args.models_storage), interactive_flag=False)
+         current_model_dict = model_dict[list(model_dict.keys())[0]]
+         model = current_model_dict["model"]
+         config = current_model_dict["config"]
+         # args.blur_index defaults to None: fall back to a single pass with no blur index
+         for std_dev, size, blur_index in product(args.std_dev, args.size, args.blur_index or [None]):
+             if dataset is None:
+                 config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
+                     blur_kernel_half_size=[0, 0],
+                     ds_factor=1,
+                     noise_stddev=list(std_dev),
+                     sampler=SAMPLER_SATURATED
+                 )
+                 config[DATALOADER]["gpu_gen"] = True
+                 config[DATALOADER][SIZE] = size
+                 config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
+             else:
+                 config[DATALOADER][CONFIG_DEGRADATION] = dict(
+                     noise_stddev=list(std_dev),
+                     degradation_blur=DEGRADATION_BLUR_MAT if blur_flag else DEGRADATION_BLUR_NONE,
+                     blur_index=blur_index
+                 )
+                 config[DATALOADER][NAME] = dataset
+                 config[DATALOADER][SIZE] = size
+                 config[DATALOADER][BATCH_SIZE][VALIDATION] = 1 if size[0] > 512 else 4
+             dataloader = get_data_loader(config, frozen_seed=42)
+             output_dir = Path(args.output_dir)/(config[NAME] + "_" + config[PRETTY_NAME])
+             output_dir.mkdir(parents=True, exist_ok=True)
+
+             all_metrics = infer(model, dataloader[VALIDATION], config, device, output_dir,
+                                 traces=args.traces, number_of_images=args.number_of_images,
+                                 degradation_key=CONFIG_DEAD_LEAVES if dataset is None else CONFIG_DEGRADATION)
+             if all_metrics is not None:
+                 df = pd.DataFrame(all_metrics).T
+                 prefix = f"{size[0]:04d}x{size[1]:04d}"
+                 if not (std_dev[0] == 0 and std_dev[1] == 0):
+                     prefix += f"_noise=[{std_dev[0]:02d},{std_dev[1]:02d}]"
+                 if blur_index is not None:
+                     prefix += f"_blur={blur_index:02d}"
+                 prefix += f"_{config[PRETTY_NAME]}"
+                 df.to_csv(output_dir/f"__{prefix}_metrics_.csv", index=False)
+                 # Metrics analysis proper could go into another script
+
+
+ if __name__ == "__main__":
+     infer_main(sys.argv[1:])
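
A usage sketch (not part of the diff): parse_int_pairs turns a string such as "256,256 512,512" into [(256, 256), (512, 512)], so sizes and noise ranges can be swept in one run. The experiment id below is a placeholder:

    # parse_int_pairs("0,50 8,8") -> [(0, 50), (8, 8)]
    infer_main([
        "-e", "2000",             # placeholder experiment id
        "--size", "256,256",      # one (height, width) pair per element
        "--std-dev", "0,50 8,8",  # noise stddev ranges to sweep
        "-n", "100",              # stop after 100 images
    ])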
scripts/interactive_inference_natural.py ADDED
@@ -0,0 +1,65 @@
+ import sys
+ sys.path.append("src")
+ from interactive_pipe import interactive_pipeline
+ from rstor.analyzis.interactive.pipelines import natural_inference_pipeline, morph_canvas, CANVAS
+ from rstor.analyzis.interactive.model_selection import get_default_models
+ from pathlib import Path
+ from rstor.analyzis.parser import get_parser
+ import argparse
+ from batch_processing import Batch
+ from interactive_pipe.data_objects.image import Image
+ from rstor.analyzis.interactive.images import image_selector
+ from rstor.analyzis.interactive.crop import plug_crop_selector
+ from rstor.analyzis.interactive.metrics import plug_configure_metrics
+ from interactive_pipe import interactive, KeyboardControl
+
+
+ def plug_morph_canvas():
+     interactive(
+         canvas=KeyboardControl(CANVAS[0], CANVAS, name="canvas", keyup="p", modulo=True)
+     )(morph_canvas)
+
+
+ def image_loading_batch(input: Path, args: argparse.Namespace) -> dict:
+     """Wrapper to load image files from a directory using batch_processing
+     """
+     if not args.disable_preload:
+         img = Image.from_file(input).data
+         return {"name": input.name, "path": input, "buffer": img}
+     else:
+         return {"name": input.name, "path": input, "buffer": None}
+
+
+ def main(argv):
+     batch = Batch(argv)
+     batch.set_io_description(
+         input_help='input image files',
+         output_help=argparse.SUPPRESS
+     )
+     parser = get_parser()
+     parser.add_argument("-nop", "--disable-preload", action="store_true", help="Disable image preload")
+     args = batch.parse_args(parser)
+     # batch.set_multiprocessing_enabled(False)
+     img_list = batch.run(image_loading_batch)
+     if args.keyboard:
+         image_control = KeyboardControl(0, [0, len(img_list)-1], keydown="3", keyup="9", modulo=True)
+     else:
+         image_control = (0, [0, len(img_list)-1])
+     interactive(image_index=image_control)(image_selector)
+     plug_crop_selector(num_pad=args.keyboard)
+     plug_configure_metrics(key_shortcut="a")  # alternative: "a" if args.keyboard else None
+     plug_morph_canvas()
+     model_dict = get_default_models(args.experiments, Path(args.models_storage), keyboard_control=args.keyboard)
+     interactive_pipeline(
+         gui=args.backend,
+         cache=True,
+         safe_input_buffer_deepcopy=False
+     )(natural_inference_pipeline)(
+         img_list,
+         model_dict
+     )
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1:])
scripts/interactive_inference_synthetic.py ADDED
@@ -0,0 +1,25 @@
+ import sys
+ sys.path.append("src")
+ from interactive_pipe import interactive_pipeline
+ from rstor.synthetic_data.interactive.interactive_dead_leaves import dead_leave_plugin
+ from rstor.analyzis.interactive.pipelines import deadleave_inference_pipeline
+ from rstor.analyzis.interactive.model_selection import get_default_models
+ from rstor.analyzis.interactive.crop import plug_crop_selector
+ from rstor.analyzis.interactive.metrics import plug_configure_metrics
+ from pathlib import Path
+ from rstor.analyzis.parser import get_parser
+
+
+ def main(argv):
+     parser = get_parser()
+     args = parser.parse_args(argv)
+     plug_crop_selector(num_pad=args.keyboard)
+     model_dict = get_default_models(args.experiments, Path(args.models_storage))
+     dead_leave_plugin(ds=1)
+     plug_configure_metrics(key_shortcut="a")
+     interactive_pipeline(gui="auto", cache=True, safe_input_buffer_deepcopy=False)(
+         deadleave_inference_pipeline)(model_dict)
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1:])
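
A minimal invocation sketch (not part of the diff); the synthetic demo only needs trained experiment ids, which are placeholders here:

    main(["-e", "1000", "1001", "-k"])  # compare two trained models, keyboard control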
scripts/metrics_analyzis.ipynb ADDED
The diff for this file is too large to render.
scripts/quantitative_results.ipynb ADDED
The diff for this file is too large to render.
scripts/remote_training.py ADDED
@@ -0,0 +1,116 @@
+ import kaggle
+ from pathlib import Path, PurePosixPath
+ import json
+ try:
+     from __kaggle_login import kaggle_users
+ except ImportError:
+     raise ImportError("Please create a __kaggle_login.py file with a kaggle_users "
+                       "dict containing your Kaggle credentials.")
+ import argparse
+ import sys
+ import subprocess
+ from configuration import ROOT_DIR, OUTPUT_FOLDER_NAME
+ from train import get_parser as get_train_parser
+ from typing import Optional
+ from configuration import KAGGLE_DATASET_LIST, NB_ID, GIT_USER, GIT_REPO, TRAIN_SCRIPT
+
+
+ def get_git_branch_name():
+     try:
+         branch_name = subprocess.check_output(["git", "branch", "--show-current"]).strip().decode()
+         return branch_name
+     except subprocess.CalledProcessError:
+         return "Error: Could not determine the Git branch name."
+
+
+ def prepare_notebook(
+     output_nb_path: Path,
+     exp: int,
+     branch: str,
+     git_user: str = None,
+     git_repo: str = None,
+     template_nb_path: Path = Path(__file__).parent/"remote_training_template.ipynb",
+     wandb_flag: bool = False,
+     output_dir: Path = "scripts/"+OUTPUT_FOLDER_NAME,
+     dataset_files: Optional[list] = None,
+     train_script: str = TRAIN_SCRIPT
+ ):
+     assert git_user is not None, "Please provide a git username for the repo"
+     assert git_repo is not None, "Please provide a git repo name for the repo"
+     expressions = [
+         ("exp", f"{exp}"),
+         ("branch", f"'{branch}'"),
+         ("git_user", f"'{git_user}'"),
+         ("git_repo", f"'{git_repo}'"),
+         ("wandb_flag", "True" if wandb_flag else "False"),
+         ("output_dir", "None" if output_dir is None else f"'{output_dir}'"),
+         ("dataset_files", "None" if dataset_files is None else f"{dataset_files}"),
+         ("train_script", "'"+train_script+"'")
+     ]
+     with open(template_nb_path) as f:
+         template_nb = f.readlines()
+     for line_idx, li in enumerate(template_nb):
+         for expr, expr_replace in expressions:
+             if f"!!!{expr}!!!" in li:
+                 template_nb[line_idx] = template_nb[line_idx].replace(f"!!!{expr}!!!", expr_replace)
+     template_nb = "".join(template_nb)
+     with open(output_nb_path, "w") as w:
+         w.write(template_nb)
+
+
+ def main(argv):
+     parser = argparse.ArgumentParser(description="Train a model on Kaggle using a script")
+     parser.add_argument("-n", "--nb_id", type=str, help="Notebook name in kaggle", default=NB_ID)
+     parser.add_argument("-u", "--user", type=str, help="Kaggle user", choices=list(kaggle_users.keys()))
+     parser.add_argument("--branch", type=str, help="Git branch name", default=get_git_branch_name())
+     parser.add_argument("-p", "--push", action="store_true", help="Push")
+     parser.add_argument("-d", "--download", action="store_true", help="Download results")
+     get_train_parser(parser)
+     args = parser.parse_args(argv)
+     nb_id = args.nb_id
+     exp_str = "_".join(f"{exp:04d}" for exp in args.exp)
+     kaggle_user = kaggle_users[args.user]
+     uname_kaggle = kaggle_user["username"]
+     kaggle.api._load_config(kaggle_user)
+     if args.download:
+         tmp_dir = ROOT_DIR/f"__tmp_{exp_str}"
+         tmp_dir.mkdir(exist_ok=True, parents=True)
+         kaggle.api.kernels_output_cli(f"{kaggle_user['username']}/{nb_id}", path=str(tmp_dir))
+         subprocess.run(["tar", "-xzf", tmp_dir/"output.tgz"])
+         # @FIXME: Windows probably does not have the tar command
+         import shutil
+         shutil.rmtree(tmp_dir, ignore_errors=True)
+         return
+     kernel_root = ROOT_DIR/f"__nb_{uname_kaggle}"
+     kernel_root.mkdir(exist_ok=True, parents=True)
+
+     kernel_path = kernel_root/exp_str
+     kernel_path.mkdir(exist_ok=True, parents=True)
+     branch = args.branch
+     config = {
+         "id": str(PurePosixPath(f"{kaggle_user['username']}")/nb_id),
+         "title": nb_id.lower(),
+         "code_file": f"{nb_id}.ipynb",
+         "language": "python",
+         "kernel_type": "notebook",
+         "is_private": "true",
+         "enable_gpu": "true" if not args.cpu else "false",
+         "enable_tpu": "false",
+         "enable_internet": "true",
+         "dataset_sources": KAGGLE_DATASET_LIST,
+         "competition_sources": [],
+         "kernel_sources": [],
+         "model_sources": []
+     }
+     prepare_notebook((kernel_path/nb_id).with_suffix(".ipynb"), args.exp, branch,
+                      git_user=GIT_USER, git_repo=GIT_REPO, wandb_flag=not args.no_wandb)
+     assert (kernel_path/nb_id).with_suffix(".ipynb").exists()
+     with open(kernel_path/"kernel-metadata.json", "w") as f:
+         json.dump(config, f, indent=4)
+
+     if args.push:
+         kaggle.api.kernels_push_cli(str(kernel_path))
+
+
+ if __name__ == '__main__':
+     main(sys.argv[1:])
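
The script imports kaggle_users from a __kaggle_login.py file that is deliberately not committed. A minimal sketch of that file, assuming the standard Kaggle API credential fields (username/key); all values are placeholders:

    # __kaggle_login.py -- keep out of version control!
    kaggle_users = {
        "user1": {
            "username": "your_kaggle_username",
            "key": "your_kaggle_api_key",
        },
    }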
scripts/remote_training_template.ipynb ADDED
@@ -0,0 +1 @@
+ {"cells":[{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["exp = !!!exp!!!\n","branch = !!!branch!!!\n","git_user = !!!git_user!!!\n","git_repo = !!!git_repo!!!\n","wandb_flag = !!!wandb_flag!!!\n","output_dir = !!!output_dir!!!\n","dataset_files = !!!dataset_files!!!\n","train_script = !!!train_script!!!"]},{"cell_type":"markdown","metadata":{},"source":["# Clone git repo"]},{"cell_type":"code","execution_count":null,"metadata":{"trusted":true},"outputs":[],"source":["%cd ~\n","!git clone https://github.com/$git_user/$git_repo >/dev/null\n","%cd $git_repo\n","!git checkout $branch\n","!pip install -e ."]},{"cell_type":"markdown","metadata":{},"source":["# Load Kaggle datasets"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/deadleaves_div2k_512\n","!cp /kaggle/input/deadleaves-div2k-512/deadleaves_div2k_512/* \"__dataset/deadleaves_div2k_512\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/deadleaves_primitives_div2k_512\n","!cp /kaggle/input/deadleaves-primitives-div2k-512/deadleaves_primitives_div2k_512/* \"__dataset/deadleaves_primitives_div2k_512\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/div2k\n","!cp -r \"/kaggle/input/div2k-dataset/DIV2K_train_HR\" \"__dataset/div2k/\"\n","!cp -r \"/kaggle/input/div2k-dataset/DIV2K_valid_HR/\" \"__dataset/div2k/\""]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["!mkdir __dataset/kernels\n","!cp \"/kaggle/input/motion-blur-kernels/custom_blur_centered.mat\" __dataset/kernels/"]},{"cell_type":"markdown","metadata":{},"source":["# Setup weights and biases"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if wandb_flag:\n"," from kaggle_secrets import UserSecretsClient\n"," user_secrets = UserSecretsClient()\n"," wandb_api_key = user_secrets.get_secret(\"wandb_api_key\")\n","\n"," !pip install wandb >/dev/null\n"," !wandb login $wandb_api_key"]},{"cell_type":"markdown","metadata":{},"source":["# Launch training"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["exp_str = ' '.join([str(e) for e in exp])\n","wb_ext = \"-nowb\" if not wandb_flag else \"\"\n","!python $train_script -e $exp_str $wb_ext"]},{"cell_type":"markdown","metadata":{},"source":["# Prepare outputs"]},{"cell_type":"code","execution_count":null,"metadata":{},"outputs":[],"source":["if output_dir is not None:\n"," !tar -cvzf /kaggle/working/output.tgz $output_dir"]}],"metadata":{"kaggle":{"accelerator":"gpu","dataSources":[{"datasetId":4234777,"sourceId":7299921,"sourceType":"datasetVersion"}],"dockerImageVersionId":30626,"isGpuEnabled":true,"isInternetEnabled":true,"language":"python","sourceType":"notebook"},"kernelspec":{"display_name":"Python 3","language":"python","name":"python3"},"language_info":{"codemirror_mode":{"name":"ipython","version":3},"file_extension":".py","mimetype":"text/x-python","name":"python","nbconvert_exporter":"python","pygments_lexer":"ipython3","version":"3.10.12"}},"nbformat":4,"nbformat_minor":4}
scripts/save_deadleaves.py ADDED
@@ -0,0 +1,172 @@
+ # -*- coding: utf-8 -*-
+ """
+ Created on Sat Mar 23 15:38:28 2024
+
+ @author: jamyl
+ """
+ import cv2
+ from pathlib import Path
+ from time import perf_counter
+ import matplotlib.pyplot as plt
+ from typing import Tuple
+
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from torch.utils.data import Dataset
+ from numba import cuda
+ from tqdm import tqdm
+ import argparse
+ from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
+ from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
+ from rstor.properties import (DATASET_PATH, DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
+                               SAMPLER_NATURAL, SAMPLER_UNIFORM, DATASET_DL_DIV2K_512,
+                               DATASET_DL_EXTRAPRIMITIVES_DIV2K_512)
+
+
+ class DeadLeavesDatasetGPU(Dataset):
+     def __init__(
+         self,
+         size: Tuple[int, int] = (128, 128),
+         length: int = 1000,
+         frozen_seed: int = None,  # useful for validation set!
+         ds_factor: int = 5,
+         **config_dead_leaves
+     ):
+         self.frozen_seed = frozen_seed
+         self.ds_factor = ds_factor
+         self.size = (size[0]*ds_factor, size[1]*ds_factor)
+         self.length = length
+         self.config_dead_leaves = config_dead_leaves
+
+         # downsample kernel
+         sigma = 3/5
+         k_size = 5  # This fits with sigma = 3/5, the cutoff value is 0.0038 (negligible)
+         x = (torch.arange(k_size) - 2).to('cuda')
+         kernel = torch.stack(torch.meshgrid((x, x), indexing='ij'))
+         dist_sq = kernel[0]**2 + kernel[1]**2
+         kernel = (-dist_sq/(2*sigma**2)).exp()  # Gaussian weights exp(-d^2 / (2 sigma^2))
+         kernel = kernel / kernel.sum()
+         self.downsample_kernel = kernel.repeat(3, 1, 1, 1)  # shape [3, 1, k_size, k_size]
+
+     def __len__(self) -> int:
+         return self.length
+
+     def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
+         """Get a single deadleave chart and its degraded version.
+
+         Args:
+             idx (int): index of the item to retrieve
+
+         Returns:
+             Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart
+         """
+         seed = self.frozen_seed + idx if self.frozen_seed is not None else None
+
+         # Return numba device array
+         numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
+         if self.ds_factor > 1:
+             # Downsample using strided gaussian conv (sigma=3/5)
+             th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE,
+                                        device="cuda").permute(2, 0, 1)[None]  # [b, c, h, w]
+             th_chart = F.pad(th_chart,
+                              pad=(2, 2, 0, 0),
+                              mode="replicate")
+             th_chart = F.conv2d(th_chart,
+                                 self.downsample_kernel,
+                                 padding='valid',
+                                 groups=3,
+                                 stride=self.ds_factor)
+
+             # Convert back to numba
+             numba_chart = cuda.as_cuda_array(th_chart.permute(0, 2, 3, 1))  # [b, h, w, c]
+
+         # convert back to numpy (temporary for legacy)
+         chart = numba_chart.copy_to_host()[0]
+
+         return chart
+
+
+ def generate_images(path: Path, dataset: Dataset, imin=0):
+     for i in tqdm(range(imin, dataset.length)):
+         img = dataset[i]
+         img = (img * 255).astype(np.uint8)
+         out_path = path / "{:04d}.png".format(i)
+         cv2.imwrite(out_path.as_posix(), img)
+
+
+ def bench(dataset):
+     print("dataset initialised")
+     t1 = perf_counter()
+     chart = dataset[0]
+     d = (perf_counter()-t1)
+     print(f"generation done {d}")
+     print(f"{d*1_000/60} min for 1_000")
+     plt.imshow(chart)
+     plt.show()
+
+
+ if __name__ == "__main__":
+     argparser = argparse.ArgumentParser()
+     argparser.add_argument("-o", "--output-dir", type=str, default=str(DATASET_PATH))
+     argparser.add_argument(
+         "-n", "--name", type=str,
+         choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
+                  DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
+         default=DATASET_DL_RANDOMRGB_1024
+     )
+     argparser.add_argument("-b", "--benchmark", action="store_true")
+     default_config = dict(
+         size=(1_024, 1_024),
+         length=1_000,
+         frozen_seed=42,
+         background_color=(0.2, 0.4, 0.6),
+         colored=True,
+         radius_min=5,
+         radius_max=2_000,
+         ds_factor=5,
+     )
+
+     args = argparser.parse_args()
+     dataset_dir = args.output_dir
+     name = args.name
+     path = Path(dataset_dir)/name
+     path.mkdir(parents=True, exist_ok=True)
+     if name == DATASET_DL_RANDOMRGB_1024:
+         config = default_config
+         config["sampler"] = SAMPLER_UNIFORM
+     elif name == DATASET_DL_DIV2K_1024:
+         config = default_config
+         config["sampler"] = SAMPLER_NATURAL
+         config["natural_image_list"] = sorted(
+             list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
+         )
+     elif name == DATASET_DL_DIV2K_512:
+         config = default_config
+         config["size"] = (512, 512)
+         config["rmin"] = 3
+         config["length"] = 4000
+         config["sampler"] = SAMPLER_NATURAL
+         config["natural_image_list"] = sorted(
+             list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
+         )
+     elif name == DATASET_DL_EXTRAPRIMITIVES_DIV2K_512:
+         config = default_config
+         config["size"] = (512, 512)
+         config["sampler"] = SAMPLER_NATURAL
+         config["circle_primitives"] = False
+         config["length"] = 4000
+         config["natural_image_list"] = sorted(
+             list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png"))
+         )
+     else:
+         raise NotImplementedError
+     dataset = DeadLeavesDatasetGPU(**config)
+     if args.benchmark:
+         bench(dataset)
+     else:
+         generate_images(path, dataset)
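
A quick check of the downsample kernel built in DeadLeavesDatasetGPU.__init__ (a sketch, not in the diff): with sigma = 3/5, the 1D Gaussian weight at the kernel edge (distance 2 from center) is exp(-4 / (2 * 0.36)) ≈ 0.0038, which is exactly the cutoff value the comment refers to, so a 5-tap kernel loses almost nothing:

    import math
    sigma = 3 / 5
    edge_weight = math.exp(-2**2 / (2 * sigma**2))
    print(f"{edge_weight:.4f}")  # ~0.0039, negligible vs. the center weight of 1.0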
scripts/train.py ADDED
@@ -0,0 +1,171 @@
+ import sys
+ import argparse
+ from typing import Optional
+ import torch
+ import logging
+ from pathlib import Path
+ import json
+ from tqdm import tqdm
+ from rstor.properties import (
+     ID, NAME, NB_EPOCHS,
+     TRAIN, VALIDATION, LR,
+     LOSS_MSE, METRIC_PSNR, METRIC_SSIM,
+     DEVICE, SCHEDULER_CONFIGURATION, SCHEDULER, REDUCELRONPLATEAU,
+     REDUCTION_SUM,
+     SELECTED_METRICS,
+     LOSS
+ )
+ from rstor.learning.metrics import compute_metrics
+ from rstor.learning.loss import compute_loss
+ from torch.optim.lr_scheduler import ReduceLROnPlateau
+ from configuration import WANDBSPACE, ROOT_DIR, OUTPUT_FOLDER_NAME
+ from rstor.learning.experiments import get_training_content
+ from rstor.learning.experiments_definition import get_experiment_config
+ WANDB_AVAILABLE = False
+ try:
+     import wandb
+     WANDB_AVAILABLE = True
+ except ImportError:
+     logging.warning("Could not import wandb. Disabling wandb.")
+
+
+ def get_parser(parser: Optional[argparse.ArgumentParser] = None) -> argparse.ArgumentParser:
+     if parser is None:
+         parser = argparse.ArgumentParser(description="Train a model")
+     parser.add_argument("-e", "--exp", nargs="+", type=int, required=True, help="Experiment id")
+     parser.add_argument("-o", "--output-dir", type=str, default=ROOT_DIR/OUTPUT_FOLDER_NAME, help="Output directory")
+     parser.add_argument("-nowb", "--no-wandb", action="store_true", help="Disable weights and biases")
+     parser.add_argument("--cpu", action="store_true", help="Force CPU")
+     return parser
+
+
+ def training_loop(
+     model,
+     optimizer,
+     dl_dict: dict,
+     config: dict,
+     scheduler=None,
+     device: str = DEVICE,
+     wandb_flag: bool = False,
+     output_dir: Path = None,
+ ):
+     best_accuracy = 0.
+     chosen_metrics = config.get(SELECTED_METRICS, [METRIC_PSNR, METRIC_SSIM])
+     for n_epoch in tqdm(range(config[NB_EPOCHS])):
+         current_metrics = {
+             TRAIN: 0.,
+             VALIDATION: 0.,
+             LR: optimizer.param_groups[0]['lr'],
+         }
+         for met in chosen_metrics:
+             current_metrics[met] = 0.
+         for phase in [TRAIN, VALIDATION]:
+             total_elements = 0
+             if phase == TRAIN:
+                 model.train()
+             else:
+                 model.eval()
+             for x, y in tqdm(dl_dict[phase], desc=f"{phase} - Epoch {n_epoch}"):
+                 x, y = x.to(device), y.to(device)
+                 optimizer.zero_grad()
+                 with torch.set_grad_enabled(phase == TRAIN):
+                     y_pred = model(x)
+                     loss = compute_loss(y_pred, y, mode=config.get(LOSS, LOSS_MSE))
+                     if torch.isnan(loss):
+                         print(f"Loss is NaN at epoch {n_epoch} and phase {phase}!")
+                         continue
+                     if phase == TRAIN:
+                         loss.backward()
+                         torch.nn.utils.clip_grad_norm_(model.parameters(), 1.)
+                         optimizer.step()
+                 current_metrics[phase] += loss.item()
+                 if phase == VALIDATION:
+                     metrics_on_batch = compute_metrics(
+                         y_pred,
+                         y,
+                         chosen_metrics=chosen_metrics,
+                         reduction=REDUCTION_SUM
+                     )
+                     total_elements += y_pred.shape[0]
+                     for k, v in metrics_on_batch.items():
+                         current_metrics[k] += v
+
+             current_metrics[phase] /= (len(dl_dict[phase]))
+             if phase == VALIDATION:
+                 for k, v in metrics_on_batch.items():
+                     current_metrics[k] /= total_elements
+                     try:
+                         current_metrics[k] = current_metrics[k].item()
+                     except AttributeError:
+                         pass
+             debug_print = f"{phase}: Epoch {n_epoch} - Loss: {current_metrics[phase]:.3e} "
+             for k, v in current_metrics.items():
+                 if k not in [TRAIN, VALIDATION, LR]:
+                     debug_print += f"{k}: {v:.3} |"
+             print(debug_print)
+         if scheduler is not None and isinstance(scheduler, ReduceLROnPlateau):
+             scheduler.step(current_metrics[VALIDATION])
+         if output_dir is not None:
+             with open(output_dir/f"metrics_{n_epoch}.json", "w") as f:
+                 json.dump(current_metrics, f)
+         if wandb_flag:
+             wandb.log(current_metrics)
+         if best_accuracy < current_metrics[METRIC_PSNR]:
+             best_accuracy = current_metrics[METRIC_PSNR]
+             if output_dir is not None:
+                 print("new best model saved!")
+                 torch.save(model.state_dict(), output_dir/"best_model.pt")
+     if output_dir is not None:
+         torch.save(model.cpu().state_dict(), output_dir/"last_model.pt")
+     return model
+
+
+ def train(config: dict, output_dir: Path, device: str = DEVICE, wandb_flag: bool = False):
+     logging.basicConfig(level=logging.INFO)
+     logging.info(f"Training experiment {config[ID]} on device {device}...")
+     output_dir.mkdir(parents=True, exist_ok=True)
+     with open(output_dir/"config.json", "w") as f:
+         json.dump(config, f)
+     model, optimizer, dl_dict = get_training_content(config, training_mode=True, device=device)
+     model.to(device)
+     if wandb_flag:
+         import wandb
+         wandb.init(
+             project=WANDBSPACE,
+             entity="balthazarneveu",
+             name=config[NAME],
+             tags=["debug"],
+             # tags=["base"],
+             config=config
+         )
+     scheduler = None
+     if config.get(SCHEDULER, False):
+         scheduler_config = config[SCHEDULER_CONFIGURATION]
+         if config[SCHEDULER] == REDUCELRONPLATEAU:
+             scheduler = ReduceLROnPlateau(optimizer, mode='min', verbose=True, **scheduler_config)
+         else:
+             raise NameError(f"Scheduler {config[SCHEDULER]} not implemented")
+     model = training_loop(model, optimizer, dl_dict, config, scheduler=scheduler, device=device,
+                           wandb_flag=wandb_flag, output_dir=output_dir)
+
+     if wandb_flag:
+         wandb.finish()
+
+
+ def train_main(argv):
+     parser = get_parser()
+     args = parser.parse_args(argv)
+     if not WANDB_AVAILABLE:
+         args.no_wandb = True
+     device = "cpu" if args.cpu else DEVICE
+     for exp in args.exp:
+         config = get_experiment_config(exp)
+         print(config)
+         output_dir = Path(args.output_dir)/config[NAME]
+         logging.info(f"Training experiment {config[ID]} on device {device}...")
+         train(config, device=device, output_dir=output_dir, wandb_flag=not args.no_wandb)
+
+
+ if __name__ == "__main__":
+     train_main(sys.argv[1:])
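
A usage sketch (not part of the diff), with placeholder experiment ids; -nowb disables weights-and-biases logging and --cpu forces CPU:

    train_main(["-e", "1000", "1001", "-nowb"])  # train two experiments without wandb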
src/rstor/__init__.py ADDED
File without changes
src/rstor/analyzis/interactive/crop.py ADDED
@@ -0,0 +1,75 @@
+ import cv2
+ import numpy as np
+ from interactive_pipe import interactive
+
+
+ def get_color_channel_offset(image):
+     if len(image.shape) == 2:
+         offset = 0
+     elif len(image.shape) == 3:
+         channel_guesser_max_size = 4
+         if image.shape[0] <= channel_guesser_max_size:  # channel first C,H,W
+             offset = 0
+         elif image.shape[-1] <= channel_guesser_max_size:  # channel last or numpy H,W,C
+             offset = 1
+         else:
+             raise NameError(f"Not supported shape {image.shape}")
+     else:
+         raise NameError(f"Not supported shape {image.shape}")
+     return offset
+
+
+ def crop_selector(image, center_x=0.5, center_y=0.5, size=9., global_params={}):
+     # size is defined as a power of 2
+     offset = get_color_channel_offset(image)
+     crop_size_pixels = int(2.**(size)/2.)
+     h, w = image.shape[-2-offset], image.shape[-1-offset]
+     ar = w/h
+     half_crop_h, half_crop_w = crop_size_pixels, int(ar*crop_size_pixels)
+
+     def round(val):
+         return int(np.round(val))
+     center_x_int = round(half_crop_w + center_x*(w-2*half_crop_w))
+     center_y_int = round(half_crop_h + center_y*(h-2*half_crop_h))
+     start_x = max(0, center_x_int-half_crop_w)
+     start_y = max(0, center_y_int-half_crop_h)
+     end_x = min(start_x+2*half_crop_w, w-1)
+     end_y = min(start_y+2*half_crop_h, h-1)
+     start_x = max(0, end_x-2*half_crop_w)
+     start_y = max(0, end_y-2*half_crop_h)
+     MAX_ALLOWED_SIZE = 512
+     w_resize = int(min(MAX_ALLOWED_SIZE, w))
+     h_resize = int(w_resize/w*h)
+     h_resize = int(min(MAX_ALLOWED_SIZE, h_resize))
+     w_resize = int(h_resize/h*w)
+     global_params["crop"] = (start_x, start_y, end_x, end_y)
+     global_params["resize"] = (w_resize, h_resize)
+     return
+
+
+ def plug_crop_selector(num_pad: bool = False):
+     interactive(
+         center_x=(0.5, [0., 1.], "cx", ["4" if num_pad else "left", "6" if num_pad else "right"]),
+         center_y=(0.5, [0., 1.], "cy", ["8" if num_pad else "up", "2" if num_pad else "down"]),
+         size=(9., [6., 13., 0.3], "crop size", ["+", "-"])
+     )(crop_selector)
+
+
+ def crop(*images, global_params={}):
+     images_resized = []
+     for image in images:
+         offset = get_color_channel_offset(image)
+         start_x, start_y, end_x, end_y = global_params["crop"]
+         w_resize, h_resize = global_params["resize"]
+         if offset == 0:
+             crop = image[..., start_y:end_y, start_x:end_x]
+         if offset == 1:
+             crop = image[..., start_y:end_y, start_x:end_x, :]
+         image_resized = cv2.resize(crop, (w_resize, h_resize), interpolation=cv2.INTER_NEAREST)
+         images_resized.append(image_resized)
+     return tuple(images_resized)
+
+
+ def rescale_thumbnail(image, global_params={}):
+     if image is None:  # support no blur kernel!
+         return None
+     resize_dim = max(global_params.get("resize", (512, 512)))
+     return cv2.resize(image, (resize_dim, resize_dim), interpolation=cv2.INTER_NEAREST)
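
To make the crop-size convention concrete (a sketch, not in the diff): size is a power-of-two exponent, so the default size=9. yields a half-crop of int(2**9 / 2) = 256 pixels vertically, scaled horizontally by the image aspect ratio:

    size = 9.
    half_crop_h = int(2.**size / 2.)     # 256 pixels
    ar = 1920 / 1080                     # example aspect ratio (placeholder dimensions)
    half_crop_w = int(ar * half_crop_h)  # 455 pixels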
src/rstor/analyzis/interactive/degradation.py ADDED
@@ -0,0 +1,71 @@
+ import numpy as np
+ from interactive_pipe import interactive
+ from skimage.filters import gaussian
+ from rstor.properties import DATASET_BLUR_KERNEL_PATH
+ from scipy.io import loadmat
+ import cv2
+
+
+ @interactive(
+     sigma=(3/5, [0., 2.])
+ )
+ def downsample(chart: np.ndarray, sigma=3/5, global_params={}):
+     ds_factor = global_params.get("ds_factor", 5)
+     if sigma > 0.:
+         ds_chart = gaussian(chart, sigma=(sigma, sigma, 0), mode='nearest', cval=0, preserve_range=True, truncate=4.0)
+     else:
+         ds_chart = chart.copy()
+     ds_chart = ds_chart[ds_factor//2::ds_factor, ds_factor//2::ds_factor]
+     return ds_chart
+
+
+ @interactive(
+     k_size_x=(0, [0, 10]),
+     k_size_y=(0, [0, 10]),
+ )
+ def degrade_blur_gaussian(chart: np.ndarray, k_size_x: int = 1, k_size_y: int = 1):
+     if k_size_x == 0 and k_size_y == 0:
+         return chart  # skip the blur entirely when both kernel sizes are zero
+     blurred = cv2.GaussianBlur(chart, (2*k_size_x+1, 2*k_size_y+1), 0)
+     return blurred
+
+
+ @interactive(
+     noise_stddev=(0., [0., 50.])
+ )
+ def degrade_noise(img: np.ndarray, noise_stddev=0., global_params={}):
+     seed = global_params.get("seed", 42)
+     np.random.seed(seed)
+     if noise_stddev > 0.:
+         noise = np.random.normal(0, noise_stddev/255., img.shape)
+         img = img.copy()+noise
+     return img
+
+
+ @interactive(
+     ksize=(3, [1, 10])
+ )
+ def get_blur_kernel_box(ksize=3):
+     return np.ones((ksize, ksize), dtype=np.float32) / (1.*ksize**2)
+
+
+ @interactive(
+     blur_index=(-1, [-1, 1000])
+ )
+ def get_blur_kernel(blur_index: int = -1, global_params={}):
+     if blur_index == -1:
+         return None
+     blur_mat = global_params.get("blur_mat", False)
+     if blur_mat is False:
+         blur_mat = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze()
+         global_params["blur_mat"] = blur_mat
+     blur_k = blur_mat[blur_index]
+     blur_k = blur_k/blur_k.sum()
+     return blur_k
+
+
+ def degrade_blur(img: np.ndarray, blur_kernel: np.ndarray, global_params={}):
+     if blur_kernel is None:
+         return img
+     img_blur = cv2.filter2D(img, -1, blur_kernel)
+     return img_blur
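
A standalone sketch of the box-blur path (not part of the diff, using only the calls above): get_blur_kernel_box builds a normalized ksize x ksize averaging kernel and degrade_blur applies it with cv2.filter2D:

    import numpy as np
    import cv2

    img = np.random.rand(64, 64, 3).astype(np.float32)  # placeholder image
    kernel = np.ones((3, 3), dtype=np.float32) / 9.     # what get_blur_kernel_box(ksize=3) returns
    blurred = cv2.filter2D(img, -1, kernel)             # what degrade_blur does
    assert blurred.shape == img.shape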
src/rstor/analyzis/interactive/images.py ADDED
@@ -0,0 +1,10 @@
+ from interactive_pipe.data_objects.image import Image
+ from typing import List
+
+
+ def image_selector(image_list: List[dict], image_index: int = 0):
+     current_image = image_list[image_index % len(image_list)]
+     img = current_image.get("buffer", None)
+     if img is None:
+         img = Image.from_file(current_image["path"]).data
+     return img
src/rstor/analyzis/interactive/inference.py ADDED
@@ -0,0 +1,12 @@
+ import torch
+ import numpy as np
+ from rstor.properties import DEVICE
+
+
+ def infer(degraded: np.ndarray, model: torch.nn.Module):
+     degraded_tensor = torch.from_numpy(degraded).permute(-1, 0, 1).float().unsqueeze(0)
+     model.eval()
+     with torch.no_grad():
+         output = model(degraded_tensor.to(DEVICE))
+     output = output.squeeze().permute(1, 2, 0).cpu().numpy()
+     return np.ascontiguousarray(output)
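
The layout conversion in infer is the usual numpy-to-torch round trip (a sketch, not in the diff): an H,W,C float image becomes a 1,C,H,W batch and back:

    import numpy as np
    import torch

    img = np.zeros((64, 64, 3), dtype=np.float32)
    t = torch.from_numpy(img).permute(-1, 0, 1).unsqueeze(0)  # -> [1, 3, 64, 64]
    assert t.shape == (1, 3, 64, 64)
    back = t.squeeze(0).permute(1, 2, 0).numpy()              # -> [64, 64, 3]
    assert back.shape == (64, 64, 3)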
src/rstor/analyzis/interactive/metrics.py ADDED
@@ -0,0 +1,36 @@
+ from rstor.learning.metrics import compute_metrics, ALL_METRICS
+ import torch
+ import numpy as np
+ from rstor.properties import METRIC_PSNR, METRIC_SSIM
+ from interactive_pipe import interactive, KeyboardControl
+ from typing import Optional
+
+
+ def plug_configure_metrics(key_shortcut: Optional[str] = None) -> None:
+     interactive(
+         advanced_metrics=KeyboardControl(False, keydown=key_shortcut) if key_shortcut is not None else (True,)
+     )(configure_metrics)
+
+
+ def configure_metrics(advanced_metrics=False, global_params={}) -> None:
+     chosen_metrics = ALL_METRICS if advanced_metrics else [METRIC_PSNR, METRIC_SSIM]
+     global_params["chosen_metrics"] = chosen_metrics
+
+
+ def get_metrics(prediction: torch.Tensor, target: torch.Tensor,
+                 image_name: str,  # use functools.partial to root where you want the title to appear
+                 global_params: dict = {}) -> None:
+     if isinstance(prediction, np.ndarray):
+         prediction_ = torch.from_numpy(prediction).permute(-1, 0, 1).float().unsqueeze(0)
+     else:
+         prediction_ = prediction
+     if isinstance(target, np.ndarray):
+         target_ = torch.from_numpy(target).permute(-1, 0, 1).float().unsqueeze(0)
+     else:
+         target_ = target
+     chosen_metrics = global_params.get("chosen_metrics", [METRIC_PSNR])
+     metrics = compute_metrics(prediction_, target_, chosen_metrics=chosen_metrics)
+     global_params["metrics"] = metrics
+     title = f"{image_name}: "
+     title += " ".join([f"{key}: {value:.4f}" for key, value in metrics.items()])
+     global_params["__output_styles"][image_name] = {"title": title, "image_name": image_name}
src/rstor/analyzis/interactive/model_selection.py ADDED
@@ -0,0 +1,58 @@
+ import torch
+ from interactive_pipe import KeyboardControl
+ from rstor.learning.experiments import get_training_content
+ from rstor.learning.experiments_definition import get_experiment_config
+ from rstor.properties import DEVICE, PRETTY_NAME
+ from tqdm import tqdm
+ from pathlib import Path
+ from typing import List, Tuple
+
+ from interactive_pipe import interactive
+ MODELS_PATH = Path("scripts")/"__output"
+
+
+ def model_selector(models_dict: dict, global_params={}, model_name="vanilla"):
+     if isinstance(model_name, str):
+         current_model = models_dict[model_name]
+     elif isinstance(model_name, int):
+         model_names = [name for name in models_dict.keys()]
+         current_model = models_dict[model_names[model_name % len(model_names)]]
+     else:
+         raise ValueError(f"Model name {model_name} not understood")
+     global_params["model_config"] = current_model["config"]
+     return current_model["model"]
+
+
+ def get_model_from_exp(exp: int, model_storage: Path = MODELS_PATH, device=DEVICE) -> Tuple[torch.nn.Module, dict]:
+     config = get_experiment_config(exp)
+     model, _, _ = get_training_content(config, training_mode=False)
+     state_dict = torch.load(model_storage/f"{exp:04d}"/"best_model.pt")
+     assert state_dict is not None, f"Model {exp} not found"
+     model.load_state_dict(state_dict)
+     model = model.to(device)
+     return model, config
+
+
+ def get_default_models(
+     exp_list: List[int] = [1000, 1001],
+     model_storage: Path = MODELS_PATH,
+     keyboard_control: bool = False,
+     interactive_flag: bool = True
+ ) -> dict:
+     model_dict = {}
+     assert model_storage.exists(), f"Model storage {model_storage} does not exist"
+     for exp in tqdm(exp_list, desc="Loading models"):
+         model, config = get_model_from_exp(exp, model_storage=model_storage)
+         name = config.get(PRETTY_NAME, f"{exp:04d}")
+         model_dict[name] = {
+             "model": model,
+             "config": config
+         }
+     exp_names = [name for name in model_dict.keys()]
+     if interactive_flag:
+         if keyboard_control:
+             model_control = KeyboardControl(0, [0, len(exp_names)-1], keydown="pagedown", keyup="pageup", modulo=True)
+         else:
+             model_control = (exp_names[0], exp_names)
+         interactive(model_name=model_control)(model_selector)  # Create the model selection dialog
+     return model_dict
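
A usage sketch (not part of the diff), loading models outside of the GUI as scripts/infer.py does; the experiment id and storage path are placeholders:

    from pathlib import Path
    model_dict = get_default_models([1000], Path("scripts/__output"), interactive_flag=False)
    entry = model_dict[list(model_dict.keys())[0]]
    model, config = entry["model"], entry["config"]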
src/rstor/analyzis/interactive/pipelines.py ADDED
@@ -0,0 +1,61 @@
+ from rstor.synthetic_data.interactive.interactive_dead_leaves import generate_deadleave
+ from rstor.analyzis.interactive.crop import crop_selector, crop, rescale_thumbnail
+ from rstor.analyzis.interactive.inference import infer
+ from rstor.analyzis.interactive.degradation import degrade_noise, degrade_blur, downsample, degrade_blur_gaussian, get_blur_kernel
+ from rstor.analyzis.interactive.model_selection import model_selector
+ from rstor.analyzis.interactive.images import image_selector
+ from rstor.analyzis.interactive.metrics import get_metrics, configure_metrics
+ from interactive_pipe import interactive, KeyboardControl
+ from typing import Tuple, List
+ from functools import partial
+ import numpy as np
+
+
+ get_metrics_restored = partial(get_metrics, image_name="restored")
+ get_metrics_degraded = partial(get_metrics, image_name="degraded")
+
+
+ def deadleave_inference_pipeline(models_dict: dict) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
+     groundtruth = generate_deadleave()
+     groundtruth = downsample(groundtruth)
+     model = model_selector(models_dict)
+     degraded = degrade_blur_gaussian(groundtruth)
+     degraded = degrade_noise(degraded)
+     restored = infer(degraded, model)
+     crop_selector(restored)
+     groundtruth, degraded, restored = crop(groundtruth, degraded, restored)
+     configure_metrics()
+     get_metrics_restored(restored, groundtruth)
+     get_metrics_degraded(degraded, groundtruth)
+     return groundtruth, degraded, restored
+
+
+ CANVAS_DICT = {
+     "demo": [["degraded", "restored"]],
+     "landscape_light": [["degraded", "restored", "groundtruth"]],
+     "landscape": [["degraded", "restored", "blur_kernel", "groundtruth"]],
+     "full": [["degraded", "restored"], ["blur_kernel", "groundtruth"]]
+ }
+ CANVAS = list(CANVAS_DICT.keys())
+
+
+ def morph_canvas(canvas=CANVAS[0], global_params={}):
+     global_params["__pipeline"].outputs = CANVAS_DICT[canvas]
+     return None
+
+
+ def natural_inference_pipeline(input_image_list: List[np.ndarray], models_dict: dict):
+     model = model_selector(models_dict)
+     img_clean = image_selector(input_image_list)
+     crop_selector(img_clean)
+     groundtruth = crop(img_clean)
+     blur_kernel = get_blur_kernel()
+     degraded = degrade_blur(groundtruth, blur_kernel)
+     degraded = degrade_noise(degraded)
+     blur_kernel = rescale_thumbnail(blur_kernel)
+     restored = infer(degraded, model)
+     configure_metrics()
+     get_metrics_restored(restored, groundtruth)
+     get_metrics_degraded(degraded, groundtruth)
+     morph_canvas()
+     return [[degraded, restored], [blur_kernel, groundtruth]]
src/rstor/analyzis/metrics_plots.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import matplotlib.pyplot as plt
3
+ import numpy as np
4
+ from pathlib import Path
5
+ import pandas as pd
6
+
7
+
8
+ def snr_to_sigma(snr):
9
+ return 10**(-snr/20.)*255.
10
+
11
+
12
+ def sigma_to_snr(sigma):
13
+ return -20.*np.log10(sigma/255.)
14
+
15
+
16
+ def plot_results(selected_paths, title=None, diff=True, ylim=None):
17
+ # plt.figure(figsize=(10, 10))
18
+ all_stats = {}
19
+ fig, ax = plt.subplots(layout='constrained', figsize=(10, 10))
20
+ for selected_path, selected_regex in selected_paths:
21
+ selected_path = Path(selected_path)
22
+ assert selected_path.exists()
23
+ results_path = sorted(list(selected_path.glob(selected_regex)))
24
+ stats = []
25
+ for result_path in results_path:
26
+ df = pd.read_csv(result_path)
27
+ in_psnr = df["in_PSNR"].mean()
28
+ out_psnr = df["out_PSNR"].mean()
29
+ out_ssim = df["out_SSIM"].mean()
30
+ noise_stddev = np.array([
31
+ float(el.replace("}", "").split(":")[1]) for el in df["degradation"]]).mean()
32
+ stats.append({
33
+ # "label": label,
34
+ "in_psnr": in_psnr,
35
+ "out_psnr": out_psnr,
36
+ "noise_stddev": noise_stddev,
37
+ "ssim": out_ssim
38
+ })
39
+ label = selected_path.name + " " + df["size"][0]
40
+
41
+ stats_array = pd.DataFrame(stats)
42
+ all_stats[label] = stats_array
43
+ x_data = stats_array["in_psnr"].copy()
44
+ x_data = snr_to_sigma(x_data)
45
+
46
+ ax.plot(
47
+ x_data,
48
+ stats_array["out_psnr"]-stats_array["in_psnr"] if diff else stats_array["out_psnr"],
49
+ "-o",
50
+ label=label
51
+ )
52
+ # label=selected_path.name)
53
+ if not diff:
54
+ neutral_sigma = np.linspace(1, 80, 80)
55
+ ax.plot(neutral_sigma, sigma_to_snr(neutral_sigma), "k--", alpha=0.1, label="Neutral")
56
+ secax = ax.secondary_xaxis('top', functions=(sigma_to_snr, snr_to_sigma))
57
+ secax.set_xlabel('PSNR [dB]')
58
+
59
+ ax.set_xlabel("noise sigma [0-255 scale]")
60
+ ax.set_ylabel("PSNR improvement" if diff else "PSNR out")
61
+ plt.xlim(1., 50.)
62
+ if diff:
63
+ plt.ylim(0, 15)
64
+ else:
65
+ if ylim is not None:
66
+ plt.ylim(*ylim)
67
+ if title is not None:
68
+ plt.title(title)
69
+ plt.legend()
70
+ plt.grid()
71
+ plt.show()
72
+
73
+ return all_stats
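As a sanity check, snr_to_sigma and sigma_to_snr above are exact inverses on the 8-bit scale (sigma = 10 maps to roughly 28.1 dB and back). A minimal usage sketch of plot_results, with hypothetical result folders and csv naming:

    # hypothetical paths and regex, for illustration only
    all_stats = plot_results(
        [("__output/1004", "*.csv"), ("__output/2000", "*.csv")],
        title="Denoising PSNR gain vs noise level",
        diff=True,
    )
    assert abs(sigma_to_snr(snr_to_sigma(28.13)) - 28.13) < 1e-9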
src/rstor/analyzis/parser.py ADDED
@@ -0,0 +1,26 @@
1
+ from rstor.analyzis.interactive.model_selection import MODELS_PATH
2
+ import argparse
3
+
4
+
5
+ def get_models_parser(parser: argparse.ArgumentParser = None, help: str = "Inference",
6
+ default_models_path: str = MODELS_PATH) -> argparse.ArgumentParser:
7
+ if parser is None:
8
+ parser = argparse.ArgumentParser(description=help)
9
+ parser.add_argument("-e", "--experiments", type=int, nargs="+", required=True,
10
+ help="Experiment indices to be used at inference time")
11
+ parser.add_argument("-m", "--models-storage", type=str, help="Model storage path", default=default_models_path)
12
+ return parser
13
+
14
+
15
+ def get_parser(
16
+ parser: argparse.ArgumentParser = None,
17
+ help: str = "Live inference pipeline"
18
+ ) -> argparse.ArgumentParser:
19
+ """Generic parser for live interactive inference
20
+ """
21
+ if parser is None:
22
+ parser = argparse.ArgumentParser(description=help)
23
+ get_models_parser(parser=parser, help=help)
24
+ parser.add_argument("-k", "--keyboard", action="store_true", help="Keyboard control - less sliders")
25
+ parser.add_argument("-b", "--backend", default="gradio", help="Backend to use for the GUI", choices=["gradio", "qt"])
26
+ return parser
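A quick sketch of how the two parsers compose, using only the arguments defined above:

    # emulates `python some_script.py -e 1000 2001 -k`
    args = get_parser().parse_args(["-e", "1000", "2001", "-k"])
    print(args.experiments)             # [1000, 2001]
    print(args.models_storage)          # defaults to MODELS_PATH
    print(args.keyboard, args.backend)  # True "gradio"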
src/rstor/architecture/base.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ from rstor.properties import LEAKY_RELU, RELU, SIMPLE_GATE
3
+ from typing import Optional, Tuple
4
+
5
+
6
+ class SimpleGate(torch.nn.Module):
7
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
8
+ x1, x2 = x.chunk(2, dim=1)
9
+ return x1 * x2
10
+
11
+
12
+ def get_non_linearity(activation: str):
13
+ if activation == LEAKY_RELU:
14
+ non_linearity = torch.nn.LeakyReLU()
15
+ elif activation == RELU:
16
+ non_linearity = torch.nn.ReLU()
17
+ elif activation is None:
18
+ non_linearity = torch.nn.Identity()
19
+ elif activation == SIMPLE_GATE:
20
+ non_linearity = SimpleGate()
21
+ else:
22
+ raise ValueError(f"Unknown activation {activation}")
23
+ return non_linearity
24
+
25
+
26
+ class BaseModel(torch.nn.Module):
27
+ """Base class for all restoration models with additional useful methods"""
28
+
29
+ def count_parameters(self):
30
+ return sum(p.numel() for p in self.parameters() if p.requires_grad)
31
+
32
+ def receptive_field(
33
+ self,
34
+ channels: Optional[int] = 3,
35
+ size: Optional[int] = 256,
36
+ device: Optional[str] = None
37
+ ) -> Tuple[int, int]:
38
+ """Compute the receptive field of the model
39
+
40
+ Returns:
41
+ Tuple[int, int]: receptive field (x, y) in pixels
42
+ """
43
+ input_tensor = torch.ones(1, channels, size, size, requires_grad=True)
44
+ if device is not None:
45
+ input_tensor = input_tensor.to(device).detach().requires_grad_(True)  # keep a leaf tensor so .grad gets populated
46
+ out = self.forward(input_tensor)
47
+ grad = torch.zeros_like(out)
48
+ grad[..., out.shape[-2]//2, out.shape[-1]//2] = torch.nan # set NaN gradient at the middle of the output
49
+ out.backward(gradient=grad)
50
+ self.zero_grad()
51
+ receptive_field_mask = input_tensor.grad.isnan()[0, 0]
52
+ receptive_field_indexes = torch.where(receptive_field_mask)
53
+ # Count NaN in the input
54
+ receptive_x = 1+receptive_field_indexes[-1].max() - receptive_field_indexes[-1].min() # Horizontal x
55
+ receptive_y = 1+receptive_field_indexes[-2].max() - receptive_field_indexes[-2].min() # Vertical y
56
+ return receptive_x.item(), receptive_y.item()
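The receptive field trick above plants a NaN gradient at the center of the output and reads back which input pixels it reaches. A toy check, assuming a single 3x3 convolution model (whose receptive field must be 3x3):

    # toy example, for illustration only
    class OneConv(BaseModel):
        def __init__(self):
            super().__init__()
            self.conv = torch.nn.Conv2d(3, 3, 3, padding=1)

        def forward(self, x):
            return self.conv(x)

    model = OneConv()
    print(model.count_parameters())  # 3*3*3*3 weights + 3 biases = 84
    print(model.receptive_field())   # expected: (3, 3)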
src/rstor/architecture/convolution_blocks.py ADDED
@@ -0,0 +1,47 @@
1
+ import torch
2
+ from rstor.properties import LEAKY_RELU
3
+ from rstor.architecture.base import get_non_linearity
4
+
5
+
6
+ class BaseConvolutionBlock(torch.nn.Module):
7
+ def __init__(
8
+ self,
9
+ ch_in: int,
10
+ ch_out: int,
11
+ k_size: int,
12
+ activation=LEAKY_RELU,
13
+ bias: bool = True
14
+ ) -> None:
15
+ super().__init__()
16
+ self.conv = torch.nn.Conv2d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
17
+ self.non_linearity = get_non_linearity(activation)
18
+ self.conv_non_lin = torch.nn.Sequential(self.conv, self.non_linearity)
19
+
20
+ def forward(self, x_in: torch.Tensor) -> torch.Tensor:
21
+ return self.conv_non_lin(x_in)
22
+
23
+
24
+ class ResConvolutionBlock(torch.nn.Module):
25
+ def __init__(
26
+ self,
27
+ ch_in: int,
28
+ ch_out: int,
29
+ k_size: int,
30
+ activation=LEAKY_RELU,
31
+ bias: bool = True,
32
+ residual: bool = True
33
+ ) -> None:
34
+ super().__init__()
35
+ self.conv1 = torch.nn.Conv2d(ch_in, ch_out, k_size, padding=k_size//2, bias=bias)
36
+ self.non_linearity = get_non_linearity(activation)
37
+ self.conv2 = torch.nn.Conv2d(ch_out, ch_out, k_size, padding=k_size//2, bias=bias)
38
+ self.residual = residual
39
+
40
+ def forward(self, x_in: torch.Tensor) -> torch.Tensor:
41
+ y = self.conv1(x_in)
42
+ y = self.non_linearity(y)
43
+ y = self.conv2(y)
44
+ if self.residual:
45
+ y = x_in + y
46
+ y = self.non_linearity(y)
47
+ return y
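Note that the residual addition x_in + y implicitly requires ch_in == ch_out. A quick shape-check sketch:

    # illustration only
    block = ResConvolutionBlock(ch_in=16, ch_out=16, k_size=3)
    x = torch.randn(1, 16, 64, 64)
    assert block(x).shape == x.shape  # padding=k_size//2 keeps H, W; the residual needs matching channels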
src/rstor/architecture/nafnet.py ADDED
@@ -0,0 +1,299 @@
1
+ """
2
+ NAFNet: Nonlinear Activation Free Network
3
+ Architecture adapted from Simple Baselines for Image Restoration
4
+ https://github.com/megvii-research/NAFNet/tree/main
5
+ """
6
+ from torch import nn
7
+ import torch.nn.functional as F
8
+ import torch
9
+ from rstor.architecture.base import BaseModel, get_non_linearity
10
+ from typing import Optional, List
11
+ from rstor.properties import RELU, SIMPLE_GATE
12
+
13
+
14
+ class LayerNormFunction(torch.autograd.Function):
15
+
16
+ @staticmethod
17
+ def forward(ctx, x, weight, bias, eps):
18
+ ctx.eps = eps
19
+ N, C, H, W = x.size()
20
+ mu = x.mean(1, keepdim=True)
21
+ var = (x - mu).pow(2).mean(1, keepdim=True)
22
+ y = (x - mu) / (var + eps).sqrt()
23
+ ctx.save_for_backward(y, var, weight)
24
+ y = weight.view(1, C, 1, 1) * y + bias.view(1, C, 1, 1)
25
+ return y
26
+
27
+ @staticmethod
28
+ def backward(ctx, grad_output):
29
+ eps = ctx.eps
30
+
31
+ N, C, H, W = grad_output.size()
32
+ y, var, weight = ctx.saved_variables
33
+ g = grad_output * weight.view(1, C, 1, 1)
34
+ mean_g = g.mean(dim=1, keepdim=True)
35
+
36
+ mean_gy = (g * y).mean(dim=1, keepdim=True)
37
+ gx = 1. / torch.sqrt(var + eps) * (g - y * mean_gy - mean_g)
38
+ return gx, (grad_output * y).sum(dim=3).sum(dim=2).sum(dim=0), grad_output.sum(dim=3).sum(dim=2).sum(
39
+ dim=0), None
40
+
41
+
42
+ class LayerNorm2d(nn.Module):
43
+ def __init__(self, channels, eps=1e-6):
44
+ super(LayerNorm2d, self).__init__()
45
+ self.register_parameter('weight', nn.Parameter(torch.ones(channels)))
46
+ self.register_parameter('bias', nn.Parameter(torch.zeros(channels)))
47
+ self.eps = eps
48
+
49
+ def forward(self, x):
50
+ return LayerNormFunction.apply(x, self.weight, self.bias, self.eps)
51
+
52
+
53
+ class NAFBlock(nn.Module):
54
+ def __init__(
55
+ self,
56
+ c, DW_Expand=2, FFN_Expand=2, drop_out_rate=0.,
57
+ activation: Optional[str] = SIMPLE_GATE,
58
+ layer_norm_flag: Optional[bool] = True,
59
+ channel_attention_flag: Optional[bool] = True,
60
+ ):
61
+ super().__init__()
62
+ self.layer_norm_flag = layer_norm_flag
63
+ self.channel_attention_flag = channel_attention_flag
64
+ dw_channel = c * DW_Expand
65
+ half_dw_channel = dw_channel // 2
66
+ self.conv1 = nn.Conv2d(in_channels=c, out_channels=dw_channel, kernel_size=1,
67
+ padding=0, stride=1, groups=1, bias=True)
68
+ self.conv2 = nn.Conv2d(
69
+ in_channels=dw_channel,
70
+ out_channels=dw_channel if activation == SIMPLE_GATE else half_dw_channel,
71
+ kernel_size=3,
72
+ padding=1, stride=1,
73
+ groups=dw_channel if activation == SIMPLE_GATE else half_dw_channel,
74
+ bias=True
75
+ )
76
+ # To keep the same number of parameters between the SimpleGate and ReLU versions,
77
+ # conv2 has to halve the number of channels while still using a grouped convolution:
78
+ # w -> w/2 ... not really a depthwise convolution but rather grouped by channels of 2!
79
+ self.conv3 = nn.Conv2d(in_channels=half_dw_channel, out_channels=c,
80
+ kernel_size=1, padding=0, stride=1, groups=1, bias=True)
81
+
82
+ # Simplified Channel Attention
83
+ if self.channel_attention_flag:
84
+ self.sca = nn.Sequential(
85
+ nn.AdaptiveAvgPool2d(1),
86
+ nn.Conv2d(in_channels=half_dw_channel, out_channels=half_dw_channel, kernel_size=1,
87
+ padding=0, stride=1,
88
+ groups=1, bias=True),
89
+ )
90
+
91
+ # SimpleGate
92
+ self.sg = get_non_linearity(activation)
93
+ ffn_channel = FFN_Expand * c  # expansion relative to the block width, as in the reference NAFNet
94
+ half_ffn_channel = ffn_channel // 2 if activation == SIMPLE_GATE else ffn_channel
95
+ self.conv4 = nn.Conv2d(
96
+ in_channels=c,
97
+ out_channels=ffn_channel if activation == SIMPLE_GATE else half_ffn_channel,
98
+ kernel_size=1,
99
+ padding=0, stride=1, groups=1, bias=True)
100
+ self.conv5 = nn.Conv2d(in_channels=half_ffn_channel, out_channels=c,
101
+ kernel_size=1, padding=0, stride=1, groups=1, bias=True)
102
+ if self.layer_norm_flag:
103
+ self.norm1 = LayerNorm2d(c)
104
+ self.norm2 = LayerNorm2d(c)
105
+
106
+ self.dropout1 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
107
+ self.dropout2 = nn.Dropout(drop_out_rate) if drop_out_rate > 0. else nn.Identity()
108
+
109
+ self.beta = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
110
+ self.gamma = nn.Parameter(torch.zeros((1, c, 1, 1)), requires_grad=True)
111
+
112
+ def forward(self, inp):
113
+ x = inp
114
+ if self.layer_norm_flag:
115
+ x = self.norm1(x)
116
+
117
+ x = self.conv1(x)
118
+ x = self.conv2(x)
119
+ x = self.sg(x)
120
+ if self.channel_attention_flag:
121
+ x = x * self.sca(x)
122
+ x = self.conv3(x)
123
+
124
+ x = self.dropout1(x)
125
+
126
+ y = inp + x * self.beta
127
+
128
+ x = self.conv4(self.norm2(y) if self.layer_norm_flag else y)
129
+ x = self.sg(x)
130
+ x = self.conv5(x)
131
+
132
+ x = self.dropout2(x)
133
+
134
+ return y + x * self.gamma
135
+
136
+
137
+ class NAFNet(BaseModel):
138
+ def __init__(
139
+ self,
140
+ img_channel: Optional[int] = 3,
141
+ width: Optional[int] = 16,
142
+ middle_blk_num: Optional[int] = 1,
143
+ enc_blk_nums: List[int] = [],
144
+ dec_blk_nums: List[int] = [],
145
+ activation: Optional[bool] = SIMPLE_GATE,
146
+ layer_norm_flag: Optional[bool] = True,
147
+ channel_attention_flag: Optional[bool] = True,
148
+ ) -> None:
149
+ super().__init__()
150
+
151
+ self.intro = nn.Conv2d(
152
+ in_channels=img_channel,
153
+ out_channels=width,
154
+ kernel_size=3,
155
+ padding=1, stride=1,
156
+ groups=1,
157
+ bias=True
158
+ )
159
+ config_block = {
160
+ "activation": activation,
161
+ "layer_norm_flag": layer_norm_flag,
162
+ "channel_attention_flag": channel_attention_flag
163
+ }
164
+ self.ending = nn.Conv2d(
165
+ in_channels=width, out_channels=img_channel, kernel_size=3,
166
+ padding=1, stride=1, groups=1,
167
+ bias=True)
168
+
169
+ self.encoders = nn.ModuleList()
170
+ self.decoders = nn.ModuleList()
171
+ self.middle_blks = nn.ModuleList()
172
+ self.ups = nn.ModuleList()
173
+ self.downs = nn.ModuleList()
174
+
175
+ chan = width
176
+ for num in enc_blk_nums:
177
+ self.encoders.append(
178
+ nn.Sequential(
179
+ *[NAFBlock(chan, **config_block) for _ in range(num)]
180
+ )
181
+ )
182
+ self.downs.append(
183
+ nn.Conv2d(chan, 2*chan, 2, 2)
184
+ )
185
+ chan = chan * 2
186
+
187
+ self.middle_blks = \
188
+ nn.Sequential(
189
+ *[NAFBlock(chan, **config_block) for _ in range(middle_blk_num)]
190
+ )
191
+
192
+ for num in dec_blk_nums:
193
+ self.ups.append(
194
+ nn.Sequential(
195
+ nn.Conv2d(chan, chan * 2, 1, bias=False),
196
+ nn.PixelShuffle(2)
197
+ )
198
+ )
199
+ chan = chan // 2
200
+ self.decoders.append(
201
+ nn.Sequential(
202
+ *[NAFBlock(chan, **config_block) for _ in range(num)]
203
+ )
204
+ )
205
+
206
+ self.padder_size = 2 ** len(self.encoders)
207
+
208
+ def forward(self, inp: torch.Tensor) -> torch.Tensor:
209
+ B, C, H, W = inp.shape
210
+ inp = self.sanitize_image_size(inp)
211
+
212
+ x = self.intro(inp)
213
+
214
+ encs = []
215
+
216
+ for encoder, down in zip(self.encoders, self.downs):
217
+ x = encoder(x)
218
+ encs.append(x)
219
+ x = down(x)
220
+
221
+ x = self.middle_blks(x)
222
+
223
+ for decoder, up, enc_skip in zip(self.decoders, self.ups, encs[::-1]):
224
+ x = up(x)
225
+ x = x + enc_skip
226
+ x = decoder(x)
227
+
228
+ x = self.ending(x)
229
+ x = x + inp
230
+
231
+ return x[:, :, :H, :W]
232
+
233
+ def sanitize_image_size(self, x: torch.Tensor) -> torch.Tensor:
234
+ _, _, h, w = x.size()
235
+ mod_pad_h = (self.padder_size - h % self.padder_size) % self.padder_size
236
+ mod_pad_w = (self.padder_size - w % self.padder_size) % self.padder_size
237
+ x = F.pad(x, (0, mod_pad_w, 0, mod_pad_h))
238
+ return x
239
+
240
+
241
+ class UNet(NAFNet):
242
+ def __init__(
243
+ self,
244
+ activation: Optional[bool] = RELU,
245
+ layer_norm_flag: Optional[bool] = False,
246
+ channel_attention_flag: Optional[bool] = False,
247
+ **kwargs):
248
+ super().__init__(
249
+ activation=activation,
250
+ layer_norm_flag=layer_norm_flag,
251
+ channel_attention_flag=channel_attention_flag, **kwargs)
252
+
253
+
254
+ if __name__ == '__main__':
255
+ tiny_receptive_field = True
256
+ if tiny_receptive_field:
257
+ enc_blks = [1, 1, 2]
258
+ middle_blk_num = 1
259
+ dec_blks = [1, 1, 1]
260
+ width = 16
261
+ # Receptive field is 208x208
262
+ else:
263
+ enc_blks = [1, 1, 1, 28]
264
+ middle_blk_num = 1
265
+ dec_blks = [1, 1, 1, 1]
266
+ width = 2
267
+ # Receptive field is 544x544
268
+ device = "cpu"
269
+
270
+ for model_name in ["NAFNet", "UNet"]:
271
+ if model_name == "NAFNet":
272
+ model = NAFNet(
273
+ img_channel=3,
274
+ width=width,
275
+ middle_blk_num=middle_blk_num,
276
+ enc_blk_nums=enc_blks,
277
+ dec_blk_nums=dec_blks,
278
+ activation=SIMPLE_GATE,
279
+ layer_norm_flag=False,
280
+ channel_attention_flag=False
281
+ )
282
+ if model_name == "UNet":
283
+ model = UNet(
284
+ img_channel=3,
285
+ width=width,
286
+ middle_blk_num=middle_blk_num,
287
+ enc_blk_nums=enc_blks,
288
+ dec_blk_nums=dec_blks
289
+ )
290
+ model.to(device)
291
+ with torch.no_grad():
292
+ x = torch.randn(1, 3, 256, 256).to(device)
293
+ y = model(x)
294
+
295
+ # print(y.shape)
296
+ # print(y)
297
+ # print(model)
298
+ print(f"{model.count_parameters()/1E3:.2f}k parameters")
299
+ print(model.receptive_field(size=256 if tiny_receptive_field else 1024, device=device))
src/rstor/architecture/selector.py ADDED
@@ -0,0 +1,19 @@
1
+ from rstor.properties import MODEL, NAME, N_PARAMS, ARCHITECTURE
2
+ from rstor.architecture.stacked_convolutions import StackedConvolutions
3
+ from rstor.architecture.nafnet import NAFNet, UNet
4
+ import torch
5
+
6
+
7
+ def load_architecture(config: dict) -> torch.nn.Module:
8
+ conf_model = config[MODEL][ARCHITECTURE]
9
+ if config[MODEL][NAME] == StackedConvolutions.__name__:
10
+ model = StackedConvolutions(**conf_model)
11
+ elif config[MODEL][NAME] == NAFNet.__name__:
12
+ model = NAFNet(**conf_model)
13
+ elif config[MODEL][NAME] == UNet.__name__:
14
+ model = UNet(**conf_model)
15
+ else:
16
+ raise ValueError(f"Unknown model {config[MODEL][NAME]}")
17
+ config[MODEL][N_PARAMS] = model.count_parameters()
18
+ config[MODEL]["receptive_field"] = model.receptive_field()
19
+ return model
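A minimal configuration sketch for load_architecture, using the key constants imported above and architecture kwargs matching StackedConvolutions' signature:

    # illustration only
    config = {MODEL: {NAME: "StackedConvolutions",
                      ARCHITECTURE: dict(num_layers=8, k_size=3, h_dim=16)}}
    model = load_architecture(config)
    # side effect: config[MODEL][N_PARAMS] and config[MODEL]["receptive_field"] get filled in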
src/rstor/architecture/stacked_convolutions.py ADDED
@@ -0,0 +1,30 @@
1
+ from rstor.architecture.base import BaseModel
2
+ from rstor.architecture.convolution_blocks import BaseConvolutionBlock, ResConvolutionBlock
3
+ from rstor.properties import LEAKY_RELU
4
+ import torch
5
+
6
+
7
+ class StackedConvolutions(BaseModel):
8
+ def __init__(self,
9
+ ch_in: int = 3,
10
+ ch_out: int = 3,
11
+ h_dim: int = 64,
12
+ num_layers: int = 8,
13
+ k_size: int = 3,
14
+ activation: str = LEAKY_RELU,
15
+ bias: bool = True,
16
+ ) -> None:
17
+ super().__init__()
18
+ assert num_layers % 2 == 0, "Number of layers should be even"
19
+ self.conv_in_modality = BaseConvolutionBlock(
20
+ ch_in, h_dim, k_size, activation=activation, bias=bias)
21
+ conv_list = []
22
+ for _i in range(num_layers-2):
23
+ conv_list.append(ResConvolutionBlock(
24
+ h_dim, h_dim, k_size, activation=activation, bias=bias, residual=True))
25
+ self.conv_out_modality = BaseConvolutionBlock(
26
+ h_dim, ch_out, k_size, activation=None, bias=bias)
27
+ self.conv_stack = torch.nn.Sequential(self.conv_in_modality, *conv_list, self.conv_out_modality)
28
+
29
+ def forward(self, x_in: torch.Tensor) -> torch.Tensor:
30
+ return self.conv_stack(x_in)
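With the defaults this stacks one input convolution, num_layers - 2 residual blocks (two convolutions each) and one linear output convolution. A quick instantiation sketch:

    # illustration only
    model = StackedConvolutions(ch_in=3, ch_out=3, h_dim=64, num_layers=8)
    y = model(torch.randn(1, 3, 128, 128))
    assert y.shape == (1, 3, 128, 128)  # fully convolutional, size preserving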
src/rstor/data/augmentation.py ADDED
@@ -0,0 +1,27 @@
1
+ import torch
2
+ from typing import Tuple, Optional
3
+
4
+
5
+ def augment_flip(
6
+ img: torch.Tensor,
7
+ flip: Optional[Tuple[bool, bool]] = None
8
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
9
+ """Roll pixels horizontally to avoid negative index
10
+
11
+ Args:
12
+ img (torch.Tensor): [N, 3, H, W] image tensor
13
+ lab (torch.Tensor): [N, 3, H, W] label tensor
14
+ flip (Optional[bool], optional): forced flip_h, flip_v value. Defaults to None.
15
+ If not provided, a random flip_h, flip_v values are used
16
+ Returns:
17
+ torch.Tensor, torch.Tensor: flipped image, labels
18
+
19
+ """
20
+ if flip is None:
21
+ flip = torch.randint(0, 2, (2,))
22
+ flipped_img = img
23
+ if flip[0] > 0:
24
+ flipped_img = torch.flip(flipped_img, (-1,))
25
+ if flip[1] > 0:
26
+ flipped_img = torch.flip(flipped_img, (-2,))
27
+ return flipped_img
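Usage sketch, forcing a deterministic flip (flip is interpreted as (flip_h, flip_v)):

    # illustration only
    img = torch.arange(12.).reshape(1, 3, 2, 2)
    flipped = augment_flip(img, flip=(True, False))  # horizontal flip only
    assert torch.equal(flipped, torch.flip(img, (-1,)))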
src/rstor/data/dataloader.py ADDED
@@ -0,0 +1,120 @@
1
+ from torch.utils.data import DataLoader
2
+ from rstor.data.synthetic_dataloader import DeadLeavesDataset, DeadLeavesDatasetGPU
3
+ from rstor.data.stored_images_dataloader import RestorationDataset
4
+ from rstor.properties import (
5
+ DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE, NAME, CONFIG_DEGRADATION,
6
+ DATASET_SYNTH_LIST, DATASET_DIV2K, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,
7
+ DATASET_PATH
8
+ )
9
+ from typing import Optional
10
+ from random import seed, shuffle
11
+
12
+
13
+ def get_data_loader_synthetic(config, frozen_seed=42):
14
+ # print(config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
15
+ if config[DATALOADER].get("gpu_gen", False):
16
+ print("Using GPU dead leaves generator")
17
+ ds = DeadLeavesDatasetGPU
18
+ else:
19
+ ds = DeadLeavesDataset
20
+ dl_train = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][TRAIN],
21
+ frozen_seed=None, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
22
+ dl_valid = ds(config[DATALOADER][SIZE], config[DATALOADER][LENGTH][VALIDATION],
23
+ frozen_seed=frozen_seed, **config[DATALOADER].get(CONFIG_DEAD_LEAVES, {}))
24
+ dl_dict = create_dataloaders(config, dl_train, dl_valid)
25
+ return dl_dict
26
+
27
+
28
+ def create_dataloaders(config, dl_train, dl_valid) -> dict:
29
+ dl_dict = {
30
+ TRAIN: DataLoader(
31
+ dl_train,
32
+ shuffle=True,
33
+ batch_size=config[DATALOADER][BATCH_SIZE][TRAIN],
34
+ ),
35
+ VALIDATION: DataLoader(
36
+ dl_valid,
37
+ shuffle=False,
38
+ batch_size=config[DATALOADER][BATCH_SIZE][VALIDATION]
39
+ ),
40
+ # TEST: DataLoader(dl_test, shuffle=False, batch_size=config[DATALOADER][BATCH_SIZE][TEST])
41
+ }
42
+ return dl_dict
43
+
44
+
45
+ def get_data_loader_from_disk(config, frozen_seed: Optional[int] = 42) -> dict:
46
+ ds = RestorationDataset
47
+ dataset_name = config[DATALOADER][NAME]  # NAME is required here!
48
+ if dataset_name == DATASET_DIV2K:
49
+ dataset_root = DATASET_PATH/DATASET_DIV2K
50
+ train_root = dataset_root/"DIV2K_train_HR"/"DIV2K_train_HR"
51
+ valid_root = dataset_root/"DIV2K_valid_HR"/"DIV2K_valid_HR"
52
+ train_files = sorted(list(train_root.glob("*.png")))
53
+ train_files = 5*train_files # Just to get 4000 elements...
54
+ valid_files = sorted(list(valid_root.glob("*.png")))
55
+ elif dataset_name in DATASET_SYNTH_LIST:
56
+ dataset_root = DATASET_PATH/dataset_name
57
+ all_files = sorted(list(dataset_root.glob("*.png")))
58
+ seed(frozen_seed)
59
+ shuffle(all_files)  # Easy way to perform cross-validation if needed
60
+ cut_index = int(0.9*len(all_files))
61
+ train_files = all_files[:cut_index]
62
+ valid_files = all_files[cut_index:]
+ else:
+ raise ValueError(f"Unknown dataset {dataset_name}")
63
+ dl_train = ds(
64
+ train_files,
65
+ size=config[DATALOADER][SIZE],
66
+ frozen_seed=None,
67
+ **config[DATALOADER].get(CONFIG_DEGRADATION, {})
68
+ )
69
+ dl_valid = ds(
70
+ valid_files,
71
+ size=config[DATALOADER][SIZE],
72
+ frozen_seed=frozen_seed,
73
+ **config[DATALOADER].get(CONFIG_DEGRADATION, {})
74
+ )
75
+ dl_dict = create_dataloaders(config, dl_train, dl_valid)
76
+ return dl_dict
77
+
78
+
79
+ def get_data_loader(config, frozen_seed=42):
80
+ dataset_name = config[DATALOADER].get(NAME, False)
81
+ if dataset_name:
82
+ return get_data_loader_from_disk(config, frozen_seed)
83
+ else:
84
+ return get_data_loader_synthetic(config, frozen_seed)
85
+
86
+
87
+ if __name__ == "__main__":
88
+ # Example of usage synthetic dataset
89
+ for dataset_name in [DATASET_DIV2K, None, DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024]:
90
+ if dataset_name is None:
91
+ dead_leaves_dataset = DeadLeavesDatasetGPU(colored=True)
92
+ dl = DataLoader(dead_leaves_dataset, batch_size=4, shuffle=True)
93
+ else:
94
+ # Example of usage stored images dataset
95
+ config = {
96
+ DATALOADER: {
97
+ NAME: dataset_name,
98
+ SIZE: (128, 128),
99
+ BATCH_SIZE: {
100
+ TRAIN: 4,
101
+ VALIDATION: 4
102
+ },
103
+ }
104
+ }
105
+ dl_dict = get_data_loader(config)
106
+ dl = dl_dict[TRAIN]
107
+ # dl = dl_dict[VALIDATION]
108
+ for i, (batch_inp, batch_target) in enumerate(dl):
109
+ print(batch_inp.shape, batch_target.shape) # Should print [batch_size, size[0], size[1], 3] for each batch
110
+ if i == 1: # Just to break the loop after two batches for demonstration
111
+ import matplotlib.pyplot as plt
112
+ plt.subplot(1, 2, 1)
113
+ plt.imshow(batch_inp.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
114
+ plt.title("Degraded")
115
+ plt.subplot(1, 2, 2)
116
+ plt.imshow(batch_target.permute(0, 2, 3, 1).reshape(-1, batch_inp.shape[-1], 3).cpu().numpy())
117
+ plt.title("Target")
118
+ plt.show()
119
+ # print(batch_target)
120
+ break
src/rstor/data/degradation.py ADDED
@@ -0,0 +1,156 @@
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ Created on Sun Mar 24 01:21:46 2024
4
+
5
+ @author: jamyl
6
+ """
7
+ import torch
8
+ from rstor.properties import DATASET_BLUR_KERNEL_PATH
9
+ import random
10
+ from scipy.io import loadmat
11
+ import cv2
12
+
13
+
14
+ class Degradation():
15
+ def __init__(self,
16
+ length: int = 1000,
17
+ frozen_seed: int = None):
18
+ self.frozen_seed = frozen_seed
19
+ self.current_degradation = {}
20
+
21
+
22
+ class DegradationNoise(Degradation):
23
+ def __init__(self,
24
+ length: int = 1000,
25
+ noise_stddev: float = [0., 50.],
26
+ frozen_seed: int = None):
27
+ super().__init__(length, frozen_seed)
28
+ self.noise_stddev = noise_stddev
29
+
30
+ if frozen_seed is not None:
31
+ random.seed(frozen_seed)
32
+ self.noise_stddev = [(self.noise_stddev[1] - self.noise_stddev[0]) *
33
+ random.random() + self.noise_stddev[0] for _ in range(length)]
34
+
35
+ def __call__(self, x: torch.Tensor, idx: int):
36
+ # WARNING! INPLACE OPERATIONS!!!!!
37
+ # expects x of shape [b, c, h, w]
38
+ assert x.ndim == 4
39
+ assert x.shape[1] in [1, 3]
40
+
41
+ if self.frozen_seed is not None:
42
+ std_dev = self.noise_stddev[idx]
43
+ else:
44
+ std_dev = (self.noise_stddev[1] - self.noise_stddev[0]) * random.random() + self.noise_stddev[0]
45
+
46
+ if std_dev > 0.:
47
+ # x += (std_dev/255.)*np.random.randn(*x.shape)
48
+ x += (std_dev/255.)*torch.randn(*x.shape, device=x.device)
49
+ self.current_degradation[idx] = {
50
+ "noise_stddev": std_dev
51
+ }
52
+ return x
53
+
54
+
55
+ class DegradationBlurMat(Degradation):
56
+ def __init__(self,
57
+ length: int = 1000,
58
+ frozen_seed: int = None,
59
+ blur_index: int = None):
60
+ super().__init__(length, frozen_seed)
61
+
62
+ kernels = loadmat(DATASET_BLUR_KERNEL_PATH)["kernels"].squeeze()
63
+ # conversion to torch (the shape of the kernel is not constant)
64
+ self.kernels = tuple([
65
+ torch.from_numpy(kernel/kernel.sum(keepdims=True)).unsqueeze(0).unsqueeze(0)
66
+ for kernel in kernels] + [torch.ones((1, 1)).unsqueeze(0).unsqueeze(0)])
67
+ self.n_kernels = len(self.kernels)
68
+
69
+ if frozen_seed is not None:
70
+ random.seed(frozen_seed)
71
+ self.kernel_ids = [random.randint(0, self.n_kernels-1) for _ in range(length)]
72
+ if blur_index is not None:
73
+ self.frozen_seed = 42
74
+ self.kernel_ids = [blur_index for _ in range(length)]
75
+
76
+ def __call__(self, x: torch.Tensor, idx: int):
77
+ # expects x of shape [b, c, h, w]
78
+ assert x.ndim == 4
79
+ assert x.shape[1] in [1, 3]
80
+ device = x.device
81
+
82
+ if self.frozen_seed is not None:
83
+ kernel_id = self.kernel_ids[idx]
84
+ else:
85
+ kernel_id = random.randint(0, self.n_kernels-1)
86
+
87
+ kernel = self.kernels[kernel_id].to(device).repeat(3, 1, 1, 1).float() # repeat for grouped conv
88
+ _, _, kh, kw = kernel.shape
89
+ # We use padding = same to make
90
+ # sure that the output size does not depend on the kernel.
91
+
92
+ # define nn.Conf layer to define both padding mode and padding value...
93
+ conv_layer = torch.nn.Conv2d(in_channels=x.shape[1],
94
+ out_channels=x.shape[1],
95
+ kernel_size=(kh, kw),
96
+ padding="same",
97
+ padding_mode='replicate',
98
+ groups=3,
99
+ bias=False)
100
+
101
+ # Set the predefined kernel as weights and freeze the parameters
102
+ with torch.no_grad():
103
+ conv_layer.weight = torch.nn.Parameter(kernel)
104
+ conv_layer.weight.requires_grad = False
105
+ # breakpoint()
106
+ x = conv_layer(x)
107
+ # Alternative Functional version with 0 padding :
108
+ # x = F.conv2d(x, kernel, padding="same", groups=3)
109
+
110
+ self.current_degradation[idx] = {
111
+ "blur_kernel_id": kernel_id
112
+ }
113
+ return x
114
+
115
+
116
+ class DegradationBlurGauss(Degradation):
117
+ def __init__(self,
118
+ length: int = 1000,
119
+ blur_kernel_half_size: int = [0, 2],
120
+ frozen_seed: int = None):
121
+ super().__init__(length, frozen_seed)
122
+
123
+ self.blur_kernel_half_size = blur_kernel_half_size
124
+ # conversion to torch (the shape of the kernel is not constant)
125
+ if frozen_seed is not None:
126
+ random.seed(self.frozen_seed)
127
+ self.blur_kernel_half_size = [
128
+ (
129
+ random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1]),
130
+ random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
131
+ ) for _ in range(length)
132
+ ]
133
+
134
+ def __call__(self, x: torch.Tensor, idx: int):
135
+ # expects x of shape [b, c, h, w]
136
+ assert x.ndim == 4
137
+ assert x.shape[1] in [1, 3]
138
+ device = x.device
139
+
140
+ if self.frozen_seed is not None:
141
+ k_size_x, k_size_y = self.blur_kernel_half_size[idx]
142
+ else:
143
+ k_size_x = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
144
+ k_size_y = random.randint(self.blur_kernel_half_size[0], self.blur_kernel_half_size[1])
145
+
146
+ k_size_x = 2 * k_size_x + 1
147
+ k_size_y = 2 * k_size_y + 1
148
+
149
+ x = x.squeeze(0).permute(1, 2, 0).cpu().numpy()
150
+ x = cv2.GaussianBlur(x, (k_size_x, k_size_y), 0)
151
+ x = torch.from_numpy(x).to(device).permute(2, 0, 1).unsqueeze(0)
152
+
153
+ self.current_degradation[idx] = {
154
+ "blur_kernel_half_size": ((k_size_x - 1) // 2, (k_size_y - 1) // 2),  # k_size_x/y are full sizes at this point
155
+ }
156
+ return x
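Usage sketch for the noise degradation; with frozen_seed set, one standard deviation is pre-drawn per index so validation items are reproducible:

    # illustration only
    deg = DegradationNoise(length=10, noise_stddev=[0., 50.], frozen_seed=42)
    x = torch.zeros(1, 3, 64, 64)
    noisy = deg(x, idx=0)  # beware: also modifies x in place
    print(deg.current_degradation[0])  # {'noise_stddev': <value drawn for index 0>}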
src/rstor/data/stored_images_dataloader.py ADDED
@@ -0,0 +1,156 @@
1
+ import torch
2
+ from torch.utils.data import DataLoader, Dataset
3
+ from rstor.data.augmentation import augment_flip
4
+ from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
5
+ from rstor.properties import DEVICE, AUGMENTATION_FLIP, AUGMENTATION_ROTATE, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
6
+ from rstor.properties import DATALOADER, BATCH_SIZE, TRAIN, VALIDATION, LENGTH, CONFIG_DEAD_LEAVES, SIZE
7
+ from typing import Tuple, Optional, Union
8
+ from torchvision import transforms
9
+ # from torchvision.transforms import RandomCrop
10
+ from pathlib import Path
11
+ from tqdm import tqdm
12
+ from time import time
13
+ from torchvision.io import read_image
14
+ IMAGES_FOLDER = "images"
15
+
16
+
17
+ def load_image(path):
18
+ return read_image(str(path))
19
+
20
+
21
+ class RestorationDataset(Dataset):
22
+ def __init__(
23
+ self,
24
+ images_path: Path,
25
+ size: Tuple[int, int] = (128, 128),
26
+ device: str = DEVICE,
27
+ preloaded: bool = False,
28
+ augmentation_list: Optional[list] = [],
29
+ frozen_seed: int = None, # useful for validation set!
30
+ blur_kernel_half_size: int = [0, 2],
31
+ noise_stddev: float = [0., 50.],
32
+ degradation_blur=DEGRADATION_BLUR_NONE,
33
+ blur_index=None,
34
+ **_extra_kwargs
35
+ ):
36
+ self.preloaded = preloaded
37
+ self.augmentation_list = augmentation_list
38
+ self.device = device
39
+ self.frozen_seed = frozen_seed
40
+ if not isinstance(images_path, list):
41
+ self.path_list = sorted(list(images_path.glob("*.png")))
42
+ else:
43
+ self.path_list = images_path
44
+
45
+ self.length = len(self.path_list)
46
+
47
+ self.n_samples = len(self.path_list)
48
+ # If we can preload everything in memory, we can do it
49
+ if preloaded:
50
+ self.data_list = [load_image(pth) for pth in tqdm(self.path_list)]
51
+ else:
52
+ self.data_list = self.path_list
53
+
54
+ # if AUGMENTATION_FLIP in self.augmentation_list:
55
+ # img_data = augment_flip(img_data)
56
+ # img_data = self.cropper(img_data)
57
+ self.transforms = []
58
+
59
+ if self.frozen_seed is None:
60
+ if AUGMENTATION_FLIP in self.augmentation_list:
61
+ self.transforms.append(transforms.RandomHorizontalFlip(p=0.5))
62
+ self.transforms.append(transforms.RandomVerticalFlip(p=0.5))
63
+ if AUGMENTATION_ROTATE in self.augmentation_list:
64
+ self.transforms.append(transforms.RandomRotation(degrees=180))
65
+
66
+ crop = transforms.RandomCrop(size) if frozen_seed is None else transforms.CenterCrop(size)
67
+ self.transforms.append(crop)
68
+ self.transforms = transforms.Compose(self.transforms)
69
+
70
+ # self.cropper = RandomCrop(size=size)
71
+
72
+ self.degradation_blur_type = degradation_blur
73
+ if degradation_blur == DEGRADATION_BLUR_GAUSS:
74
+ self.degradation_blur = DegradationBlurGauss(self.length,
75
+ blur_kernel_half_size,
76
+ frozen_seed)
77
+ self.blur_deg_str = "blur_kernel_half_size"
78
+ elif degradation_blur == DEGRADATION_BLUR_MAT:
79
+ self.degradation_blur = DegradationBlurMat(self.length,
80
+ frozen_seed,
81
+ blur_index)
82
+ self.blur_deg_str = "blur_kernel_id"
83
+ elif degradation_blur == DEGRADATION_BLUR_NONE:
84
+ pass
85
+ else:
86
+ raise ValueError(f"Unknown degradation blur {degradation_blur}")
87
+
88
+ self.degradation_noise = DegradationNoise(self.length,
89
+ noise_stddev,
90
+ frozen_seed)
91
+ self.current_degradation = {}
92
+
93
+ def __getitem__(self, index: int) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
94
+ """Access a specific image from dataset and augment
95
+
96
+ Args:
97
+ index (int): access index
98
+
99
+ Returns:
100
+ torch.Tensor: image tensor [C, H, W]
101
+ """
102
+ if self.preloaded:
103
+ img_data = self.data_list[index]
104
+ else:
105
+ img_data = load_image(self.data_list[index])
106
+ img_data = img_data.to(self.device)
107
+
108
+ # if AUGMENTATION_FLIP in self.augmentation_list:
109
+ # img_data = augment_flip(img_data)
110
+ # img_data = self.cropper(img_data)
111
+
112
+ img_data = self.transforms(img_data)
113
+ img_data = img_data.float()/255.
114
+ degraded_img = img_data.clone().unsqueeze(0)
115
+
116
+ self.current_degradation[index] = {}
117
+ if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
118
+ degraded_img = self.degradation_blur(degraded_img, index)
119
+ self.current_degradation[index][self.blur_deg_str] = self.degradation_blur.current_degradation[index][self.blur_deg_str]
120
+
121
+ degraded_img = self.degradation_noise(degraded_img, index)
122
+ self.current_degradation[index]["noise_stddev"] = self.degradation_noise.current_degradation[index]["noise_stddev"]
123
+
124
+ degraded_img = degraded_img.squeeze(0)
132
+
133
+ return degraded_img, img_data
134
+
135
+ def __len__(self):
136
+ return self.n_samples
137
+
138
+
139
+ if __name__ == "__main__":
140
+ dataset_restoration = RestorationDataset(
141
+ Path("__dataset/div2k/DIV2K_train_HR/DIV2K_train_HR/"),
142
+ preloaded=True,
143
+ )
144
+ dataloader = DataLoader(
145
+ dataset_restoration,
146
+ batch_size=16,
147
+ shuffle=True
148
+ )
149
+ start = time()
150
+ total = 0
151
+ for batch in tqdm(dataloader):
152
+ # print(batch.shape)
153
+ torch.cuda.synchronize()
154
+ total += batch.shape[0]
155
+ end = time()
156
+ print(f"Time elapsed: {(end-start)/total*1000.:.2f}ms/image")
src/rstor/data/synthetic_dataloader.py ADDED
@@ -0,0 +1,187 @@
1
+
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from torch.utils.data import Dataset
5
+ from typing import Tuple
6
+ from rstor.data.degradation import DegradationBlurMat, DegradationBlurGauss, DegradationNoise
7
+ from rstor.properties import DEVICE, AUGMENTATION_FLIP, DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS
8
+ from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
9
+ from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
10
+ import cv2
11
+ from skimage.filters import gaussian
12
+ import random
13
+ import numpy as np
14
+
15
+ from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
16
+
17
+
18
+ class DeadLeavesDataset(Dataset):
19
+ def __init__(
20
+ self,
21
+ size: Tuple[int, int] = (128, 128),
22
+ length: int = 1000,
23
+ frozen_seed: int = None, # useful for validation set!
24
+ blur_kernel_half_size: int = [0, 2],
25
+ ds_factor: int = 5,
26
+ noise_stddev: float = [0., 50.],
27
+ degradation_blur=DEGRADATION_BLUR_NONE,
28
+ **config_dead_leaves
29
+ # number_of_circles: int = -1,
30
+ # background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
31
+ # colored: Optional[bool] = False,
32
+ # radius_mean: Optional[int] = -1,
33
+ # radius_stddev: Optional[int] = -1,
34
+ ):
35
+
36
+ self.frozen_seed = frozen_seed
37
+ self.ds_factor = ds_factor
38
+ self.size = (size[0]*ds_factor, size[1]*ds_factor)
39
+ self.length = length
40
+ self.config_dead_leaves = config_dead_leaves
41
+ self.blur_kernel_half_size = blur_kernel_half_size
42
+ self.noise_stddev = noise_stddev
43
+
45
+ self.degradation_blur_type = degradation_blur
46
+ if degradation_blur == DEGRADATION_BLUR_GAUSS:
47
+ self.degradation_blur = DegradationBlurGauss(self.length,
48
+ blur_kernel_half_size,
49
+ frozen_seed)
50
+ self.blur_deg_str = "blur_kernel_half_size"
51
+ elif degradation_blur == DEGRADATION_BLUR_MAT:
52
+ self.degradation_blur = DegradationBlurMat(self.length,
53
+ frozen_seed)
54
+ self.blur_deg_str = "blur_kernel_id"
55
+ elif degradation_blur == DEGRADATION_BLUR_NONE:
56
+ pass
57
+ else:
58
+ raise ValueError(f"Unknown degradation blur {degradation_blur}")
59
+
60
+ self.degradation_noise = DegradationNoise(self.length,
61
+ noise_stddev,
62
+ frozen_seed)
63
+ self.current_degradation = {}
64
+
65
+ def __len__(self):
66
+ return self.length
67
+
68
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
69
+ # TODO: there is a bug in this CPU version, the dead leaves don't appear to be right
70
+ seed = self.frozen_seed + idx if self.frozen_seed is not None else None
71
+ chart = cpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
72
+
73
+ if self.ds_factor > 1:
74
+ # print(f"Downsampling {chart.shape} with factor {self.ds_factor}...")
75
+ sigma = 3/5
76
+ chart = gaussian(
77
+ chart, sigma=(sigma, sigma, 0), mode='nearest',
78
+ cval=0, preserve_range=True, truncate=4.0)
79
+ chart = chart[::self.ds_factor, ::self.ds_factor]
80
+
81
+ th_chart = torch.from_numpy(chart).permute(2, 0, 1).unsqueeze(0)
82
+ degraded_chart = th_chart
83
+
84
+ self.current_degradation[idx] = {}
85
+ if self.degradation_blur_type != DEGRADATION_BLUR_NONE:
86
+ degraded_chart = self.degradation_blur(degraded_chart, idx)
87
+ self.current_degradation[idx][self.blur_deg_str] = self.degradation_blur.current_degradation[idx][self.blur_deg_str]
88
+
89
+ degraded_chart = self.degradation_noise(degraded_chart, idx)
90
+ self.current_degradation[idx]["noise_stddev"] = self.degradation_noise.current_degradation[idx]["noise_stddev"]
91
+
92
+ degraded_chart = degraded_chart.squeeze(0)
93
+ th_chart = th_chart.squeeze(0)
94
+
95
+ return degraded_chart, th_chart
96
+
97
+
98
+ class DeadLeavesDatasetGPU(Dataset):
99
+ def __init__(
100
+ self,
101
+ size: Tuple[int, int] = (128, 128),
102
+ length: int = 1000,
103
+ frozen_seed: int = None, # useful for validation set!
104
+ blur_kernel_half_size: int = [0, 2],
105
+ ds_factor: int = 5,
106
+ noise_stddev: float = [0., 50.],
107
+ use_gaussian_kernel=True,
108
+ **config_dead_leaves
109
+ # number_of_circles: int = -1,
110
+ # background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
111
+ # colored: Optional[bool] = False,
112
+ # radius_mean: Optional[int] = -1,
113
+ # radius_stddev: Optional[int] = -1,
114
+ ):
115
+ self.frozen_seed = frozen_seed
116
+ self.ds_factor = ds_factor
117
+ self.size = (size[0]*ds_factor, size[1]*ds_factor)
118
+ self.length = length
119
+ self.config_dead_leaves = config_dead_leaves
120
+
121
+ # downsample kernel
122
+ sigma = 3/5
123
+ k_size = 5  # This fits with sigma = 3/5, the cutoff value is 0.0038 (negligible)
124
+ x = (torch.arange(k_size) - 2).to('cuda')
125
+ kernel = torch.stack(torch.meshgrid((x, x), indexing='ij'))
126
+ kernel.requires_grad = False
127
+ dist_sq = kernel[0]**2 + kernel[1]**2
128
+ kernel = (-dist_sq/(2*sigma**2)).exp()  # dist_sq is already the squared distance
129
+ kernel = kernel / kernel.sum()
130
+ self.downsample_kernel = kernel.repeat(3, 1, 1, 1) # shape [3, 1, k_size, k_size]
131
+ self.downsample_kernel.requires_grad = False
132
+ self.use_gaussian_kernel = use_gaussian_kernel
133
+ if use_gaussian_kernel:
134
+ self.degradation_blur = DegradationBlurGauss(length,
135
+ blur_kernel_half_size,
136
+ frozen_seed)
137
+ else:
138
+ self.degradation_blur = DegradationBlurMat(length,
139
+ frozen_seed)
140
+
141
+ self.degradation_noise = DegradationNoise(length,
142
+ noise_stddev,
143
+ frozen_seed)
144
+ self.current_degradation = {}
145
+
146
+ def __len__(self) -> int:
147
+ return self.length
148
+
149
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]:
150
+ """Get a single deadleave chart and its degraded version.
151
+
152
+ Args:
153
+ idx (int): index of the item to retrieve
154
+
155
+ Returns:
156
+ Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart
157
+ """
158
+ seed = self.frozen_seed + idx if self.frozen_seed is not None else None
159
+
160
+ # Return numba device array
161
+ numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
162
+ th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda")[
163
+ None].permute(0, 3, 1, 2) # [1, c, h, w]
164
+ if self.ds_factor > 1:
165
+ # Downsample using strided gaussian conv (sigma=3/5)
166
+ th_chart = F.pad(th_chart,
167
+ pad=(2, 2, 0, 0),
168
+ mode="replicate")
169
+ th_chart = F.conv2d(th_chart,
170
+ self.downsample_kernel,
171
+ padding='valid',
172
+ groups=3,
173
+ stride=self.ds_factor)
174
+
175
+ degraded_chart = self.degradation_blur(th_chart, idx)
176
+ degraded_chart = self.degradation_noise(degraded_chart, idx)
177
+
178
+ blur_deg_str = "blur_kernel_half_size" if self.use_gaussian_kernel else "blur_kernel_id"
179
+ self.current_degradation[idx] = {
180
+ blur_deg_str: self.degradation_blur.current_degradation[idx][blur_deg_str],
181
+ "noise_stddev": self.degradation_noise.current_degradation[idx]["noise_stddev"]
182
+ }
183
+
184
+ degraded_chart = degraded_chart.squeeze(0)
185
+ th_chart = th_chart.squeeze(0)
186
+
187
+ return degraded_chart, th_chart
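Usage sketch for the CPU dataset, assuming the dead-leaves generator runs on this machine; ds_factor=5 means charts are synthesized at 5x the requested size, gaussian filtered, then decimated back:

    # illustration only
    ds = DeadLeavesDataset(size=(128, 128), length=4, frozen_seed=0, ds_factor=5)
    degraded, target = ds[0]
    print(degraded.shape, target.shape)  # torch.Size([3, 128, 128]) twice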
src/rstor/learning/experiments.py ADDED
@@ -0,0 +1,24 @@
1
+ from rstor.properties import DEVICE, OPTIMIZER, PARAMS
2
+ from rstor.architecture.selector import load_architecture
3
+ from rstor.data.dataloader import get_data_loader
4
+ from typing import Tuple
5
+ import torch
6
+
7
+
8
+ def get_training_content(
9
+ config: dict,
10
+ training_mode: bool = False,
11
+ device=DEVICE) -> Tuple[torch.nn.Module, torch.optim.Optimizer, dict]:
12
+ model = load_architecture(config)
13
+ optimizer, dl_dict = None, None
14
+ if training_mode:
15
+ optimizer = torch.optim.Adam(model.parameters(), **config[OPTIMIZER][PARAMS])
16
+ dl_dict = get_data_loader(config)
17
+ return model, optimizer, dl_dict
18
+
19
+
20
+ if __name__ == "__main__":
21
+ from rstor.learning.experiments_definition import default_experiment
22
+ config = default_experiment(1)
23
+ model, optimizer, dl_dict = get_training_content(config, training_mode=True)
24
+ print(config)
src/rstor/learning/experiments_definition.py ADDED
@@ -0,0 +1,489 @@
1
+ from rstor.properties import (NB_EPOCHS, DATALOADER, BATCH_SIZE, SIZE, LENGTH,
2
+ TRAIN, VALIDATION, SCHEDULER, REDUCELRONPLATEAU,
3
+ MODEL, ARCHITECTURE, ID, NAME, SCHEDULER_CONFIGURATION, OPTIMIZER, PARAMS, LR,
4
+ LOSS, LOSS_MSE, CONFIG_DEAD_LEAVES,
5
+ SELECTED_METRICS, METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS,
6
+ DATASET_DL_DIV2K_512, DATASET_DIV2K,
7
+ CONFIG_DEGRADATION,
8
+ PRETTY_NAME,
9
+ DEGRADATION_BLUR_NONE, DEGRADATION_BLUR_MAT, DEGRADATION_BLUR_GAUSS,
10
+ AUGMENTATION_FLIP, AUGMENTATION_ROTATE,
11
+ DATASET_DL_EXTRAPRIMITIVES_DIV2K_512)
12
+
13
+
14
+ from typing import Tuple
15
+
16
+
17
+ def model_configurations(config, model_preset="StackedConvolutions", bias: bool = True) -> None:  # mutates config in place
18
+ if model_preset == "StackedConvolutions":
19
+ config[MODEL] = {
20
+ ARCHITECTURE: dict(
21
+ num_layers=8,
22
+ k_size=3,
23
+ h_dim=16,
24
+ bias=bias
25
+ ),
26
+ NAME: "StackedConvolutions"
27
+ }
28
+ elif model_preset == "NAFNet" or model_preset == "UNet":
29
+ # https://github.com/megvii-research/NAFNet/blob/main/options/test/GoPro/NAFNet-width64.yml
30
+ config[MODEL] = {
31
+ ARCHITECTURE: dict(
32
+ width=64,
33
+ enc_blk_nums=[1, 1, 1, 28],
34
+ middle_blk_num=1,
35
+ dec_blk_nums=[1, 1, 1, 1],
36
+ ),
37
+ NAME: model_preset
38
+ }
39
+ else:
40
+ raise ValueError(f"Unknown model preset {model_preset}")
41
+
42
+
43
+ def presets_experiments(
44
+ exp: int,
45
+ b: int = 32,
46
+ n: int = 50,
47
+ bias: bool = True,
48
+ length: int = 5000,
49
+ data_size: Tuple[int, int] = (128, 128),
50
+ model_preset: str = "StackedConvolutions",
51
+ lpips: bool = False
52
+ ) -> dict:
53
+ config = {
54
+ ID: exp,
55
+ NAME: f"{exp:04d}",
56
+ NB_EPOCHS: n
57
+ }
58
+ config[DATALOADER] = {
59
+ BATCH_SIZE: {
60
+ TRAIN: b,
61
+ VALIDATION: b
62
+ },
63
+ SIZE: data_size, # (width, height)
64
+ LENGTH: {
65
+ TRAIN: length,
66
+ VALIDATION: 800
67
+ }
68
+ }
69
+ config[OPTIMIZER] = {
70
+ NAME: "Adam",
71
+ PARAMS: {
72
+ LR: 1e-3
73
+ }
74
+ }
75
+ model_configurations(config, model_preset=model_preset, bias=bias)
76
+ config[SCHEDULER] = REDUCELRONPLATEAU
77
+ config[SCHEDULER_CONFIGURATION] = {
78
+ "factor": 0.8,
79
+ "patience": 5
80
+ }
81
+ config[LOSS] = LOSS_MSE
82
+ config[SELECTED_METRICS] = [METRIC_PSNR, METRIC_SSIM]
83
+ if lpips:
84
+ config[SELECTED_METRICS].append(METRIC_LPIPS)
85
+ return config
86
+
87
+
88
+ def get_experiment_config(exp: int) -> dict:
89
+ if exp == -1:
90
+ config = presets_experiments(exp, length=10, n=2)
91
+ elif exp == -2:
92
+ config = presets_experiments(exp, length=10, n=2, lpips=True)
93
+ elif exp == -3:
94
+ config = presets_experiments(exp, n=20)
95
+ config[DATALOADER]["gpu_gen"] = True
96
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
97
+ blur_kernel_half_size=[0, 0],
98
+ ds_factor=1,
99
+ noise_stddev=[0., 50.]
100
+ )
101
+ config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50"
102
+ elif exp == -4:
103
+ config = presets_experiments(exp, b=4, n=20)
104
+ config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
105
+ config[DATALOADER][CONFIG_DEGRADATION] = dict(
106
+ noise_stddev=[0., 50.]
107
+ )
108
+ config[PRETTY_NAME] = "Vanilla exp from disk - noisy 0-50"
109
+ elif exp == 1000:
110
+ config = presets_experiments(exp, n=60)
111
+ config[PRETTY_NAME] = "Vanilla small blur"
112
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=1, noise_stddev=[0., 0.])
113
+ elif exp == 1001:
114
+ config = presets_experiments(exp, n=60)
115
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 6], ds_factor=1, noise_stddev=[0., 0.])
116
+ config[PRETTY_NAME] = "Vanilla large blur 0 - 6"
117
+ elif exp == 1002:
118
+ config = presets_experiments(exp, n=6) # Less epochs because of the large downsample factor
119
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=5, noise_stddev=[0., 0.])
120
+ config[PRETTY_NAME] = "Vanilla small blur - ds=5"
121
+ elif exp == 1003:
122
+ config = presets_experiments(exp, n=6) # Less epochs because of the large downsample factor
123
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=5, noise_stddev=[0., 50.])
124
+ config[PRETTY_NAME] = "Vanilla small blur - ds=5 - noisy 0-50"
125
+ elif exp == 1004:
126
+ config = presets_experiments(exp, n=60)
127
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 0], ds_factor=1, noise_stddev=[0., 50.])
128
+ config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50"
129
+ elif exp == 1005:
130
+ config = presets_experiments(exp, bias=False, n=60)
131
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 0], ds_factor=1, noise_stddev=[0., 50.])
132
+ config[PRETTY_NAME] = "Vanilla denoise only - ds=1 - noisy 0-50 - bias free"
133
+ elif exp == 1006:
134
+ config = presets_experiments(exp, n=60)
135
+ config[PRETTY_NAME] = "Vanilla small blur - noisy 0-50"
136
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 2], ds_factor=1, noise_stddev=[0., 50.])
137
+ elif exp == 1007:
138
+ config = presets_experiments(exp, n=60)
139
+ config[PRETTY_NAME] = "Vanilla large blur 0 - 6 - noisy 0-50"
140
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(blur_kernel_half_size=[0, 6], ds_factor=1, noise_stddev=[0., 50.])
141
+ elif exp == 2000:
142
+ config = presets_experiments(exp, n=60, b=16, model_preset="NAFNet")
143
+ config[PRETTY_NAME] = "NAFNet denoise 0-50"
144
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
145
+ blur_kernel_half_size=[0, 0],
146
+ ds_factor=1,
147
+ noise_stddev=[0., 50.]
148
+ )
149
+ elif exp == 2001:
150
+ config = presets_experiments(exp, n=60, b=16, model_preset="UNet")
151
+ config[PRETTY_NAME] = "UNet denoise 0-50"
152
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
153
+ blur_kernel_half_size=[0, 0],
154
+ ds_factor=1,
155
+ noise_stddev=[0., 50.]
156
+ )
157
+ elif exp == 2002:
158
+ config = presets_experiments(exp, n=20, b=8, data_size=(256, 256), model_preset="NAFNet")
159
+ config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl 256x256"
160
+ config[DATALOADER]["gpu_gen"] = True
161
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
162
+ blur_kernel_half_size=[0, 0],
163
+ ds_factor=1,
164
+ noise_stddev=[0., 50.]
165
+ )
166
+ elif exp == 2003:
167
+ config = presets_experiments(exp, n=20, b=8, data_size=(128, 128), model_preset="NAFNet")
168
+ config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl - 128x128"
169
+ config[DATALOADER]["gpu_gen"] = True
170
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
171
+ blur_kernel_half_size=[0, 0],
172
+ ds_factor=1,
173
+ noise_stddev=[0., 50.]
174
+ )
175
+ elif exp == 2004:
176
+ config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
177
+ config[PRETTY_NAME] = "NAFNet Light denoise 0-50 gpu dl - 128x128"
178
+ config[DATALOADER]["gpu_gen"] = True
179
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
180
+ blur_kernel_half_size=[0, 0],
181
+ ds_factor=1,
182
+ noise_stddev=[0., 50.]
183
+ )
184
+ config[MODEL][ARCHITECTURE] = dict(
185
+ width=64,
186
+ enc_blk_nums=[1, 1, 1, 2],
187
+ middle_blk_num=1,
188
+ dec_blk_nums=[1, 1, 1, 1],
189
+ )
190
+ elif exp == 2005:
191
+ config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
192
+ config[PRETTY_NAME] = "NAFNet TresLight denoise 0-50 gpu dl - 128x128"
193
+ config[DATALOADER]["gpu_gen"] = True
194
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
195
+ blur_kernel_half_size=[0, 0],
196
+ ds_factor=1,
197
+ noise_stddev=[0., 50.]
198
+ )
199
+ config[MODEL][ARCHITECTURE] = dict(
200
+ width=64,
201
+ enc_blk_nums=[1, 1, 2],
202
+ middle_blk_num=1,
203
+ dec_blk_nums=[1, 1, 1],
204
+ )
205
+ elif exp == 2006:
206
+ config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
207
+ config[PRETTY_NAME] = "NAFNet TresLight denoise 0-50 ds=5 gpu dl - 128x128"
208
+ config[DATALOADER]["gpu_gen"] = True
209
+ config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
210
+ blur_kernel_half_size=[0, 0],
211
+ ds_factor=5,
212
+ noise_stddev=[0., 50.]
213
+ )
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+     elif exp == 2007:
+         config = presets_experiments(exp, n=20, b=16, data_size=(128, 128), model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet denoise 0-50 gpu dl -ds=5 128x128"
+         config[DATALOADER]["gpu_gen"] = True
+         config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
+             blur_kernel_half_size=[0, 0],
+             ds_factor=5,
+             noise_stddev=[0., 50.]
+         )
+     elif exp == 1008:
+         config = presets_experiments(exp, n=20)
+         config[DATALOADER]["gpu_gen"] = True
+         config[DATALOADER][CONFIG_DEAD_LEAVES] = dict(
+             blur_kernel_half_size=[0, 0],
+             ds_factor=5,
+             noise_stddev=[0., 50.]
+         )
+         config[PRETTY_NAME] = "Vanilla denoise only - ds=5 - noisy 0-50"
+     # ---------------------------------
+     # Pure DL DENOISING trainings!
+     # ---------------------------------
+     elif exp == 3000:
+         config = presets_experiments(exp, n=30, b=4, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet denoise - DL_DIV2K_512 0-50"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3001:  # ENABLE GRADIENT CLIPPING
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_DIV2K_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3002:  # ENABLE GRADIENT CLIPPING
+         config = presets_experiments(exp, n=30, b=16, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_DIV2K_512 0-50 128x128"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[DATALOADER][SIZE] = (128, 128)
+     elif exp == 3010 or exp == 3011:  # exp 3011 = REDO with gradient clipping
+         config = presets_experiments(exp, n=50, b=4, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DL_DIV2K_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3020:
+         config = presets_experiments(exp, b=32, n=50)
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[PRETTY_NAME] = "Vanilla denoise DL 0-50 - noisy 0-50"
+     # ---------------------------------
+     # Pure DIV2K DENOISING trainings!
+     # ---------------------------------
+     elif exp == 3120:
+         config = presets_experiments(exp, b=32, n=50)
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[PRETTY_NAME] = "Vanilla DIV2K_512 0-50 - noisy 0-50"
+     elif exp == 3111:
+         config = presets_experiments(exp, n=50, b=4, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DIV2K_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3101:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet41.4M denoise - DIV2K_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     # ---------------------------------
+     # Pure EXTRA PRIMITIVES
+     # ---------------------------------
+     elif exp == 3030:
+         config = presets_experiments(exp, b=128, n=50)
+         config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.]
+         )
+         config[PRETTY_NAME] = "Vanilla DL_PRIMITIVES_512 0-50 - noisy 0-50"
+         # config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3040:
+         config = presets_experiments(exp, n=50, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet3.4M Light denoise - DL_PRIMITIVES_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 3050:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet41.4M denoise - DL_PRIMITIVES_512 0-50 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(noise_stddev=[0., 50.])
+         config[DATALOADER][SIZE] = (256, 256)
+     # ---------------------------------
+     # DEBLURRING
+     # ---------------------------------
+     elif exp == 5000:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet deblur - DL_DIV2K_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 5001:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 5002:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet deblur - DL_DIV2K_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 5003:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 5004:
+         config = presets_experiments(exp, n=30, b=8, model_preset="NAFNet")
+         config[PRETTY_NAME] = "NAFNet deblur - DIV2K_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 5005:
+         config = presets_experiments(exp, n=30, b=8, model_preset="UNet")
+         config[PRETTY_NAME] = "UNet deblur - DL_512 256x256"
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     # elif exp == 6000:  # -> FAILED, no kernels normalization!
+     #     config = presets_experiments(exp, b=32, n=50)
+     #     config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+     #     config[DATALOADER][CONFIG_DEGRADATION] = dict(
+     #         noise_stddev=[0., 50.],
+     #         degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+     #         augmentation_list=[AUGMENTATION_FLIP]
+     #     )
+     #     config[PRETTY_NAME] = "Vanilla deblur DL_DIV2K_512"
+     elif exp == 6002:
+         config = presets_experiments(exp, b=128, n=50)
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[PRETTY_NAME] = "Vanilla deblur DL_DIV2K_512"
+     # elif exp == 6001:  # -> FAILED, no kernels normalization!
+     #     config = presets_experiments(exp, b=32, n=50)
+     #     config[DATALOADER][NAME] = DATASET_DIV2K
+     #     config[DATALOADER][CONFIG_DEGRADATION] = dict(
+     #         noise_stddev=[0., 50.],
+     #         degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+     #         augmentation_list=[AUGMENTATION_FLIP]
+     #     )
+     #     config[PRETTY_NAME] = "Vanilla deblur DIV2K_512"
+     elif exp == 6003:
+         config = presets_experiments(exp, b=128, n=50)
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[PRETTY_NAME] = "Vanilla deblur DIV2K_512"
+
+     elif exp == 7000:
+         config = presets_experiments(exp, b=16, n=30, model_preset="NAFNet")
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+         config[DATALOADER][NAME] = DATASET_DL_DIV2K_512
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[PRETTY_NAME] = "NAFNet Light deblur DL"
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 7001:
+         config = presets_experiments(exp, b=16, n=50, model_preset="NAFNet")
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[MODEL][ARCHITECTURE] = dict(
+             width=64,
+             enc_blk_nums=[1, 1, 2],
+             middle_blk_num=1,
+             dec_blk_nums=[1, 1, 1],
+         )
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 50.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Deblur = Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[PRETTY_NAME] = "NAFNet Light deblur DIV2K"
+         config[DATALOADER][SIZE] = (256, 256)
+     elif exp == 7002:
+         config = presets_experiments(exp, n=20, b=8, model_preset="UNet")
+         config[PRETTY_NAME] = "UNet deblur - DIV2K"
+         config[DATALOADER][NAME] = DATASET_DIV2K
+         config[DATALOADER][CONFIG_DEGRADATION] = dict(
+             noise_stddev=[0., 0.],
+             degradation_blur=DEGRADATION_BLUR_MAT,  # Using .mat kernels
+             augmentation_list=[AUGMENTATION_FLIP]
+         )
+         config[DATALOADER][SIZE] = (256, 256)
+     else:
+         raise ValueError(f"Experiment {exp} not found")
+     return config
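
Note: this hunk only shows the tail of the experiment selector, so the name of the enclosing function is not visible here. A minimal usage sketch, with `get_experiment_config` as a hypothetical stand-in for that selector:

    from rstor.properties import PRETTY_NAME
    config = get_experiment_config(3001)  # hypothetical name for the selector above
    print(config[PRETTY_NAME])            # "NAFNet41.4M denoise - DL_DIV2K_512 0-50 256x256"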
src/rstor/learning/loss.py ADDED
@@ -0,0 +1,25 @@
+ import torch
+ from typing import Optional
+ from rstor.properties import LOSS_MSE
+
+
+ def compute_loss(
+     predic: torch.Tensor,
+     target: torch.Tensor,
+     mode: Optional[str] = LOSS_MSE
+ ) -> torch.Tensor:
+     """
+     Compute loss based on the predicted and true values.
+
+     Args:
+         predic (torch.Tensor): [N, C, H, W] predicted values
+         target (torch.Tensor): [N, C, H, W] target values.
+         mode (Optional[str], optional): mode of loss computation.
+
+     Returns:
+         torch.Tensor: The computed loss.
+     """
+     assert mode in [LOSS_MSE], f"Mode {mode} not supported"
+     if mode == LOSS_MSE:
+         loss = torch.nn.functional.mse_loss(predic, target)
+     return loss
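
A quick sanity check of compute_loss (a sketch, not part of this commit); with the default mode it reduces to a plain MSE:

    import torch
    from rstor.learning.loss import compute_loss
    pred, target = torch.rand(2, 3, 64, 64), torch.rand(2, 3, 64, 64)
    assert torch.isclose(compute_loss(pred, target), torch.nn.functional.mse_loss(pred, target))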
src/rstor/learning/metrics.py ADDED
@@ -0,0 +1,140 @@
+ import torch
+ from rstor.properties import METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS, REDUCTION_AVERAGE, REDUCTION_SKIP, REDUCTION_SUM
+ from torchmetrics.image import StructuralSimilarityIndexMeasure as SSIM
+ from torchmetrics.image.lpip import LearnedPerceptualImagePatchSimilarity
+ from typing import List, Optional
+ ALL_METRICS = [METRIC_PSNR, METRIC_SSIM, METRIC_LPIPS]
+
+
+ def compute_psnr(
+     predic: torch.Tensor,
+     target: torch.Tensor,
+     clamp_mse=1e-10,
+     reduction: Optional[str] = REDUCTION_AVERAGE
+ ) -> torch.Tensor:
+     """
+     Compute the average PSNR metric for a batch of predicted and true values.
+
+     Args:
+         predic (torch.Tensor): [N, C, H, W] predicted values.
+         target (torch.Tensor): [N, C, H, W] target values.
+         clamp_mse (float, optional): lower clamp on the MSE to avoid infinite PSNR. Defaults to 1e-10.
+         reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
+
+     Returns:
+         torch.Tensor: The average PSNR value for the batch.
+     """
+     with torch.no_grad():
+         mse_per_image = torch.mean((predic - target) ** 2, dim=(-3, -2, -1))
+         mse_per_image = torch.clamp(mse_per_image, min=clamp_mse)
+         psnr_per_image = 10 * torch.log10(1 / mse_per_image)
+         if reduction == REDUCTION_AVERAGE:
+             average_psnr = torch.mean(psnr_per_image)
+         elif reduction == REDUCTION_SUM:
+             average_psnr = torch.sum(psnr_per_image)
+         elif reduction == REDUCTION_SKIP:
+             average_psnr = psnr_per_image
+         else:
+             raise ValueError(f"Unknown reduction {reduction}")
+     return average_psnr
+
+
+ def compute_ssim(
+     predic: torch.Tensor,
+     target: torch.Tensor,
+     reduction: Optional[str] = REDUCTION_AVERAGE
+ ) -> torch.Tensor:
+     """
+     Compute the average SSIM metric for a batch of predicted and true values.
+
+     Args:
+         predic (torch.Tensor): [N, C, H, W] predicted values.
+         target (torch.Tensor): [N, C, H, W] target values.
+         reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
+
+     Returns:
+         torch.Tensor: The average SSIM value for the batch.
+     """
+     with torch.no_grad():
+         reduction_mode = {
+             REDUCTION_SKIP: None,
+             REDUCTION_AVERAGE: "elementwise_mean",
+             REDUCTION_SUM: "sum"
+         }[reduction]
+         ssim = SSIM(data_range=1.0, reduction=reduction_mode).to(predic.device)
+         assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
+         assert predic.device == target.device, f"{predic.device} != {target.device}"
+         ssim_value = ssim(predic, target)
+     return ssim_value
+
+
+ def compute_lpips(
+     predic: torch.Tensor,
+     target: torch.Tensor,
+     reduction: Optional[str] = REDUCTION_AVERAGE,
+ ) -> torch.Tensor:
+     """
+     Compute the average LPIPS metric for a batch of predicted and true values.
+     https://richzhang.github.io/PerceptualSimilarity/
+
+     Args:
+         predic (torch.Tensor): [N, C, H, W] predicted values.
+         target (torch.Tensor): [N, C, H, W] target values.
+         reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
+
+     Returns:
+         torch.Tensor: The average LPIPS value for the batch.
+     """
+     reduction_mode = {
+         REDUCTION_SKIP: "sum",  # placeholder, bypassed by the per-image loop below
+         REDUCTION_AVERAGE: "mean",
+         REDUCTION_SUM: "sum"
+     }[reduction]
+
+     with torch.no_grad():
+         lpip_metrics = LearnedPerceptualImagePatchSimilarity(
+             reduction=reduction_mode,
+             normalize=True  # If set to True will instead expect input to be in the [0,1] range.
+         ).to(predic.device)
+         assert predic.shape == target.shape, f"{predic.shape} != {target.shape}"
+         assert predic.device == target.device, f"{predic.device} != {target.device}"
+         if reduction == REDUCTION_SKIP:
+             lpip_value = []
+             for idx in range(predic.shape[0]):
+                 lpip_value.append(lpip_metrics(
+                     predic[idx, ...].unsqueeze(0).clip(0, 1),
+                     target[idx, ...].unsqueeze(0).clip(0, 1)
+                 ))
+             lpip_value = torch.stack(lpip_value)
+         elif reduction in [REDUCTION_SUM, REDUCTION_AVERAGE]:
+             lpip_value = lpip_metrics(predic.clip(0, 1), target.clip(0, 1))
+     return lpip_value
+
+
+ def compute_metrics(
+     predic: torch.Tensor,
+     target: torch.Tensor,
+     reduction: Optional[str] = REDUCTION_AVERAGE,
+     chosen_metrics: Optional[List[str]] = ALL_METRICS) -> dict:
+     """
+     Compute the metrics for a batch of predicted and true values.
+
+     Args:
+         predic (torch.Tensor): [N, C, H, W] predicted values.
+         target (torch.Tensor): [N, C, H, W] target values.
+         reduction (str): Reduction method. REDUCTION_AVERAGE/REDUCTION_SKIP/REDUCTION_SUM.
+         chosen_metrics (list): List of metrics to compute. Defaults to ALL_METRICS (PSNR, SSIM, LPIPS).
+
+     Returns:
+         dict: computed metrics.
+     """
+     metrics = {}
+     if METRIC_PSNR in chosen_metrics:
+         average_psnr = compute_psnr(predic, target, reduction=reduction)
+         metrics[METRIC_PSNR] = average_psnr.item() if reduction != REDUCTION_SKIP else average_psnr
+     if METRIC_SSIM in chosen_metrics:
+         ssim_value = compute_ssim(predic, target, reduction=reduction)
+         metrics[METRIC_SSIM] = ssim_value.item() if reduction != REDUCTION_SKIP else ssim_value
+     if METRIC_LPIPS in chosen_metrics:
+         lpip_value = compute_lpips(predic, target, reduction=reduction)
+         metrics[METRIC_LPIPS] = lpip_value.item() if reduction != REDUCTION_SKIP else lpip_value
+     return metrics
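
Usage sketch for compute_metrics (not part of this commit); restricting chosen_metrics to PSNR/SSIM avoids constructing the pretrained LPIPS network:

    import torch
    from rstor.learning.metrics import compute_metrics
    from rstor.properties import METRIC_PSNR, METRIC_SSIM
    pred, target = torch.rand(4, 3, 64, 64), torch.rand(4, 3, 64, 64)
    scores = compute_metrics(pred, target, chosen_metrics=[METRIC_PSNR, METRIC_SSIM])
    print(scores)  # e.g. {'PSNR': ..., 'SSIM': ...} as plain floats under the default average reduction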
src/rstor/properties.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ from pathlib import Path
+ RELU = "ReLU"
+ LEAKY_RELU = "LeakyReLU"
+ SIMPLE_GATE = "simple_gate"
+ LOSS = "loss"
+ LOSS_MSE = "MSE"
+ METRIC_PSNR = "PSNR"
+ METRIC_SSIM = "SSIM"
+ METRIC_LPIPS = "LPIPS"
+ SELECTED_METRICS = "selected_metrics"
+ DATALOADER = "data_loader"
+ BATCH_SIZE = "batch_size"
+ SIZE = "size"
+ TRAIN, VALIDATION, TEST = "train", "validation", "test"
+ LENGTH = "length"
+ ID = "id"
+ NAME = "name"
+ PRETTY_NAME = "pretty_name"
+ NB_EPOCHS = "nb_epochs"
+ ARCHITECTURE = "architecture"
+ MODEL = "model"
+ N_PARAMS = "n_params"
+ OPTIMIZER = "optimizer"
+ LR = "lr"
+ PARAMS = "parameters"
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+ SCHEDULER_CONFIGURATION = "scheduler_configuration"
+ SCHEDULER = "scheduler"
+ REDUCELRONPLATEAU = "ReduceLROnPlateau"
+ CONFIG_DEAD_LEAVES = "config_dead_leaves"
+ CONFIG_DEGRADATION = "config_degradation"
+ REDUCTION_SUM = "reduction_sum"
+ REDUCTION_AVERAGE = "reduction_average"
+ REDUCTION_SKIP = "reduction_skip"
+ TRACES_TARGET = "target"
+ TRACES_DEGRADED = "degraded"
+ TRACES_RESTORED = "restored"
+ TRACES_METRICS = "metrics"
+ TRACES_ALL = "all"
+
+ DEGRADATION_BLUR_NONE = "none"
+ DEGRADATION_BLUR_MAT = "mat"
+ DEGRADATION_BLUR_GAUSS = "gauss"
+
+
+ SAMPLER_SATURATED = "saturated"
+ SAMPLER_UNIFORM = "uniform"
+ SAMPLER_NATURAL = "natural"
+ SAMPLER_DIV2K = "div2k"
+
+ DATASET_FOLDER = "__dataset"
+ DATASET_PATH = Path(__file__).parent.parent.parent/DATASET_FOLDER
+ DATASET_DL_RANDOMRGB_1024 = "deadleaves_randomrgb_1024"
+ DATASET_DL_DIV2K_1024 = "deadleaves_div2k_1024"
+ DATASET_DL_DIV2K_512 = "deadleaves_div2k_512"
+ DATASET_DL_EXTRAPRIMITIVES_DIV2K_512 = "deadleaves_primitives_div2k_512"
+ DATASET_SYNTH_LIST = [DATASET_DL_DIV2K_512, DATASET_DL_DIV2K_1024,
+                       DATASET_DL_RANDOMRGB_1024, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512]
+ DATASET_BLUR_KERNEL_PATH = DATASET_PATH / "kernels" / "custom_blur_centered.mat"
+ AUGMENTATION_FLIP = "flip"
+ AUGMENTATION_ROTATE = "rotate"
+
+
+ DATASET_DIV2K = "div2k"
@@ -0,0 +1,73 @@
 
+ import numpy as np
+ import cv2
+
+ from pathlib import Path
+ from rstor.properties import DATASET_PATH
+ from typing import List
+
+
+ def sample_uniform_rgb(size: int, seed: int = None) -> np.ndarray:
+     """
+     Generate random RGB values, sampled uniformly in the RGB cube.
+
+     Args:
+         size (int): number of colors to sample
+         seed (int, optional): Seed for the random number generator. Defaults to None.
+
+     Returns:
+         np.ndarray: Random RGB values as a numpy array.
+     """
+     # https://github.com/numpy/numpy/issues/17079
+     # https://numpy.org/devdocs/reference/random/new-or-different.html#new-or-different
+     rng = np.random.default_rng(np.random.SeedSequence(seed))
+
+     random_samples = rng.uniform(size=(size, 3))
+     rgb = random_samples
+
+     # Old version with saturation, kept below for reference:
+     # lab = (random_samples + np.array([0., -0.5, -0.5])[None]) * np.array([100., 127 * 2, 127 * 2])[None]
+     # rgb = cv2.cvtColor(lab[None, :].astype(np.float32), cv2.COLOR_Lab2RGB)
+     return rgb.squeeze()
+
+
+ def sample_saturated_color(size: int, seed: int = None) -> np.ndarray:
+     """
+     Generate saturated RGB values, sampled uniformly in Lab and converted back to RGB.
+
+     Args:
+         size (int): number of colors to sample
+         seed (int, optional): Seed for the random number generator. Defaults to None.
+
+     Returns:
+         np.ndarray: Random RGB values as a numpy array.
+     """
+     # https://github.com/numpy/numpy/issues/17079
+     # https://numpy.org/devdocs/reference/random/new-or-different.html#new-or-different
+     rng = np.random.default_rng(np.random.SeedSequence(seed))
+
+     random_samples = rng.uniform(size=(size, 3))
+
+     lab = (random_samples + np.array([0., -0.5, -0.5])[None]) * np.array([100., 127 * 2, 127 * 2])[None]
+     rgb = cv2.cvtColor(lab[None, :].astype(np.float32), cv2.COLOR_Lab2RGB)
+     return rgb.squeeze()
+
+
+ def sample_color_from_images(size: int, seed: int = None, path_to_images: List[Path] = []) -> np.ndarray:
+     """
+     Sample RGB values from the pixels of a randomly picked image.
+
+     Args:
+         size (int): number of colors to sample
+         seed (int, optional): Seed for the random number generator. Defaults to None.
+         path_to_images (List[Path]): images to sample colors from.
+
+     Returns:
+         np.ndarray: Sampled RGB values as a [size, 3] numpy array.
+     """
+     print("path : ", path_to_images)
+     assert len(path_to_images) > 0, "Please provide a list of images to sample colors from."
+     rng = np.random.default_rng(np.random.SeedSequence(seed))
+
+     # Randomly pick an image and load it
+     img_id = rng.integers(0, len(path_to_images))
+
+     img = cv2.imread(path_to_images[img_id].as_posix())
+     img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255
+
+     pixels = img.reshape(-1, 3)
+     n_pixels = pixels.shape[0]
+
+     # sample a pixel color for each disc
+     pixel_ids = rng.integers(0, n_pixels, size)
+     colors = pixels[pixel_ids, :]
+
+     return colors
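
Usage sketch for the samplers above (not part of this commit); each call returns an array of shape [size, 3] with float RGB values:

    from rstor.synthetic_data.color_sampler import sample_uniform_rgb, sample_saturated_color
    uniform_colors = sample_uniform_rgb(16, seed=0)        # uniform in the RGB cube
    saturated_colors = sample_saturated_color(16, seed=0)  # uniform in Lab, converted back to RGB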
src/rstor/synthetic_data/dead_leaves_cpu.py ADDED
@@ -0,0 +1,79 @@
+ from typing import Tuple, Optional, List
+ import numpy as np
+ import cv2
+ from rstor.synthetic_data.dead_leaves_sampler import define_dead_leaves_chart
+ from rstor.properties import SAMPLER_UNIFORM
+ from pathlib import Path
+
+
+ def cpu_dead_leaves_chart(size: Tuple[int, int] = (100, 100),
+                           number_of_circles: int = -1,
+                           background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+                           colored: Optional[bool] = True,
+                           radius_min: Optional[int] = -1,
+                           radius_max: Optional[int] = -1,
+                           radius_alpha: Optional[int] = 3,
+                           seed: int = None,
+                           reverse: Optional[bool] = True,
+                           sampler=SAMPLER_UNIFORM,
+                           natural_image_list: List[Path] = []) -> np.ndarray:
+     """
+     Generation of a dead leaves chart by splatting circles on top of each other.
+
+     Args:
+         size (Tuple[int, int], optional): size of the generated chart. Defaults to (100, 100).
+         number_of_circles (int, optional): number of circles to generate.
+             If negative, it is computed based on the size. Defaults to -1.
+         background_color (Optional[Tuple[float, float, float]], optional):
+             background color of the chart. Defaults to gray (0.5, 0.5, 0.5).
+         colored (Optional[bool], optional): Whether to generate colored circles. Defaults to True.
+         radius_min (Optional[int], optional): minimum radius of the circles. Defaults to -1. (=> 1)
+         radius_max (Optional[int], optional): maximum radius of the circles. Defaults to -1. (=> 2000)
+         radius_alpha (Optional[int], optional): exponent of the power-law radius distribution
+             p(r) ~ r^(-radius_alpha). Defaults to 3.
+         seed (int, optional): seed for the random number generator. Defaults to None
+         reverse (Optional[bool], optional): View circles from the back
+             by reversing the drawing order. Defaults to True.
+             WARNING: This option is extremely slow on CPU.
+         sampler (str, optional): color sampler to use. Defaults to SAMPLER_UNIFORM.
+         natural_image_list (List[Path], optional): images to sample colors from
+             when the natural sampler is selected.
+
+     Returns:
+         np.ndarray: generated dead leaves chart as a NumPy array.
+     """
+     center_x, center_y, radius, color = define_dead_leaves_chart(
+         size,
+         number_of_circles,
+         colored,
+         radius_min,
+         radius_max,
+         radius_alpha,
+         seed,
+         sampler=sampler,
+         natural_image_list=natural_image_list
+     )
+     if not colored:
+         color = np.concatenate((color, color, color), axis=1)
+
+     # number_of_circles may be -1 (auto), so iterate over the actually sampled primitives.
+     if reverse:
+         chart = np.zeros((size[0], size[1], 3), dtype=np.float32)
+         buffer = np.zeros_like(chart)
+         is_not_covered_mask = np.ones((*chart.shape[:2], 1))
+         for i in range(len(radius)):
+             cv2.circle(buffer, (center_x[i], center_y[i]), radius[i], color[i], -1)
+             chart += buffer * is_not_covered_mask
+             is_not_covered_mask = cv2.circle(is_not_covered_mask, (center_x[i], center_y[i]), radius[i], 0, -1)
+
+             if not np.any(is_not_covered_mask):
+                 break
+
+         chart += np.multiply(background_color, np.ones((size[0], size[1], 3), dtype=np.float32)) * is_not_covered_mask
+     else:
+         chart = np.multiply(background_color, np.ones((size[0], size[1], 3), dtype=np.float32))
+         for i in range(len(radius)):
+             # circle is inplace
+             cv2.circle(chart, (center_x[i], center_y[i]), radius[i], color[i], -1)
+
+     chart = chart.clip(0, 1)
+
+     if not colored:
+         chart = chart[:, :, 0, None]  # return shape [h, w, 1] in gray mode
+     return chart
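
Usage sketch (not part of this commit); reverse=False keeps the fast painter's-algorithm path (reverse=True is flagged above as extremely slow on CPU):

    from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
    chart = cpu_dead_leaves_chart(size=(256, 256), number_of_circles=1000, seed=42, reverse=False)
    print(chart.shape)  # (256, 256, 3), float values clipped to [0, 1]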
src/rstor/synthetic_data/dead_leaves_gpu.py ADDED
@@ -0,0 +1,176 @@
+ from rstor.utils import DEFAULT_NUMPY_FLOAT_TYPE, THREADS_PER_BLOCK
+ from rstor.properties import SAMPLER_UNIFORM
+ from typing import Tuple, Optional
+ from rstor.synthetic_data.dead_leaves_sampler import define_dead_leaves_chart
+ import numpy as np
+ from numba import cuda
+ import math
+
+
+ def gpu_dead_leaves_chart(
+     size: Tuple[int, int] = (100, 100),
+     number_of_circles: int = -1,
+     background_color: Optional[Tuple[float, float, float]] = (0.5, 0.5, 0.5),
+     colored: Optional[bool] = True,
+     radius_min: Optional[int] = -1,
+     radius_max: Optional[int] = -1,
+     radius_alpha: Optional[int] = 3,
+     seed: int = None,
+     reverse=True,
+     sampler=SAMPLER_UNIFORM,
+     natural_image_list=None,
+     circle_primitives: bool = True,
+     anisotropy: float = 1.,
+     angle: float = 0.
+ ) -> np.ndarray:
+     center_x, center_y, radius, color = define_dead_leaves_chart(
+         size,
+         number_of_circles,
+         colored,
+         radius_min,
+         radius_max,
+         radius_alpha,
+         seed,
+         sampler=sampler,
+         natural_image_list=natural_image_list
+     )
+
+     # Generate on gpu
+     chart = _generate_dead_leaves(
+         size,
+         centers=np.stack((center_x, center_y), axis=-1),
+         radia=radius,
+         colors=color,
+         background=background_color,
+         reverse=reverse,
+         circle_primitives=circle_primitives,
+         anisotropy=anisotropy,
+         angle=angle
+     )
+
+     return chart
+
+
+ def _generate_dead_leaves(size, centers, radia, colors, background, reverse,
+                           circle_primitives: bool, anisotropy: float = 1., angle: float = 0.):
+     assert centers.ndim == 2
+     ny, nx = size
+     nc = colors.shape[-1]
+
+     # Init empty array on GPU
+     generation_ = cuda.device_array((ny, nx, nc), DEFAULT_NUMPY_FLOAT_TYPE)
+     # Move useful arrays to GPU
+     centers_ = cuda.to_device(centers)
+     radia_ = cuda.to_device(radia)
+     colors_ = cuda.to_device(colors)
+
+     # Dispatch threads on a 2D grid (one thread per pixel)
+     threadsperblock = (THREADS_PER_BLOCK, THREADS_PER_BLOCK)
+     blockspergrid_x = math.ceil(nx/threadsperblock[1])
+     blockspergrid_y = math.ceil(ny/threadsperblock[0])
+     blockspergrid = (blockspergrid_x, blockspergrid_y)
+
+     if reverse:
+         cuda_dead_leaves_gen_reversed[blockspergrid, threadsperblock](
+             generation_,
+             centers_,
+             radia_,
+             colors_,
+             background,
+             circle_primitives,
+             anisotropy,
+             np.deg2rad(angle)
+         )
+     else:
+         cuda_dead_leaves_gen[blockspergrid, threadsperblock](
+             generation_,
+             centers_,
+             radia_,
+             colors_,
+             background)
+
+     return generation_
+
+
+ @cuda.jit(cache=False)
+ def cuda_dead_leaves_gen_reversed(generation, centers, radia, colors, background,
+                                   circle_primitives: bool, anisotropy: float, angle: float):
+     idx, idy = cuda.grid(2)
+     ny, nx, nc = generation.shape
+
+     n_discs = centers.shape[0]
+
+     # Out of bound threads
+     if idx >= nx or idy >= ny:
+         return
+
+     for disc_id in range(n_discs):
+         dx_ = idx - centers[disc_id, 0]
+         dy_ = idy - centers[disc_id, 1]
+         # Rotate, then stretch the x axis to get anisotropic primitives
+         dx = math.cos(angle)*dx_ + math.sin(angle)*dy_
+         dy = -math.sin(angle)*dx_ + math.cos(angle)*dy_
+         dx = dx * anisotropy
+         dist_sq = dx*dx + dy*dy
+
+         # Naive thread diverging version
+         r = radia[disc_id]
+         r_sq = r*r
+         if circle_primitives:
+             if dist_sq <= r_sq:
+                 # Copy back to global memory
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c]
+                 return
+         else:
+             if (disc_id % 4) == 0 and dist_sq <= r_sq:
+                 # Disc with a radial color gradient
+                 alpha = dist_sq/r_sq
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
+                 return
+             elif (disc_id % 4) == 1 and (abs(dx)+abs(dy)) <= r:
+                 # Diamond (L1 ball) with a color gradient
+                 alpha = dist_sq/r_sq
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
+                 return
+             elif (disc_id % 4) == 2 and abs(dx) <= r and abs(dy) <= r:
+                 # Plain square
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c]
+                 return
+             elif (disc_id % 200) == 3 and abs(dy) <= r//5:
+                 # Sparse horizontal stripe, with a color gradient
+                 alpha = dist_sq/r_sq
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
+                 return
+             elif (disc_id % 200) == 4 and abs(dx) <= r//5:
+                 # Sparse vertical stripe, with a color gradient
+                 alpha = dist_sq/r_sq
+                 for c in range(nc):
+                     generation[idy, idx, c] = colors[disc_id, c] * alpha + colors[disc_id, (c+1) % 3] * (1-alpha)
+                 return
+     for c in range(nc):
+         generation[idy, idx, c] = background[c]
+
+
+ @cuda.jit(cache=False)
+ def cuda_dead_leaves_gen(generation, centers, radia, colors, background):
+     # Dispatched on a 2D grid (see _generate_dead_leaves):
+     # each thread handles every channel of a single pixel.
+     idx, idy = cuda.grid(2)
+     ny, nx, nc = generation.shape
+
+     n_discs = centers.shape[0]
+
+     # Out of bound threads
+     if idx >= nx or idy >= ny:
+         return
+
+     # Front view: the last drawn disc covering the pixel wins.
+     last_disc = -1
+     for disc_id in range(n_discs):
+         dx = idx - centers[disc_id, 0]
+         dy = idy - centers[disc_id, 1]
+         dist_sq = dx*dx + dy*dy
+
+         r = radia[disc_id]
+         if dist_sq <= r*r:
+             last_disc = disc_id
+
+     # Copy back to global memory
+     for c in range(nc):
+         if last_disc >= 0:
+             generation[idy, idx, c] = colors[last_disc, c]
+         else:
+             generation[idy, idx, c] = background[c]
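
Unlike the CPU generator, gpu_dead_leaves_chart returns a numba device array; a usage sketch (not part of this commit, requires a CUDA-capable GPU):

    from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
    chart_device = gpu_dead_leaves_chart(size=(256, 256), seed=42)
    chart = chart_device.copy_to_host()  # back to a (256, 256, 3) numpy array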
src/rstor/synthetic_data/dead_leaves_sampler.py ADDED
@@ -0,0 +1,74 @@
+ from typing import Tuple, Optional, List
+ from rstor.properties import SAMPLER_SATURATED, SAMPLER_NATURAL, SAMPLER_UNIFORM
+ from rstor.synthetic_data.color_sampler import sample_uniform_rgb, sample_saturated_color, sample_color_from_images
+ import numpy as np
+ from pathlib import Path
+
+
+ def define_dead_leaves_chart(
+     size: Tuple[int, int] = (100, 100),
+     number_of_circles: int = -1,
+     colored: Optional[bool] = True,
+     radius_min: Optional[int] = -1,
+     radius_max: Optional[int] = -1,
+     radius_alpha: Optional[int] = 3,
+     seed: int = None,
+     sampler=SAMPLER_UNIFORM,
+     natural_image_list: Optional[List[Path]] = None
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
+     """
+     Defines the geometric and color properties of the primitives of a dead leaves chart,
+     to be rendered later.
+
+     Args:
+         size (Tuple[int, int], optional): size of the generated chart. Defaults to (100, 100).
+         number_of_circles (int, optional): number of circles to generate.
+             If negative, it is computed based on the size. Defaults to -1.
+         colored (Optional[bool], optional): Whether to generate colored circles. Defaults to True.
+         radius_min (Optional[int], optional): minimum radius of the circles. Defaults to -1. (=> 1)
+         radius_max (Optional[int], optional): maximum radius of the circles. Defaults to -1. (=> 2000)
+         radius_alpha (Optional[int], optional): exponent of the power-law radius distribution
+             p(r) ~ r^(-radius_alpha). Defaults to 3.
+         seed (int, optional): seed for the random number generator. Defaults to None
+         sampler (str, optional): color sampler to use. Defaults to SAMPLER_UNIFORM.
+         natural_image_list (Optional[List[Path]], optional): images to sample colors from
+             when the natural sampler is selected.
+
+     Returns:
+         Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]: center_x, center_y, radius, color
+     """
+     rng = np.random.default_rng(np.random.SeedSequence(seed))
+
+     if number_of_circles < 0:
+         number_of_circles = 30 * max(size)
+     if radius_min < 0.:
+         radius_min = 1.
+     if radius_max < 0.:
+         radius_max = 2000.
+
+     # Pick random circle centers and radii
+     center_x = rng.integers(0, size[1], size=number_of_circles)
+     center_y = rng.integers(0, size[0], size=number_of_circles)
+
+     # Sample radii from a power-law distribution p(r) ~ r^(-radius_alpha), truncated to
+     # [radius_min, radius_max], using the change of variables formula for random variables.
+     radius = rng.uniform(
+         low=radius_max ** (1 - radius_alpha),
+         high=radius_min ** (1 - radius_alpha),
+         size=number_of_circles
+     )
+     radius = radius ** (-1/(radius_alpha - 1))
+     radius = radius.round().astype(int)
+
+     # Pick random colors
+     if colored:
+         if sampler == SAMPLER_UNIFORM:
+             color = sample_uniform_rgb(number_of_circles, seed=rng.integers(0, 1e10)).astype(float)
+         elif sampler == SAMPLER_SATURATED:
+             color = sample_saturated_color(number_of_circles, seed=rng.integers(0, 1e10)).astype(float)
+         elif sampler == SAMPLER_NATURAL:
+             assert natural_image_list is not None, "Please provide a list of images to sample colors from."
+             color = sample_color_from_images(number_of_circles, seed=rng.integers(0, 1e10),
+                                              path_to_images=natural_image_list).astype(float)
+         else:
+             raise NotImplementedError(f"Unknown color sampler {sampler}")
+     else:
+         color = rng.uniform(0.25, 0.75, size=(number_of_circles, 1))
+     return center_x, center_y, radius, color
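
The radius sampling above is inverse-CDF (change of variables) sampling of a power law p(r) proportional to r^(-radius_alpha), truncated to [radius_min, radius_max]; a quick numpy check (a sketch, not part of this commit):

    import numpy as np
    rng = np.random.default_rng(0)
    alpha, r_min, r_max = 3, 1., 2000.
    u = rng.uniform(r_max ** (1 - alpha), r_min ** (1 - alpha), size=100_000)
    r = u ** (-1. / (alpha - 1))  # same transform as in define_dead_leaves_chart
    assert r_min <= r.min() and r.max() <= r_max  # small radii dominate for alpha > 1, as expected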
src/rstor/synthetic_data/interactive/interactive_dead_leaves.py ADDED
@@ -0,0 +1,81 @@
+ from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
+ from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
+ from rstor.properties import SAMPLER_UNIFORM, SAMPLER_DIV2K, SAMPLER_NATURAL, SAMPLER_SATURATED, DATASET_PATH
+ import sys
+ import numpy as np
+ from interactive_pipe import interactive_pipeline, interactive
+ from typing import Optional
+
+
+ def dead_leave_plugin(ds=1):
+     interactive(
+         background_intensity=(0.5, [0., 1.]),
+         number_of_circles=(-1, [-1, 10000]),
+         colored=(True,),
+         radius_alpha=(3., [2., 10.]),
+         seed=(0, [-1, 42]),
+         ds=(ds, [1, 5]),
+         numba_flag=(True,),  # Defaults to the GPU (numba) generator
+         sampler=(SAMPLER_UNIFORM, [SAMPLER_UNIFORM, SAMPLER_DIV2K, SAMPLER_SATURATED]),
+         circle_primitives=(True,),
+         anisotropy=(1., [0.1, 10.]),
+         angle=(0., [-180., 180.])
+     )(generate_deadleave)
+
+
+ def generate_deadleave(
+     background_intensity: float = 0.5,
+     number_of_circles: int = -1,
+     colored: Optional[bool] = False,
+     radius_alpha: Optional[int] = 3,
+     seed=0,
+     ds=3,
+     numba_flag=True,
+     sampler=SAMPLER_UNIFORM,
+     circle_primitives=True,
+     anisotropy=1.,
+     angle=0.,
+     global_params={}
+ ) -> np.ndarray:
+     global_params["ds_factor"] = ds
+     bg_color = (background_intensity, background_intensity, background_intensity)
+     natural_image_list = None
+     if sampler == SAMPLER_DIV2K:
+         sampler = SAMPLER_NATURAL
+         div2k_path = DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR"
+         natural_image_list = sorted([file for file in div2k_path.glob("*.png")])
+     if not numba_flag:
+         chart = cpu_dead_leaves_chart((512*ds, 512*ds), number_of_circles, bg_color, colored,
+                                       radius_alpha=radius_alpha,
+                                       seed=None if seed < 0 else seed,
+                                       sampler=sampler,
+                                       reverse=False,
+                                       natural_image_list=natural_image_list)
+     else:
+         chart = gpu_dead_leaves_chart((512*ds, 512*ds), number_of_circles, bg_color, colored,
+                                       radius_alpha=radius_alpha,
+                                       seed=None if seed < 0 else seed,
+                                       sampler=sampler,
+                                       natural_image_list=natural_image_list,
+                                       circle_primitives=circle_primitives,
+                                       anisotropy=anisotropy,
+                                       angle=angle).copy_to_host()
+     if chart.shape[-1] == 1:
+         chart = chart.repeat(3, axis=-1)
+         # Required to switch from colors to gray scale visualization.
+     return chart
+
+
+ def deadleave_pipeline():
+     deadleave_chart = generate_deadleave()
+     return deadleave_chart
+
+
+ def main(argv):
+     dead_leave_plugin(ds=1)
+     interactive_pipeline(gui="auto")(deadleave_pipeline)()
+
+
+ if __name__ == "__main__":
+     main(sys.argv[1:])
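
The module doubles as a standalone demo; assuming src is on PYTHONPATH (as app.py arranges via sys.path.append), it can be launched from the repository root with:

    python src/rstor/synthetic_data/interactive/interactive_dead_leaves.py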
src/rstor/utils.py ADDED
@@ -0,0 +1,12 @@
+ import numpy as np
+ import numba
+ import torch
+
+ THREADS_PER_BLOCK = 32  # 32 or 16
+ DEFAULT_NUMPY_FLOAT_TYPE = np.float32
+ DEFAULT_CUDA_FLOAT_TYPE = numba.float32
+ DEFAULT_TORCH_FLOAT_TYPE = torch.float32
+
+
+ DEFAULT_NUMPY_INT_TYPE = np.int32
+ DEFAULT_CUDA_INT_TYPE = numba.int32
test/test_dataloader.py ADDED
@@ -0,0 +1,49 @@
+ import torch
+ from rstor.data.synthetic_dataloader import DeadLeavesDataset
+
+
+ def test_dead_leaves_dataset():
+     # Test case 1: Default parameters
+     dataset = DeadLeavesDataset(noise_stddev=(0, 0), ds_factor=1)
+     assert len(dataset) == 1000
+     assert dataset.size == (128, 128)
+     assert dataset.frozen_seed is None
+     assert dataset.config_dead_leaves == {}
+
+     # Test case 2: Custom parameters
+     dataset = DeadLeavesDataset(size=(256, 256), length=500, frozen_seed=42, number_of_circles=5,
+                                 background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+                                 noise_stddev=(0, 0), ds_factor=1)
+     assert len(dataset) == 500
+     assert dataset.size == (256, 256)
+     assert dataset.frozen_seed == 42
+     assert dataset.config_dead_leaves == {
+         'number_of_circles': 5,
+         'background_color': (0.2, 0.4, 0.6),
+         'colored': True,
+         'radius_min': 1,
+         'radius_alpha': 3
+     }
+
+     # Test case 3: Check item retrieval
+     item, item_tgt = dataset[0]
+     assert isinstance(item, torch.Tensor)
+     assert item.shape == (3, 256, 256)
+
+     # Test case 4: Repeatable results with frozen seed
+     dataset1 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+     dataset2 = DeadLeavesDataset(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+     item1, item_tgt1 = dataset1[0]
+     item2, item_tgt2 = dataset2[0]
+     assert torch.all(torch.eq(item1, item2))
+
+     # Test case 5: Visualize
+     # dataset = DeadLeavesDataset(size=(256, 256), length=500, frozen_seed=43,
+     #                             background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+     #                             noise_stddev=(0, 0), ds_factor=1)
+     # item, item_tgt = dataset[0]
+     # import matplotlib.pyplot as plt
+     # plt.figure()
+     # plt.imshow(item.permute(1, 2, 0).detach().cpu())
+     # plt.show()
+     # print("done")
test/test_dataloader_gpu.py ADDED
@@ -0,0 +1,56 @@
+ import torch
+ from rstor.data.synthetic_dataloader import DeadLeavesDatasetGPU
+ import numba
+
+
+ def test_dead_leaves_dataset_gpu():
+     if not numba.cuda.is_available():
+         return
+
+     # Test case 1: Default parameters
+     dataset = DeadLeavesDatasetGPU(noise_stddev=(0, 0), ds_factor=1)
+     assert len(dataset) == 1000
+     assert dataset.size == (128, 128)
+     assert dataset.frozen_seed is None
+     assert dataset.config_dead_leaves == {}
+
+     # Test case 2: Custom parameters
+     dataset = DeadLeavesDatasetGPU(size=(256, 256), length=500, frozen_seed=42, number_of_circles=5,
+                                    background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+                                    noise_stddev=(0, 0), ds_factor=1)
+     assert len(dataset) == 500
+     assert dataset.size == (256, 256)
+     assert dataset.frozen_seed == 42
+     assert dataset.config_dead_leaves == {
+         'number_of_circles': 5,
+         'background_color': (0.2, 0.4, 0.6),
+         'colored': True,
+         'radius_min': 1,
+         'radius_alpha': 3
+     }
+
+     # Test case 3: Check item retrieval
+     item, item_tgt = dataset[0]
+     assert isinstance(item, torch.Tensor)
+     assert item.shape == (3, 256, 256)
+
+     # Test case 4: Repeatable results with frozen seed
+     dataset1 = DeadLeavesDatasetGPU(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+     dataset2 = DeadLeavesDatasetGPU(frozen_seed=42, noise_stddev=(0, 0), number_of_circles=256)
+     item1, item_tgt1 = dataset1[0]
+     item2, item_tgt2 = dataset2[0]
+     assert torch.all(torch.eq(item1, item2))
+
+     # Test case 5: Visualize
+     # dataset = DeadLeavesDatasetGPU(size=(256, 256), length=500, frozen_seed=44, number_of_circles=10_000,
+     #                                background_color=(0.2, 0.4, 0.6), colored=True, radius_min=1, radius_alpha=3,
+     #                                noise_stddev=(0, 0), ds_factor=1)
+     # item, item_tgt = dataset[0]
+     # import matplotlib.pyplot as plt
+     # plt.figure()
+     # plt.imshow(item.permute(1, 2, 0).detach().cpu())
+     # plt.show()
+     # print("done")
test/test_dataloader_stored.py ADDED
@@ -0,0 +1,67 @@
+ import torch
+ from rstor.data.stored_images_dataloader import RestorationDataset
+ from numba import cuda
+ from rstor.properties import DATASET_PATH, AUGMENTATION_FLIP, AUGMENTATION_ROTATE
+
+
+ def test_dataloader_stored():
+     if not cuda.is_available():
+         print("cuda unavailable, exiting")
+         return
+
+     # Test case 1: Default parameters
+     dataset = RestorationDataset(noise_stddev=(0, 0),
+                                  images_path=DATASET_PATH/"sample")
+     assert len(dataset) == 2
+     assert dataset.frozen_seed is None
+
+     # Test case 2: Custom parameters
+     dataset = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                  size=(64, 64),
+                                  frozen_seed=42,
+                                  noise_stddev=(0, 0))
+     assert len(dataset) == 2
+     assert dataset.frozen_seed == 42
+
+     # Test case 3: Check item retrieval
+     item, item_tgt = dataset[0]
+     assert isinstance(item, torch.Tensor)
+     assert item.shape == item_tgt.shape
+     assert item.shape == (3, 64, 64)
+
+     # Test case 4: Repeatable results with frozen seed
+     dataset1 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                   frozen_seed=42, noise_stddev=(0, 0))
+     dataset2 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                   frozen_seed=42, noise_stddev=(0, 0))
+     item1, item_tgt1 = dataset1[0]
+     item2, item_tgt2 = dataset2[0]
+
+     assert torch.all(torch.eq(item1, item2))
+
+     # Test case 5: Repeatable results with frozen seed and augmentation
+     augmentation_list = [AUGMENTATION_FLIP, AUGMENTATION_ROTATE]
+     dataset1 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                   frozen_seed=42, noise_stddev=(0, 0),
+                                   augmentation_list=augmentation_list)
+     dataset2 = RestorationDataset(images_path=DATASET_PATH/"sample",
+                                   frozen_seed=42, noise_stddev=(0, 0),
+                                   augmentation_list=augmentation_list)
+     item1, item_tgt1 = dataset1[0]
+     item2, item_tgt2 = dataset2[0]
+     assert torch.all(torch.eq(item1, item2))
+
+     # Test case 6: Visualize
+     # dataset = RestorationDataset(images_path=DATASET_PATH/"sample",
+     #                              noise_stddev=(0, 0),
+     #                              augmentation_list=augmentation_list)
+     # item, item_tgt = dataset[0]
+     # import matplotlib.pyplot as plt
+     # plt.figure()
+     # plt.imshow(item.permute(1, 2, 0).detach().cpu())
+     # plt.show()
+     # breakpoint()
+     # print("done")
test/test_dead_leaves.py ADDED
@@ -0,0 +1,44 @@
+ import numpy as np
+ from rstor.synthetic_data.dead_leaves_cpu import cpu_dead_leaves_chart
+ from rstor.properties import SAMPLER_NATURAL, DATASET_PATH
+
+
+ def test_dead_leaves_chart():
+     # Test case 1: Default parameters
+     chart = cpu_dead_leaves_chart()
+     assert isinstance(chart, np.ndarray)
+     assert chart.shape == (100, 100, 3)
+
+     # Test case 2: Custom size and number of circles
+     chart = cpu_dead_leaves_chart(size=(200, 150), number_of_circles=10)
+     assert isinstance(chart, np.ndarray)
+     assert chart.shape == (200, 150, 3)
+
+     # Test case 3: Colored circles
+     chart = cpu_dead_leaves_chart(colored=True, number_of_circles=300)
+     assert isinstance(chart, np.ndarray)
+     assert chart.shape == (100, 100, 3)
+
+     # Test case 4: Custom minimum radius and power-law exponent
+     chart = cpu_dead_leaves_chart(radius_min=5, radius_alpha=2, number_of_circles=300)
+     assert isinstance(chart, np.ndarray)
+     assert chart.shape == (100, 100, 3)
+
+     # Test case 5: Custom background color
+     chart = cpu_dead_leaves_chart(background_color=(0.2, 0.4, 0.6), number_of_circles=300)
+     assert isinstance(chart, np.ndarray)
+     assert chart.shape == (100, 100, 3)
+
+     # Test case 6: Custom seed
+     chart1 = cpu_dead_leaves_chart(seed=42, number_of_circles=300)
+     chart2 = cpu_dead_leaves_chart(seed=42, number_of_circles=300)
+     assert np.array_equal(chart1, chart2)
+
+
+ def test_dead_leaves_color_sampler():
+     img_list = sorted(
+         list((DATASET_PATH / "sample").glob("*.png"))
+     )
+     _gen = cpu_dead_leaves_chart(number_of_circles=300, sampler=SAMPLER_NATURAL, natural_image_list=img_list)
+     # from interactive_pipe.data_objects.image import Image
+     # Image(_gen).show()