Spaces:
Runtime error
Runtime error
File size: 8,456 Bytes
ec0fdfd |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import argparse
import os
import torch
import model
from util import util
class BaseOptions():
    """Command-line options shared by every run.

    Registers the common options, lets the selected model add its own options,
    parses the command line, and records the final option values under
    ``<checkpoints_dir>/<name>``.

    NOTE(review): this class reads ``self.isTrain`` but never assigns it; a
    subclass (presumably the train/test option classes) must set it before
    ``gather_options``/``parse`` run — confirm.
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        # Set to True by ``initialize`` once the base options are registered.
        self.initialized = False

    def initialize(self, parser):
        """Register the base options on *parser*.

        Args:
            parser: the ``argparse.ArgumentParser`` to extend.

        Returns:
            The same parser, so subclasses can add their own options to it.
        """
        # base define
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment.')
        parser.add_argument('--model', type=str, default='tc', help='name of the model type. [pluralistic]')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are save here')
        parser.add_argument('--which_iter', type=int, default=0, help='which iterations to load')
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, 1, 2 use -1 for CPU')
        # data define
        # nargs='+' keeps the CLI value a list, matching the list default.
        parser.add_argument('--mask_type', type=int, nargs='+', default=[0, 1, 3], help='0:center,1:regular,2:irregular,3:external')
        parser.add_argument('--img_file', type=str, default='/data/dataset/train', help='training and testing dataset')
        parser.add_argument('--mask_file', type=str, default='none', help='load test mask')
        parser.add_argument('--img_nc', type=int, default=3, help='# of image channels')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='preprocessing image at load time')
        parser.add_argument('--load_size', type=int, default=542, help='scale examples to this size')
        parser.add_argument('--fine_size', type=int, default=512, help='then crop to this size')
        parser.add_argument('--fixed_size', type=int, default=256, help='fixed the image size in S1 with transformer')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the image')
        parser.add_argument('--data_powers', type=int, default=5, help='# times of the scale to 2 times')
        parser.add_argument('--reverse_mask', action='store_true', help='if specified, random reverse the mask region')
        parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
        parser.add_argument('--nThreads', type=int, default=8, help='# threads for loading data')
        parser.add_argument('--no_shuffle', action='store_true', help='if true, takes examples serial')
        # display parameter define
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        parser.add_argument('--display_id', type=int, default=None, help='display id of the web')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='display name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8092, help='port of the web display')
        parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all examples in a single visidom web panel')
        # Encoder-Decoder define
        parser.add_argument('--ngf', type=int, default=32, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=32, help='# of dis filters in the first conv layer')
        parser.add_argument('--num_res_blocks', type=int, default=2, help='# of residual block in the encoder and decoder layer')
        parser.add_argument('--netD', type=str, default='style', help='specify discriminator architecture ')
        parser.add_argument('--netG', type=str, default='diff', help='specify decoder architecture')
        parser.add_argument('--netE', type=str, default='diff', help='specify encoder architecture')
        parser.add_argument('--kernel_G', type=int, default=3, help='kernel size for the decoder')
        parser.add_argument('--kernel_E', type=int, default=1, help='kernel size for the encoder')
        parser.add_argument('--add_noise', action='store_true', help='if true, add noise to the decoder')
        parser.add_argument('--attn_E', action='store_true', help='if true, use attention in the encoder')
        parser.add_argument('--attn_G', action='store_true', help='if true, use attention in the decoder')
        parser.add_argument('--attn_D', action='store_true', help='if true, use attention in the discriminator')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--n_layers_G', type=int, default=4, help='# of down sample layers in the Encoder and Decoder')
        parser.add_argument('--norm', type=str, default='pixel', help='instance normalization or batch normalization [instance | batch | pixel | none]')
        parser.add_argument('--activation', type=str, default='leakyrelu', help='activation layer [relu | gelu | leakyrelu | none]')
        parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--lipip_path', type=str, default='./model/lpips/vgg.pth', help='the pretrained LIPPS model')
        # Transformer define
        parser.add_argument('--netT', type=str, default='original', help='specify transformer architecture')
        parser.add_argument('--embed_dim', type=int, default=512, help='the numbers of embedding dimension')
        parser.add_argument('--dropout', type=float, default=0., help='the dropout probability in transformer')
        parser.add_argument('--kernel_T', type=int, default=1, help='kernel size for the transformer projection')
        parser.add_argument('--n_encoders', type=int, default=12, help='the numbers of encoder in transformer')
        parser.add_argument('--n_decoders', type=int, default=0, help='the numbers of decoder in transformer')
        parser.add_argument('--embed_type', type=str, default='learned', choices=['learned', 'sine'])
        parser.add_argument('--top_k', type=int, default=10, help='sample the results on top k value')
        # VQ define
        parser.add_argument('--num_embeds', type=int, default=1024, help='the numbers of words for image')
        parser.add_argument('--use_pos_G', action='store_true', help='if true, position embedding in G')
        parser.add_argument('--word_size', type=int, default=16, help='the numbers of word for each image')
        self.initialized = True
        return parser

    def gather_options(self):
        """Parse base options plus the options added by the selected model.

        Registers the base options (unless already done), does a preliminary
        parse to learn ``--model``, then lets that model's option setter
        extend the parser before the final, strict parse.

        Returns:
            The parsed ``argparse.Namespace``.
        """
        if not self.initialized:
            parser = self.initialize(self.parser)
        else:
            # Previously ``parser`` was left unbound on this path.
            parser = self.parser
        # get basic options; unknown (model-specific) flags are tolerated here
        opt, _ = parser.parse_known_args()
        # modify the options for different models
        model_option_set = model.get_option_setter(opt.model)
        parser = model_option_set(parser, self.isTrain)
        # NOTE(review): calling this twice re-runs the model option setter on
        # the same parser, which may register duplicate flags — confirm callers
        # only call it once.
        self.parser = parser
        opt = parser.parse_args()
        return opt

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Convert a comma-separated id string into a list of non-negative ints.

        Whitespace around ids and empty entries (e.g. a trailing comma) are
        ignored; negative ids (``-1`` means CPU) are dropped.

        Raises:
            ValueError: if a non-empty entry is not an integer.
        """
        ids = []
        for token in gpu_ids.split(','):
            token = token.strip()
            if not token:
                continue
            gpu_id = int(token)
            if gpu_id >= 0:
                ids.append(gpu_id)
        return ids

    def parse(self):
        """Parse, report and store the options.

        Sets ``opt.isTrain`` from ``self.isTrain``, prints and saves the
        options, replaces ``opt.gpu_ids`` (a string) by a list of ints, and
        makes the first listed GPU the current CUDA device if any is given.

        Returns:
            The parsed options, also stored on ``self.opt``.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain
        self.print_options(opt)
        # set gpu ids
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])
        self.opt = opt
        return self.opt

    @staticmethod
    def print_options(opt):
        """Print all options and save them to ``train_opt.txt``/``test_opt.txt``.

        The file is written to ``<opt.checkpoints_dir>/<opt.name>``, which is
        created first; the name depends on ``opt.isTrain``.
        """
        lines = ['--------------Options--------------']
        lines.extend('%s: %s' % (str(k), str(v)) for k, v in sorted(vars(opt).items()))
        lines.append('----------------End----------------')
        message = '\n'.join(lines)
        print(message)
        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        base_name = 'train_opt.txt' if opt.isTrain else 'test_opt.txt'
        file_name = os.path.join(expr_dir, base_name)
        with open(file_name, 'wt') as opt_file:
            opt_file.write(message + '\n')