import ast import collections import collections.abc import enum import itertools import json import os import operator import re import copy import random import asdl import attr import pyrsistent import entmax import torch import torch.nn.functional as F from seq2struct.utils import vocab from seq2struct.models.nl2code import decoder @attr.s class TreeState: node = attr.ib() parent_field_type = attr.ib() class TreeTraversal: class Handler: handlers = {} @classmethod def register_handler(cls, func_type): if func_type in cls.handlers: raise RuntimeError(f"{func_type} handler is already registered") def inner_func(func): cls.handlers[func_type] = func.__name__ return func return inner_func @attr.s(frozen=True) class QueueItem: item_id = attr.ib() state = attr.ib() node_type = attr.ib() parent_action_emb = attr.ib() parent_h = attr.ib() parent_field_name = attr.ib() def to_str(self): return "".format( self.state, self.node_type, self.parent_field_name ) class State(enum.Enum): SUM_TYPE_INQUIRE = 0 SUM_TYPE_APPLY = 1 CHILDREN_INQUIRE = 2 CHILDREN_APPLY = 3 LIST_LENGTH_INQUIRE = 4 LIST_LENGTH_APPLY = 5 GEN_TOKEN = 6 POINTER_INQUIRE = 7 POINTER_APPLY = 8 NODE_FINISHED = 9 def __init__(self, model, desc_enc): if model is None: return self.model = model self.desc_enc = desc_enc model.state_update.set_dropout_masks(batch_size=1) self.recurrent_state = decoder.lstm_init( model._device, None, self.model.recurrent_size, 1 ) self.prev_action_emb = model.zero_rule_emb root_type = model.preproc.grammar.root_type if root_type in model.preproc.ast_wrapper.sum_types: initial_state = TreeTraversal.State.SUM_TYPE_INQUIRE else: initial_state = TreeTraversal.State.CHILDREN_INQUIRE self.queue = pyrsistent.pvector() self.cur_item = TreeTraversal.QueueItem( item_id=0, state=initial_state, node_type=root_type, parent_action_emb=self.model.zero_rule_emb, parent_h=self.model.zero_recurrent_emb, parent_field_name=None, ) self.next_item_id = 1 self.update_prev_action_emb = TreeTraversal._update_prev_action_emb_apply_rule def clone(self): other = self.__class__(None, None) other.model = self.model other.desc_enc = self.desc_enc other.recurrent_state = self.recurrent_state other.prev_action_emb = self.prev_action_emb other.queue = self.queue other.cur_item = self.cur_item other.next_item_id = self.next_item_id other.actions = self.actions other.update_prev_action_emb = self.update_prev_action_emb return other def step(self, last_choice, extra_choice_info=None, attention_offset=None): while True: self.update_using_last_choice( last_choice, extra_choice_info, attention_offset ) handler_name = TreeTraversal.Handler.handlers[self.cur_item.state] handler = getattr(self, handler_name) choices, continued = handler(last_choice) if continued: last_choice = choices continue else: return choices def update_using_last_choice( self, last_choice, extra_choice_info, attention_offset ): if last_choice is None: return if self.model.visualize_flag: print("cur_item.state", self.cur_item.state) if ( self.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY or self.cur_item.state == TreeTraversal.State.CHILDREN_APPLY or self.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY ): print("last choice", self.model.preproc.all_rules[last_choice]) else: print("last choice", last_choice) self.update_prev_action_emb(self, last_choice, extra_choice_info) @classmethod def _update_prev_action_emb_apply_rule(cls, self, last_choice, extra_choice_info): rule_idx = self.model._tensor([last_choice]) self.prev_action_emb = self.model.rule_embedding(rule_idx) @classmethod def _update_prev_action_emb_gen_token(cls, self, last_choice, extra_choice_info): # token_idx shape: batch (=1), LongTensor token_idx = self.model._index(self.model.terminal_vocab, last_choice) # action_emb shape: batch (=1) x emb_size self.prev_action_emb = self.model.terminal_embedding(token_idx) @classmethod def _update_prev_action_emb_pointer(cls, self, last_choice, extra_choice_info): # TODO batching self.prev_action_emb = self.model.pointer_action_emb_proj[ self.cur_item.node_type ](self.desc_enc.pointer_memories[self.cur_item.node_type][:, last_choice]) def pop(self): if self.queue: self.cur_item = self.queue[-1] self.queue = self.queue.delete(-1) return True return False @Handler.register_handler(State.SUM_TYPE_INQUIRE) def process_sum_inquire(self, last_choice): # 1. ApplyRule, like expr -> Call # a. Ask which one to choose output, self.recurrent_state, rule_logits = self.model.apply_rule( self.cur_item.node_type, self.recurrent_state, self.prev_action_emb, self.cur_item.parent_h, self.cur_item.parent_action_emb, self.desc_enc, ) self.cur_item = attr.evolve( self.cur_item, state=TreeTraversal.State.SUM_TYPE_APPLY, parent_h=output ) self.update_prev_action_emb = ( TreeTraversal._update_prev_action_emb_apply_rule ) choices = self.rule_choice(self.cur_item.node_type, rule_logits) return choices, False @Handler.register_handler(State.SUM_TYPE_APPLY) def process_sum_apply(self, last_choice): # b. Add action, prepare for #2 sum_type, singular_type = self.model.preproc.all_rules[last_choice] assert sum_type == self.cur_item.node_type self.cur_item = attr.evolve( self.cur_item, node_type=singular_type, parent_action_emb=self.prev_action_emb, state=TreeTraversal.State.CHILDREN_INQUIRE, ) return None, True @Handler.register_handler(State.CHILDREN_INQUIRE) def process_children_inquire(self, last_choice): # 2. ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords] # Check if we have no children type_info = self.model.ast_wrapper.singular_types[ self.cur_item.node_type ] if not type_info.fields: if self.pop(): last_choice = None return last_choice, True else: return None, False # a. Ask about presence output, self.recurrent_state, rule_logits = self.model.apply_rule( self.cur_item.node_type, self.recurrent_state, self.prev_action_emb, self.cur_item.parent_h, self.cur_item.parent_action_emb, self.desc_enc, ) self.cur_item = attr.evolve( self.cur_item, state=TreeTraversal.State.CHILDREN_APPLY, parent_h=output ) self.update_prev_action_emb = ( TreeTraversal._update_prev_action_emb_apply_rule ) choices = self.rule_choice(self.cur_item.node_type, rule_logits) return choices, False @Handler.register_handler(State.CHILDREN_APPLY) def process_children_apply(self, last_choice): # b. Create the children node_type, children_presence = self.model.preproc.all_rules[last_choice] assert node_type == self.cur_item.node_type self.queue = self.queue.append( TreeTraversal.QueueItem( item_id=self.cur_item.item_id, state=TreeTraversal.State.NODE_FINISHED, node_type=None, parent_action_emb=None, parent_h=None, parent_field_name=None, ) ) for field_info, present in reversed( list( zip( self.model.ast_wrapper.singular_types[node_type].fields, children_presence, ) ) ): if not present: continue # seq field: LIST_LENGTH_INQUIRE x # sum type: SUM_TYPE_INQUIRE x # product type: # no children: not possible # children: CHILDREN_INQUIRE # constructor type: not possible x # builtin type: GEN_TOKEN x child_type = field_type = field_info.type if field_info.seq: child_state = TreeTraversal.State.LIST_LENGTH_INQUIRE elif field_type in self.model.ast_wrapper.sum_types: child_state = TreeTraversal.State.SUM_TYPE_INQUIRE elif field_type in self.model.ast_wrapper.product_types: assert self.model.ast_wrapper.product_types[field_type].fields child_state = TreeTraversal.State.CHILDREN_INQUIRE elif field_type in self.model.preproc.grammar.pointers: child_state = TreeTraversal.State.POINTER_INQUIRE elif field_type in self.model.ast_wrapper.primitive_types: child_state = TreeTraversal.State.GEN_TOKEN child_type = present else: raise ValueError( "Unable to handle field type {}".format(field_type) ) self.queue = self.queue.append( TreeTraversal.QueueItem( item_id=self.next_item_id, state=child_state, node_type=child_type, parent_action_emb=self.prev_action_emb, parent_h=self.cur_item.parent_h, parent_field_name=field_info.name, ) ) self.next_item_id += 1 advanced = self.pop() assert advanced last_choice = None return last_choice, True @Handler.register_handler(State.LIST_LENGTH_INQUIRE) def process_list_length_inquire(self, last_choice): list_type = self.cur_item.node_type + "*" output, self.recurrent_state, rule_logits = self.model.apply_rule( list_type, self.recurrent_state, self.prev_action_emb, self.cur_item.parent_h, self.cur_item.parent_action_emb, self.desc_enc, ) self.cur_item = attr.evolve( self.cur_item, state=TreeTraversal.State.LIST_LENGTH_APPLY, parent_h=output ) self.update_prev_action_emb = ( TreeTraversal._update_prev_action_emb_apply_rule ) choices = self.rule_choice(list_type, rule_logits) return choices, False @Handler.register_handler(State.LIST_LENGTH_APPLY) def process_list_length_apply(self, last_choice): list_type, num_children = self.model.preproc.all_rules[last_choice] elem_type = self.cur_item.node_type assert list_type == elem_type + "*" child_node_type = elem_type if elem_type in self.model.ast_wrapper.sum_types: child_state = TreeTraversal.State.SUM_TYPE_INQUIRE if self.model.preproc.use_seq_elem_rules: child_node_type = elem_type + "_seq_elem" elif elem_type in self.model.ast_wrapper.product_types: child_state = TreeTraversal.State.CHILDREN_INQUIRE elif elem_type == "identifier": child_state = TreeTraversal.State.GEN_TOKEN child_node_type = "str" elif elem_type in self.model.ast_wrapper.primitive_types: # TODO: Fix this raise ValueError("sequential builtin types not supported") else: raise ValueError( "Unable to handle seq field type {}".format(elem_type) ) for i in range(num_children): self.queue = self.queue.append( TreeTraversal.QueueItem( item_id=self.next_item_id, state=child_state, node_type=child_node_type, parent_action_emb=self.prev_action_emb, parent_h=self.cur_item.parent_h, parent_field_name=self.cur_item.parent_field_name, ) ) self.next_item_id += 1 advanced = self.pop() assert advanced last_choice = None return last_choice, True @Handler.register_handler(State.GEN_TOKEN) def process_gen_token(self, last_choice): if last_choice == vocab.EOS: if self.pop(): last_choice = None return last_choice, True else: return None, False self.recurrent_state, output, gen_logodds = self.model.gen_token( self.cur_item.node_type, self.recurrent_state, self.prev_action_emb, self.cur_item.parent_h, self.cur_item.parent_action_emb, self.desc_enc, ) self.update_prev_action_emb = ( TreeTraversal._update_prev_action_emb_gen_token ) choices = self.token_choice(output, gen_logodds) return choices, False @Handler.register_handler(State.POINTER_INQUIRE) def process_pointer_inquire(self, last_choice): # a. Ask which one to choose output, self.recurrent_state, logits, attention_logits = self.model.compute_pointer_with_align( self.cur_item.node_type, self.recurrent_state, self.prev_action_emb, self.cur_item.parent_h, self.cur_item.parent_action_emb, self.desc_enc, ) self.cur_item = attr.evolve( self.cur_item, state=TreeTraversal.State.POINTER_APPLY, parent_h=output ) self.update_prev_action_emb = ( TreeTraversal._update_prev_action_emb_pointer ) choices = self.pointer_choice( self.cur_item.node_type, logits, attention_logits ) return choices, False @Handler.register_handler(State.POINTER_APPLY) def process_pointer_apply(self, last_choice): if self.pop(): last_choice = None return last_choice, True else: return None, False @Handler.register_handler(State.NODE_FINISHED) def process_node_finished(self, last_choice): if self.pop(): last_choice = None return last_choice, True else: return None, False def rule_choice(self, node_type, rule_logits): raise NotImplementedError def token_choice(self, output, gen_logodds): raise NotImplementedError def pointer_choice(self, node_type, logits, attention_logits): raise NotImplementedError