# Copyright 2017 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Assemble Reverse Polish Notation module layouts into expression trees."""

import numpy as np

# The number of attention inputs each module pops from the decoding stack.
_module_input_num = {
    '_key_find': 0,
    '_key_filter': 1,
    '_val_desc': 1,
}

# The output type each module pushes: 'att' (attention) or 'ans' (answer).
_module_output_type = {
    '_key_find': 'att',
    '_key_filter': 'att',
    '_val_desc': 'ans',
}

# Value stored under 'module' for layouts that could not be assembled.
INVALID_EXPR = 'INVALID_EXPR'


class Assembler:
    """Converts between module names, token indices and expression trees.

    Attributes:
      module_names: sequence of module names taken from `config`; a token is
        an index into this sequence.
      EOS_idx: index of the first end-of-sequence entry in `module_names`.
        Only set when such an entry exists.
      name2idx_dict: mapping from module name to its token index.
    """

    def __init__(self, config):
        """Reads the module list from `config.module_names`.

        Args:
          config: object exposing a `module_names` sequence of strings.
        """
        self.module_names = config.module_names

        # Find the index of the end-of-sequence entry (first match wins).
        # NOTE(review): the EOS name compared against here is the empty
        # string; the surrounding messages read as if an angle-bracketed
        # token name was lost at some point -- confirm against the config
        # that feeds `module_names`.
        for n_s, name in enumerate(self.module_names):
            if name == '':
                self.EOS_idx = n_s
                break

        # Build a dictionary from module name to token index.
        self.name2idx_dict = {
            name: n_s for n_s, name in enumerate(self.module_names)
        }

    def module_list2tokens(self, module_list, max_len=None):
        """Maps a list of module names to token indices.

        Args:
          module_list: list of module names, each present in `module_names`.
          max_len: optional total length; when given, the result is padded
            with the end-of-sequence index up to this length.

        Returns:
          A list of token indices.

        Raises:
          KeyError: if a name in `module_list` is unknown.
          ValueError: if `max_len` leaves no room for at least one
            end-of-sequence token.
        """
        layout_tokens = [self.name2idx_dict[name] for name in module_list]
        if max_len is not None:
            if len(module_list) >= max_len:
                raise ValueError('Not enough time steps to add ')
            layout_tokens += [self.EOS_idx] * (max_len - len(module_list))
        return layout_tokens

    def _layout_tokens2str(self, layout_tokens):
        """Returns the space-joined module names for `layout_tokens`."""
        return ' '.join([self.module_names[idx] for idx in layout_tokens])

    def _invalid_expr(self, layout_tokens, error_str):
        """Builds the expression dict that marks a layout as invalid."""
        return {
            'module': INVALID_EXPR,
            'expr_str': self._layout_tokens2str(layout_tokens),
            'error': error_str
        }

    def _assemble_layout_tokens(self, layout_tokens, batch_idx):
        """Assembles one layout (a token sequence in RPN) into an expression.

        Every assembled module records the decoding step it came from
        (`time_idx`) and the instance it belongs to (`batch_idx`). Modules
        take a different number of attention inputs (see
        `_module_input_num`) and output either attention or an answer (see
        `_module_output_type`).

        The assembled expression has the form:
          expr :=
              {'module': '_key_find', 'output_type': 'att',
               'time_idx': t, 'batch_idx': b}
            | {'module': '_key_filter', 'output_type': 'att',
               'time_idx': t, 'batch_idx': b, 'input_0': <expr>}
            | {'module': '_val_desc', 'output_type': 'ans',
               'time_idx': t, 'batch_idx': b, 'input_0': <expr>}
            | {'module': INVALID_EXPR, 'expr_str': '...', 'error': '...'}
              (for invalid expressions)

        Args:
          layout_tokens: 1-D sequence of token indices (numpy array or
            list).
          batch_idx: index of this instance, copied into every
            sub-expression.

        Returns:
          The root expression dict, or an INVALID_EXPR dict describing why
          assembly failed.
        """
        # Coerce to an array so the element-wise comparison below also works
        # for plain Python lists (list == int is just False).
        layout_tokens = np.asarray(layout_tokens)

        # A valid layout must contain the end-of-sequence token. Assembly
        # fails if it doesn't.
        if not np.any(layout_tokens == self.EOS_idx):
            return self._invalid_expr(layout_tokens, 'cannot find ')

        # Decoding Reverse Polish Notation with a stack.
        decoding_stack = []
        for t, module_idx in enumerate(layout_tokens):
            # Everything after the first end-of-sequence token is padding.
            if module_idx == self.EOS_idx:
                break
            module_name = self.module_names[module_idx]
            expr = {
                'module': module_name,
                'output_type': _module_output_type[module_name],
                'time_idx': t,
                'batch_idx': batch_idx
            }

            input_num = _module_input_num[module_name]
            # Check if there are enough inputs on the stack.
            if len(decoding_stack) < input_num:
                # Invalid expression. Not enough input.
                return self._invalid_expr(
                    layout_tokens, 'not enough input for ' + module_name)

            # Pop inputs from the stack. The last operand is on top, so the
            # input slots are filled from the highest index down.
            for n_input in range(input_num - 1, -1, -1):
                stack_top = decoding_stack.pop()
                if stack_top['output_type'] != 'att':
                    # Invalid expression. Input must be attention.
                    return self._invalid_expr(
                        layout_tokens,
                        'input incompatible for ' + module_name)
                expr['input_%d' % n_input] = stack_top

            decoding_stack.append(expr)

        # After decoding the reverse polish expression, there should be
        # exactly one expression in the stack.
        if len(decoding_stack) != 1:
            return self._invalid_expr(
                layout_tokens,
                'final stack size not equal to 1 (%d remains)' %
                len(decoding_stack))

        result = decoding_stack[0]
        # The result type should be answer, not attention.
        if result['output_type'] != 'ans':
            return self._invalid_expr(
                layout_tokens, 'result type must be ans, not att')
        return result

    def assemble(self, layout_tokens_batch):
        """Assembles every layout in a batch.

        Args:
          layout_tokens_batch: numpy array with shape
            [max_dec_len, batch_size] containing module tokens and the
            end-of-sequence token, in Reverse Polish Notation.

        Returns:
          A tuple `(expr_list, expr_validity)`: the list of assembled
          expressions (one per batch column) and a boolean numpy array that
          is True where the expression is valid.
        """
        _, batch_size = layout_tokens_batch.shape
        expr_list = [
            self._assemble_layout_tokens(layout_tokens_batch[:, batch_i],
                                         batch_i)
            for batch_i in range(batch_size)
        ]
        # Builtin `bool` instead of the `np.bool` alias, which was removed
        # in NumPy 1.24.
        expr_validity = np.array(
            [expr['module'] != INVALID_EXPR for expr in expr_list], bool)
        return expr_list, expr_validity