# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Define model HParams."""
import tensorflow as tf


def create_hparams(hparam_string=None):
  """Build the model hyperparameters, overriding defaults from a string.

  Args:
    hparam_string: Optional comma-separated `name=value` string of
      non-default hyperparameter settings, parsed on top of the defaults.

  Returns:
    A `tf.contrib.training.HParams` instance holding every model
    hyperparameter, with any overrides from `hparam_string` applied.
  """
  defaults = dict(
      # ----- Generator / discriminator architecture -----
      # Which generator architecture to use.
      arch='resnet',
      lrelu_leakiness=0.2,
      batch_norm_decay=0.9,
      weight_decay=1e-5,
      normal_init_std=0.02,
      generator_kernel_size=3,
      discriminator_kernel_size=3,

      # Stop training after this many examples have been processed;
      # 0 means train indefinitely.
      num_training_examples=0,

      # Data augmentation toggles (training job only).
      augment_source_images=False,
      augment_target_images=False,

      # ----- Discriminator -----
      # Filters in the first discriminator layer.
      num_discriminator_filters=64,
      # Convs per block at each spatial size.
      discriminator_conv_block_size=1,
      # Filter-count multiplier applied at each layer.
      discriminator_filter_factor=2.0,
      # Stddev of gaussian noise added to every hidden layer of D.
      # (lmetz: start seeing results at >= 0.1)
      discriminator_noise_stddev=0.2,
      # If True, also add that gaussian noise to D's input images.
      discriminator_image_noise=False,
      # Stride of the first discriminator conv.
      discriminator_first_stride=1,
      # If True, use average pooling instead of stride 2.
      discriminator_do_pooling=False,
      # Dropout keep probability.
      discriminator_dropout_keep_prob=0.9,

      # ----- DCGAN generator -----
      # Filters in the decoder's last layer (halved repeatedly from the
      # first layer).
      num_decoder_filters=64,
      # Filters in the encoder's first layer (doubled repeatedly after it).
      num_encoder_filters=64,
      # Shape the noise vector is projected to when transferring from
      # noise; kept as two scalars (not [4, 4, 64]) for hparam-search
      # flexibility.
      projection_shape_size=4,
      projection_shape_channels=64,
      # How spatial resolution is increased:
      #   resize_conv      - nearest-neighbor resize followed by a conv.
      #   conv2d_transpose - a transposed convolution.
      upsample_method='resize_conv',

      # ----- Visualization -----
      # Emit an image summary every N steps.
      summary_steps=500,

      # ----- Task classifier -----
      # Task-specific prediction tower. Choices:
      #   none                    - no task tower.
      #   doubling_pose_estimator - classifier + quaternion regressor,
      #                             [conv + pool]* + FC.
      # DSN-paper classifiers: gtsrb, svhn, mnist, pose_mini.
      task_tower='doubling_pose_estimator',
      weight_decay_task_classifier=1e-5,
      source_task_loss_weight=1.0,
      transferred_task_loss_weight=1.0,
      # Private layers in the doubling_pose_estimator tower.
      num_private_layers=2,
      # Log-quaternion loss weights for source/transferred samples of the
      # cropped_linemod dataset; in the DSN work 1/8 of the classifier
      # weight worked well.
      source_pose_weight=0.125 * 2.0,
      transferred_pose_weight=0.125 * 1.0,
      # If True, the style transfer network also updates its weights to
      # help the task tower; if False it only optimizes for the domain
      # classifier.
      task_tower_in_g_step=True,
      # Weight of the task loss in the G step.
      task_loss_in_g_weight=1.0,

      # ----- 'simple' generator arch -----
      simple_num_conv_layers=1,
      simple_conv_filters=8,

      # ----- Resnet arch -----
      resnet_blocks=6,       # Number of resnet blocks.
      resnet_filters=64,     # Filters per conv in resnet blocks.
      # If True, add the block input back onto the conv output; if False
      # the arch degenerates to a plain conv/relu/BN stack.
      resnet_residuals=True,

      # ----- Residual / interpretable model -----
      res_int_blocks=2,      # Number of residual blocks.
      res_int_convs=2,       # Conv calls inside each block.
      res_int_filters=64,    # Filters per convolution.

      # ----- Latent variables -----
      # If True, generate random noise and project it into the generator
      # input.
      noise_channel=True,
      # Dimensionality of the input noise vector.
      noise_dims=10,
      # If True, one-hot encode the source image class and project it as
      # an extra generator input channel, giving G access to the class.
      condition_on_source_class=False,

      # ----- Loss weights -----
      domain_loss_weight=1.0,
      style_transfer_loss_weight=1.0,

      # ----- Source/transferred similarity loss -----
      # Weight of the loss pulling transferred images toward the source
      # images; 0 disables it.
      transferred_similarity_loss_weight=0.0,
      # Similarity metric. Choices:
      #   mpse       - mean pairwise squared error.
      #   mse        - mean squared error.
      #   hinged_mse - MSE over squared differences greater than
      #                transferred_similarity_max_diff.
      #   hinged_mae - MAE over absolute differences greater than
      #                transferred_similarity_max_diff.
      transferred_similarity_loss='mpse',
      # Hinge threshold for the similarity loss; expected range [0, 1].
      transferred_similarity_max_diff=0.4,

      # ----- Optimization -----
      learning_rate=0.001,
      batch_size=32,
      lr_decay_steps=20000,
      lr_decay_rate=0.95,
      # Recommendation from the DCGAN paper.
      adam_beta1=0.5,
      clip_gradient_norm=5.0,
      # Consecutive runs of the discriminator train_op per cycle.
      discriminator_steps=1,
      # Consecutive runs of the generator train_op per cycle.
      generator_steps=1,
  )
  hparams = tf.contrib.training.HParams(**defaults)

  if hparam_string:
    tf.logging.info('Parsing command line hparams: %s', hparam_string)
    hparams.parse(hparam_string)

  tf.logging.info('Final parsed hparams: %s', hparams.values())
  return hparams