Spaces:
Running
Running
# -*- coding: utf-8 -*- | |
""" | |
Created on Sat Mar 23 15:38:28 2024 | |
@author: jamyl | |
""" | |
import cv2 | |
from pathlib import Path | |
from time import perf_counter | |
import matplotlib.pyplot as plt | |
from typing import Tuple | |
import logging | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
from torch.utils.data import Dataset | |
try: | |
from numba import cuda | |
except ImportError: | |
logging.warning("Numba not installed, GPU acceleration will not be available") | |
cuda = None | |
from tqdm import tqdm | |
import argparse | |
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart | |
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE | |
from rstor.properties import DATASET_PATH, DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024, SAMPLER_NATURAL, SAMPLER_UNIFORM, DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512 | |
class DeadLeavesDatasetGPU(Dataset): | |
def __init__( | |
self, | |
size: Tuple[int, int] = (128, 128), | |
length: int = 1000, | |
frozen_seed: int = None, # useful for validation set! | |
ds_factor: int = 5, | |
**config_dead_leaves | |
): | |
self.frozen_seed = frozen_seed | |
self.ds_factor = ds_factor | |
self.size = (size[0]*ds_factor, size[1]*ds_factor) | |
self.length = length | |
self.config_dead_leaves = config_dead_leaves | |
# downsample kernel | |
sigma = 3/5 | |
k_size = 5 # This fits with sigma = 3/5, the cutoff value is 0.0038 (neglectable) | |
x = (torch.arange(k_size) - 2).to('cuda') | |
kernel = torch.stack(torch.meshgrid((x, x), indexing='ij')) | |
dist_sq = kernel[0]**2 + kernel[1]**2 | |
kernel = (-dist_sq.square()/(2*sigma**2)).exp() | |
kernel = kernel / kernel.sum() | |
self.downsample_kernel = kernel.repeat(3, 1, 1, 1) # shape [3, 1, k_size, k_size] | |
def __len__(self) -> int: | |
return self.length | |
def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor]: | |
"""Get a single deadleave chart and its degraded version. | |
Args: | |
idx (int): index of the item to retrieve | |
Returns: | |
Tuple[torch.Tensor, torch.Tensor]: degraded chart, target chart | |
""" | |
seed = self.frozen_seed + idx if self.frozen_seed is not None else None | |
# Return numba device array | |
numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves) | |
if self.ds_factor > 1: | |
# print(f"Downsampling {chart.shape} with factor {self.ds_factor}...") | |
# Downsample using strided gaussian conv (sigma=3/5) | |
th_chart = torch.as_tensor(numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, | |
device="cuda").permute(2, 0, 1)[None] # [b, c, h, w] | |
th_chart = F.pad(th_chart, | |
pad=(2, 2, 0, 0), | |
mode="replicate") | |
th_chart = F.conv2d(th_chart, | |
self.downsample_kernel, | |
padding='valid', | |
groups=3, | |
stride=self.ds_factor) | |
# Convert back to numba | |
numba_chart = cuda.as_cuda_array(th_chart.permute(0, 2, 3, 1)) # [b, h, w, c] | |
# convert back to numpy (temporary for legacy) | |
chart = numba_chart.copy_to_host()[0] | |
return chart | |
def generate_images(path: Path, dataset: Dataset, imin=0): | |
for i in tqdm(range(imin, dataset.length)): | |
img = dataset[i] | |
img = (img * 255).astype(np.uint8) | |
out_path = path / "{:04d}.png".format(i) | |
cv2.imwrite(out_path.as_posix(), img) | |
def bench(dataset): | |
print("dataset initialised") | |
t1 = perf_counter() | |
chart = dataset[0] | |
d = (perf_counter()-t1) | |
print(f"generation done {d}") | |
print(f"{d*1_000/60} min for 1_000") | |
plt.imshow(chart) | |
plt.show() | |
if __name__ == "__main__": | |
argparser = argparse.ArgumentParser() | |
argparser.add_argument("-o", "--output-dir", type=str, default=str(DATASET_PATH)) | |
argparser.add_argument( | |
"-n", "--name", type=str, | |
choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024, | |
DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512], | |
default=DATASET_DL_RANDOMRGB_1024 | |
) | |
argparser.add_argument("-b", "--benchmark", action="store_true") | |
default_config = dict( | |
size=(1_024, 1_024), | |
length=1_000, | |
frozen_seed=42, | |
background_color=(0.2, 0.4, 0.6), | |
colored=True, | |
radius_min=5, | |
radius_max=2_000, | |
ds_factor=5, | |
) | |
args = argparser.parse_args() | |
dataset_dir = args.output_dir | |
name = args.name | |
path = Path(dataset_dir)/name | |
# print(path) | |
path.mkdir(parents=True, exist_ok=True) | |
if name == DATASET_DL_RANDOMRGB_1024: | |
config = default_config | |
config["sampler"] = SAMPLER_UNIFORM | |
elif name == DATASET_DL_DIV2K_1024: | |
config = default_config | |
config["sampler"] = SAMPLER_NATURAL | |
config["natural_image_list"] = sorted( | |
list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png")) | |
) | |
elif name == DATASET_DL_DIV2K_512: | |
config = default_config | |
config["size"] = (512, 512) | |
config["rmin"] = 3 | |
config["length"] = 4000 | |
config["sampler"] = SAMPLER_NATURAL | |
config["natural_image_list"] = sorted( | |
list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png")) | |
) | |
elif name == DATASET_DL_EXTRAPRIMITIVES_DIV2K_512: | |
config = default_config | |
config["size"] = (512, 512) | |
config["sampler"] = SAMPLER_NATURAL | |
config["circle_primitives"] = False | |
config["length"] = 4000 | |
config["natural_image_list"] = sorted( | |
list((DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR").glob("*.png")) | |
) | |
else: | |
raise NotImplementedError | |
dataset = DeadLeavesDatasetGPU(**config) | |
if args.benchmark: | |
bench(dataset) | |
else: | |
generate_images(path, dataset) | |