# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it.
Specifically, in this file we define the:
- Shared Encoding Similarity Loss Module, with:
- The MMD Similarity method
- The Correlation Similarity method
- The Gradient Reversal (Domain-Adversarial) method
- Difference Loss Module
- Reconstruction Loss Module
- Task Loss Module
"""
from functools import partial

import tensorflow as tf

import losses
import models
import utils

slim = tf.contrib.slim
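
# NOTE: This file targets the TensorFlow 1.x API surface (tf.contrib.slim,
# tf.contrib.losses); under TensorFlow 2.x it would need tf.compat.v1 shims.
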
################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
"""The global_step-dependent weight that specifies when to kick in DSN losses.
Args:
params: A dictionary of parameters. Expecting 'domain_separation_startpoint'
Returns:
A weight to that effectively enables or disables the DSN-related losses,
i.e. similarity, difference, and reconstruction losses.
"""
  return tf.where(
      tf.less(slim.get_or_create_global_step(),
              params['domain_separation_startpoint']), 1e-10, 1.0)

################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
"""Creates a DSN model.
Args:
source_images: images from the source domain, a tensor of size
[batch_size, height, width, channels]
source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
hot for the number of classes.
domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
the labeled images that belong to the source domain.
target_images: images from the target domain, a tensor of size
[batch_size, height width, channels].
target_labels: a dictionary with the name, tensor pairs.
similarity_loss: The type of method to use for encouraging
the codes from the shared encoder to be similar.
params: A dictionary of parameters. Expecting 'weight_decay',
'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
'decoder_name', 'encoder_name'
basic_tower_name: the name of the tower to use for the shared encoder.
Raises:
ValueError: if the arch is not one of the available architectures.
"""
  network = getattr(models, basic_tower_name)
  num_classes = source_labels['classes'].get_shape().as_list()[1]

  # Make sure we are using the appropriate number of classes.
  network = partial(network, num_classes=num_classes)

  # Add the classification/pose estimation loss to the source domain.
  source_endpoints = add_task_loss(source_images, source_labels, network,
                                   params)

  if similarity_loss == 'none':
    # No domain adaptation, we can stop here.
    return

  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = network(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Plot target accuracy of the train set.
  target_accuracy = utils.accuracy(
      tf.argmax(target_logits, 1), tf.argmax(target_labels['classes'], 1))

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)
  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  source_shared = source_endpoints[params['layers_to_regularize']]
  target_shared = target_endpoints[params['layers_to_regularize']]
  # When using the semisupervised model we include labeled target data in the
  # source classifier. We do not want to include those target-domain examples
  # when we compute the similarity loss.
  indices = tf.range(0, source_shared.get_shape().as_list()[0])
  indices = tf.boolean_mask(indices, domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, indices),
                      tf.gather(target_shared, indices), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params)
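

# A minimal usage sketch for create_model (hypothetical parameter values;
# assumes a tower such as 'dann_mnist' is defined in models.py, an encoder/
# decoder pair such as 'default_encoder'/'small_decoder' exists there, and a
# similarity loss such as 'mmd_loss' is defined in losses.py):
#
#   params = {'weight_decay': 1e-5, 'layers_to_regularize': 'fc3',
#             'use_separation': True, 'domain_separation_startpoint': 1000,
#             'alpha_weight': 0.01, 'beta_weight': 0.05, 'gamma_weight': 0.25,
#             'recon_loss_name': 'sum_of_pairwise_squares',
#             'encoder_name': 'default_encoder',
#             'decoder_name': 'small_decoder'}
#   create_model(source_images, source_labels, domain_selection_mask,
#                target_images, target_labels, 'mmd_loss', params,
#                'dann_mnist')
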
def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
"""Adds a loss encouraging the shared encoding from each domain to be similar.
Args:
method_name: the name of the encoding similarity method to use. Valid
options include `dann_loss', `mmd_loss' or `correlation_loss'.
source_samples: a tensor of shape [num_samples, num_features].
target_samples: a tensor of shape [num_samples, num_features].
params: a dictionary of parameters. Expecting 'gamma_weight'.
scope: optional name scope for summary tags.
Raises:
ValueError: if `method_name` is not recognized.
"""
  weight = dsn_loss_coefficient(params) * params['gamma_weight']
  method = getattr(losses, method_name)
  method(source_samples, target_samples, weight, scope)

def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
"""Adds a reconstruction loss.
Args:
recon_loss_name: The name of the reconstruction loss.
images: A `Tensor` of size [batch_size, height, width, 3].
recons: A `Tensor` whose size matches `images`.
weight: A scalar coefficient for the loss.
domain: The name of the domain being reconstructed.
Raises:
ValueError: If `recon_loss_name` is not recognized.
"""
  if recon_loss_name == 'sum_of_pairwise_squares':
    loss_fn = tf.contrib.losses.mean_pairwise_squared_error
  elif recon_loss_name == 'sum_of_squares':
    loss_fn = tf.contrib.losses.mean_squared_error
  else:
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  loss = loss_fn(recons, images, weight)

  assert_op = tf.Assert(tf.is_finite(loss), [loss])
  with tf.control_dependencies([assert_op]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, loss)

def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
"""Adds the encoders/decoders for our domain separation model w/ incoherence.
Args:
source_data: images from the source domain, a tensor of size
[batch_size, height, width, channels]
source_shared: a tensor with first dimension batch_size
target_data: images from the target domain, a tensor of size
[batch_size, height, width, channels]
target_shared: a tensor with first dimension batch_size
params: A dictionary of parameters. Expecting 'layers_to_regularize',
'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
'encoder_name', 'weight_decay'
"""
  def normalize_images(images):
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)
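
  # Despite its name, the "concatenation" of the shared and private codes used
  # below is an element-wise sum, so the two representations must have the
  # same shape.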
  def concat_operation(shared_repr, private_repr):
    return shared_repr + private_repr
  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder and encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders, one per domain.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)
  decoder_fn = getattr(models, decoder_name)
  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)
  # The private representation of each domain comes from its own encoder.
  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))
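
  # The difference losses below push each domain's private code to be
  # orthogonal to its shared code (a soft subspace-orthogonality constraint).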
  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')
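
  # Concatenating along axis=2 (the width axis) lays each input image and its
  # three reconstructions side by side in a single summary image.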
  # Add summaries. Wrapping map() in list() keeps this working under
  # Python 3, where map() returns an iterator rather than a list.
  source_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          source_data, source_recons, source_shared_recons,
          source_private_recons
      ])))
  target_reconstructions = tf.concat(
      axis=2,
      values=list(map(normalize_images, [
          target_data, target_recons, target_shared_recons,
          target_private_recons
      ])))
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)
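
  # If the inputs carry a fourth (depth) channel, summarize it separately.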
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)

def add_task_loss(source_images, source_labels, basic_tower, params):
"""Adds a classification and/or pose estimation loss to the model.
Args:
source_images: images from the source domain, a tensor of size
[batch_size, height, width, channels]
source_labels: labels from the source domain, a tensor of size [batch_size].
or a tuple of (quaternions, class_labels)
basic_tower: a function that creates the single tower of the model.
params: A dictionary of parameters. Expecting 'weight_decay', 'pose_weight'.
Returns:
The source endpoints.
Raises:
RuntimeError: if basic tower does not support pose estimation.
"""
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well.
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model that supports pose estimation, '
                         'e.g. pose_mini.')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      # tf.identity creates a new op under the control dependency, so the
      # finiteness assert actually runs before the loss is consumed.
      quaternion_loss = tf.identity(loss)

    tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints