Last commit not found
import argparse | |
import cv2 | |
import numpy as np | |
import os | |
import sys | |
from basicsr.utils import scandir | |
from multiprocessing import Pool | |
from os import path as osp | |
from tqdm import tqdm | |
def main(args): | |
"""A multi-thread tool to crop large images to sub-images for faster IO. | |
opt (dict): Configuration dict. It contains: | |
n_thread (int): Thread number. | |
compression_level (int): CV_IMWRITE_PNG_COMPRESSION from 0 to 9. A higher value means a smaller size | |
and longer compression time. Use 0 for faster CPU decompression. Default: 3, same in cv2. | |
input_folder (str): Path to the input folder. | |
save_folder (str): Path to save folder. | |
crop_size (int): Crop size. | |
step (int): Step for overlapped sliding window. | |
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. | |
Usage: | |
For each folder, run this script. | |
Typically, there are GT folder and LQ folder to be processed for DIV2K dataset. | |
After process, each sub_folder should have the same number of subimages. | |
Remember to modify opt configurations according to your settings. | |
""" | |
opt = {} | |
opt["n_thread"] = args.n_thread | |
opt["compression_level"] = args.compression_level | |
opt["input_folder"] = args.input | |
opt["save_folder"] = args.output | |
opt["crop_size"] = args.crop_size | |
opt["step"] = args.step | |
opt["thresh_size"] = args.thresh_size | |
extract_subimages(opt) | |
def extract_subimages(opt): | |
"""Crop images to subimages. | |
Args: | |
opt (dict): Configuration dict. It contains: | |
input_folder (str): Path to the input folder. | |
save_folder (str): Path to save folder. | |
n_thread (int): Thread number. | |
""" | |
input_folder = opt["input_folder"] | |
save_folder = opt["save_folder"] | |
if not osp.exists(save_folder): | |
os.makedirs(save_folder) | |
print(f"mkdir {save_folder} ...") | |
else: | |
print(f"Folder {save_folder} already exists. Exit.") | |
sys.exit(1) | |
# scan all images | |
img_list = list(scandir(input_folder, full_path=True)) | |
pbar = tqdm(total=len(img_list), unit="image", desc="Extract") | |
pool = Pool(opt["n_thread"]) | |
for path in img_list: | |
pool.apply_async(worker, args=(path, opt), callback=lambda arg: pbar.update(1)) | |
pool.close() | |
pool.join() | |
pbar.close() | |
print("All processes done.") | |
def worker(path, opt): | |
"""Worker for each process. | |
Args: | |
path (str): Image path. | |
opt (dict): Configuration dict. It contains: | |
crop_size (int): Crop size. | |
step (int): Step for overlapped sliding window. | |
thresh_size (int): Threshold size. Patches whose size is lower than thresh_size will be dropped. | |
save_folder (str): Path to save folder. | |
compression_level (int): for cv2.IMWRITE_PNG_COMPRESSION. | |
Returns: | |
process_info (str): Process information displayed in progress bar. | |
""" | |
crop_size = opt["crop_size"] | |
step = opt["step"] | |
thresh_size = opt["thresh_size"] | |
img_name, extension = osp.splitext(osp.basename(path)) | |
# remove the x2, x3, x4 and x8 in the filename for DIV2K | |
img_name = ( | |
img_name.replace("x2", "").replace("x3", "").replace("x4", "").replace("x8", "") | |
) | |
img = cv2.imread(path, cv2.IMREAD_UNCHANGED) | |
h, w = img.shape[0:2] | |
h_space = np.arange(0, h - crop_size + 1, step) | |
if h - (h_space[-1] + crop_size) > thresh_size: | |
h_space = np.append(h_space, h - crop_size) | |
w_space = np.arange(0, w - crop_size + 1, step) | |
if w - (w_space[-1] + crop_size) > thresh_size: | |
w_space = np.append(w_space, w - crop_size) | |
index = 0 | |
for x in h_space: | |
for y in w_space: | |
index += 1 | |
cropped_img = img[x : x + crop_size, y : y + crop_size, ...] | |
cropped_img = np.ascontiguousarray(cropped_img) | |
cv2.imwrite( | |
osp.join(opt["save_folder"], f"{img_name}_s{index:03d}{extension}"), | |
cropped_img, | |
[cv2.IMWRITE_PNG_COMPRESSION, opt["compression_level"]], | |
) | |
process_info = f"Processing {img_name} ..." | |
return process_info | |
if __name__ == "__main__": | |
parser = argparse.ArgumentParser() | |
parser.add_argument( | |
"--input", type=str, default="datasets/DF2K/DF2K_HR", help="Input folder" | |
) | |
parser.add_argument( | |
"--output", type=str, default="datasets/DF2K/DF2K_HR_sub", help="Output folder" | |
) | |
parser.add_argument("--crop_size", type=int, default=480, help="Crop size") | |
parser.add_argument( | |
"--step", type=int, default=240, help="Step for overlapped sliding window" | |
) | |
parser.add_argument( | |
"--thresh_size", | |
type=int, | |
default=0, | |
help="Threshold size. Patches whose size is lower than thresh_size will be dropped.", | |
) | |
parser.add_argument("--n_thread", type=int, default=20, help="Thread number.") | |
parser.add_argument( | |
"--compression_level", type=int, default=3, help="Compression level" | |
) | |
args = parser.parse_args() | |
main(args) | |