"""Custom RNN decoder.""" |
|
|
|
import tensorflow.compat.v1 as tf |
|
import lstm_object_detection.lstm.utils as lstm_utils |
|
|
|
|
|
class _NoVariableScope(object): |
|
|
|
def __enter__(self): |
|
return |
|
|
|
def __exit__(self, exc_type, exc_value, traceback): |
|
return False |
|
|
|
|
|
def rnn_decoder(decoder_inputs, |
|
initial_state, |
|
cell, |
|
loop_function=None, |
|
scope=None): |
|
"""RNN decoder for the LSTM-SSD model. |
|
|
|
This decoder returns a list of all states, rather than only the final state. |
|
Args: |
|
    decoder_inputs: A list of 4D input Tensors (one per timestep), each with
      shape [batch_size, height, width, depth].
    initial_state: The initial state of cell, with the structure the cell
      expects (an LSTM (c, h) state tuple for the LSTM-SSD).
|
cell: rnn_cell.RNNCell defining the cell function and size. |
|
loop_function: If not None, this function will be applied to the i-th output |
|
in order to generate the i+1-st input, and decoder_inputs will be ignored, |
|
except for the first element ("GO" symbol). This can be used for decoding, |
|
but also for training to emulate http://arxiv.org/abs/1506.03099. |
|
Signature -- loop_function(prev, i) = next |
|
* prev is a 2D Tensor of shape [batch_size x output_size], |
|
* i is an integer, the step number (when advanced control is needed), |
|
* next is a 2D Tensor of shape [batch_size x input_size]. |
|
scope: optional VariableScope for the created subgraph. |
|
  Returns:
    A tuple of the form (outputs, states), where:
      outputs: A list of the same length as decoder_inputs of 4D Tensors
        containing the generated outputs, one per timestep.
      states: A list of the same length as decoder_inputs containing the cell
        state after each timestep, each with the same structure as
        initial_state.
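
  Example (an illustrative sketch, not part of the original module; `feat_t0`,
  `feat_t1`, `feat_t2` stand in for per-timestep feature maps, and `lstm_cell`
  and `init_state` are a compatible RNNCell-style cell and its initial state):

    outputs, states = rnn_decoder(
        [feat_t0, feat_t1, feat_t2], init_state, lstm_cell)
    # outputs[t] is the cell output at timestep t; states[t] is the cell
    # state after timestep t.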
|
""" |
|
with tf.variable_scope(scope) if scope else _NoVariableScope(): |
|
state_tuple = initial_state |
|
outputs = [] |
|
states = [] |
|
prev = None |
|
for local_step, decoder_input in enumerate(decoder_inputs): |
|
if loop_function is not None and prev is not None: |
|
with tf.variable_scope('loop_function', reuse=True): |
|
decoder_input = loop_function(prev, local_step) |
|
output, state_tuple = cell(decoder_input, state_tuple) |
|
outputs.append(output) |
|
states.append(state_tuple) |
|
if loop_function is not None: |
|
prev = output |
|
return outputs, states |
|
|
|
def multi_input_rnn_decoder(decoder_inputs, |
|
initial_state, |
|
cell, |
|
sequence_step, |
|
selection_strategy='RANDOM', |
|
is_training=None, |
|
is_quantized=False, |
|
preprocess_fn_list=None, |
|
pre_bottleneck=False, |
|
flatten_state=False, |
|
scope=None): |
|
"""RNN decoder for the Interleaved LSTM-SSD model. |
|
|
|
This decoder takes multiple sequences of inputs and selects the input to feed |
|
to the rnn at each timestep using its selection_strategy, which can be random, |
|
learned, or deterministic. |
|
This decoder returns a list of all states, rather than only the final state. |
|
Args: |
|
    decoder_inputs: A list (one entry per input sequence) of lists (one entry
      per timestep) of 4D input Tensors.
    initial_state: The initial state of cell, an LSTM (c, h) state tuple.
|
cell: rnn_cell.RNNCell defining the cell function and size. |
|
sequence_step: Tensor [batch_size] of the step number of the first elements |
|
in the sequence. |
|
    selection_strategy: Method for picking the decoder_input to use at each
      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
      number of times to use the second input before using the first.
|
is_training: boolean, whether the network is training. When using learned |
|
selection, attempts exploration if training. |
|
is_quantized: flag to enable/disable quantization mode. |
|
    preprocess_fn_list: List of functions accepting two tensor arguments: one
      timestep of decoder_inputs and the lstm state. If not None,
      decoder_inputs[i] will be updated with preprocess_fn_list[i] at the
      start of each timestep.
|
pre_bottleneck: if True, use separate bottleneck weights for each sequence. |
|
Useful when input sequences have differing numbers of channels. Final |
|
bottlenecks will have the same dimension. |
|
flatten_state: Whether the LSTM state is flattened. |
|
scope: optional VariableScope for the created subgraph. |
|
  Returns:
    A tuple of the form (outputs, states), where:
      outputs: A list with one entry per timestep (the same length as
        decoder_inputs[0]) of Tensors containing the generated outputs.
      states: A list with one entry per timestep containing the cell state
        after that step, each with the same structure as initial_state.
|
Raises: |
|
    ValueError: If selection_strategy is not recognized, or if the unroll
      length is greater than 1 when flatten_state is True.
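
  Example (an illustrative sketch, not part of the original module;
  `large_feats` and `small_feats` stand in for per-timestep feature-map lists
  from two hypothetical feature extractors, and `cell` and `init_state` are a
  compatible interleaved LSTM cell and its (c, h) initial state):

    step = tf.zeros([batch_size], dtype=tf.int32)
    outputs, states = multi_input_rnn_decoder(
        [large_feats, small_feats], init_state, cell, step,
        selection_strategy='SKIP1', is_training=True)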
|
""" |
|
if flatten_state and len(decoder_inputs[0]) > 1: |
|
raise ValueError('In export mode, unroll length should not be more than 1') |
|
with tf.variable_scope(scope) if scope else _NoVariableScope(): |
|
state_tuple = initial_state |
|
outputs = [] |
|
states = [] |
|
batch_size = decoder_inputs[0][0].shape[0].value |
|
num_sequences = len(decoder_inputs) |
|
sequence_length = len(decoder_inputs[0]) |
|
|
|
for local_step in range(sequence_length): |
|
for sequence_index in range(num_sequences): |
|
if preprocess_fn_list is not None: |
|
decoder_inputs[sequence_index][local_step] = ( |
|
preprocess_fn_list[sequence_index]( |
|
decoder_inputs[sequence_index][local_step], state_tuple[0])) |
|
if pre_bottleneck: |
|
decoder_inputs[sequence_index][local_step] = cell.pre_bottleneck( |
|
inputs=decoder_inputs[sequence_index][local_step], |
|
state=state_tuple[1], |
|
input_index=sequence_index) |
|
|
|
action = generate_action(selection_strategy, local_step, sequence_step, |
|
[batch_size, 1, 1, 1]) |
|
inputs, _ = ( |
|
select_inputs(decoder_inputs, action, local_step, is_training, |
|
is_quantized)) |
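      # tf.name_scope(None) resets to the root name scope, giving the identity
      # op below a fixed, scope-independent name that can be looked up later.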
|
|
|
with tf.name_scope(None): |
|
inputs = tf.identity(inputs, 'raw_inputs/base_endpoint') |
|
output, state_tuple_out = cell(inputs, state_tuple) |
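      # Keep the previous state where action == 1 (small model) and adopt the
      # newly computed state where action == 0 (large model); see select_state.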
|
state_tuple = select_state(state_tuple, state_tuple_out, action) |
|
|
|
outputs.append(output) |
|
states.append(state_tuple) |
|
return outputs, states |
|
|
|
|
|
def generate_action(selection_strategy, local_step, sequence_step, |
|
action_shape): |
|
"""Generate current (binary) action based on selection strategy. |
|
|
|
Args: |
|
    selection_strategy: Method for picking the decoder_input to use at each
      timestep. Must be 'RANDOM', or 'SKIPX' for an integer X, where X is the
      number of times to use the second input before using the first.
|
    local_step: Integer step number within the current unrolled batch.
|
sequence_step: Tensor [batch_size] of the step number of the first elements |
|
in the sequence. |
|
action_shape: The shape of action tensor to be generated. |
|
|
|
Returns: |
|
    A tensor of shape action_shape; each element is an individual action
    (0 selects the first input sequence, 1 selects the second).
|
|
|
Raises: |
|
    ValueError: if selection_strategy is not supported, or if a 'SKIP'
      strategy is not followed by an integer.
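
  For example, with selection_strategy='SKIP2', timesteps 0, 3, 6, ... produce
  action 0 (use the first input) and every other timestep produces action 1
  (use the second input).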
|
""" |
|
if selection_strategy.startswith('RANDOM'): |
|
action = tf.random.uniform(action_shape, maxval=2, dtype=tf.int32) |
|
action = tf.minimum(action, 1) |
|
|
|
|
|
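    # On the first frame of a sequence (sequence_step == 0), force the action
    # to 0 so that the first input is used (the large model; see select_state).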
if local_step == 0 and sequence_step is not None: |
|
action *= tf.minimum( |
|
tf.reshape(tf.cast(sequence_step, tf.int32), action_shape), 1) |
|
elif selection_strategy.startswith('SKIP'): |
|
inter_count = int(selection_strategy[4:]) |
|
if local_step % (inter_count + 1) == 0: |
|
action = tf.zeros(action_shape) |
|
else: |
|
action = tf.ones(action_shape) |
|
else: |
|
raise ValueError('Selection strategy %s not recognized' % |
|
selection_strategy) |
|
return tf.cast(action, tf.int32) |
|
|
|
|
|
def select_inputs(decoder_inputs, action, local_step, is_training, is_quantized, |
|
get_alt_inputs=False): |
|
"""Selects sequence from decoder_inputs based on 1D actions. |
|
|
|
  Given multiple input batches, creates a single output batch by selecting,
  for the i-th batch element, the input from sequence action[i].
|
|
|
Args: |
|
decoder_inputs: A 2-D list of tensor inputs. |
|
    action: A tensor of indices into decoder_inputs, broadcastable against the
      stacked inputs (e.g. shape [batch_size, 1, 1, 1] for 4D inputs). Each
      element is the index of the sequence to choose for that batch element.
|
local_step: The current timestep. |
|
is_training: boolean, whether the network is training. When using learned |
|
selection, attempts exploration if training. |
|
is_quantized: flag to enable/disable quantization mode. |
|
get_alt_inputs: Whether the non-chosen inputs should also be returned. |
|
|
|
Returns: |
|
    A tuple (inputs, inputs_alt): the selected inputs, and the inputs that
      were not chosen if get_alt_inputs is True (otherwise inputs_alt is
      None).
|
|
|
Raises: |
|
    ValueError: if decoder_inputs contains anything other than two sequences.
|
""" |
|
num_seqs = len(decoder_inputs) |
|
  if num_seqs != 2:
|
raise ValueError('Currently only supports two sets of inputs.') |
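  # Stack the candidate inputs along a new trailing axis and use a one-hot
  # mask over that axis to pick, per batch element, the sequence given by
  # action.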
|
stacked_inputs = tf.stack( |
|
[decoder_inputs[seq_index][local_step] for seq_index in range(num_seqs)], |
|
axis=-1) |
|
action_index = tf.one_hot(action, num_seqs) |
|
selected_inputs = ( |
|
lstm_utils.quantize_op(stacked_inputs * action_index, is_training, |
|
is_quantized, scope='quant_selected_inputs')) |
|
inputs = tf.reduce_sum(selected_inputs, axis=-1) |
|
inputs_alt = None |
|
|
|
if get_alt_inputs: |
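    # Invert the one-hot mask (on_value/off_value swapped) to gather the
    # inputs that were not chosen by action.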
|
|
|
action_index_alt = tf.one_hot(action, num_seqs, on_value=0.0, off_value=1.0) |
|
selected_inputs = ( |
|
lstm_utils.quantize_op(stacked_inputs * action_index_alt, is_training, |
|
is_quantized, scope='quant_selected_inputs_alt')) |
|
inputs_alt = tf.reduce_sum(selected_inputs, axis=-1) |
|
return inputs, inputs_alt |
|
|
|
def select_state(previous_state, new_state, action): |
|
"""Select state given action. |
|
|
|
  Currently only supports a binary action. If the action is 0, the state was
  generated by the large model, so the state is updated. If the action is 1,
  the state was generated by the small model, and the interleaved model skips
  the update and keeps the previous state.
|
|
|
Args: |
|
previous_state: A state tuple representing state from previous step. |
|
new_state: A state tuple representing newly computed state. |
|
    action: A tensor broadcastable against the state tensors (e.g. shape
      [batch_size, 1, 1, 1]).
|
|
|
Returns: |
|
A state tuple selected based on the given action. |
|
""" |
|
action = tf.cast(action, tf.float32) |
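  # Where action == 1, keep the previous state; where action == 0, adopt the
  # newly computed state.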
|
state_c = previous_state[0] * action + new_state[0] * (1 - action) |
|
state_h = previous_state[1] * action + new_state[1] * (1 - action) |
|
return (state_c, state_h) |
|
|