File size: 8,456 Bytes
ec0fdfd
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
import argparse
import os
import torch
import model
from util import util


class BaseOptions():
    """Command-line options shared by the training and testing entry points.

    The class owns an ``argparse.ArgumentParser``, registers the common
    options on it (``initialize``), lets the selected model register its own
    options (``gather_options``), and finally parses, prints and saves the
    result (``parse``).

    ``self.isTrain`` is read by ``gather_options`` and ``parse`` but is never
    assigned in this class; it must be provided elsewhere (presumably by a
    subclass -- TODO confirm against the train/test option classes).
    """

    def __init__(self):
        self.parser = argparse.ArgumentParser()
        self.initialized = False
        # Guards against registering the model-specific options twice on the
        # same parser when ``gather_options`` is called more than once.
        self._model_options_added = False

    def initialize(self, parser):
        """Register the shared options on ``parser`` and return it.

        Args:
            parser: the ``argparse.ArgumentParser`` to add the options to.

        Returns:
            The same ``parser`` object, with all base options registered.
            ``self.initialized`` is set to ``True`` as a side effect.
        """
        # base define
        parser.add_argument('--name', type=str, default='experiment_name', help='name of the experiment.')
        parser.add_argument('--model', type=str, default='tc', help='name of the model type. [pluralistic]')
        parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are save here')
        # The default is a real int so that ``parser.get_default`` agrees with ``type=int``.
        parser.add_argument('--which_iter', type=int, default=0, help='which iterations to load')
        parser.add_argument('--epoch', type=str, default='latest', help='which epoch to load')
        parser.add_argument('--gpu_ids', type=str, default='0', help='gpu ids: e.g. 0, 1, 2 use -1 for CPU')
        # data define
        # NOTE(review): the default is a list but ``type=int`` without ``nargs`` yields a
        # single int when the flag is given on the command line -- confirm which shape
        # the consumers of ``opt.mask_type`` expect before changing this.
        parser.add_argument('--mask_type', type=int, default=[0,1,3], help='0:center,1:regular,2:irregular,3:external')
        parser.add_argument('--img_file', type=str, default='/data/dataset/train', help='training and testing dataset')
        parser.add_argument('--mask_file', type=str, default='none', help='load test mask')
        parser.add_argument('--img_nc', type=int, default=3, help='# of image channels')
        parser.add_argument('--preprocess', type=str, default='resize_and_crop', help='preprocessing image at load time')
        parser.add_argument('--load_size', type=int, default=542, help='scale examples to this size')
        parser.add_argument('--fine_size', type=int, default=512, help='then crop to this size')
        parser.add_argument('--fixed_size', type=int, default=256, help='fixed the image size in S1 with transformer')
        parser.add_argument('--no_flip', action='store_true', help='if specified, do not flip the image')
        parser.add_argument('--data_powers', type=int, default=5, help='# times of the scale to 2 times')
        parser.add_argument('--reverse_mask', action='store_true', help='if specified, random reverse the mask region')
        parser.add_argument('--batch_size', type=int, default=8, help='input batch size')
        parser.add_argument('--nThreads', type=int, default=8, help='# threads for loading data')
        parser.add_argument('--no_shuffle', action='store_true', help='if true, takes examples serial')
        # display parameter define
        parser.add_argument('--display_winsize', type=int, default=256, help='display window size')
        parser.add_argument('--display_id', type=int, default=None, help='display id of the web')
        parser.add_argument('--display_server', type=str, default="http://localhost", help='server of the web display')
        parser.add_argument('--display_env', type=str, default='main', help='display name (default is "main")')
        parser.add_argument('--display_port', type=int, default=8092, help='port of the web display')
        parser.add_argument('--display_single_pane_ncols', type=int, default=0, help='if positive, display all examples in a single visidom web panel')
        # Encoder-Decoder define
        parser.add_argument('--ngf', type=int, default=32, help='# of gen filters in the last conv layer')
        parser.add_argument('--ndf', type=int, default=32, help='# of dis filters in the first conv layer')
        parser.add_argument('--num_res_blocks', type=int, default=2, help='# of residual block in the encoder and decoder layer')
        parser.add_argument('--netD', type=str, default='style', help='specify discriminator architecture ')
        parser.add_argument('--netG', type=str, default='diff', help='specify decoder architecture')
        parser.add_argument('--netE', type=str, default='diff', help='specify encoder architecture')
        parser.add_argument('--kernel_G', type=int, default=3, help='kernel size for the decoder')
        parser.add_argument('--kernel_E', type=int, default=1, help='kernel size for the encoder')
        parser.add_argument('--add_noise', action='store_true', help='if true, add noise to the decoder')
        parser.add_argument('--attn_E', action='store_true', help='if true, use attention in the encoder')
        parser.add_argument('--attn_G', action='store_true', help='if true, use attention in the decoder')
        parser.add_argument('--attn_D', action='store_true', help='if true, use attention in the discriminator')
        parser.add_argument('--n_layers_D', type=int, default=3, help='only used if netD==n_layers')
        parser.add_argument('--n_layers_G', type=int, default=4, help='# of down sample layers in the Encoder and Decoder')
        parser.add_argument('--norm', type=str, default='pixel', help='instance normalization or batch normalization [instance | batch | pixel | none]')
        parser.add_argument('--activation', type=str, default='leakyrelu', help='activation layer [relu | gelu | leakyrelu | none]')
        parser.add_argument('--init_type', type=str, default='kaiming', help='network initialization [normal | xavier | kaiming | orthogonal]')
        parser.add_argument('--init_gain', type=float, default=0.02, help='scaling factor for normal, xavier and orthogonal.')
        parser.add_argument('--lipip_path', type=str, default='./model/lpips/vgg.pth', help='the pretrained LIPPS model')
        # Transformer define
        parser.add_argument('--netT', type=str, default='original', help='specify transformer architecture')
        parser.add_argument('--embed_dim', type=int, default=512, help='the numbers of embedding dimension')
        parser.add_argument('--dropout', type=float, default=0., help='the dropout probability in transformer')
        parser.add_argument('--kernel_T', type=int, default=1, help='kernel size for the transformer projection')
        parser.add_argument('--n_encoders', type=int, default=12, help='the numbers of encoder in transformer')
        parser.add_argument('--n_decoders', type=int, default=0, help='the numbers of decoder in transformer')
        parser.add_argument('--embed_type', type=str, default='learned', choices=['learned', 'sine'])
        parser.add_argument('--top_k', type=int, default=10, help='sample the results on top k value')
        # VQ define
        parser.add_argument('--num_embeds', type=int, default=1024, help='the numbers of words for image')
        parser.add_argument('--use_pos_G', action='store_true', help='if true, position embedding in G')
        parser.add_argument('--word_size', type=int, default=16, help='the numbers of word for each image')
        self.initialized = True
        return parser

    def gather_options(self):
        """Add additional model-specific options and parse the command line.

        The base options are registered on first use. The command line is then
        parsed leniently to learn ``--model``, the option setter returned by
        ``model.get_option_setter`` extends the parser, and the command line
        is parsed again strictly.

        Returns:
            The ``argparse.Namespace`` produced by the final strict parse.
        """
        if not self.initialized:
            self.parser = self.initialize(self.parser)
        # Always read the parser from the instance: previously the local name
        # was only bound inside the branch above, so any call made after
        # initialization raised UnboundLocalError.
        parser = self.parser

        # get basic options
        opt, _ = parser.parse_known_args()

        # modify the options for different models (only once per parser, since
        # argparse rejects registering the same option string twice)
        if not getattr(self, '_model_options_added', False):
            model_option_set = model.get_option_setter(opt.model)
            parser = model_option_set(parser, self.isTrain)
            self.parser = parser
            self._model_options_added = True

        return parser.parse_args()

    @staticmethod
    def _parse_gpu_ids(gpu_ids):
        """Turn a comma-separated id string into a list of non-negative ints.

        Blank tokens (e.g. from a trailing comma) are skipped; negative ids
        are dropped, so ``'-1'`` yields an empty list.
        """
        tokens = (token.strip() for token in gpu_ids.split(','))
        parsed = (int(token) for token in tokens if token)
        return [gpu_id for gpu_id in parsed if gpu_id >= 0]

    def parse(self):
        """Parse the options, print/save them and select the CUDA device.

        Returns:
            The parsed options, also stored on ``self.opt``. ``opt.isTrain``
            is copied from ``self.isTrain`` and ``opt.gpu_ids`` is converted
            from its string form to a list of ints.
        """
        opt = self.gather_options()
        opt.isTrain = self.isTrain

        # Printed before the conversion below, so the dump shows the raw
        # ``--gpu_ids`` string exactly as it was given.
        self.print_options(opt)

        # set gpu ids
        opt.gpu_ids = self._parse_gpu_ids(opt.gpu_ids)
        if opt.gpu_ids:
            torch.cuda.set_device(opt.gpu_ids[0])

        self.opt = opt

        return self.opt

    @staticmethod
    def print_options(opt):
        """Print and save options.

        Writes every attribute of ``opt`` (sorted by name) to stdout and to
        ``<checkpoints_dir>/<name>/train_opt.txt`` or ``test_opt.txt``,
        depending on ``opt.isTrain``.
        """
        # Format once and reuse for both the console and the file.
        lines = ['--------------Options--------------']
        lines.extend('%s: %s' % (str(k), str(v)) for k, v in sorted(vars(opt).items()))
        lines.append('----------------End----------------')

        for line in lines:
            print(line)

        # save to the disk
        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
        util.mkdirs(expr_dir)
        file_name = os.path.join(expr_dir, 'train_opt.txt' if opt.isTrain else 'test_opt.txt')
        with open(file_name, 'wt') as opt_file:
            opt_file.writelines(line + '\n' for line in lines)