# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Define model HParams."""
import tensorflow as tf


def create_hparams(hparam_string=None):
  """Create model hyperparameters. Parse nondefault from given string."""
  hparams = tf.contrib.training.HParams(
      # The name of the architecture to use.
      arch='resnet',
      lrelu_leakiness=0.2,
      batch_norm_decay=0.9,
      weight_decay=1e-5,
      normal_init_std=0.02,
      generator_kernel_size=3,
      discriminator_kernel_size=3,

      # Stop training after this many examples are processed.
      # If 0, train indefinitely.
      num_training_examples=0,

      # Apply data augmentation to datasets
      # Applies only in training job
      augment_source_images=False,
      augment_target_images=False,

      # Discriminator
      # Number of filters in first layer of discriminator
      num_discriminator_filters=64,
      discriminator_conv_block_size=1,  # How many convs to have at each size
      discriminator_filter_factor=2.0,  # Multiply # filters by this each layer
      # Add gaussian noise with this stddev to every hidden layer of D
      discriminator_noise_stddev=0.2,  # lmetz: Start seeing results at >= 0.1
      # If true, add this gaussian noise to input images to D as well
      discriminator_image_noise=False,
      discriminator_first_stride=1,  # Stride in first conv of discriminator
      discriminator_do_pooling=False,  # If true, replace stride 2 with avg pool
      discriminator_dropout_keep_prob=0.9,  # keep probability for dropout

      # DCGAN Generator
      # Number of filters in generator decoder last layer (repeatedly halved
      # from 1st layer)
      num_decoder_filters=64,
      # Number of filters in generator encoder 1st layer (repeatedly doubled
      # after 1st layer)
      num_encoder_filters=64,

      # This is the shape to which the noise vector is projected (if we're
      # transferring from noise).
      # Written this way instead of [4, 4, 64] for hparam search flexibility.
      projection_shape_size=4,
      projection_shape_channels=64,

      # Indicates the method by which we enlarge the spatial representation
      # of an image. Possible values include:
      # - resize_conv: Performs a nearest neighbor resize followed by a conv.
      # - conv2d_transpose: Performs a conv2d_transpose.
      upsample_method='resize_conv',

      # Visualization
      summary_steps=500,  # Output image summary every N steps

      ###################################
      # Task Classifier Hyperparameters #
      ###################################

      # Which task-specific prediction tower to use. Possible choices are:
      #  none: No task tower.
      #  doubling_pose_estimator: classifier + quaternion regressor.
      #    [conv + pool]* + FC
      # Classifiers used in DSN paper:
      #  gtsrb: Classifier used for GTSRB
      #  svhn: Classifier used for SVHN
      #  mnist: Classifier used for MNIST
      #  pose_mini: Classifier + regressor used for pose_mini
      task_tower='doubling_pose_estimator',
      weight_decay_task_classifier=1e-5,
      source_task_loss_weight=1.0,
      transferred_task_loss_weight=1.0,

      # Number of private layers in doubling_pose_estimator task tower
      num_private_layers=2,

      # The weight for the log quaternion loss we use for source and transferred
      # samples of the cropped_linemod dataset.
      # In the DSN work, 1/8 of the classifier weight worked well for our log
      # quaternion loss
      source_pose_weight=0.125 * 2.0,
      transferred_pose_weight=0.125 * 1.0,

      # If set to True, the style transfer network also attempts to change its
      # weights to maximize the performance of the task tower. If set to False,
      # then the style transfer network only attempts to change its weights to
      # make the transferred images more likely according to the domain
      # classifier.
      task_tower_in_g_step=True,
      task_loss_in_g_weight=1.0,  # Weight of task loss in G

      #########################################
      # 'simple' generator arch model hparams #
      #########################################
      simple_num_conv_layers=1,
      simple_conv_filters=8,

      ##########################
      # ResNet Hyperparameters #
      ##########################
      resnet_blocks=6,  # Number of resnet blocks
      resnet_filters=64,  # Number of filters per conv in resnet blocks
      # If true, add original input back to result of convolutions inside the
      # resnet arch. If false, it turns into a simple stack of conv/relu/BN
      # layers.
      resnet_residuals=True,

      #######################################
      # The residual / interpretable model. #
      #######################################
      res_int_blocks=2,  # The number of residual blocks.
      res_int_convs=2,  # The number of conv calls inside each block.
      res_int_filters=64,  # The number of filters used by each convolution.

      ####################
      # Latent variables #
      ####################
      # If true, generate random noise and project it as input to the generator.
      noise_channel=True,
      # The number of dimensions in the input noise vector.
      noise_dims=10,

      # If true, then one hot encode source image class and project as an
      # additional channel for the input to generator. This gives the generator
      # access to the class, which may help generation performance.
      condition_on_source_class=False,

      ########################
      # Loss Hyperparameters #
      ########################
      domain_loss_weight=1.0,
      style_transfer_loss_weight=1.0,

      ########################################################################
      # Encourages the transferred images to be similar to the source images #
      # using a configurable metric.                                         #
      ########################################################################

      # The weight of the loss function encouraging the source and transferred
      # images to be similar. If set to 0, then the loss function is not used.
      transferred_similarity_loss_weight=0.0,

      # The type of loss used to encourage transferred and source image
      # similarity. Valid values include:
      #   mpse: Mean Pairwise Squared Error
      #   mse: Mean Squared Error
      #   hinged_mse: Computes the mean squared error using squared differences
      #     greater than hparams.transferred_similarity_max_diff
      #   hinged_mae: Computes the mean absolute error using absolute
      #     differences greater than hparams.transferred_similarity_max_diff.
      transferred_similarity_loss='mpse',

      # The maximum allowable difference between the source and target images.
      # This value is used, in effect, to produce a hinge loss. Note that the
      # range of values should be between 0 and 1.
      transferred_similarity_max_diff=0.4,

      ################################
      # Optimization Hyperparameters #
      ################################
      learning_rate=0.001,
      batch_size=32,
      lr_decay_steps=20000,
      lr_decay_rate=0.95,

      # Recommendation from the DCGAN paper:
      adam_beta1=0.5,
      clip_gradient_norm=5.0,

      # The number of times we run the discriminator train_op in a row.
      discriminator_steps=1,

      # The number of times we run the generator train_op in a row.
      generator_steps=1)

  if hparam_string:
    tf.logging.info('Parsing command line hparams: %s', hparam_string)
    hparams.parse(hparam_string)

  tf.logging.info('Final parsed hparams: %s', hparams.values())
  return hparams
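

# Example usage (illustrative values): `hparam_string` uses the comma-separated
# name=value syntax accepted by tf.contrib.training.HParams.parse, with each
# value cast to the type of the corresponding default above.
#
#   hparams = create_hparams('batch_size=64,learning_rate=0.0002')
#   assert hparams.batch_size == 64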