Spaces:
Build error
Build error
File size: 9,169 Bytes
d444fe9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 |
import argparse
import os
class BaseOptions:
    """Command-line options shared by the training and testing scripts.

    ``initialize`` registers every option on an ``argparse.ArgumentParser``.
    ``parse`` / ``parse_to_dict`` build that parser on first use (and reuse it
    afterwards) and parse either ``sys.argv`` or an explicit argument list.
    """

    def __init__(self):
        # Set to True once initialize() has registered all options.
        self.initialized = False
        # Built lazily by gather_options(); print_options() reads defaults from it.
        self.parser = None

    def initialize(self, parser):
        """Register all options on ``parser`` and return it.

        Args:
            parser: an ``argparse.ArgumentParser`` to populate.

        Returns:
            The same parser, with every option group added.
        """
        # Datasets related
        g_data = parser.add_argument_group('Data')
        g_data.add_argument('--dataroot', type=str, default='./data',
                            help='path to images (data folder)')
        g_data.add_argument('--loadSize', type=int, default=512, help='load size of input image')

        # Experiment related
        g_exp = parser.add_argument_group('Experiment')
        g_exp.add_argument('--name', type=str, default='example',
                           help='name of the experiment. It decides where to store samples and models')
        g_exp.add_argument('--debug', action='store_true', help='debug mode or not')
        g_exp.add_argument('--num_views', type=int, default=1, help='How many views to use for multiview network.')
        g_exp.add_argument('--random_multiview', action='store_true', help='Select random multiview combination.')

        # Training related
        g_train = parser.add_argument_group('Training')
        g_train.add_argument('--gpu_id', type=int, default=0, help='gpu id for cuda')
        g_train.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0 0,1,2, 0,2, -1 for CPU mode')
        g_train.add_argument('--num_threads', default=1, type=int, help='# threads for loading data')
        g_train.add_argument('--serial_batches', action='store_true',
                             help='if true, takes images in order to make batches, otherwise takes them randomly')
        g_train.add_argument('--pin_memory', action='store_true', help='pin_memory')
        g_train.add_argument('--batch_size', type=int, default=2, help='input batch size')
        g_train.add_argument('--learning_rate', type=float, default=1e-3, help='adam learning rate')
        g_train.add_argument('--learning_rateC', type=float, default=1e-3, help='adam learning rate (color network)')
        g_train.add_argument('--num_epoch', type=int, default=100, help='num epoch to train')
        g_train.add_argument('--freq_plot', type=int, default=10, help='frequency of the error plot')
        g_train.add_argument('--freq_save', type=int, default=50, help='frequency of the save_checkpoints')
        g_train.add_argument('--freq_save_ply', type=int, default=100, help='frequency of the save ply')
        g_train.add_argument('--no_gen_mesh', action='store_true')
        g_train.add_argument('--no_num_eval', action='store_true')
        g_train.add_argument('--resume_epoch', type=int, default=-1, help='epoch resuming the training')
        g_train.add_argument('--continue_train', action='store_true', help='continue training: load the latest model')

        # Testing related
        g_test = parser.add_argument_group('Testing')
        g_test.add_argument('--resolution', type=int, default=256, help='# of grid in mesh reconstruction')
        g_test.add_argument('--test_folder_path', type=str, default=None, help='the folder of test image')

        # Sampling related
        g_sample = parser.add_argument_group('Sampling')
        g_sample.add_argument('--sigma', type=float, default=5.0, help='perturbation standard deviation for positions')
        g_sample.add_argument('--num_sample_inout', type=int, default=5000, help='# of sampling points')
        g_sample.add_argument('--num_sample_color', type=int, default=0, help='# of color sampling points')
        g_sample.add_argument('--z_size', type=float, default=200.0, help='z normalization factor')

        # Model related
        g_model = parser.add_argument_group('Model')
        # General
        g_model.add_argument('--norm', type=str, default='group',
                             help='instance normalization or batch normalization or group normalization')
        g_model.add_argument('--norm_color', type=str, default='instance',
                             help='instance normalization or batch normalization or group normalization')
        # hg filter specify
        g_model.add_argument('--num_stack', type=int, default=4, help='# of hourglass')
        g_model.add_argument('--num_hourglass', type=int, default=2, help='# of stacked layer of hourglass')
        g_model.add_argument('--skip_hourglass', action='store_true', help='skip connection in hourglass')
        g_model.add_argument('--hg_down', type=str, default='ave_pool', help='ave pool || conv64 || conv128')
        # An int default (not the string '256') keeps get_default() equal to the
        # parsed value, so print_options() does not flag it as changed.
        g_model.add_argument('--hourglass_dim', type=int, default=256, help='256 | 512')
        # Classification General
        g_model.add_argument('--mlp_dim', nargs='+', default=[257, 1024, 512, 256, 128, 1], type=int,
                             help='# of dimensions of mlp')
        g_model.add_argument('--mlp_dim_color', nargs='+', default=[513, 1024, 512, 256, 128, 3],
                             type=int, help='# of dimensions of color mlp')
        g_model.add_argument('--use_tanh', action='store_true',
                             help='using tanh after last conv of image_filter network')

        # for train
        parser.add_argument('--random_flip', action='store_true', help='if random flip')
        parser.add_argument('--random_trans', action='store_true', help='if random translation')
        parser.add_argument('--random_scale', action='store_true', help='if random scale')
        parser.add_argument('--no_residual', action='store_true', help='no skip connection in mlp')
        parser.add_argument('--schedule', type=int, nargs='+', default=[60, 80],
                            help='Decrease learning rate at these epochs.')
        parser.add_argument('--gamma', type=float, default=0.1, help='LR is multiplied by gamma on schedule.')
        parser.add_argument('--color_loss_type', type=str, default='l1', help='mse | l1')

        # for eval
        parser.add_argument('--val_test_error', action='store_true', help='validate errors of test data')
        parser.add_argument('--val_train_error', action='store_true', help='validate errors of train data')
        parser.add_argument('--gen_test_mesh', action='store_true', help='generate test mesh')
        parser.add_argument('--gen_train_mesh', action='store_true', help='generate train mesh')
        parser.add_argument('--all_mesh', action='store_true', help='generate meshs from all hourglass output')
        parser.add_argument('--num_gen_mesh_test', type=int, default=1,
                            help='how many meshes to generate during testing')

        # path
        parser.add_argument('--checkpoints_path', type=str, default='./checkpoints', help='path to save checkpoints')
        parser.add_argument('--load_netG_checkpoint_path', type=str, default=None,
                            help='path to load netG checkpoint')
        parser.add_argument('--load_netC_checkpoint_path', type=str, default=None,
                            help='path to load netC checkpoint')
        parser.add_argument('--results_path', type=str, default='./results', help='path to save results ply')
        parser.add_argument('--load_checkpoint_path', type=str, help='path to load checkpoint')
        parser.add_argument('--single', type=str, default='', help='single data for training')

        # for single image reconstruction
        parser.add_argument('--mask_path', type=str, help='path for input mask')
        parser.add_argument('--img_path', type=str, help='path for input image')

        # aug
        group_aug = parser.add_argument_group('aug')
        group_aug.add_argument('--aug_alstd', type=float, default=0.0, help='augmentation pca lighting alpha std')
        group_aug.add_argument('--aug_bri', type=float, default=0.0, help='augmentation brightness')
        group_aug.add_argument('--aug_con', type=float, default=0.0, help='augmentation contrast')
        group_aug.add_argument('--aug_sat', type=float, default=0.0, help='augmentation saturation')
        group_aug.add_argument('--aug_hue', type=float, default=0.0, help='augmentation hue')
        group_aug.add_argument('--aug_blur', type=float, default=0.0, help='augmentation blur')

        # special tasks

        self.initialized = True
        return parser

    def gather_options(self, args=None):
        """Build the parser once, then parse ``args``.

        Args:
            args: list of argument strings; ``None`` means ``sys.argv[1:]``.

        Returns:
            The ``argparse.Namespace`` of parsed options.
        """
        # Build and register options only on first use; later calls reuse the
        # stored parser instead of hitting an unbound local.
        if not self.initialized or self.parser is None:
            parser = argparse.ArgumentParser(
                formatter_class=argparse.ArgumentDefaultsHelpFormatter)
            self.parser = self.initialize(parser)
        return self.parser.parse_args(args)

    def print_options(self, opt):
        """Print every option in ``opt``, annotating values that differ from the default.

        Requires ``gather_options`` (or ``parse``) to have been called first so
        that ``self.parser`` exists.
        """
        lines = ['----------------- Options ---------------']
        for k, v in sorted(vars(opt).items()):
            default = self.parser.get_default(k)
            comment = '\t[default: %s]' % str(default) if v != default else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('----------------- End -------------------')
        print('\n'.join(lines))

    def parse(self, args=None):
        """Parse ``args`` (default: ``sys.argv[1:]``) and return the Namespace."""
        return self.gather_options(args)

    def parse_to_dict(self, args=None):
        """Parse ``args`` (default: ``sys.argv[1:]``) and return the options as a dict."""
        return vars(self.gather_options(args))