# gap-text2sql-main/mrat-sql-gap/seq2struct/models/nl2code/train_tree_traversal.py
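"""Tree traversal used at training time.

TrainTreeTraversal follows the gold decoding sequence instead of sampling:
each grammar-rule, token, and pointer decision contributes a loss term,
accumulated in self.loss.
"""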
import ast
import collections
import collections.abc
import copy
import enum
import itertools
import json
import operator
import os
import random
import re

import asdl
import attr
import entmax
import pyrsistent
import torch
import torch.nn.functional as F

from seq2struct.models.nl2code.tree_traversal import TreeTraversal
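
# Debug-only record of one grammar-rule decision: which node type was being
# expanded, the candidate rules and their probabilities (sorted descending),
# and, filled in after the fact, which choices counted as correct.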
@attr.s
class ChoiceHistoryEntry:
    rule_left = attr.ib()
    choices = attr.ib()
    probs = attr.ib()
    valid_choices = attr.ib()
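
# TrainTreeTraversal specializes TreeTraversal for training: each *_choice
# hook stores the step's logits as a pending "choice point", and
# update_using_last_choice() turns that choice point plus the gold choice
# into a loss term appended to self.loss.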
class TrainTreeTraversal(TreeTraversal):

    @attr.s(frozen=True)
    class XentChoicePoint:
        logits = attr.ib()

        def compute_loss(self, outer, idx, extra_indices):
            if extra_indices:
                # Several choices are acceptable at this step: gather their
                # log-probabilities and reduce them to a single loss term.
                logprobs = torch.nn.functional.log_softmax(self.logits, dim=1)
                valid_logprobs = logprobs[:, [idx] + extra_indices]
                return outer.model.multi_loss_reduction(valid_logprobs)
            else:
                # idx shape: batch (=1)
                idx = outer.model._tensor([idx])
                # loss_piece shape: batch (=1)
                return outer.model.xent_loss(self.logits, idx)
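
    # Loss for a generated terminal token: the model's gen_token_loss combines
    # the generate-vs-copy decision (gen_logodds) with the token distribution
    # conditioned on the LSTM output and the encoded description.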
    @attr.s(frozen=True)
    class TokenChoicePoint:
        lstm_output = attr.ib()
        gen_logodds = attr.ib()

        def compute_loss(self, outer, token, extra_tokens):
            return outer.model.gen_token_loss(
                self.lstm_output,
                self.gen_logodds,
                token,
                outer.desc_enc)
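
    # loss and history are pyrsistent persistent vectors, so clone() is cheap:
    # branches share structure instead of copying.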
    def __init__(self, model, desc_enc, debug=False):
        super().__init__(model, desc_enc)

        self.choice_point = None
        # attention_choice is only set by pointer_choice(); initialize it so
        # update_using_last_choice() can safely test it before any pointer step.
        self.attention_choice = None
        self.loss = pyrsistent.pvector()

        self.debug = debug
        self.history = pyrsistent.pvector()

    def clone(self):
        super_clone = super().clone()
        super_clone.choice_point = self.choice_point
        super_clone.attention_choice = self.attention_choice
        super_clone.loss = self.loss
        super_clone.debug = self.debug
        super_clone.history = self.history
        return super_clone
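
    # rule_choice records the logits over grammar rules for this node type;
    # in debug mode it also logs every candidate rule with its probability,
    # sorted most likely first.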
    def rule_choice(self, node_type, rule_logits):
        self.choice_point = self.XentChoicePoint(rule_logits)
        if self.debug:
            choices = []
            probs = []
            for rule_idx, logprob in sorted(
                    self.model.rule_infer(node_type, rule_logits),
                    key=operator.itemgetter(1),
                    reverse=True):
                _, rule = self.model.preproc.all_rules[rule_idx]
                choices.append(rule)
                probs.append(logprob.exp().item())
            self.history = self.history.append(
                ChoiceHistoryEntry(node_type, choices, probs, None))
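
    # A pointer decision is supervised twice: once on the pointer logits and
    # once on the raw attention logits (used when attention_offset is given).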
    def token_choice(self, output, gen_logodds):
        self.choice_point = self.TokenChoicePoint(output, gen_logodds)

    def pointer_choice(self, node_type, logits, attention_logits):
        self.choice_point = self.XentChoicePoint(logits)
        self.attention_choice = self.XentChoicePoint(attention_logits)
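
    # Called with the gold choice for the step that was just scored: converts
    # the pending choice point(s) into loss terms, then resets them.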
    def update_using_last_choice(self, last_choice, extra_choice_info, attention_offset):
        super().update_using_last_choice(last_choice, extra_choice_info, attention_offset)
        if last_choice is None:
            return

        if self.debug and isinstance(self.choice_point, self.XentChoicePoint):
            valid_choice_indices = [last_choice] + (
                [] if extra_choice_info is None else extra_choice_info)
            self.history[-1].valid_choices = [
                self.model.preproc.all_rules[rule_idx][1]
                for rule_idx in valid_choice_indices]

        self.loss = self.loss.append(
            self.choice_point.compute_loss(self, last_choice, extra_choice_info))

        if attention_offset is not None and self.attention_choice is not None:
            self.loss = self.loss.append(
                self.attention_choice.compute_loss(
                    self, attention_offset, extra_indices=None))

        self.choice_point = None
        self.attention_choice = None
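
# Usage sketch (hypothetical driver; in seq2struct the decoder's training
# code walks the gold AST, feeds each gold decision back into the traversal
# via step(), and sums traversal.loss):
#
#     traversal = TrainTreeTraversal(model, desc_enc)
#     traversal.step(None)                        # enter the initial state
#     for gold_choice in gold_choices:            # gold_choices: assumed input
#         traversal.step(gold_choice)
#     loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)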