image-deblurring / scripts /save_deadleaves.py
balthou's picture
disable numba imports
86d104b
# -*- coding: utf-8 -*-
"""
Created on Sat Mar 23 15:38:28 2024
@author: jamyl
"""
import cv2
from pathlib import Path
from time import perf_counter
import matplotlib.pyplot as plt
from typing import Tuple
import logging
import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import Dataset
try:
from numba import cuda
except ImportError:
logging.warning("Numba not installed, GPU acceleration will not be available")
cuda = None
from tqdm import tqdm
import argparse
from rstor.synthetic_data.dead_leaves_gpu import gpu_dead_leaves_chart
from rstor.utils import DEFAULT_TORCH_FLOAT_TYPE
from rstor.properties import DATASET_PATH, DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024, SAMPLER_NATURAL, SAMPLER_UNIFORM, DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512
class DeadLeavesDatasetGPU(Dataset):
    """Dataset of dead-leaves charts rendered on the GPU.

    Charts are rendered by ``gpu_dead_leaves_chart`` at ``ds_factor`` times
    the requested ``size``. When ``ds_factor > 1`` they are brought back to
    ``size`` with a strided, normalized 5x5 Gaussian convolution
    (sigma = 3/5). Items are returned as host numpy arrays.

    Args:
        size: output chart size (2-tuple). Each entry is multiplied by
            ``ds_factor`` before rendering. Presumably (height, width);
            TODO confirm against ``gpu_dead_leaves_chart``.
        length: number of items reported by ``len()``.
        frozen_seed: when not None, item ``idx`` is rendered with seed
            ``frozen_seed + idx``, so the dataset is reproducible (useful for
            a validation set). When None, ``seed=None`` is forwarded.
        ds_factor: supersampling factor. Values <= 1 disable downsampling.
        **config_dead_leaves: forwarded verbatim to ``gpu_dead_leaves_chart``.
    """

    def __init__(
        self,
        size: Tuple[int, int] = (128, 128),
        length: int = 1000,
        frozen_seed: int = None,  # useful for validation set!
        ds_factor: int = 5,
        **config_dead_leaves
    ):
        self.frozen_seed = frozen_seed
        self.ds_factor = ds_factor
        self.size = (size[0] * ds_factor, size[1] * ds_factor)
        self.length = length
        self.config_dead_leaves = config_dead_leaves
        # Downsampling kernel. With sigma = 3/5, a 5x5 support truncates the
        # unnormalized Gaussian at exp(-2**2 / (2 * sigma**2)) ~= 0.0038 on
        # the axes, which is negligible. The kernel is cast to the chart dtype
        # so that conv2d sees matching dtypes.
        self.downsample_kernel = self._gaussian_kernel(
            sigma=3 / 5,
            k_size=5,
            channels=3,
            device="cuda",
            dtype=DEFAULT_TORCH_FLOAT_TYPE,
        )  # shape [3, 1, 5, 5]

    @staticmethod
    def _gaussian_kernel(
        sigma: float,
        k_size: int,
        channels: int = 3,
        device="cuda",
        dtype: torch.dtype = None,
    ) -> torch.Tensor:
        """Build a normalized 2D Gaussian kernel for a depthwise conv2d.

        Args:
            sigma: standard deviation of the Gaussian, in pixels.
            k_size: side length of the square kernel (odd sizes keep it centered).
            channels: number of copies stacked along dim 0 (one per group).
            device: device on which the kernel is created.
            dtype: optional dtype to cast the kernel to.

        Returns:
            Tensor of shape [channels, 1, k_size, k_size]. Each spatial slice
            sums to 1.
        """
        half = k_size // 2
        offsets = torch.arange(k_size, device=device) - half
        grid_y, grid_x = torch.meshgrid(offsets, offsets, indexing="ij")
        dist_sq = grid_y**2 + grid_x**2
        # Gaussian in the squared distance: exp(-r^2 / (2 sigma^2)).
        kernel = torch.exp(-dist_sq / (2 * sigma**2))
        kernel = kernel / kernel.sum()
        if dtype is not None:
            kernel = kernel.to(dtype)
        return kernel.repeat(channels, 1, 1, 1)

    def _downsample(self, th_chart: torch.Tensor) -> torch.Tensor:
        """Low-pass filter and decimate a [b, c, h, w] tensor by ``ds_factor``.

        All four spatial borders are replicate-padded by ``k_size // 2``.
        An input of ``ds_factor * n`` pixels along an axis therefore yields
        exactly ``n`` output pixels along that axis.
        """
        half = self.downsample_kernel.shape[-1] // 2
        # F.pad order is (left, right, top, bottom): pad width AND height.
        padded = F.pad(th_chart, pad=(half, half, half, half), mode="replicate")
        return F.conv2d(
            padded,
            self.downsample_kernel,
            padding="valid",
            groups=self.downsample_kernel.shape[0],  # depthwise: one group per channel
            stride=self.ds_factor,
        )

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> np.ndarray:
        """Render the dead-leaves chart for ``idx``.

        Args:
            idx (int): index of the item. It is combined with ``frozen_seed``
                to choose the seed.

        Returns:
            np.ndarray: host copy of the chart laid out as [h, w, c]. It is
            downsampled by ``ds_factor`` when ``ds_factor > 1``.
        """
        seed = None if self.frozen_seed is None else self.frozen_seed + idx
        # gpu_dead_leaves_chart returns a numba device array laid out [h, w, c].
        numba_chart = gpu_dead_leaves_chart(self.size, seed=seed, **self.config_dead_leaves)
        if self.ds_factor <= 1:
            # No downsampling: the rendered [h, w, c] chart is copied as-is.
            return numba_chart.copy_to_host()
        th_chart = torch.as_tensor(
            numba_chart, dtype=DEFAULT_TORCH_FLOAT_TYPE, device="cuda"
        ).permute(2, 0, 1)[None]  # [b, c, h, w]
        th_chart = self._downsample(th_chart)
        # Drop the batch axis and return to the [h, w, c] layout on the host.
        return th_chart[0].permute(1, 2, 0).cpu().numpy()
def generate_images(path: Path, dataset: Dataset, imin=0):
    """Render items ``imin .. len(dataset) - 1`` and save each one as a PNG.

    Each item is assumed to hold float values in [0, 1]. It is rounded to the
    nearest integer level, clipped to [0, 255], cast to uint8 and written to
    ``path / "{i:04d}.png"``.

    Args:
        path: existing output directory.
        dataset: indexable dataset returning float image arrays.
        imin: first index to generate (useful to resume an interrupted run).

    Raises:
        OSError: if OpenCV fails to write one of the files.
    """
    for i in tqdm(range(imin, len(dataset))):
        chart = dataset[i]
        # Round and clip so that values marginally outside [0, 1] cannot wrap
        # around when cast to uint8.
        img = np.clip(np.rint(chart * 255), 0, 255).astype(np.uint8)
        out_path = path / f"{i:04d}.png"
        # NOTE(review): cv2.imwrite treats the last axis as BGR. Confirm the
        # chart channel order matches what downstream loaders expect.
        # cv2.imwrite signals failure through its return value, not by raising.
        if not cv2.imwrite(out_path.as_posix(), img):
            raise OSError(f"Failed to write image to {out_path}")
def bench(dataset):
    """Time the generation of a single chart and display it.

    Prints the wall-clock duration of ``dataset[0]`` and an extrapolated
    duration, in minutes, for 1,000 charts. The chart is then shown with
    matplotlib. NOTE(review): the first call may include one-time warm-up or
    compilation cost, which would inflate the extrapolation. Confirm before
    relying on the figure.
    """
    print("dataset initialised")
    start = perf_counter()
    sample = dataset[0]
    elapsed = perf_counter() - start
    minutes_per_thousand = elapsed * 1_000 / 60
    print(f"generation done {elapsed}")
    print(f"{minutes_per_thousand} min for 1_000")
    plt.imshow(sample)
    plt.show()
if __name__ == "__main__":
    argparser = argparse.ArgumentParser(
        description="Generate a dead-leaves dataset and save it as PNG files."
    )
    argparser.add_argument(
        "-o", "--output-dir", type=str, default=str(DATASET_PATH),
        help="root directory; the dataset is written to <output-dir>/<name>",
    )
    argparser.add_argument(
        "-n", "--name", type=str,
        choices=[DATASET_DL_RANDOMRGB_1024, DATASET_DL_DIV2K_1024,
                 DATASET_DL_DIV2K_512, DATASET_DL_EXTRAPRIMITIVES_DIV2K_512],
        default=DATASET_DL_RANDOMRGB_1024,
        help="which dataset preset to generate",
    )
    argparser.add_argument(
        "-b", "--benchmark", action="store_true",
        help="time and display a single chart instead of writing the dataset",
    )
    default_config = dict(
        size=(1_024, 1_024),
        length=1_000,
        frozen_seed=42,
        background_color=(0.2, 0.4, 0.6),
        colored=True,
        radius_min=5,
        radius_max=2_000,
        ds_factor=5,
    )
    # Per-preset overrides applied on top of ``default_config``.
    overrides = {
        DATASET_DL_RANDOMRGB_1024: dict(sampler=SAMPLER_UNIFORM),
        DATASET_DL_DIV2K_1024: dict(sampler=SAMPLER_NATURAL),
        DATASET_DL_DIV2K_512: dict(
            size=(512, 512),
            radius_min=3,  # same keyword as default_config's radius_min
            length=4_000,
            sampler=SAMPLER_NATURAL,
        ),
        DATASET_DL_EXTRAPRIMITIVES_DIV2K_512: dict(
            size=(512, 512),
            sampler=SAMPLER_NATURAL,
            circle_primitives=False,
            length=4_000,
        ),
    }

    args = argparser.parse_args()
    name = args.name
    path = Path(args.output_dir) / name

    if name not in overrides:
        raise NotImplementedError(f"Unsupported dataset name: {name}")
    # Build a fresh dict so the shared defaults are never mutated.
    config = {**default_config, **overrides[name]}

    if config["sampler"] == SAMPLER_NATURAL:
        div2k_dir = DATASET_PATH / "div2k" / "DIV2K_train_HR" / "DIV2K_train_HR"
        natural_image_list = sorted(div2k_dir.glob("*.png"))
        if not natural_image_list:
            raise FileNotFoundError(
                f"No .png images found in {div2k_dir}; "
                "the natural sampler needs the DIV2K training images."
            )
        config["natural_image_list"] = natural_image_list

    # Create the output directory only once the configuration is known to be valid.
    path.mkdir(parents=True, exist_ok=True)
    dataset = DeadLeavesDatasetGPU(**config)
    if args.benchmark:
        bench(dataset)
    else:
        generate_images(path, dataset)