File size: 13,402 Bytes
97b6013
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
# Copyright 2016 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functions to create a DSN model and add the different losses to it.

Specifically, in this file we define the:
  - Shared Encoding Similarity Loss Module, with:
    - The MMD Similarity method
    - The Correlation Similarity method
    - The Gradient Reversal (Domain-Adversarial) method
  - Difference Loss Module
  - Reconstruction Loss Module
  - Task Loss Module
"""
from functools import partial

import tensorflow as tf

import losses
import models
import utils

slim = tf.contrib.slim


################################################################################
# HELPER FUNCTIONS
################################################################################
def dsn_loss_coefficient(params):
  """Computes the global_step-dependent gate applied to DSN loss weights.

  Args:
    params: A dictionary of parameters. Expecting
      'domain_separation_startpoint', the global step from which the gate
      evaluates to 1.0.

  Returns:
    A tensor that evaluates to 1e-10 while the global step is less than
    params['domain_separation_startpoint'] and to 1.0 otherwise. Multiplying a
    loss weight by it effectively disables that loss before the start point.
  """
  global_step = slim.get_or_create_global_step()
  before_startpoint = tf.less(global_step,
                              params['domain_separation_startpoint'])
  return tf.where(before_startpoint, 1e-10, 1.0)


################################################################################
# MODEL CREATION
################################################################################
def create_model(source_images, source_labels, domain_selection_mask,
                 target_images, target_labels, similarity_loss, params,
                 basic_tower_name):
  """Builds a DSN model: the task tower plus the domain adaptation losses.

  The source tower and its task loss are always created. Unless
  `similarity_loss` is 'none', the same tower is then applied to the target
  images with reused variables, evaluation summaries are added for the target
  domain, a similarity loss is added on the shared codes and, if
  params['use_separation'] is set, the private encoders and shared decoder are
  added as well.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary with the name, tensor pairs. 'classes' is one-
      hot for the number of classes.
    domain_selection_mask: a boolean tensor of size [batch_size, ] which denotes
      the labeled images that belong to the source domain.
    target_images: images from the target domain, a tensor of size
      [batch_size, height, width, channels].
    target_labels: a dictionary with the name, tensor pairs. 'classes' is
      required; 'quaternions' is optional.
    similarity_loss: The type of method to use for encouraging
      the codes from the shared encoder to be similar, or 'none' to build
      only the source task model.
    params: A dictionary of parameters. Expecting 'weight_decay',
      'layers_to_regularize', 'use_separation', 'domain_separation_startpoint',
      'alpha_weight', 'beta_weight', 'gamma_weight', 'recon_loss_name',
      'decoder_name', 'encoder_name'
    basic_tower_name: the name of the tower to use for the shared encoder. It
      is looked up as an attribute of the `models` module.

  Returns:
    None. The function is called for the ops, losses and summaries it adds.

  Raises:
    AttributeError: if `basic_tower_name` is not an attribute of `models`.
  """
  base_tower = getattr(models, basic_tower_name)
  class_count = source_labels['classes'].get_shape().as_list()[1]

  # Bind the number of classes so both towers are built with the same head.
  tower = partial(base_tower, num_classes=class_count)

  # The source domain carries the classification/pose estimation loss.
  source_endpoints = add_task_loss(source_images, source_labels, tower, params)

  if similarity_loss == 'none':
    # Without domain adaptation there is nothing more to build.
    return

  # Second application of the tower: reuse the variables created above.
  with tf.variable_scope('towers', reuse=True):
    target_logits, target_endpoints = tower(
        target_images, weight_decay=params['weight_decay'], prefix='target')

  # Track the accuracy on the target images of the train set.
  predicted_classes = tf.argmax(target_logits, 1)
  true_classes = tf.argmax(target_labels['classes'], 1)
  target_accuracy = utils.accuracy(predicted_classes, true_classes)

  if 'quaternions' in target_labels:
    target_quaternion_loss = losses.log_quaternion_loss(
        target_labels['quaternions'], target_endpoints['quaternion_pred'],
        params)
    tf.summary.scalar('eval/Target quaternions', target_quaternion_loss)

  tf.summary.scalar('eval/Target accuracy', target_accuracy)

  shared_layer = params['layers_to_regularize']
  source_shared = source_endpoints[shared_layer]
  target_shared = target_endpoints[shared_layer]

  # In the semi-supervised setting the source batch also holds labeled target
  # images. Those rows must not take part in the similarity loss, so only the
  # rows selected by the mask are kept. The same row indices are applied to
  # both the source and the target codes.
  batch_size = source_shared.get_shape().as_list()[0]
  selected_rows = tf.boolean_mask(
      tf.range(0, batch_size), domain_selection_mask)
  add_similarity_loss(similarity_loss,
                      tf.gather(source_shared, selected_rows),
                      tf.gather(target_shared, selected_rows), params)

  if params['use_separation']:
    add_autoencoders(
        source_images,
        source_shared,
        target_images,
        target_shared,
        params=params)


def add_similarity_loss(method_name,
                        source_samples,
                        target_samples,
                        params,
                        scope=None):
  """Adds a loss that pulls the shared encodings of both domains together.

  The loss function is looked up by name in the `losses` module and called
  with the samples, a weight and the scope. The weight is
  params['gamma_weight'] scaled by the step-dependent DSN coefficient.

  Args:
    method_name: the name of the encoding similarity method to use. Valid
      options include `dann_loss', `mmd_loss' or `correlation_loss'.
    source_samples: a tensor of shape [num_samples, num_features].
    target_samples: a tensor of shape [num_samples, num_features].
    params: a dictionary of parameters. Expecting 'gamma_weight' and
      'domain_separation_startpoint'.
    scope: optional name scope for summary tags.

  Raises:
    AttributeError: if `method_name` is not an attribute of `losses`.
  """
  step_gate = dsn_loss_coefficient(params)
  loss_weight = step_gate * params['gamma_weight']
  similarity_fn = getattr(losses, method_name)
  similarity_fn(source_samples, target_samples, loss_weight, scope)


def add_reconstruction_loss(recon_loss_name, images, recons, weight, domain):
  """Adds a weighted reconstruction loss and a scalar summary for it.

  Args:
    recon_loss_name: The name of the reconstruction loss, either
      'sum_of_pairwise_squares' or 'sum_of_squares'.
    images: A `Tensor` holding the original images.
    recons: A `Tensor` whose size matches `images`.
    weight: A scalar coefficient for the loss.
    domain: The name of the domain being reconstructed; used in the summary
      tag.

  Raises:
    ValueError: If `recon_loss_name` is not recognized.
  """
  # Reject unknown names up front, before any op is created.
  if recon_loss_name not in ('sum_of_pairwise_squares', 'sum_of_squares'):
    raise ValueError('recon_loss_name value [%s] not recognized.' %
                     recon_loss_name)

  if recon_loss_name == 'sum_of_pairwise_squares':
    recon_loss_fn = tf.contrib.losses.mean_pairwise_squared_error
  else:
    recon_loss_fn = tf.contrib.losses.mean_squared_error

  # The reconstructions are passed first and the originals second.
  recon_loss = recon_loss_fn(recons, images, weight)

  # The summary op is created under the finite-check, so evaluating the
  # summary also evaluates the assertion.
  finite_check = tf.Assert(tf.is_finite(recon_loss), [recon_loss])
  with tf.control_dependencies([finite_check]):
    tf.summary.scalar('losses/%s Recon Loss' % domain, recon_loss)


def add_autoencoders(source_data, source_shared, target_data, target_shared,
                     params):
  """Adds the encoders/decoders for our domain separation model w/ incoherence.

  Builds one private encoder per domain and a decoder whose variables are
  reused for every reconstruction, then adds the difference losses, the
  reconstruction losses and image summaries of the reconstructions.

  Args:
    source_data: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_shared: a tensor with first dimension batch_size
    target_data: images from the target domain, a tensor of size
      [batch_size, height, width, channels]
    target_shared: a tensor with first dimension batch_size
    params: A dictionary of parameters. Expecting 'layers_to_regularize',
      'beta_weight', 'alpha_weight', 'recon_loss_name', 'decoder_name',
      'encoder_name', 'weight_decay', 'domain_separation_startpoint'

  Raises:
    AttributeError: if params['encoder_name'] or params['decoder_name'] is not
      an attribute of `models`.
    ValueError: if params['recon_loss_name'] is not recognized.
  """

  def normalize_images(images):
    # Shift so the minimum is zero, then divide by the resulting maximum.
    images -= tf.reduce_min(images)
    return images / tf.reduce_max(images)

  def concat_operation(shared_repr, private_repr):
    # Despite the name, the two codes are combined by addition.
    return shared_repr + private_repr

  mu = dsn_loss_coefficient(params)

  # The layer to concatenate the networks at.
  concat_layer = params['layers_to_regularize']

  # The coefficient for modulating the private/shared difference loss.
  difference_loss_weight = params['beta_weight'] * mu

  # The reconstruction weight.
  recon_loss_weight = params['alpha_weight'] * mu

  # The reconstruction loss to use.
  recon_loss_name = params['recon_loss_name']

  # The decoder/encoder to use.
  decoder_name = params['decoder_name']
  encoder_name = params['encoder_name']

  _, height, width, _ = source_data.get_shape().as_list()
  code_size = source_shared.get_shape().as_list()[-1]
  weight_decay = params['weight_decay']

  encoder_fn = getattr(models, encoder_name)
  # Private encoders: one per domain, each in its own variable scope.
  with tf.variable_scope('source_encoder'):
    source_endpoints = encoder_fn(
        source_data, code_size, weight_decay=weight_decay)

  with tf.variable_scope('target_encoder'):
    target_endpoints = encoder_fn(
        target_data, code_size, weight_decay=weight_decay)

  decoder_fn = getattr(models, decoder_name)

  decoder = partial(
      decoder_fn,
      height=height,
      width=width,
      channels=source_data.get_shape().as_list()[-1],
      weight_decay=weight_decay)

  source_private = source_endpoints[concat_layer]
  target_private = target_endpoints[concat_layer]

  # Source auto-encoding. The first decoder call creates the variables; every
  # later call reuses them.
  with tf.variable_scope('decoder'):
    source_recons = decoder(concat_operation(source_shared, source_private))

  # Reconstructions from only the private or only the shared code, obtained by
  # zeroing out the other one. They are used for the image summaries below.
  with tf.variable_scope('decoder', reuse=True):
    source_private_recons = decoder(
        concat_operation(tf.zeros_like(source_private), source_private))
    source_shared_recons = decoder(
        concat_operation(source_shared, tf.zeros_like(source_shared)))

  # Target auto-encoding.
  with tf.variable_scope('decoder', reuse=True):
    target_recons = decoder(concat_operation(target_shared, target_private))
    target_shared_recons = decoder(
        concat_operation(target_shared, tf.zeros_like(target_shared)))
    target_private_recons = decoder(
        concat_operation(tf.zeros_like(target_private), target_private))

  losses.difference_loss(
      source_private,
      source_shared,
      weight=difference_loss_weight,
      name='Source')
  losses.difference_loss(
      target_private,
      target_shared,
      weight=difference_loss_weight,
      name='Target')

  add_reconstruction_loss(recon_loss_name, source_data, source_recons,
                          recon_loss_weight, 'source')
  add_reconstruction_loss(recon_loss_name, target_data, target_recons,
                          recon_loss_weight, 'target')

  # Add summaries. Original, full, shared-only and private-only images are
  # joined along axis 2. A real list is built (rather than passing `map(...)`)
  # because `map` returns a lazy iterator under Python 3, which is not a list
  # of tensors.
  source_reconstructions = tf.concat(
      axis=2,
      values=[
          normalize_images(images)
          for images in (source_data, source_recons, source_shared_recons,
                         source_private_recons)
      ])
  target_reconstructions = tf.concat(
      axis=2,
      values=[
          normalize_images(images)
          for images in (target_data, target_recons, target_shared_recons,
                         target_private_recons)
      ])
  tf.summary.image(
      'Source Images:Recons:RGB',
      source_reconstructions[:, :, :, :3],
      max_outputs=10)
  tf.summary.image(
      'Target Images:Recons:RGB',
      target_reconstructions[:, :, :, :3],
      max_outputs=10)

  # With four channels, the fourth one gets its own 'Depth' summary.
  if source_reconstructions.get_shape().as_list()[3] == 4:
    tf.summary.image(
        'Source Images:Recons:Depth',
        source_reconstructions[:, :, :, 3:4],
        max_outputs=10)
    tf.summary.image(
        'Target Images:Recons:Depth',
        target_reconstructions[:, :, :, 3:4],
        max_outputs=10)


def add_task_loss(source_images, source_labels, basic_tower, params):
  """Adds a classification and/or pose estimation loss to the model.

  Builds the tower on the source images inside the 'towers' variable scope,
  adds a softmax cross-entropy loss on the class logits and, when the labels
  contain 'quaternions', a weighted log-quaternion pose loss.

  Args:
    source_images: images from the source domain, a tensor of size
      [batch_size, height, width, channels]
    source_labels: a dictionary of label tensors from the source domain.
      'classes' is required and is used as the cross-entropy target;
      'quaternions' is optional and enables the pose loss.
    basic_tower: a function that creates the single tower of the model. It is
      called with the images plus `weight_decay` and `prefix` keyword
      arguments and must return a (logits, endpoints) pair.
    params: A dictionary of parameters. Expecting 'weight_decay', and
      'pose_weight' when pose labels are present. `params` is also forwarded
      to `losses.log_quaternion_loss`.

  Returns:
    The source endpoints.

  Raises:
    RuntimeError: if basic tower does not support pose estimation.
  """
  with tf.variable_scope('towers'):
    source_logits, source_endpoints = basic_tower(
        source_images, weight_decay=params['weight_decay'], prefix='Source')

  if 'quaternions' in source_labels:  # We have pose estimation as well
    if 'quaternion_pred' not in source_endpoints:
      raise RuntimeError('Please use a model for estimation e.g. pose_mini')

    loss = losses.log_quaternion_loss(source_labels['quaternions'],
                                      source_endpoints['quaternion_pred'],
                                      params)

    assert_op = tf.Assert(tf.is_finite(loss), [loss])
    with tf.control_dependencies([assert_op]):
      # Control dependencies only attach to ops created inside the context.
      # A plain `quaternion_loss = loss` creates no op, so the assertion would
      # never guard the loss. tf.identity creates one that carries it.
      quaternion_loss = tf.identity(loss)
      tf.summary.histogram('log_quaternion_loss_hist', quaternion_loss)
    slim.losses.add_loss(quaternion_loss * params['pose_weight'])
    tf.summary.scalar('losses/quaternion_loss', quaternion_loss)

  classification_loss = tf.losses.softmax_cross_entropy(
      source_labels['classes'], source_logits)

  tf.summary.scalar('losses/classification_loss', classification_loss)
  return source_endpoints