File size: 6,114 Bytes
cec5823
 
 
 
 
 
 
 
 
 
 
86d104b
cec5823
 
 
 
86d104b
 
 
 
 
cec5823
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 23 15:38:28 2024

@author: jamyl
"""
import cv2
from pathlib import Path
from time import perf_counter
import matplotlib.pyplot as plt
from typing import Tuple
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
try:
    from numba import cuda
except ImportError:
    logging.warning("Numba not installed, GPU acceleration will not be available")
    cuda = None
from tqdm import tqdm
import argparse
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
from rstor.properties import DATASET_PATH, DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024, SAMPLER_NATURAL, SAMPLER_UNIFORM, DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512


class DeadLeavesDatasetGPU(Dataset):
    """Dataset of dead leaves charts rendered on the GPU.

    Each chart is rendered at ``ds_factor`` times the requested ``size`` and,
    when ``ds_factor > 1``, anti-aliased back down to ``size`` with a strided
    Gaussian convolution.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        ds_factor: int = 5,
        **config_dead_leaves
    ):
        """
        Args:
            size: output chart size after downsampling.
            length: number of items reported by ``__len__``.
            frozen_seed: when not None, item ``idx`` is generated with seed
                ``frozen_seed + idx`` so items are reproducible.
            ds_factor: supersampling factor; charts are rendered at
                ``size * ds_factor`` and downsampled when ``ds_factor > 1``.
            **config_dead_leaves: forwarded to ``gpu_dead_leaves_chart``.
        """
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves

        # Downsample kernel: 5x5 Gaussian with sigma = 3/5. At distance 2 from
        # the centre the unnormalised weight is exp(-4 / (2 * 0.36)) ~= 0.0039,
        # so truncating the kernel to 5 taps is negligible.
        sigma = 3/5
        k_size = 5
        kernel = self._gaussian_kernel(sigma, k_size, device="cuda")
        # Match the chart dtype so conv2d does not fail on a dtype mismatch.
        kernel = kernel.to(DEFAULT_TORCH_FLOAT_TYPE)
        self.downsample_kernel = kernel.repeat(3, 1, 1, 1)  # shape [3, 1, k_size, k_size]
        # Symmetric padding keeps both axes aligned identically and yields
        # exactly size[i] outputs per axis for any ds_factor >= 2.
        self.downsample_padding = k_size // 2

    @staticmethod
    def _gaussian_kernel(sigma: float, k_size: int, device=None) -> torch.Tensor:
        """Return a normalised ``[k_size, k_size]`` isotropic Gaussian kernel (float32)."""
        x = torch.arange(k_size, dtype=torch.float32, device=device) - k_size // 2
        yy, xx = torch.meshgrid(x, x, indexing='ij')
        dist_sq = yy**2 + xx**2  # squared distance to the kernel centre
        kernel = torch.exp(-dist_sq / (2 * sigma**2))
        return kernel / kernel.sum()

    def _downsample(self, th_chart: torch.Tensor) -> torch.Tensor:
        """Anti-alias and decimate a ``[b, c, h, w]`` tensor by ``self.ds_factor``.

        Uses replicate padding on all four borders, then a depthwise strided
        convolution with ``self.downsample_kernel``.
        """
        pad = self.downsample_padding
        padded = F.pad(th_chart, pad=(pad, pad, pad, pad), mode="replicate")
        return F.conv2d(padded,
                        self.downsample_kernel,
                        padding='valid',
                        groups=self.downsample_kernel.shape[0],
                        stride=self.ds_factor)

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> np.ndarray:
        """Generate a single dead leaves chart.

        Args:
            idx (int): index of the item to retrieve (only used for seeding
                when ``frozen_seed`` is set).

        Returns:
            np.ndarray: chart of shape [h, w, c] on the host.
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None

        # Returns a numba device array; it is indexed below as [h, w, c].
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor <= 1:
            # No downsampling: the chart is already [h, w, c], copy it as is.
            return numba_chart.copy_to_host()

        th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE,
                                   device="cuda").permute(2, 0, 1)[None]  # [1, c, h, w]
        th_chart = self._downsample(th_chart)
        # [1, c, h, w] -> [h, w, c] host array
        return th_chart[0].permute(1, 2, 0).contiguous().cpu().numpy()


def generate_images(path: Path, dataset: Dataset, imin=0):
    """Write items ``imin .. len(dataset) - 1`` of *dataset* as PNG files.

    Each item is clipped to [0, 1], scaled to [0, 255] and cast to uint8, then
    saved as ``path / "{index:04d}.png"``.

    Args:
        path: existing output directory.
        dataset: dataset whose items are float images (presumably [h, w, c]
            in [0, 1] — TODO confirm for non dead-leaves datasets).
        imin: first index to write (e.g. to resume a partial run).

    Raises:
        OSError: if OpenCV reports that an image could not be written.
    """
    for i in tqdm(range(imin, len(dataset))):
        # Clip first: casting out-of-range floats to uint8 wraps or is undefined.
        img = np.clip(dataset[i], 0., 1.)
        img = (img * 255).astype(np.uint8)
        out_path = path / "{:04d}.png".format(i)
        # NOTE(review): cv2.imwrite treats the last axis as BGR; if items are RGB
        # the channels are stored swapped (consistent when read back with cv2).
        if not cv2.imwrite(out_path.as_posix(), img):
            raise OSError(f"Failed to write image {out_path}")


def bench(dataset, warmup: bool = True):
    """Time the generation of one item of *dataset* and display it.

    Args:
        dataset: dataset to sample from (item 0 is used).
        warmup: if True, generate item 0 once before timing. This excludes
            one-off start-up costs (CUDA initialisation and, presumably, numba
            JIT compilation) from the measurement used to extrapolate the
            cost of 1 000 items.
    """
    print("dataset initialised")
    if warmup:
        _ = dataset[0]
    t1 = perf_counter()
    chart = dataset[0]

    d = (perf_counter()-t1)
    print(f"generation done {d}")
    print(f"{d*1_000/60} min for 1_000")
    plt.imshow(chart)
    plt.show()


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-o", "--output-dir", type=str, default=str(DATASET_PATH))
    argparser.add_argument(
        "-n", "--name", type=str,
        choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
                 DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
        default=DATASET_DL_RANDOMRGB_1024
    )
    argparser.add_argument("-b", "--benchmark", action="store_true")
    default_config = dict(
        size=(1_024, 1_024),
        length=1_000,
        frozen_seed=42,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=5,
        radius_max=2_000,
        ds_factor=5,
    )

    def list_div2k_images():
        """Return the sorted DIV2K training PNGs used by the natural sampler."""
        div2k_dir = DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR"
        images = sorted(div2k_dir.glob("*.png"))
        if not images:
            raise FileNotFoundError(f"No DIV2K training images (*.png) found in {div2k_dir}")
        return images

    args = argparser.parse_args()
    name = args.name
    path = Path(args.output_dir)/name

    # Copy so per-dataset overrides never mutate the shared defaults.
    config = dict(default_config)
    if name == DATASET_DL_RANDOMRGB_1024:
        config["sampler"] = SAMPLER_UNIFORM
    elif name == DATASET_DL_DIV2K_1024:
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = list_div2k_images()
    elif name == DATASET_DL_DIV2K_512:
        config["size"] = (512, 512)
        # Overrides the "radius_min" default (was mistakenly set as "rmin").
        config["radius_min"] = 3
        config["length"] = 4000
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = list_div2k_images()
    elif name == DATASET_DL_EXTRAPRIMITIVES_DIV2K_512:
        config["size"] = (512, 512)
        config["sampler"] = SAMPLER_NATURAL
        config["circle_primitives"] = False
        config["length"] = 4000
        config["natural_image_list"] = list_div2k_images()
    else:
        raise NotImplementedError(f"Unknown dataset name: {name}")

    path.mkdir(parents=True, exist_ok=True)
    dataset = DeadLeavesDatasetGPU(**config)
    if args.benchmark:
        bench(dataset)
    else:
        generate_images(path, dataset)