import argparse import os import torch import model from util import util class BaseOptions(): def __init__(self): self.parser = argparse.ArgumentParser() self.initialized = False def initialize(self, parser): # base define parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment.') parser.add_argument('--model', type=str, default='tc', help='name of the model type. [pluralistic]') parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are save here') parser.add_argument('--which_iter', type=int, default='0', help='which iterations to load') parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load') parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, 1, 2 use -1 for CPU') # data define parser.add_argument('--mask_type', type=int, default=[0,1,3], help='0:center,1:regular,2:irregular,3:external') parser.add_argument('--img_file', type=str, default='/data/dataset/train', help='training and testing dataset') parser.add_argument('--mask_file', type=str, default='none', help='load test mask') parser.add_argument('--img_nc', type=int, default=3, help='# of image channels') parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='preprocessing image at load time') parser.add_argument('--load_size', type=int, default=542, help='scale examples to this size') parser.add_argument('--fine_size', type=int, default=512, help='then crop to this size') parser.add_argument('--fixed_size', type=int, default=256, help='fixed the image size in S1 with transformer') parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the image') parser.add_argument('--data_powers', type=int, default=5, help='# times of the scale to 2 times') parser.add_argument('--reverse_mask', action='store_true', help='if specified, random reverse the mask region') parser.add_argument('--batch_size', type=int, default=8, help='input batch size') parser.add_argument('--nThreads', type=int, default=8, help='# threads for loading data') parser.add_argument('--no_shuffle', action='store_true', help='if true, takes examples serial') # display parameter define parser.add_argument('--display_winsize', type=int, default=256, help='display window size') parser.add_argument('--display_id', type=int, default=None, help='display id of the web') parser.add_argument('--display_server', type=str, default="http://localhost", help='server of the web display') parser.add_argument('--display_env', type=str, default='main', help='display name (default is "main")') parser.add_argument('--display_port', type=int, default=8092, help='port of the web display') parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all examples in a single visidom web panel') # Encoder-Decoder define parser.add_argument('--ngf', type=int, default=32, help='# of gen filters in the last conv layer') parser.add_argument('--ndf', type=int, default=32, help='# of dis filters in the first conv layer') parser.add_argument('--num_res_blocks', type=int, default=2, help='# of residual block in the encoder and decoder layer') parser.add_argument('--netD', type=str, default='style', help='specify discriminator architecture ') parser.add_argument('--netG', type=str, default='diff', help='specify decoder architecture') parser.add_argument('--netE', type=str, default='diff', help='specify encoder architecture') parser.add_argument('--kernel_G', type=int, default=3, help='kernel size for the decoder') parser.add_argument('--kernel_E', type=int, default=1, help='kernel size for the encoder') parser.add_argument('--add_noise', action='store_true', help='if true, add noise to the decoder') parser.add_argument('--attn_E', action='store_true', help='if true, use attention in the encoder') parser.add_argument('--attn_G', action='store_true', help='if true, use attention in the decoder') parser.add_argument('--attn_D', action='store_true', help='if true, use attention in the decoder') parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers') parser.add_argument('--n_layers_G', type=int, default=4, help='# of down sample layers in the Encoder and Decoder') parser.add_argument('--norm', type=str, default='pixel', help='instance normalization or batch normalization [instance | batch | pixel | none]') parser.add_argument('--activation', type=str, default='leakyrelu', help='activation layer [relu | gelu | leakyrelu | none]') parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]') parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.') parser.add_argument('--lipip_path', type=str, default='./model/lpips/vgg.pth', help='the pretrained LIPPS model') # Transformer define parser.add_argument('--netT', type=str, default='original', help='specify transformer architecture') parser.add_argument('--embed_dim', type=int, default=512, help='the numbers of embedding dimension') parser.add_argument('--dropout', type=float, default=0., help='the dropout probability in transformer') parser.add_argument('--kernel_T', type=int, default=1, help='kernel size for the transformer projection') parser.add_argument('--n_encoders', type=int, default=12, help='the numbers of encoder in transformer') parser.add_argument('--n_decoders', type=int, default=0, help='the numbers of decoder in transformer') parser.add_argument('--embed_type', type=str, default='learned', choices=['learned', 'sine']) parser.add_argument('--top_k', type=int, default=10, help='sample the results on top k value') # VQ define parser.add_argument('--num_embeds', type=int, default=1024, help='the numbers of words for image') parser.add_argument('--use_pos_G', action='store_true', help='if true, position embedding in G') parser.add_argument('--word_size', type=int, default=16, help='the numbers of word for each image') self.initialized = True return parser def gather_options(self): """Add additional model-specific options""" if not self.initialized: parser = self.initialize(self.parser) # get basic options opt, _ = parser.parse_known_args() # modify the options for different models model_option_set = model.get_option_setter(opt.model) parser = model_option_set(parser, self.isTrain) opt = parser.parse_args() return opt def parse(self): """Parse the options""" opt = self.gather_options() opt.isTrain = self.isTrain self.print_options(opt) # set gpu ids str_ids = opt.gpu_ids.split(',') opt.gpu_ids = [] for str_id in str_ids: id = int(str_id) if id >= 0: opt.gpu_ids.append(id) if len(opt.gpu_ids): torch.cuda.set_device(opt.gpu_ids[0]) self.opt = opt return self.opt @staticmethod def print_options(opt): """print and save options""" print('--------------Options--------------') for k, v in sorted(vars(opt).items()): print('%s: %s' % (str(k), str(v))) print('----------------End----------------') # save to the disk expr_dir = os.path.join(opt.checkpoints_dir, opt.name) util.mkdirs(expr_dir) if opt.isTrain: file_name = os.path.join(expr_dir, 'train_opt.txt') else: file_name = os.path.join(expr_dir, 'test_opt.txt') with open(file_name, 'wt') as opt_file: opt_file.write('--------------Options--------------\n') for k, v in sorted(vars(opt).items()): opt_file.write('%s: %s\n' % (str(k), str(v))) opt_file.write('----------------End----------------\n')