File size: 9,184 Bytes
1a79cb6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
import argparse
import torch
import os
from datetime import datetime
import time
import torch 
import random
import numpy as np 
import sys



class Options(object):
    """Command-line options for training and testing.

    ``initialize`` builds the argument parser. ``parse`` reads ``sys.argv``,
    derives the checkpoint/results directories, selects the CUDA device,
    seeds the RNGs and records the command line plus the resolved options
    inside ``ckpt_dir``.
    """

    def __init__(self):
        super(Options, self).__init__()

    def initialize(self):
        """Build and return the ``argparse.ArgumentParser`` with every option."""
        parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
        parser.add_argument('--mode', type=str, default='train', help='Mode of code. [train|test]')
        parser.add_argument('--model', type=str, default='ganimation', help='[ganimation|stargan], see model.__init__ for more details.')
        parser.add_argument('--lucky_seed', type=int, default=0, help='seed for random initialize, 0 to use current time.')
        parser.add_argument('--visdom_env', type=str, default="main", help='visdom env.')
        parser.add_argument('--visdom_port', type=int, default=8097, help='visdom port.')
        parser.add_argument('--visdom_display_id', type=int, default=1, help='set value larger than 0 to display with visdom.')

        parser.add_argument('--results', type=str, default="results", help='save test results to this path.')
        parser.add_argument('--interpolate_len', type=int, default=5, help='interpolate length for test.')
        parser.add_argument('--no_test_eval', action='store_true', help='do not use eval mode during test time.')
        parser.add_argument('--save_test_gif', action='store_true', help='save gif images instead of the concatenation of static images.')

        # Not marked required here so callers can build/inspect the parser
        # without it; parse() enforces its presence.
        parser.add_argument('--data_root', required=False, help='path to the dataset root (required by parse()).')
        parser.add_argument('--imgs_dir', type=str, default="imgs", help='path to image')
        parser.add_argument('--aus_pkl', type=str, default="aus_openface.pkl", help='AUs pickle dictionary.')
        parser.add_argument('--train_csv', type=str, default="train_ids.csv", help='train images paths')
        parser.add_argument('--test_csv', type=str, default="test_ids.csv", help='test images paths')

        parser.add_argument('--batch_size', type=int, default=25, help='input batch size.')
        parser.add_argument('--serial_batches', action='store_true', help='if specified, input images in order.')
        parser.add_argument('--n_threads', type=int, default=6, help='number of workers to load data.')
        # float("inf") default means "no limit"; argparse does not apply
        # `type` to non-string defaults, so the default stays a float.
        parser.add_argument('--max_dataset_size', type=int, default=float("inf"), help='maximum number of samples.')

        parser.add_argument('--resize_or_crop', type=str, default='none', help='Preprocessing image, [resize_and_crop|crop|none]')
        parser.add_argument('--load_size', type=int, default=148, help='scale image to this size.')
        parser.add_argument('--final_size', type=int, default=128, help='crop image to this size.')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip image.')
        parser.add_argument('--no_aus_noise', action='store_true', help='if specified, do not add noise to target AUs.')

        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids, eg. 0,1,2; -1 for cpu.')
        parser.add_argument('--ckpt_dir', type=str, default='./ckpts', help='directory to save check points.')
        parser.add_argument('--load_epoch', type=int, default=0, help='load epoch; 0: do not load')
        parser.add_argument('--log_file', type=str, default="logs.txt", help='log loss')
        parser.add_argument('--opt_file', type=str, default="opt.txt", help='options file')

        # train options
        parser.add_argument('--img_nc', type=int, default=3, help='image number of channel')
        parser.add_argument('--aus_nc', type=int, default=17, help='aus number of channel')
        parser.add_argument('--ngf', type=int, default=64, help='ngf')
        parser.add_argument('--ndf', type=int, default=64, help='ndf')
        parser.add_argument('--use_dropout', action='store_true', help='if specified, use dropout.')

        parser.add_argument('--gan_type', type=str, default='wgan-gp', help='GAN loss [wgan-gp|lsgan|gan]')
        parser.add_argument('--init_type', type=str, default='normal', help='network initialization [normal|xavier|kaiming|orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--norm', type=str, default='instance', help='instance normalization or batch normalization [batch|instance|none]')
        parser.add_argument('--beta1', type=float, default=0.5, help='momentum term of adam')
        parser.add_argument('--lr', type=float, default=0.0001, help='initial learning rate for adam')
        parser.add_argument('--lr_policy', type=str, default='lambda', help='learning rate policy: lambda|step|plateau|cosine')
        parser.add_argument('--lr_decay_iters', type=int, default=50, help='multiply by a gamma every lr_decay_iters iterations')

        parser.add_argument('--epoch_count', type=int, default=1, help='the starting epoch count, we save the model by <epoch_count>, <epoch_count>+<save_epoch_freq>, ...')
        parser.add_argument('--niter', type=int, default=20, help='# of iter at starting learning rate')
        parser.add_argument('--niter_decay', type=int, default=10, help='# of iter to linearly decay learning rate to zero')

        # loss options
        parser.add_argument('--lambda_dis', type=float, default=1.0, help='discriminator weight in loss')
        parser.add_argument('--lambda_aus', type=float, default=160.0, help='AUs weight in loss')
        parser.add_argument('--lambda_rec', type=float, default=10.0, help='reconstruct loss weight')
        parser.add_argument('--lambda_mask', type=float, default=0, help='mse loss weight')
        parser.add_argument('--lambda_tv', type=float, default=0, help='total variation loss weight')
        parser.add_argument('--lambda_wgan_gp', type=float, default=10., help='wgan gradient penalty weight')

        # frequency options
        parser.add_argument('--train_gen_iter', type=int, default=5, help='train G every n iterations.')
        parser.add_argument('--print_losses_freq', type=int, default=100, help='print log every print_freq step.')
        parser.add_argument('--plot_losses_freq', type=int, default=20000, help='plot log every plot_freq step.')
        parser.add_argument('--sample_img_freq', type=int, default=2000, help='draw image every sample_img_freq step.')
        parser.add_argument('--save_epoch_freq', type=int, default=2, help='save checkpoint every save_epoch_freq epoch.')

        return parser

    def parse(self):
        """Parse ``sys.argv`` and return the fully resolved options namespace.

        Side effects:
          * train from scratch (``load_epoch == 0``): ``ckpt_dir`` becomes
            ``<ckpt_dir>/<dataset>/<model>/<timestamp name>`` and is created;
          * test: visdom is disabled and ``results`` becomes
            ``<results>/<dataset>_<model>_<load_epoch>`` (created);
          * the first non-negative id in ``gpu_ids`` becomes the CUDA device;
          * ``random``, ``numpy`` and ``torch`` are seeded with ``lucky_seed``
            (the current time when 0);
          * the command line is appended to ``<ckpt_dir>/run_script.sh`` and the
            options dump to ``<ckpt_dir>/<opt_file>``.

        Exits through ``parser.error`` (``SystemExit``) if ``--data_root`` is
        missing or empty.
        """
        parser = self.initialize()
        # `name` is not a CLI flag; it is injected as a default so it shows up
        # in the namespace (and in the printed options) as a run timestamp.
        parser.set_defaults(name=datetime.now().strftime("%y%m%d_%H%M%S"))
        opt = parser.parse_args()

        if not opt.data_root:
            parser.error("--data_root is required")
        # normpath drops trailing separators ('/' and, on Windows, '\\') so
        # basename yields the dataset folder name.
        dataset_name = os.path.basename(os.path.normpath(opt.data_root))

        # update checkpoint dir (only for a fresh training run; when resuming
        # or testing, ckpt_dir is expected to point at an existing run)
        if opt.mode == 'train' and opt.load_epoch == 0:
            opt.ckpt_dir = os.path.join(opt.ckpt_dir, dataset_name, opt.model, opt.name)
            os.makedirs(opt.ckpt_dir, exist_ok=True)

        # if test, disable visdom, update results path
        if opt.mode == "test":
            opt.visdom_display_id = 0
            opt.results = os.path.join(opt.results, "%s_%s_%s" % (dataset_name, opt.model, opt.load_epoch))
            os.makedirs(opt.results, exist_ok=True)

        # set gpu device: negative ids (e.g. -1) mean CPU; blank entries are ignored
        opt.gpu_ids = [gpu_id
                       for gpu_id in (int(s) for s in opt.gpu_ids.split(',') if s.strip())
                       if gpu_id >= 0]
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        # set seed
        if opt.lucky_seed == 0:
            opt.lucky_seed = int(time.time())
        random.seed(a=opt.lucky_seed)
        np.random.seed(seed=opt.lucky_seed)
        torch.manual_seed(opt.lucky_seed)
        if opt.gpu_ids:
            # trade cuDNN autotuning speed for run-to-run reproducibility
            torch.backends.cudnn.deterministic = True
            torch.backends.cudnn.benchmark = False
            torch.cuda.manual_seed(opt.lucky_seed)
            torch.cuda.manual_seed_all(opt.lucky_seed)

        # write command to file
        script_dir = opt.ckpt_dir
        with open(os.path.join(script_dir, "run_script.sh"), 'a+') as f:
            f.write("[%5s][%s]python %s\n" % (opt.mode, opt.name, ' '.join(sys.argv)))

        # print and write options file; non-default values are annotated
        lines = ['------------------- [%5s][%s]Options --------------------' % (opt.mode, opt.name)]
        for k, v in sorted(vars(opt).items()):
            default_v = parser.get_default(k)
            comment = '\t[default: %s]' % str(default_v) if v != default_v else ''
            lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), comment))
        lines.append('--------------------- [%5s][%s]End ----------------------' % (opt.mode, opt.name))
        msg = '\n'.join(lines) + '\n'
        print(msg)
        with open(os.path.join(script_dir, opt.opt_file), 'a+') as f:
            f.write(msg + '\n\n')

        return opt