# BOPBTL / Face_Enhancement / data / custom_dataset.py
# (author: manhkhanhUIT; commit 7fab858 "Add code")
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset
class CustomDataset(Pix2pixDataset):
    """Dataset that loads paired images from user-specified directories.

    Use the options --label_dir, --image_dir and (optionally) --instance_dir
    to specify the directories. The images in the directories are sorted in
    alphabetical order and paired in order.
    """

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Register this dataset's options and override parent defaults.

        Args:
            parser: The argument parser being built. It is first passed through
                ``Pix2pixDataset.modify_commandline_options``.
            is_train: Whether options are being built for training. This
                selects ``load_size=286`` for training and ``256`` otherwise.

        Returns:
            The parser, with defaults overridden and the ``--label_dir``,
            ``--image_dir`` and ``--instance_dir`` options added.
        """
        parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
        # With preprocess_mode="resize_and_crop", load_size is presumably the
        # resize target and crop_size the final crop. The preprocessing itself
        # is implemented outside this file, so confirm there.
        parser.set_defaults(
            preprocess_mode="resize_and_crop",
            load_size=286 if is_train else 256,
            crop_size=256,
            display_winsize=256,
            label_nc=13,
            contain_dontcare_label=False,
        )
        parser.add_argument(
            "--label_dir",
            type=str,
            required=True,
            help="path to the directory that contains label images",
        )
        parser.add_argument(
            "--image_dir",
            type=str,
            required=True,
            help="path to the directory that contains photo images",
        )
        parser.add_argument(
            "--instance_dir",
            type=str,
            default="",
            help="path to the directory that contains instance maps. Leave blank if it does not exist",
        )
        return parser

    def get_paths(self, opt):
        """Collect the label, image and (optional) instance-map file paths.

        Args:
            opt: Parsed options. Reads ``opt.label_dir``, ``opt.image_dir``
                and ``opt.instance_dir``.

        Returns:
            A tuple ``(label_paths, image_paths, instance_paths)`` holding what
            ``make_dataset`` returns for each directory. ``instance_paths`` is
            an empty list when ``opt.instance_dir`` is empty.

        Raises:
            AssertionError: If the label and image directories yield a
                different number of paths. Python's ``-O`` flag strips this
                check.
        """
        # Listings are non-recursive. read_cache=True is forwarded to
        # make_dataset (presumably it reuses a cached file list; verify there).
        label_dir = opt.label_dir
        label_paths = make_dataset(label_dir, recursive=False, read_cache=True)

        image_dir = opt.image_dir
        image_paths = make_dataset(image_dir, recursive=False, read_cache=True)

        # Instance maps are optional: an empty (or missing) --instance_dir
        # disables them.
        if opt.instance_dir:
            instance_paths = make_dataset(opt.instance_dir, recursive=False, read_cache=True)
        else:
            instance_paths = []

        # Labels and photos are paired by position, so their counts must agree.
        assert len(label_paths) == len(image_paths), (
            "The #images in %s and %s do not match (%d vs %d). "
            "Is there something wrong?"
            % (label_dir, image_dir, len(label_paths), len(image_paths))
        )
        return label_paths, image_paths, instance_paths