# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Various implementations of sequence layers for character prediction.

A 'sequence layer' is a part of a computation graph which is responsible for
producing a sequence of characters using extracted image features. There are
many reasonable ways to implement such layers. All of them use RNNs.
This module provides implementations which use an 'attention' mechanism to
spatially 'pool' image features and can also use a previously predicted
character to predict the next one (aka autoregression).

Usage:
  Select one of the available classes, e.g. Attention, or use a wrapper
  function to pick one based on your requirements:
  layer_class = sequence_layers.get_layer_class(use_attention=True,
                                                use_autoregression=True)
  layer = layer_class(net, labels_one_hot, model_params, method_params)
  char_logits = layer.create_logits()
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import collections
import logging
import numpy as np

import tensorflow as tf

from tensorflow.contrib import slim


def orthogonal_initializer(shape, dtype=tf.float32, *args, **kwargs):
  """Generates orthonormal matrices with random values.

  Orthonormal initialization is important for RNNs:
    http://arxiv.org/abs/1312.6120
    http://smerity.com/articles/2016/orthogonal_init.html

  For non-square shapes the returned matrix will be semi-orthonormal: if the
  number of columns exceeds the number of rows, then the rows are orthonormal
  vectors; but if the number of rows exceeds the number of columns, then the
  columns are orthonormal vectors.

  We use an SVD decomposition to generate an orthonormal matrix with random
  values, the same way as it is done in the Lasagne library for Theano. Note
  that both u and v returned by the SVD are orthogonal and random; we just
  need to pick the one with the right shape.

  Args:
    shape: the shape of the tensor to initialize.
    dtype: the dtype of the initialized tensor.
    *args: not used.
    **kwargs: not used.

  Returns:
    An initialized tensor.
  """
  del args
  del kwargs
  flat_shape = (shape[0], np.prod(shape[1:]))
  w = np.random.randn(*flat_shape)
  u, _, v = np.linalg.svd(w, full_matrices=False)
  w = u if u.shape == flat_shape else v
  return tf.constant(w.reshape(shape), dtype=dtype)

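# A quick sanity check of the initializer (illustrative sketch only, not
# executed by this module): for a tall matrix the columns are orthonormal,
# so w^T w is close to the identity.
#
#   w = orthogonal_initializer([8, 4])
#   with tf.Session() as sess:
#     m = sess.run(w)
#   np.testing.assert_allclose(m.T.dot(m), np.eye(4), atol=1e-5)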

SequenceLayerParams = collections.namedtuple('SequenceLayerParams', [
    'num_lstm_units', 'weight_decay', 'lstm_state_clip_value'
])
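
# Example instantiation with illustrative values (the real values come from
# the model configuration):
#   mparams = SequenceLayerParams(
#       num_lstm_units=256, weight_decay=0.00004, lstm_state_clip_value=10.0)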


class SequenceLayerBase(object):
  """A base abstruct class for all sequence layers.

  A child class has to define following methods:
    get_train_input
    get_eval_input
    unroll_cell
  """
  __metaclass__ = abc.ABCMeta

  def __init__(self, net, labels_one_hot, model_params, method_params):
    """Stores argument in member variable for further use.

    Args:
      net: A tensor with shape [batch_size, num_features, feature_size] which
        contains some extracted image features.
      labels_one_hot: Optional (can be None) ground truth labels for the
        input features. A tensor with shape
        [batch_size, seq_length, num_char_classes].
      model_params: A namedtuple with model parameters (model.ModelParams).
      method_params: A SequenceLayerParams instance.
    """
    self._params = model_params
    self._mparams = method_params
    self._net = net
    self._labels_one_hot = labels_one_hot
    self._batch_size = net.get_shape().dims[0].value

    # Initialize parameters for char logits which will be computed on the fly
    # inside an LSTM decoder.
    self._char_logits = {}
    regularizer = slim.l2_regularizer(self._mparams.weight_decay)
    self._softmax_w = slim.model_variable(
        'softmax_w',
        [self._mparams.num_lstm_units, self._params.num_char_classes],
        initializer=orthogonal_initializer,
        regularizer=regularizer)
    self._softmax_b = slim.model_variable(
        'softmax_b', [self._params.num_char_classes],
        initializer=tf.zeros_initializer(),
        regularizer=regularizer)

  @abc.abstractmethod
  def get_train_input(self, prev, i):
    """Returns a sample to be used to predict a character during training.

    This function is used as a loop_function for an RNN decoder.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    pass

  @abc.abstractmethod
  def get_eval_input(self, prev, i):
    """Returns a sample to be used to predict a character during inference.

    This function is used as a loop_function for an RNN decoder.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    pass

  @abc.abstractmethod
  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """Unrolls an RNN cell for all inputs.

    This is a placeholder to call some RNN decoder. It has an interface
    similar to tf.contrib.legacy_seq2seq.rnn_decoder.

    Args:
      decoder_inputs: A list of 2D Tensors with shape
        [batch_size x input_size]. In fact, in the presence of a loop_function
        most existing decoders use only the first element to determine
        batch_size and the length of the list to determine the number of steps.
      initial_state: 2D Tensor with shape [batch_size x cell.state_size].
      loop_function: a function applied to the i-th output in order to
        generate the (i+1)-st input (see self.get_input).
      cell: rnn_cell.RNNCell defining the cell function and size.

    Returns:
      A tuple of the form (outputs, state), where:
        outputs: A list, of the same length as decoder_inputs, of 2D Tensors
          with shape [batch_size x num_characters] containing character logits.
        state: The state of each cell at the final time-step, a 2D Tensor of
          shape [batch_size x cell.state_size].
    """
    pass

  def is_training(self):
    """Returns True if the layer is created for training stage."""
    return self._labels_one_hot is not None

  def char_logit(self, inputs, char_index):
    """Creates logits for a character if required.

    Args:
      inputs: A tensor with shape [batch_size, ?] (depth is implementation
        dependent).
      char_index: An integer index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, num_char_classes]
    """
    if char_index not in self._char_logits:
      self._char_logits[char_index] = tf.nn.xw_plus_b(inputs, self._softmax_w,
                                                      self._softmax_b)
    return self._char_logits[char_index]

  def char_one_hot(self, logit):
    """Creates one hot encoding for a logit of a character.

    Args:
      logit: A tensor with shape [batch_size, num_char_classes].

    Returns:
      A tensor with shape [batch_size, num_char_classes]
    """
    prediction = tf.argmax(logit, axis=1)
    return slim.one_hot_encoding(prediction, self._params.num_char_classes)

  def get_input(self, prev, i):
    """A wrapper for get_train_input and get_eval_input.

    Args:
      prev: output tensor from previous step of the RNN. A tensor with shape:
        [batch_size, num_char_classes].
      i: index of a character in the output sequence.

    Returns:
      A tensor with shape [batch_size, ?] - depth depends on implementation
      details.
    """
    if self.is_training():
      return self.get_train_input(prev, i)
    else:
      return self.get_eval_input(prev, i)

  def create_logits(self):
    """Creates character sequence logits for a net specified in the constructor.

    A "main" method for the sequence layer which glues together all pieces.

    Returns:
      A tensor with shape [batch_size, seq_length, num_char_classes].
    """
    with tf.variable_scope('LSTM'):
      first_label = self.get_input(prev=None, i=0)
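      # Only the first input is a real tensor; tf.contrib.legacy_seq2seq
      # decoders take the number of steps from len(decoder_inputs) and, with a
      # loop_function set, generate every subsequent input from the previous
      # output, so the remaining entries are placeholders.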
      decoder_inputs = [first_label] + [None] * (self._params.seq_length - 1)
      lstm_cell = tf.contrib.rnn.LSTMCell(
          self._mparams.num_lstm_units,
          use_peepholes=False,
          cell_clip=self._mparams.lstm_state_clip_value,
          state_is_tuple=True,
          initializer=orthogonal_initializer)
      lstm_outputs, _ = self.unroll_cell(
          decoder_inputs=decoder_inputs,
          initial_state=lstm_cell.zero_state(self._batch_size, tf.float32),
          loop_function=self.get_input,
          cell=lstm_cell)

    with tf.variable_scope('logits'):
      logits_list = [
          tf.expand_dims(self.char_logit(output, i), axis=1)
          for i, output in enumerate(lstm_outputs)
      ]

    return tf.concat(logits_list, 1)


class NetSlice(SequenceLayerBase):
  """A layer which uses a subset of image features to predict each character.
  """

  def __init__(self, *args, **kwargs):
    super(NetSlice, self).__init__(*args, **kwargs)
    self._zero_label = tf.zeros(
        [self._batch_size, self._params.num_char_classes])

  def get_image_feature(self, char_index):
    """Returns a subset of image features for a character.

    Args:
      char_index: an index of a character.

    Returns:
      A tensor with shape [batch_size, ?]. The output depth depends on the
      depth of the input net.
    """
    batch_size, features_num, _ = [d.value for d in self._net.get_shape()]
    slice_len = features_num // self._params.seq_length
    # When features_num != seq_length, we just pick a subset of image
    # features; this choice is arbitrary and there is no intuitive geometrical
    # interpretation. If features_num is not divisible by seq_length there
    # will be unused image features.
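    # For example (illustrative numbers): with features_num=72 and
    # seq_length=20, slice_len is 3 and character i uses the overlapping
    # window net[:, i:i + 3, :]; the last window ends at index 22, so
    # features 22..71 are unused.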
    net_slice = self._net[:, char_index:char_index + slice_len, :]
    feature = tf.reshape(net_slice, [batch_size, -1])
    logging.debug('Image feature: %s', feature)
    return feature

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    del prev
    return self.get_image_feature(i)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    return self.get_eval_input(prev, i)

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """See SequenceLayerBase.unroll_cell for details."""
    return tf.contrib.legacy_seq2seq.rnn_decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        cell=cell,
        loop_function=loop_function)


class NetSliceWithAutoregression(NetSlice):
  """A layer similar to NetSlice, but it also uses auto regression.

  The "auto regression" means that we use network output for previous character
  as a part of input for the current character.
  """

  def __init__(self, *args, **kwargs):
    super(NetSliceWithAutoregression, self).__init__(*args, **kwargs)

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    if i == 0:
      prev = self._zero_label
    else:
      logit = self.char_logit(prev, char_index=i - 1)
      prev = self.char_one_hot(logit)
    image_feature = self.get_image_feature(char_index=i)
    return tf.concat([image_feature, prev], 1)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    if i == 0:
      prev = self._zero_label
    else:
      prev = self._labels_one_hot[:, i - 1, :]
    image_feature = self.get_image_feature(i)
    return tf.concat([image_feature, prev], 1)


class Attention(SequenceLayerBase):
  """A layer which uses attention mechanism to select image features."""

  def __init__(self, *args, **kwargs):
    super(Attention, self).__init__(*args, **kwargs)
    self._zero_label = tf.zeros(
        [self._batch_size, self._params.num_char_classes])

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    del prev, i
    # The attention_decoder will fetch image features from the net, no need for
    # extra inputs.
    return self._zero_label

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    return self.get_eval_input(prev, i)

  def unroll_cell(self, decoder_inputs, initial_state, loop_function, cell):
    """See SequenceLayerBase.unroll_cell for details."""
    return tf.contrib.legacy_seq2seq.attention_decoder(
        decoder_inputs=decoder_inputs,
        initial_state=initial_state,
        attention_states=self._net,
        cell=cell,
        loop_function=loop_function)


class AttentionWithAutoregression(Attention):
  """A layer which uses both attention and auto regression."""

  def __init__(self, *args, **kwargs):
    super(AttentionWithAutoregression, self).__init__(*args, **kwargs)

  def get_train_input(self, prev, i):
    """See SequenceLayerBase.get_train_input for details."""
    if i == 0:
      return self._zero_label
    else:
      # TODO(gorban): update to gradually introduce gt labels.
      return self._labels_one_hot[:, i - 1, :]

  def get_eval_input(self, prev, i):
    """See SequenceLayerBase.get_eval_input for details."""
    if i == 0:
      return self._zero_label
    else:
      logit = self.char_logit(prev, char_index=i - 1)
      return self.char_one_hot(logit)


def get_layer_class(use_attention, use_autoregression):
  """A convenience function to get a layer class based on requirements.

  Args:
    use_attention: if True, the returned class will use attention.
    use_autoregression: if True, the returned class will use autoregression.

  Returns:
    One of available sequence layers (child classes for SequenceLayerBase).
  """
  if use_attention and use_autoregression:
    layer_class = AttentionWithAutoregression
  elif use_attention:
    layer_class = Attention
  elif use_autoregression:
    layer_class = NetSliceWithAutoregression
  else:
    layer_class = NetSlice

  logging.debug('Use %s as a layer class', layer_class.__name__)
  return layer_class
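
# Illustrative end-to-end sketch (net, labels_one_hot, model_params and
# method_params are assumed to be provided by the caller):
#   layer_class = get_layer_class(use_attention=True, use_autoregression=True)
#   layer = layer_class(net, labels_one_hot, model_params, method_params)
#   chars_logit = layer.create_logits()
#   # chars_logit has shape [batch_size, seq_length, num_char_classes].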