Spaces:
Runtime error
Runtime error
# Copyright (c) Microsoft Corporation. | |
# Licensed under the MIT License. | |
from data.pix2pix_dataset import Pix2pixDataset | |
from data.image_folder import make_dataset | |
class CustomDataset(Pix2pixDataset): | |
""" Dataset that loads images from directories | |
Use option --label_dir, --image_dir, --instance_dir to specify the directories. | |
The images in the directories are sorted in alphabetical order and paired in order. | |
""" | |
def modify_commandline_options(parser, is_train): | |
parser = Pix2pixDataset.modify_commandline_options(parser, is_train) | |
parser.set_defaults(preprocess_mode="resize_and_crop") | |
load_size = 286 if is_train else 256 | |
parser.set_defaults(load_size=load_size) | |
parser.set_defaults(crop_size=256) | |
parser.set_defaults(display_winsize=256) | |
parser.set_defaults(label_nc=13) | |
parser.set_defaults(contain_dontcare_label=False) | |
parser.add_argument( | |
"--label_dir", type=str, required=True, help="path to the directory that contains label images" | |
) | |
parser.add_argument( | |
"--image_dir", type=str, required=True, help="path to the directory that contains photo images" | |
) | |
parser.add_argument( | |
"--instance_dir", | |
type=str, | |
default="", | |
help="path to the directory that contains instance maps. Leave black if not exists", | |
) | |
return parser | |
def get_paths(self, opt): | |
label_dir = opt.label_dir | |
label_paths = make_dataset(label_dir, recursive=False, read_cache=True) | |
image_dir = opt.image_dir | |
image_paths = make_dataset(image_dir, recursive=False, read_cache=True) | |
if len(opt.instance_dir) > 0: | |
instance_dir = opt.instance_dir | |
instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True) | |
else: | |
instance_paths = [] | |
assert len(label_paths) == len( | |
image_paths | |
), "The #images in %s and %s do not match. Is there something wrong?" | |
return label_paths, image_paths, instance_paths | |