Spaces:
Runtime error
Runtime error
import argparse | |
import torch | |
import os | |
from datetime import datetime | |
import time | |
import torch | |
import random | |
import numpy as np | |
import sys | |
class Options(object):
    """Command-line options for training / testing the models selected by --model.

    Typical use::

        opt = Options().parse()

    ``initialize`` builds the ``argparse`` parser.

    ``parse`` does the following:
      * parses ``sys.argv``;
      * derives run-specific paths;
      * configures the GPU device and random seeds;
      * records the command line and resolved options under ``opt.ckpt_dir``.
    """

    def __init__(self):
        super(Options, self).__init__()

    def initialize(self):
        """Build and return the ``argparse.ArgumentParser`` holding every option."""
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--mode', type=str, default='train', help='Mode of code. [train|test]')
        parser.add_argument('--model', type=str, default='ganimation', help='[ganimation|stargan], see model.__init__ for more details.')
        parser.add_argument('--lucky_seed', type=int, default=0, help='seed for random initialize, 0 to use current time.')
        parser.add_argument('--visdom_env', type=str, default="main", help='visdom env.')
        parser.add_argument('--visdom_port', type=int, default=8097, help='visdom port.')
        parser.add_argument('--visdom_display_id', type=int, default=1, help='set value larger than 0 to display with visdom.')
        parser.add_argument('--results', type=str, default="results", help='save test results to this path.')
        parser.add_argument('--interpolate_len', type=int, default=5, help='interpolate length for test.')
        parser.add_argument('--no_test_eval', action='store_true', help='do not use eval mode during test time.')
        parser.add_argument('--save_test_gif', action='store_true', help='save gif images instead of the concatenation of static images.')
        parser.add_argument('--data_root', required=False, help='paths to data set.')
        parser.add_argument('--imgs_dir', type=str, default="imgs", help='path to image')
        parser.add_argument('--aus_pkl', type=str, default="aus_openface.pkl", help='AUs pickle dictionary.')
        parser.add_argument('--train_csv', type=str, default="train_ids.csv", help='train images paths')
        parser.add_argument('--test_csv', type=str, default="test_ids.csv", help='test images paths')
        parser.add_argument('--batch_size', type=int, default=25, help='input batch size.')
        parser.add_argument('--serial_batches', action='store_true', help='if specified, input images in order.')
        parser.add_argument('--n_threads', type=int, default=6, help='number of workers to load data.')
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='maximum number of samples.')
        parser.add_argument('--resize_or_crop', type=str, default='none', help='Preprocessing image, [resize_and_crop|crop|none]')
        parser.add_argument('--load_size', type=int, default=148, help='scale image to this size.')
        parser.add_argument('--final_size', type=int, default=128, help='crop image to this size.')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip image.')
        parser.add_argument('--no_aus_noise', action='store_true', help='if specified, do not add noise to target AUs.')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, eg. 0,1,2; -1 for cpu.')
        parser.add_argument('--ckpt_dir', type=str, default='./ckpts', help='directory to save check points.')
        parser.add_argument('--load_epoch', type=int, default=0, help='load epoch; 0: do not load')
        parser.add_argument('--log_file', type=str, default="logs.txt", help='log loss')
        parser.add_argument('--opt_file', type=str, default="opt.txt", help='options file')
        # train options
        parser.add_argument('--img_nc', type=int, default=3, help='image number of channel')
        parser.add_argument('--aus_nc', type=int, default=17, help='aus number of channel')
        parser.add_argument('--ngf', type=int, default=64, help='ngf')
        parser.add_argument('--ndf', type=int, default=64, help='ndf')
        parser.add_argument('--use_dropout', action='store_true', help='if specified, use dropout.')
        parser.add_argument('--gan_type', type=str, default='wgan-gp', help='GAN loss [wgan-gp|lsgan|gan]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [batch|instance|none]')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')
        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>, ...')
        parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=10, help='# of iter to linearly decay learning rate to zero')
        # loss options
        parser.add_argument('--lambda_dis', type=float, default=1.0, help='discriminator weight in loss')
        parser.add_argument('--lambda_aus', type=float, default=160.0, help='AUs weight in loss')
        parser.add_argument('--lambda_rec', type=float, default=10.0, help='reconstruct loss weight')
        parser.add_argument('--lambda_mask', type=float, default=0, help='mse loss weight')
        parser.add_argument('--lambda_tv', type=float, default=0, help='total variation loss weight')
        parser.add_argument('--lambda_wgan_gp', type=float, default=10., help='wgan gradient penalty weight')
        # frequency options
        parser.add_argument('--train_gen_iter', type=int, default=5, help='train G every n iterations.')
        parser.add_argument('--print_losses_freq', type=int, default=100, help='print log every print_freq step.')
        parser.add_argument('--plot_losses_freq', type=int, default=20000, help='plot log every plot_freq step.')
        parser.add_argument('--sample_img_freq', type=int, default=2000, help='draw image every sample_img_freq step.')
        parser.add_argument('--save_epoch_freq', type=int, default=2, help='save checkpoint every save_epoch_freq epoch.')
        return parser

    def parse(self):
        """Parse ``sys.argv`` and prepare the run environment.

        Side effects:
          * creates ``opt.ckpt_dir``; in test mode it also creates ``opt.results``;
          * selects the first requested GPU, falling back to CPU when CUDA is
            unavailable;
          * seeds the python / numpy / torch RNGs;
          * appends the command line to ``<ckpt_dir>/run_script.sh`` and the
            option dump to ``<ckpt_dir>/<opt_file>``.

        Returns:
            argparse.Namespace. ``gpu_ids`` is converted to a list of
            non-negative ints, and ``lucky_seed`` holds the seed actually used.
        """
        parser = self.initialize()
        # 'name' has no CLI flag; it is a per-run timestamp.
        parser.set_defaults(name=datetime.now().strftime("%y%m%d_%H%M%S"))
        opt = parser.parse_args()

        dataset_name = self._dataset_name(opt.data_root)

        # A fresh training run gets its own timestamped checkpoint sub-directory.
        if opt.mode == 'train' and opt.load_epoch == 0:
            opt.ckpt_dir = os.path.join(opt.ckpt_dir, dataset_name, opt.model, opt.name)
        # Always ensure the directory exists: the run log is written into it
        # below, in every mode.
        os.makedirs(opt.ckpt_dir, exist_ok=True)

        # In test mode, disable visdom and save results to a per-run sub-directory.
        if opt.mode == "test":
            opt.visdom_display_id = 0
            opt.results = os.path.join(opt.results, "%s_%s_%s" % (dataset_name, opt.model, opt.load_epoch))
            os.makedirs(opt.results, exist_ok=True)

        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self._seed_everything(opt)
        self._record_run(parser, opt)
        return opt

    @staticmethod
    def _dataset_name(data_root):
        """Return the last path component of *data_root*, or 'default' when it is unset."""
        if not data_root:
            # --data_root is optional; avoid calling .strip() on None.
            return 'default'
        return os.path.basename(data_root.strip('/'))

    @staticmethod
    def _parse_gpu_ids(gpu_ids_str):
        """Convert e.g. '0,1,2' into [0, 1, 2].

        Negative ids (e.g. -1) and empty entries are dropped; an empty list means CPU.
        If GPUs are requested but CUDA is unavailable, warns and returns [].
        """
        gpu_ids = [int(s) for s in gpu_ids_str.split(',') if s.strip()]
        gpu_ids = [i for i in gpu_ids if i >= 0]
        if gpu_ids and not torch.cuda.is_available():
            print('[Options] CUDA is not available; ignoring --gpu_ids %s and running on CPU.' % gpu_ids_str)
            return []
        return gpu_ids

    @staticmethod
    def _seed_everything(opt):
        """Seed the python, numpy and torch RNGs.

        A ``lucky_seed`` of 0 is replaced by the current time, and the chosen
        value is stored back on ``opt``.
        """
        if opt.lucky_seed == 0:
            opt.lucky_seed = int(time.time())
        random.seed(a=opt.lucky_seed)
        np.random.seed(seed=opt.lucky_seed)
        torch.manual_seed(opt.lucky_seed)
        if opt.gpu_ids:
            # Prefer reproducible cuDNN kernels over autotuned (benchmark) ones.
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            torch.cuda.manual_seed(opt.lucky_seed)
            torch.cuda.manual_seed_all(opt.lucky_seed)

    @staticmethod
    def _record_run(parser, opt):
        """Log the command line and a formatted option dump under ``opt.ckpt_dir``.

        The command line is appended to ``run_script.sh``. The option dump is
        printed and appended to ``opt.opt_file``.
        """
        script_dir = opt.ckpt_dir
        with open(os.path.join(script_dir, "run_script.sh"), 'a+', encoding='utf-8') as f:
            f.write("[%5s][%s]python %s\n" % (opt.mode, opt.name, ' '.join(sys.argv)))

        lines = ['------------------- [%5s][%s]Options --------------------' % (opt.mode, opt.name)]
        for k, v in sorted(vars(opt).items()):
            default_v = parser.get_default(k)
            # Annotate values that differ from the parser default.
            comment = '\t[default: %s]' % str(default_v) if v != default_v else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('--------------------- [%5s][%s]End ----------------------' % (opt.mode, opt.name))
        msg = '\n'.join(lines) + '\n'
        print(msg)
        # Honour --opt_file instead of a hard-coded "opt.txt".
        with open(os.path.join(script_dir, opt.opt_file), 'a+', encoding='utf-8') as f:
            f.write(msg + '\n\n')