Spaces:
Runtime error
Runtime error
File size: 2,149 Bytes
7fab858 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 |
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
from data.pix2pix_dataset import Pix2pixDataset
from data.image_folder import make_dataset
class CustomDataset(Pix2pixDataset):
""" Dataset that loads images from directories
Use option --label_dir, --image_dir, --instance_dir to specify the directories.
The images in the directories are sorted in alphabetical order and paired in order.
"""
@staticmethod
def modify_commandline_options(parser, is_train):
parser = Pix2pixDataset.modify_commandline_options(parser, is_train)
parser.set_defaults(preprocess_mode="resize_and_crop")
load_size = 286 if is_train else 256
parser.set_defaults(load_size=load_size)
parser.set_defaults(crop_size=256)
parser.set_defaults(display_winsize=256)
parser.set_defaults(label_nc=13)
parser.set_defaults(contain_dontcare_label=False)
parser.add_argument(
"--label_dir", type=str, required=True, help="path to the directory that contains label images"
)
parser.add_argument(
"--image_dir", type=str, required=True, help="path to the directory that contains photo images"
)
parser.add_argument(
"--instance_dir",
type=str,
default="",
help="path to the directory that contains instance maps. Leave black if not exists",
)
return parser
def get_paths(self, opt):
label_dir = opt.label_dir
label_paths = make_dataset(label_dir, recursive=False, read_cache=True)
image_dir = opt.image_dir
image_paths = make_dataset(image_dir, recursive=False, read_cache=True)
if len(opt.instance_dir) > 0:
instance_dir = opt.instance_dir
instance_paths = make_dataset(instance_dir, recursive=False, read_cache=True)
else:
instance_paths = []
assert len(label_paths) == len(
image_paths
), "The #images in %s and %s do not match. Is there something wrong?"
return label_paths, image_paths, instance_paths
|