gap-text2sql
/
gap-text2sql-main
/mrat-sql-gap
/seq2struct
/models
/nl2code
/infer_tree_traversal.py
import ast | |
import collections | |
import collections.abc | |
import enum | |
import itertools | |
import json | |
import os | |
import operator | |
import re | |
import copy | |
import random | |
import asdl | |
import attr | |
import pyrsistent | |
import entmax | |
import torch | |
import torch.nn.functional as F | |
from seq2struct.utils import vocab | |
from seq2struct.models.nl2code.tree_traversal import TreeTraversal | |
class InferenceTreeTraversal(TreeTraversal):
    """Tree traversal for inference that records every tree-building step.

    Each call to ``update_using_last_choice`` appends a ``TreeAction`` to
    ``self.actions`` (a persistent vector).  ``finalize`` replays the
    recorded actions to rebuild the decoded tree as nested dicts/lists and
    asks the grammar to unparse it.
    """

    class TreeAction:
        """Base class for the recorded tree-building actions."""

    # The action records are shared between clones (``clone`` copies the
    # reference to ``self.actions``), so they are declared frozen.
    # ``attr.s`` is required for the ``attr.ib()`` fields below to produce
    # an ``__init__``; without it the constructors called in
    # ``update_using_last_choice`` would raise ``TypeError``.
    @attr.s(frozen=True)
    class SetParentField(TreeAction):
        """Store a new node (or a pointer value) under a parent field."""
        # Field of the current node to fill; None marks the root node.
        parent_field_name = attr.ib()
        # Type written to the new node's '_type' key (None for pointers).
        node_type = attr.ib()
        # When not None, this value is stored directly instead of a new node.
        node_value = attr.ib(default=None)

    @attr.s(frozen=True)
    class CreateParentFieldList(TreeAction):
        """Initialise a field of the current node to an empty list."""
        parent_field_name = attr.ib()

    @attr.s(frozen=True)
    class AppendTerminalToken(TreeAction):
        """Append one generated token to a terminal field under construction."""
        parent_field_name = attr.ib()
        value = attr.ib()

    @attr.s(frozen=True)
    class FinalizeTerminal(TreeAction):
        """Join the collected tokens of a field and convert to its type."""
        parent_field_name = attr.ib()
        terminal_type = attr.ib()

    @attr.s(frozen=True)
    class NodeFinished(TreeAction):
        """Mark the current node as complete; control returns to its parent."""

    # Converters from the joined token string to a Python value.
    SIMPLE_TERMINAL_TYPES = {
        'str': str,
        'int': int,
        'float': float,
        # Anything other than the literal 'True' maps to False.
        'bool': lambda n: {'True': True, 'False': False}.get(n, False),
    }

    # Fallback values used when a converter above raises ValueError.
    SIMPLE_TERMINAL_TYPES_DEFAULT = {
        'str': '',
        'int': 0,
        'float': 0,
        'bool': True,
    }

    def __init__(self, model, desc_enc, example=None):
        """Create a traversal with an empty action log.

        Args:
            model: forwarded to ``TreeTraversal.__init__``.
            desc_enc: forwarded to ``TreeTraversal.__init__``.
            example: optional object passed to the grammar's ``unparse``
                in ``finalize``.
        """
        super().__init__(model, desc_enc)
        self.example = example
        self.actions = pyrsistent.pvector()

    def clone(self):
        """Return the parent's clone with ``actions`` and ``example`` copied over.

        ``actions`` is a persistent vector, so sharing the reference is safe:
        appending always creates a new vector.
        """
        super_clone = super().clone()
        super_clone.actions = self.actions
        super_clone.example = self.example
        return super_clone

    def rule_choice(self, node_type, rule_logits):
        """Delegate rule selection to ``model.rule_infer``."""
        return self.model.rule_infer(node_type, rule_logits)

    def token_choice(self, output, gen_logodds):
        """Delegate token selection to ``model.token_infer``."""
        return self.model.token_infer(output, gen_logodds, self.desc_enc)

    def pointer_choice(self, node_type, logits, attention_logits):
        """Return pointer candidates, merged according to the pointer map.

        If ``desc_enc.pointer_maps`` has no (non-empty) entry for
        ``node_type``, the result of ``model.pointer_infer`` is returned
        unchanged.  Otherwise, for every ``orig_index`` in the map, the
        log-probabilities of its mapped indices are combined with
        ``logsumexp`` and a list of ``(orig_index, logprob)`` pairs is
        returned.  ``attention_logits`` is not used here.
        """
        pointer_logprobs = self.model.pointer_infer(node_type, logits)
        pointer_map = self.desc_enc.pointer_maps.get(node_type)
        if not pointer_map:
            return pointer_logprobs

        # Index the (index, logprob) pairs for lookup by mapped index.
        pointer_logprobs = dict(pointer_logprobs)
        return [
            (orig_index, torch.logsumexp(
                torch.stack(
                    tuple(pointer_logprobs[i] for i in mapped_indices),
                    dim=0),
                dim=0))
            for orig_index, mapped_indices in pointer_map.items()
        ]

    def update_using_last_choice(self, last_choice, extra_choice_info, attention_offset):
        """Advance the traversal, then log the action implied by the new state.

        The parent implementation is run first; the action recorded depends
        on the state of ``self.cur_item`` afterwards.
        """
        super().update_using_last_choice(last_choice, extra_choice_info, attention_offset)

        item = self.cur_item
        state = item.state

        if state == TreeTraversal.State.CHILDREN_INQUIRE:
            # A new node of a concrete type is being opened.
            self.actions = self.actions.append(
                self.SetParentField(item.parent_field_name, item.node_type))
            type_info = self.model.ast_wrapper.singular_types[item.node_type]
            if not type_info.fields:
                # A node without fields is finished as soon as it is created.
                self.actions = self.actions.append(self.NodeFinished())

        elif state == TreeTraversal.State.LIST_LENGTH_APPLY:
            self.actions = self.actions.append(
                self.CreateParentFieldList(item.parent_field_name))

        elif state == TreeTraversal.State.GEN_TOKEN:
            if last_choice == vocab.EOS:
                # End of the token sequence: convert what was collected.
                self.actions = self.actions.append(self.FinalizeTerminal(
                    item.parent_field_name,
                    item.node_type))
            elif last_choice is not None:
                self.actions = self.actions.append(self.AppendTerminalToken(
                    item.parent_field_name,
                    last_choice))

        elif state == TreeTraversal.State.POINTER_APPLY:
            # The chosen pointer is stored directly as the field value.
            self.actions = self.actions.append(self.SetParentField(
                item.parent_field_name,
                node_type=None,
                node_value=last_choice))

        elif state == TreeTraversal.State.NODE_FINISHED:
            self.actions = self.actions.append(self.NodeFinished())

    def finalize(self):
        """Replay the recorded actions into a tree and unparse it.

        Returns:
            A ``(root, unparsed)`` tuple where ``root`` is the rebuilt tree
            (dicts with a ``'_type'`` key, lists, and terminal values) and
            ``unparsed`` is ``model.preproc.grammar.unparse(root, example)``.

        Raises:
            ValueError: on an unknown terminal type or unknown action.
            AssertionError: if the action sequence is inconsistent (second
                root, appending to a non-list field, or unfinished nodes).
        """
        root = current = None
        # Nodes to return to when the node being filled is finished.
        stack = []

        for action in self.actions:
            if isinstance(action, self.SetParentField):
                if action.node_value is None:
                    new_node = {'_type': action.node_type}
                else:
                    new_node = action.node_value

                if action.parent_field_name is None:
                    # Initial node in tree.
                    assert root is None
                    root = current = new_node
                    stack.append(root)
                    continue

                existing_list = current.get(action.parent_field_name)
                if existing_list is None:
                    current[action.parent_field_name] = new_node
                else:
                    # The field was created by CreateParentFieldList.
                    assert isinstance(existing_list, list)
                    existing_list.append(new_node)

                if action.node_value is None:
                    # Descend into the freshly created node.
                    stack.append(current)
                    current = new_node

            elif isinstance(action, self.CreateParentFieldList):
                current[action.parent_field_name] = []

            elif isinstance(action, self.AppendTerminalToken):
                tokens = current.get(action.parent_field_name)
                if tokens is None:
                    tokens = current[action.parent_field_name] = []
                tokens.append(action.value)

            elif isinstance(action, self.FinalizeTerminal):
                terminal = ''.join(current.get(action.parent_field_name, []))
                constructor = self.SIMPLE_TERMINAL_TYPES.get(action.terminal_type)
                if constructor:
                    try:
                        value = constructor(terminal)
                    except ValueError:
                        # Unparseable token text: fall back to the default.
                        value = self.SIMPLE_TERMINAL_TYPES_DEFAULT[action.terminal_type]
                elif action.terminal_type == 'bytes':
                    # ``terminal`` is a str, which has no ``decode`` in
                    # Python 3; encoding yields the bytes value.
                    value = terminal.encode('latin1')
                elif action.terminal_type == 'NoneType':
                    value = None
                else:
                    raise ValueError('Unknown terminal type: {}'.format(action.terminal_type))
                current[action.parent_field_name] = value

            elif isinstance(action, self.NodeFinished):
                current = stack.pop()

            else:
                raise ValueError(action)

        # Every opened node must have been finished.
        assert not stack
        return root, self.model.preproc.grammar.unparse(root, self.example)