# -*- coding: utf-8 -*-
"""
Created on Sat Mar 23 15:38:28 2024

@author: jamyl
"""
import cv2
from pathlib import Path
from time import perf_counter
import matplotlib.pyplot as plt
from typing import Optional, Tuple
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
try:
    from numba import cuda
except ImportError:
    logging.warning("Numba not installed, GPU acceleration will not be available")
    cuda = None
from tqdm import tqdm
import argparse
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
from rstor.properties import (
    DATASET_PATH,
    DATASET_DL_RANDOMRGB_1024,
    DATASET_DL_DIV2K_1024,
    SAMPLER_NATURAL,
    SAMPLER_UNIFORM,
    DATASET_DL_DIV2K_512,
    DATASET_DL_EXTRAPRIMITIVES_DIV2K_512,
)


class DeadLeavesDatasetGPU(Dataset):
    """Dataset of dead leaves charts generated on the fly on the GPU.

    Each chart is rendered by ``gpu_dead_leaves_chart`` at ``size * ds_factor``
    and, when ``ds_factor > 1``, downsampled back to ``size`` with a strided
    Gaussian (anti-aliasing) convolution.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: Optional[int] = None,  # useful for validation set!
        ds_factor: int = 5,
        **config_dead_leaves
    ):
        """
        Args:
            size: (h, w) of the returned charts.
            length: number of items reported by ``__len__``.
            frozen_seed: if not None, item ``idx`` is generated with seed
                ``frozen_seed + idx`` (reproducible charts).
            ds_factor: supersampling factor; charts are rendered at
                ``size * ds_factor`` then downsampled. Must be >= 1.
            **config_dead_leaves: forwarded verbatim to ``gpu_dead_leaves_chart``.

        Raises:
            ValueError: if ``ds_factor < 1``.
        """
        if ds_factor < 1:
            raise ValueError(f"ds_factor must be >= 1, got {ds_factor}")
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        # Render resolution (before downsampling).
        self.size = (size[0]*ds_factor, size[1]*ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves

        # Downsample kernel: 5x5 Gaussian, sigma = 3/5.
        # At the kernel edge (r = 2) the unnormalized value is
        # exp(-4 / (2 * 0.36)) ~= 0.0038 (negligible), hence k_size = 5.
        sigma = 3/5
        k_size = 5
        coords = torch.arange(k_size, device="cuda") - k_size // 2
        grid_y, grid_x = torch.meshgrid(coords, coords, indexing="ij")
        dist_sq = grid_y**2 + grid_x**2
        # Gaussian uses r^2 (dist_sq) directly, not its square.
        kernel = torch.exp(-dist_sq / (2*sigma**2))
        kernel = (kernel / kernel.sum()).to(DEFAULT_TORCH_FLOAT_TYPE)
        # One identical filter per RGB channel for a depthwise (groups=3) conv.
        self.downsample_kernel = kernel.repeat(3, 1, 1, 1)  # shape [3, 1, k_size, k_size]

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> np.ndarray:
        """Generate a single dead leaves chart.

        Args:
            idx (int): index of the item to retrieve. When ``frozen_seed`` is
                set, the chart is generated with seed ``frozen_seed + idx``.

        Returns:
            np.ndarray: the chart copied to host memory, channels last
                ``[h, w, c]``, where ``(h, w)`` is the ``size`` given to the
                constructor.
        """
        seed = self.frozen_seed + idx if self.frozen_seed is not None else None
        # Returns a numba device array, channels last: [h, w, c]
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor == 1:
            # No batch dimension was added here, so return the chart as is.
            return numba_chart.copy_to_host()

        # Downsample using a strided depthwise Gaussian conv (sigma=3/5).
        pad = self.downsample_kernel.shape[-1] // 2
        th_chart = torch.as_tensor(
            numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda"
        ).permute(2, 0, 1)[None]  # [b, c, h, w]
        # Pad BOTH spatial dims so the sampling grid is aligned identically on
        # h and w and the output is exactly (h // ds_factor, w // ds_factor).
        th_chart = F.pad(th_chart, pad=(pad, pad, pad, pad), mode="replicate")
        th_chart = F.conv2d(th_chart, self.downsample_kernel, padding="valid", groups=3,
                            stride=self.ds_factor)
        # Convert back to numba
        numba_chart = cuda.as_cuda_array(th_chart.permute(0, 2, 3, 1))  # [b, h, w, c]
        # convert back to numpy (temporary for legacy), dropping the batch dim
        return numba_chart.copy_to_host()[0]


def generate_images(path: Path, dataset: Dataset, imin=0):
    """Write items ``imin`` .. ``dataset.length - 1`` as ``path/NNNN.png``.

    Items are scaled by 255 and cast to uint8 before writing with OpenCV.
    NOTE(review): ``cv2.imwrite`` interprets channels as BGR; confirm whether the
    chart is RGB and whether readers of these files expect that ordering.

    Raises:
        IOError: if OpenCV fails to write an image.
    """
    for i in tqdm(range(imin, dataset.length)):
        img = dataset[i]
        # Clip before the uint8 cast so slight float overshoot cannot wrap around.
        img = np.clip(img * 255, 0, 255).astype(np.uint8)
        out_path = path / "{:04d}.png".format(i)
        if not cv2.imwrite(out_path.as_posix(), img):
            raise IOError(f"Failed to write {out_path}")


def bench(dataset):
    """Time the generation of ``dataset[0]``, extrapolate to 1000 items, and show it.

    The single timed call may include one-off initialization cost, so the
    extrapolation is an estimate.
    """
    print("dataset initialised")
    t1 = perf_counter()
    chart = dataset[0]
    d = (perf_counter()-t1)
    print(f"generation done {d}")
    print(f"{d*1_000/60} min for 1_000")
    plt.imshow(chart)
    plt.show()


if __name__ == "__main__":
    argparser = argparse.ArgumentParser()
    argparser.add_argument("-o", "--output-dir", type=str, default=str(DATASET_PATH))
    argparser.add_argument(
        "-n", "--name", type=str,
        choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
                 DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
        default=DATASET_DL_RANDOMRGB_1024
    )
    argparser.add_argument("-b", "--benchmark", action="store_true")
    default_config = dict(
        size=(1_024, 1_024),
        length=1_000,
        frozen_seed=42,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=5,
        radius_max=2_000,
        ds_factor=5,
    )
    args = argparser.parse_args()
    dataset_dir = args.output_dir
    name = args.name
    path = Path(dataset_dir)/name
    path.mkdir(parents=True, exist_ok=True)

    # Work on a copy so per-dataset overrides never mutate the shared defaults.
    config = dict(default_config)
    div2k_dir = DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR"
    if name == DATASET_DL_RANDOMRGB_1024:
        config["sampler"] = SAMPLER_UNIFORM
    elif name == DATASET_DL_DIV2K_1024:
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = sorted(div2k_dir.glob("*.png"))
    elif name == DATASET_DL_DIV2K_512:
        config["size"] = (512, 512)
        # Same key as in default_config ("rmin" was not a recognized override).
        config["radius_min"] = 3
        config["length"] = 4000
        config["sampler"] = SAMPLER_NATURAL
        config["natural_image_list"] = sorted(div2k_dir.glob("*.png"))
    elif name == DATASET_DL_EXTRAPRIMITIVES_DIV2K_512:
        config["size"] = (512, 512)
        config["sampler"] = SAMPLER_NATURAL
        config["circle_primitives"] = False
        config["length"] = 4000
        config["natural_image_list"] = sorted(div2k_dir.glob("*.png"))
    else:
        raise NotImplementedError

    dataset = DeadLeavesDatasetGPU(**config)
    if args.benchmark:
        bench(dataset)
    else:
        generate_images(path, dataset)