# antonlabate
# ver 1.3
# d758c99
import ast
import collections
import collections.abc
import enum
import itertools
import json
import os
import operator
import re
import copy
import random
import asdl
import attr
import pyrsistent
import entmax
import torch
import torch.nn.functional as F
from seq2struct import ast_util
from seq2struct import grammars
from seq2struct.models import abstract_preproc
from seq2struct.models import attention
from seq2struct.models import variational_lstm
from seq2struct.utils import registry
from seq2struct.utils import vocab
from seq2struct.utils import serialization
from seq2struct.models.nl2code.tree_traversal import TreeTraversal
from seq2struct.models.nl2code.train_tree_traversal import TrainTreeTraversal
from seq2struct.models.nl2code.infer_tree_traversal import InferenceTreeTraversal
def lstm_init(device, num_layers, hidden_size, *batch_sizes):
    """Build a zero-initialized (h, c) LSTM state pair on `device`.

    The state is shaped batch_sizes + (hidden_size,), with a leading
    num_layers dimension when num_layers is not None. Both elements of the
    returned pair are the same zero tensor.
    """
    shape = (*batch_sizes, hidden_size)
    if num_layers is not None:
        shape = (num_layers, *shape)
    zeros = torch.zeros(*shape, device=device)
    return (zeros, zeros)
def maybe_stack(items, dim=None):
    """Stack the non-None entries of `items` along `dim`.

    Returns None when every entry is None; when exactly one entry remains,
    unsqueezes it instead of stacking.
    """
    present = [entry for entry in items if entry is not None]
    if len(present) == 0:
        return None
    if len(present) == 1:
        return present[0].unsqueeze(dim)
    return torch.stack(present, dim)
def accumulate_logprobs(d, keys_and_logprobs):
    """Merge (key, logprob) pairs into dict `d` in place.

    A key seen for the first time is stored as-is; a repeated key gets its
    old and new log-probabilities combined with log-sum-exp.
    """
    for key, logprob in keys_and_logprobs:
        prior = d.get(key)
        if prior is None:
            d[key] = logprob
        else:
            combined = torch.stack((logprob, prior), dim=0)
            d[key] = torch.logsumexp(combined, dim=0)
def get_field_presence_info(ast_wrapper, node, field_infos):
    """Summarize which fields of `node` are filled in, as a hashable tuple.

    Per field: an optional/sequence builtin-typed field records the value's
    type name when present and False otherwise; an optional/sequence
    non-builtin field records a bare presence flag; a required builtin field
    records the type name; a required non-builtin field must be present and
    records True. A value of None or [] counts as absent.
    """
    summary = []
    for info in field_infos:
        value = node.get(info.name)
        is_present = value is not None and value != []
        may_be_missing = info.opt or info.seq
        is_builtin = info.type in ast_wrapper.primitive_types
        if is_builtin:
            if may_be_missing:
                # TODO: make it possible to deal with "singleton?"
                summary.append(is_present and type(value).__name__)
            else:
                summary.append(type(value).__name__)
        else:
            if may_be_missing:
                summary.append(is_present)
            else:
                assert is_present
                summary.append(True)
    return tuple(summary)
@attr.s
class NL2CodeDecoderPreprocItem:
    """One preprocessed example for the NL2Code decoder."""
    # Parsed AST: nested dicts keyed by '_type', produced by grammar.parse.
    tree = attr.ib()
    # Original code string the tree was parsed from.
    orig_code = attr.ib()
class NL2CodeDecoderPreproc(abstract_preproc.AbstractPreproc):
    """Preprocessor for the NL2Code decoder.

    Parses each item's code into a grammar AST, builds the terminal-token
    vocabulary, records every observed production (sum-type constructors,
    field-presence patterns, sequence lengths, primitive types), and
    derives the flat rule list / per-LHS rule mask used by the decoder.
    """

    def __init__(
            self,
            grammar,
            save_path,
            min_freq=3,
            max_count=5000,
            use_seq_elem_rules=False):
        """
        Args:
            grammar: registry config used to construct the grammar.
            save_path: directory under which vocab/productions/rules go.
            min_freq: minimum token frequency to enter the vocabulary.
            max_count: maximum vocabulary size.
            use_seq_elem_rules: emit separate '<sum_type>_seq_elem' rules
                for constructors that appear as sequence elements.
        """
        self.grammar = registry.construct('grammar', grammar)
        self.ast_wrapper = self.grammar.ast_wrapper

        self.vocab_path = os.path.join(save_path, 'dec_vocab.json')
        self.observed_productions_path = os.path.join(save_path, 'observed_productions.json')
        self.grammar_rules_path = os.path.join(save_path, 'grammar_rules.json')
        self.data_dir = os.path.join(save_path, 'dec')

        self.vocab_builder = vocab.VocabBuilder(min_freq, max_count)
        self.use_seq_elem_rules = use_seq_elem_rules

        # Per-section lists of NL2CodeDecoderPreprocItem.
        self.items = collections.defaultdict(list)
        # Observed-production accumulators; sets until save() freezes them
        # into sorted lists.
        self.sum_type_constructors = collections.defaultdict(set)
        self.field_presence_infos = collections.defaultdict(set)
        self.seq_lengths = collections.defaultdict(set)
        self.primitive_types = set()

        # Populated by save()/load().
        self.vocab = None
        self.all_rules = None
        self.rules_mask = None

    def validate_item(self, item, section):
        """Parse item.code; return (is_valid, parsed_tree).

        Unparseable items are only considered invalid for 'train'.
        """
        parsed = self.grammar.parse(item.code, section)
        if parsed:
            self.ast_wrapper.verify_ast(parsed)
            return True, parsed
        return section != 'train', None

    def add_item(self, item, section, validation_info):
        """Record one validated item; training items also feed the vocab
        and production statistics."""
        root = validation_info
        if section == 'train':
            for token in self._all_tokens(root):
                self.vocab_builder.add_word(token)
            self._record_productions(root)
        self.items[section].append(
            NL2CodeDecoderPreprocItem(
                tree=root,
                orig_code=item.code))

    def clear_items(self):
        """Drop all recorded items (but keep vocab/production stats)."""
        self.items = collections.defaultdict(list)

    def save(self):
        """Finalize the vocab and write items, observed productions, and
        grammar rules to disk. Also freezes the set-valued accumulators
        into plain dicts of sorted lists."""
        os.makedirs(self.data_dir, exist_ok=True)
        self.vocab = self.vocab_builder.finish()
        self.vocab.save(self.vocab_path)

        for section, items in self.items.items():
            with open(os.path.join(self.data_dir, section + '.jsonl'), 'w', encoding='utf8') as f:
                for item in items:
                    f.write(json.dumps(attr.asdict(item), ensure_ascii=False) + '\n')

        # observed_productions: convert defaultdict-of-set accumulators into
        # plain dicts with sorted lists so they serialize deterministically.
        self.sum_type_constructors = serialization.to_dict_with_sorted_values(
            self.sum_type_constructors)
        self.field_presence_infos = serialization.to_dict_with_sorted_values(
            self.field_presence_infos, key=str)
        self.seq_lengths = serialization.to_dict_with_sorted_values(
            self.seq_lengths)
        self.primitive_types = sorted(self.primitive_types)
        with open(self.observed_productions_path, 'w', encoding='utf8') as f:
            json.dump({
                'sum_type_constructors': self.sum_type_constructors,
                'field_presence_infos': self.field_presence_infos,
                'seq_lengths': self.seq_lengths,
                'primitive_types': self.primitive_types,
            }, f, indent=2, sort_keys=True, ensure_ascii=False)

        # grammar
        self.all_rules, self.rules_mask = self._calculate_rules()
        with open(self.grammar_rules_path, 'w', encoding='utf8') as f:
            json.dump({
                'all_rules': self.all_rules,
                'rules_mask': self.rules_mask,
            }, f, indent=2, sort_keys=True, ensure_ascii=False)

    def load(self):
        """Load vocab, observed productions, and grammar rules from disk."""
        self.vocab = vocab.Vocab.load(self.vocab_path)
        # Fix: close the JSON files (previously leaked via bare open()).
        with open(self.observed_productions_path, encoding='utf8') as f:
            observed_productions = json.load(f)
        self.sum_type_constructors = observed_productions['sum_type_constructors']
        self.field_presence_infos = observed_productions['field_presence_infos']
        self.seq_lengths = observed_productions['seq_lengths']
        self.primitive_types = observed_productions['primitive_types']
        with open(self.grammar_rules_path, encoding='utf8') as f:
            grammar = json.load(f)
        self.all_rules = serialization.tuplify(grammar['all_rules'])
        self.rules_mask = grammar['rules_mask']

    def dataset(self, section):
        """Read back the saved items for one section."""
        # Fix: close the .jsonl file (previously leaked via bare open()).
        path = os.path.join(self.data_dir, section + '.jsonl')
        with open(path, encoding='utf8') as f:
            return [NL2CodeDecoderPreprocItem(**json.loads(line)) for line in f]

    def _record_productions(self, tree):
        """Walk `tree` and record every production it exercises."""
        queue = [(tree, False)]
        while queue:
            node, is_seq_elem = queue.pop()
            node_type = node['_type']

            # Rules of the form:
            # expr -> Attribute | Await | BinOp | BoolOp | ...
            # expr_seq_elem -> Attribute | Await | ... | Template1 | Template2 | ...
            for type_name in [node_type] + node.get('_extra_types', []):
                if type_name in self.ast_wrapper.constructors:
                    sum_type_name = self.ast_wrapper.constructor_to_sum_type[type_name]
                    if is_seq_elem and self.use_seq_elem_rules:
                        self.sum_type_constructors[sum_type_name + '_seq_elem'].add(type_name)
                    else:
                        self.sum_type_constructors[sum_type_name].add(type_name)

            # Rules of the form:
            # FunctionDef
            # -> identifier name, arguments args
            # | identifier name, arguments args, stmt* body
            # | identifier name, arguments args, expr* decorator_list
            # | identifier name, arguments args, expr? returns
            # ...
            # | identifier name, arguments args, stmt* body, expr* decorator_list, expr returns
            assert node_type in self.ast_wrapper.singular_types
            field_presence_info = get_field_presence_info(
                self.ast_wrapper,
                node,
                self.ast_wrapper.singular_types[node_type].fields)
            self.field_presence_infos[node_type].add(field_presence_info)

            for field_info in self.ast_wrapper.singular_types[node_type].fields:
                field_value = node.get(field_info.name, [] if field_info.seq else None)
                to_enqueue = []
                if field_info.seq:
                    # Rules of the form:
                    # stmt* -> stmt
                    # | stmt stmt
                    # | stmt stmt stmt
                    self.seq_lengths[field_info.type + '*'].add(len(field_value))
                    to_enqueue = field_value
                else:
                    to_enqueue = [field_value]
                for child in to_enqueue:
                    if isinstance(child, collections.abc.Mapping) and '_type' in child:
                        queue.append((child, field_info.seq))
                    else:
                        self.primitive_types.add(type(child).__name__)

    def _calculate_rules(self):
        """Flatten the recorded productions into (all_rules, rules_mask).

        all_rules is a tuple of (lhs, rhs) pairs; rules_mask maps each LHS
        to its [start, end) slice inside all_rules. Must run after save()
        has converted the set accumulators to sorted lists — the asserts
        below check for that.
        """
        offset = 0
        all_rules = []
        rules_mask = {}

        # Rules of the form:
        # expr -> Attribute | Await | BinOp | BoolOp | ...
        # expr_seq_elem -> Attribute | Await | ... | Template1 | Template2 | ...
        for parent, children in sorted(self.sum_type_constructors.items()):
            assert not isinstance(children, set)
            rules_mask[parent] = (offset, offset + len(children))
            offset += len(children)
            all_rules += [(parent, child) for child in children]

        # Rules of the form:
        # FunctionDef
        # -> identifier name, arguments args
        # | identifier name, arguments args, stmt* body
        # | identifier name, arguments args, expr* decorator_list
        # | identifier name, arguments args, expr? returns
        # ...
        # | identifier name, arguments args, stmt* body, expr* decorator_list, expr returns
        for name, field_presence_infos in sorted(self.field_presence_infos.items()):
            assert not isinstance(field_presence_infos, set)
            rules_mask[name] = (offset, offset + len(field_presence_infos))
            offset += len(field_presence_infos)
            all_rules += [(name, presence) for presence in field_presence_infos]

        # Rules of the form:
        # stmt* -> stmt
        # | stmt stmt
        # | stmt stmt stmt
        for seq_type_name, lengths in sorted(self.seq_lengths.items()):
            assert not isinstance(lengths, set)
            rules_mask[seq_type_name] = (offset, offset + len(lengths))
            offset += len(lengths)
            all_rules += [(seq_type_name, i) for i in lengths]

        return tuple(all_rules), rules_mask

    def _all_tokens(self, root):
        """Yield every terminal token under `root` that belongs in the
        vocabulary; pointer-typed fields are skipped."""
        queue = [root]
        while queue:
            node = queue.pop()
            type_info = self.ast_wrapper.singular_types[node['_type']]
            for field_info in reversed(type_info.fields):
                field_value = node.get(field_info.name)
                if field_info.type in self.grammar.pointers:
                    pass
                elif field_info.type in self.ast_wrapper.primitive_types:
                    for token in self.grammar.tokenize_field_value(field_value):
                        yield token
                elif isinstance(field_value, (list, tuple)):
                    queue.extend(field_value)
                elif field_value is not None:
                    queue.append(field_value)
@attr.s
class TreeState:
    """One pending step of the teacher-forced tree traversal."""
    # Current AST node (dict / list / primitive) awaiting expansion.
    node = attr.ib()
    # Grammar type of the field this node fills in its parent.
    parent_field_type = attr.ib()
@registry.register('decoder', 'NL2Code')
class NL2CodeDecoder(torch.nn.Module):
    """Tree-structured NL2Code decoder.

    Generates programs by repeatedly choosing grammar rules, terminal
    tokens (generated from the vocabulary or copied from the description),
    and pointers into encoder memories, conditioned on an attention context
    over the encoded description.
    """

    Preproc = NL2CodeDecoderPreproc

    def __init__(
            self,
            device,
            preproc,
            #
            rule_emb_size=128,
            node_embed_size=64,
            # TODO: This should be automatically inferred from encoder
            enc_recurrent_size=256,
            recurrent_size=256,
            dropout=0.,
            desc_attn='bahdanau',
            copy_pointer=None,
            multi_loss_type='logsumexp',
            sup_att=None,
            use_align_mat=False,
            use_align_loss=False,
            enumerate_order=False,
            loss_type="softmax"):
        """
        Args:
            device: torch device for all parameters and buffers.
            preproc: NL2CodeDecoderPreproc holding the vocab and rules.
            rule_emb_size: embedding size for rules and terminal tokens.
            node_embed_size: embedding size for node types.
            enc_recurrent_size: size of the encoder memory vectors.
            recurrent_size: decoder LSTM hidden size.
            dropout: recurrent dropout for the decoder LSTM.
            desc_attn: 'bahdanau', 'mha', 'mha-1h', 'sep', or a module.
            copy_pointer: optional module for the copy mechanism.
            multi_loss_type: 'logsumexp' or 'mean' reduction over multiple
                gold rules.
            sup_att: '1h' to supervise attention on pointer steps.
            use_align_mat, use_align_loss: enable Spider alignment losses.
            enumerate_order: during training, marginalize the loss over all
                orderings of list-valued fields.
            loss_type: 'softmax', 'entmax', 'sparsemax', or 'label_smooth'.
        """
        super().__init__()
        self._device = device
        self.preproc = preproc
        self.ast_wrapper = preproc.ast_wrapper
        self.terminal_vocab = preproc.vocab

        self.rule_emb_size = rule_emb_size
        self.node_emb_size = node_embed_size
        self.enc_recurrent_size = enc_recurrent_size
        self.recurrent_size = recurrent_size

        # Stable integer id for every grammar rule.
        self.rules_index = {v: idx for idx, v in enumerate(self.preproc.all_rules)}

        self.use_align_mat = use_align_mat
        self.use_align_loss = use_align_loss
        self.enumerate_order = enumerate_order

        # Bug fix: _update_state reads this attribute, but nothing in this
        # class ever set it, so decoding crashed with AttributeError unless
        # external code assigned it first. Default to disabled.
        self.visualize_flag = False

        if use_align_mat:
            from seq2struct.models.spider import spider_dec_func
            self.compute_align_loss = lambda *args: \
                spider_dec_func.compute_align_loss(self, *args)
            self.compute_pointer_with_align = lambda *args: \
                spider_dec_func.compute_pointer_with_align(self, *args)

        # Vocabulary of node types the decoder can be asked to expand.
        if self.preproc.use_seq_elem_rules:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.preproc.sum_type_constructors.keys()) +
                sorted(self.preproc.field_presence_infos.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())
        else:
            self.node_type_vocab = vocab.Vocab(
                sorted(self.preproc.primitive_types) +
                sorted(self.ast_wrapper.custom_primitive_types) +
                sorted(self.ast_wrapper.sum_types.keys()) +
                sorted(self.ast_wrapper.singular_types.keys()) +
                sorted(self.preproc.seq_lengths.keys()),
                special_elems=())

        # Decoder LSTM; its input concatenates [prev action emb, description
        # context, parent state, parent action emb, node-type emb].
        self.state_update = variational_lstm.RecurrentDropoutLSTMCell(
            input_size=self.rule_emb_size * 2 + self.enc_recurrent_size + self.recurrent_size + self.node_emb_size,
            hidden_size=self.recurrent_size,
            dropout=dropout)

        self.attn_type = desc_attn
        if desc_attn == 'bahdanau':
            self.desc_attn = attention.BahdanauAttention(
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size,
                proj_size=50)
        elif desc_attn == 'mha':
            self.desc_attn = attention.MultiHeadedAttention(
                h=8,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'mha-1h':
            self.desc_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        elif desc_attn == 'sep':
            # Separate attention over the question and schema memories.
            self.question_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
            self.schema_attn = attention.MultiHeadedAttention(
                h=1,
                query_size=self.recurrent_size,
                value_size=self.enc_recurrent_size)
        else:
            # TODO: Figure out how to get right sizes (query, value) to module
            self.desc_attn = desc_attn
        self.sup_att = sup_att

        # Rule scoring and rule embedding.
        self.rule_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.rules_index)))
        self.rule_embedding = torch.nn.Embedding(
            num_embeddings=len(self.rules_index),
            embedding_dim=self.rule_emb_size)

        # Terminal-token generation: gen-vs-copy gate, token scores, and
        # token embeddings.
        self.gen_logodds = torch.nn.Linear(self.recurrent_size, 1)
        self.terminal_logits = torch.nn.Sequential(
            torch.nn.Linear(self.recurrent_size, self.rule_emb_size),
            torch.nn.Tanh(),
            torch.nn.Linear(self.rule_emb_size, len(self.terminal_vocab)))
        self.terminal_embedding = torch.nn.Embedding(
            num_embeddings=len(self.terminal_vocab),
            embedding_dim=self.rule_emb_size)
        if copy_pointer is None:
            self.copy_pointer = attention.BahdanauPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size,
                proj_size=50)
        else:
            # TODO: Figure out how to get right sizes (query, key) to module
            self.copy_pointer = copy_pointer

        if multi_loss_type == 'logsumexp':
            self.multi_loss_reduction = lambda logprobs: -torch.logsumexp(logprobs, dim=1)
        elif multi_loss_type == 'mean':
            self.multi_loss_reduction = lambda logprobs: -torch.mean(logprobs, dim=1)

        # One pointer network (and action-embedding projection) per pointer
        # type declared by the grammar.
        self.pointers = torch.nn.ModuleDict()
        self.pointer_action_emb_proj = torch.nn.ModuleDict()
        for pointer_type in self.preproc.grammar.pointers:
            self.pointers[pointer_type] = attention.ScaledDotProductPointer(
                query_size=self.recurrent_size,
                key_size=self.enc_recurrent_size)
            self.pointer_action_emb_proj[pointer_type] = torch.nn.Linear(
                self.enc_recurrent_size, self.rule_emb_size)

        self.node_type_embedding = torch.nn.Embedding(
            num_embeddings=len(self.node_type_vocab),
            embedding_dim=self.node_emb_size)

        # TODO batching
        self.zero_rule_emb = torch.zeros(1, self.rule_emb_size, device=self._device)
        self.zero_recurrent_emb = torch.zeros(1, self.recurrent_size, device=self._device)

        if loss_type == "softmax":
            self.xent_loss = torch.nn.CrossEntropyLoss(reduction='none')
        elif loss_type == "entmax":
            self.xent_loss = entmax.entmax15_loss
        elif loss_type == "sparsemax":
            self.xent_loss = entmax.sparsemax_loss
        elif loss_type == "label_smooth":
            self.xent_loss = self.label_smooth_loss
        else:
            # Fail fast instead of leaving self.xent_loss unset and crashing
            # later with an opaque AttributeError.
            raise ValueError('Unknown loss_type: {}'.format(loss_type))
def label_smooth_loss(self, X, target, smooth_value=0.1):
if self.training:
logits = torch.log_softmax(X, dim=1)
size = X.size()[1]
one_hot = torch.full(X.size(), smooth_value / (size - 1)).to(X.device)
one_hot.scatter_(1, target.unsqueeze(0), 1 - smooth_value)
loss = F.kl_div(logits, one_hot, reduction="batchmean")
return loss.unsqueeze(0)
else:
return torch.nn.functional.cross_entropy(X, target, reduction="none")
@classmethod
def _calculate_rules(cls, preproc):
offset = 0
all_rules = []
rules_mask = {}
# Rules of the form:
# expr -> Attribute | Await | BinOp | BoolOp | ...
# expr_seq_elem -> Attribute | Await | ... | Template1 | Template2 | ...
for parent, children in sorted(preproc.sum_type_constructors.items()):
assert parent not in rules_mask
rules_mask[parent] = (offset, offset + len(children))
offset += len(children)
all_rules += [(parent, child) for child in children]
# Rules of the form:
# FunctionDef
# -> identifier name, arguments args
# | identifier name, arguments args, stmt* body
# | identifier name, arguments args, expr* decorator_list
# | identifier name, arguments args, expr? returns
# ...
# | identifier name, arguments args, stmt* body, expr* decorator_list, expr returns
for name, field_presence_infos in sorted(preproc.field_presence_infos.items()):
assert name not in rules_mask
rules_mask[name] = (offset, offset + len(field_presence_infos))
offset += len(field_presence_infos)
all_rules += [(name, presence) for presence in field_presence_infos]
# Rules of the form:
# stmt* -> stmt
# | stmt stmt
# | stmt stmt stmt
for seq_type_name, lengths in sorted(preproc.seq_lengths.items()):
assert seq_type_name not in rules_mask
rules_mask[seq_type_name] = (offset, offset + len(lengths))
offset += len(lengths)
all_rules += [(seq_type_name, i) for i in lengths]
return all_rules, rules_mask
def compute_loss(self, enc_input, example, desc_enc, debug):
if not self.enumerate_order or not self.training:
mle_loss = self.compute_mle_loss(enc_input, example, desc_enc, debug)
else:
mle_loss = self.compute_loss_from_all_ordering(enc_input, example, desc_enc, debug)
if self.use_align_loss:
align_loss = self.compute_align_loss(desc_enc, example)
return mle_loss + align_loss
return mle_loss
    def compute_loss_from_all_ordering(self, enc_input, example, desc_enc, debug):
        """MLE loss marginalized (via logsumexp) over every permutation of
        each list-valued field in the example's tree.

        NOTE(review): the number of orderings is the product of the
        factorials of all list lengths, so this assumes small trees —
        confirm callers guard against blow-up. The `debug` argument is
        accepted but not forwarded to compute_mle_loss.
        """
        def get_permutations(node):
            # Collect, in DFS order, the full set of index permutations for
            # every list/tuple encountered in the tree.
            def traverse_tree(node):
                nonlocal permutations
                if isinstance(node, (list, tuple)):
                    p = itertools.permutations(range(len(node)))
                    permutations.append(list(p))
                    for child in node:
                        traverse_tree(child)
                elif isinstance(node, dict):
                    for node_name in node:
                        traverse_tree(node[node_name])
            permutations = []
            traverse_tree(node)
            return permutations

        def get_perturbed_tree(node, permutation):
            # Apply one permutation choice per list, consuming `permutation`
            # left to right in the same DFS order get_permutations used.
            def traverse_tree(node, parent_type, parent_node):
                if isinstance(node, (list, tuple)):
                    nonlocal permutation
                    p_node = [node[i] for i in permutation[0]]
                    parent_node[parent_type] = p_node
                    permutation = permutation[1:]
                    for child in node:
                        traverse_tree(child, None, None)
                elif isinstance(node, dict):
                    for node_name in node:
                        traverse_tree(node[node_name], node_name, node)
            # Deep-copy so the caller's tree is never mutated.
            node = copy.deepcopy(node)
            traverse_tree(node, None, None)
            return node

        orig_tree = example.tree
        permutations = get_permutations(orig_tree)
        products = itertools.product(*permutations)
        loss_list = []
        for product in products:
            tree = get_perturbed_tree(orig_tree, product)
            example.tree = tree
            loss = self.compute_mle_loss(enc_input, example, desc_enc)
            loss_list.append(loss)
        # Restore the original ordering on the shared example object.
        example.tree = orig_tree
        loss_v = torch.stack(loss_list, 0)
        return torch.logsumexp(loss_v, 0)
    def compute_mle_loss(self, enc_input, example, desc_enc, debug=False):
        """Teacher-forced maximum-likelihood loss for one example.

        Walks example.tree depth-first and at every step feeds the gold
        action (rule id, terminal token, or pointer index) into a
        TrainTreeTraversal, asserting that the traversal expects the
        matching kind of action. Returns the summed loss; with debug=True
        also returns the traversal history as a list of dicts.
        """
        traversal = TrainTreeTraversal(self, desc_enc, debug)
        traversal.step(None)
        queue = [
            TreeState(
                node=example.tree,
                parent_field_type=self.preproc.grammar.root_type,
            )
        ]
        while queue:
            item = queue.pop()
            node = item.node
            parent_field_type = item.parent_field_type
            if isinstance(node, (list, tuple)):
                # List-length rule: choose how many elements this sequence has.
                node_type = parent_field_type + '*'
                rule = (node_type, len(node))
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.LIST_LENGTH_APPLY
                traversal.step(rule_idx)
                if self.preproc.use_seq_elem_rules and parent_field_type in self.ast_wrapper.sum_types:
                    parent_field_type += '_seq_elem'
                # reversed so the stack-based DFS visits elements left-to-right.
                for i, elem in reversed(list(enumerate(node))):
                    queue.append(
                        TreeState(
                            node=elem,
                            parent_field_type=parent_field_type,
                        ))
                continue
            if parent_field_type in self.preproc.grammar.pointers:
                # Pointer action: `node` is an index into the pointer memory.
                assert isinstance(node, int)
                assert traversal.cur_item.state == TreeTraversal.State.POINTER_APPLY
                pointer_map = desc_enc.pointer_maps.get(parent_field_type)
                if pointer_map:
                    values = pointer_map[node]
                    if self.sup_att == '1h':
                        # Supervised attention: shift the gold position by the
                        # lengths of the preceding encoder sections.
                        if len(pointer_map) == len(enc_input['columns']):
                            if self.attn_type != 'sep':
                                traversal.step(values[0], values[1:], node + len(enc_input['question']))
                            else:
                                traversal.step(values[0], values[1:], node)
                        else:
                            if self.attn_type != 'sep':
                                traversal.step(values[0], values[1:], node + len(enc_input['question']) + len(enc_input['columns']))
                            else:
                                traversal.step(values[0], values[1:], node + len(enc_input['columns']))
                    else:
                        traversal.step(values[0], values[1:])
                else:
                    traversal.step(node)
                continue
            if parent_field_type in self.ast_wrapper.primitive_types:
                # identifier, int, string, bytes, object, singleton
                # - could be bytes, str, int, float, bool, NoneType
                # - terminal tokens vocabulary is created by turning everything into a string (with `str`)
                # - at decoding time, cast back to str/int/float/bool
                field_type = type(node).__name__
                field_value_split = self.preproc.grammar.tokenize_field_value(node) + [
                    vocab.EOS]
                for token in field_value_split:
                    assert traversal.cur_item.state == TreeTraversal.State.GEN_TOKEN
                    traversal.step(token)
                continue
            type_info = self.ast_wrapper.singular_types[node['_type']]
            if parent_field_type in self.preproc.sum_type_constructors:
                # ApplyRule, like expr -> Call
                rule = (parent_field_type, type_info.name)
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.SUM_TYPE_APPLY
                extra_rules = [
                    self.rules_index[parent_field_type, extra_type]
                    for extra_type in node.get('_extra_types', [])]
                traversal.step(rule_idx, extra_rules)
            if type_info.fields:
                # ApplyRule, like Call -> expr[func] expr*[args] keyword*[keywords]
                # Figure out which rule needs to be applied
                present = get_field_presence_info(self.ast_wrapper, node, type_info.fields)
                rule = (node['_type'], tuple(present))
                rule_idx = self.rules_index[rule]
                assert traversal.cur_item.state == TreeTraversal.State.CHILDREN_APPLY
                traversal.step(rule_idx)
            # reversed so that we perform a DFS in left-to-right order
            for field_info in reversed(type_info.fields):
                if field_info.name not in node:
                    continue
                queue.append(
                    TreeState(
                        node=node[field_info.name],
                        parent_field_type=field_info.type,
                    ))
        loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)
        if debug:
            return loss, [attr.asdict(entry) for entry in traversal.history]
        else:
            return loss
def begin_inference(self, desc_enc, example):
traversal = InferenceTreeTraversal(self, desc_enc, example)
choices = traversal.step(None)
return traversal, choices
def _desc_attention(self, prev_state, desc_enc):
# prev_state shape:
# - h_n: batch (=1) x emb_size
# - c_n: batch (=1) x emb_size
query = prev_state[0]
if self.attn_type != 'sep':
return self.desc_attn(query, desc_enc.memory, attn_mask=None)
else:
question_context, question_attention_logits = self.question_attn(query, desc_enc.question_memory)
schema_context, schema_attention_logits = self.schema_attn(query, desc_enc.schema_memory)
return question_context + schema_context, schema_attention_logits
def _tensor(self, data, dtype=None):
return torch.tensor(data, dtype=dtype, device=self._device)
def _index(self, vocab, word):
return self._tensor([vocab.index(word)])
def _update_state(
self,
node_type,
prev_state,
prev_action_emb,
parent_h,
parent_action_emb,
desc_enc):
# desc_context shape: batch (=1) x emb_size
desc_context, attention_logits = self._desc_attention(prev_state, desc_enc)
if self.visualize_flag:
attention_weights = F.softmax(attention_logits, dim = -1)
print(attention_weights)
# node_type_emb shape: batch (=1) x emb_size
node_type_emb = self.node_type_embedding(
self._index(self.node_type_vocab, node_type))
state_input = torch.cat(
(
prev_action_emb, # a_{t-1}: rule_emb_size
desc_context, # c_t: enc_recurrent_size
parent_h, # s_{p_t}: recurrent_size
parent_action_emb, # a_{p_t}: rule_emb_size
node_type_emb, # n_{f-t}: node_emb_size
),
dim=-1)
new_state = self.state_update(
# state_input shape: batch (=1) x (emb_size * 5)
state_input, prev_state)
return new_state, attention_logits
def apply_rule(
self,
node_type,
prev_state,
prev_action_emb,
parent_h,
parent_action_emb,
desc_enc):
new_state, attention_logits = self._update_state(
node_type, prev_state, prev_action_emb, parent_h, parent_action_emb, desc_enc)
# output shape: batch (=1) x emb_size
output = new_state[0]
# rule_logits shape: batch (=1) x num choices
rule_logits = self.rule_logits(output)
return output, new_state, rule_logits
def rule_infer(self, node_type, rule_logits):
rule_logprobs = torch.nn.functional.log_softmax(rule_logits, dim=-1)
rules_start, rules_end = self.preproc.rules_mask[node_type]
# TODO: Mask other probabilities first?
return list(zip(
range(rules_start, rules_end),
rule_logprobs[0, rules_start:rules_end]))
def gen_token(
self,
node_type,
prev_state,
prev_action_emb,
parent_h,
parent_action_emb,
desc_enc):
new_state, attention_logits = self._update_state(
node_type, prev_state, prev_action_emb, parent_h, parent_action_emb, desc_enc)
# output shape: batch (=1) x emb_size
output = new_state[0]
# gen_logodds shape: batch (=1)
gen_logodds = self.gen_logodds(output).squeeze(1)
return new_state, output, gen_logodds
    def gen_token_loss(
            self,
            output,
            gen_logodds,
            token,
            desc_enc):
        """Loss for emitting one terminal token, marginalizing over the two
        mechanisms: copying it from the description and generating it from
        the terminal vocabulary. Returns -log p(token), shape batch (=1).
        """
        # token_idx shape: batch (=1), LongTensor
        token_idx = self._index(self.terminal_vocab, token)
        # action_emb shape: batch (=1) x emb_size
        action_emb = self.terminal_embedding(token_idx)
        # NOTE(review): action_emb is computed but unused in this method —
        # confirm it is not needed before removing.
        # Which mechanisms apply:
        # +unk, +in desc: copy
        # +unk, -in desc: gen (an unk token)
        # -unk, +in desc: copy, gen
        # -unk, -in desc: gen
        # gen_logodds shape: batch (=1)
        desc_locs = desc_enc.find_word_occurrences(token)
        if desc_locs:
            # copy: if the token appears in the description at least once
            # copy_loc_logits shape: batch (=1) x desc length
            copy_loc_logits = self.copy_pointer(output, desc_enc.memory)
            copy_logprob = (
                # log p(copy | output)
                # shape: batch (=1)
                torch.nn.functional.logsigmoid(-gen_logodds) -
                # xent_loss: -log p(location | output)
                # TODO: sum the probability of all occurrences
                # shape: batch (=1)
                self.xent_loss(copy_loc_logits, self._tensor(desc_locs[0:1])))
        else:
            copy_logprob = None
        # gen: ~(unk & in desc), equivalent to ~unk | ~in desc
        if token in self.terminal_vocab or copy_logprob is None:
            token_logits = self.terminal_logits(output)
            # token_logits shape: batch (=1) x vocab size
            gen_logprob = (
                # log p(gen | output)
                # shape: batch (=1)
                torch.nn.functional.logsigmoid(gen_logodds) -
                # xent_loss: -log p(token | output)
                # shape: batch (=1)
                self.xent_loss(token_logits, token_idx))
        else:
            gen_logprob = None
        # loss should be -log p(...), so negate
        loss_piece = -torch.logsumexp(
            maybe_stack([copy_logprob, gen_logprob], dim=1),
            dim=1)
        return loss_piece
    def token_infer(self, output, gen_logodds, desc_enc):
        """Score every candidate terminal token at inference time.

        Combines the copy distribution (over description positions) with
        the generation distribution (over the terminal vocabulary) into a
        single per-word log-probability dict, and returns its items.
        """
        # Copy tokens
        # log p(copy | output)
        # shape: batch (=1)
        copy_logprob = torch.nn.functional.logsigmoid(-gen_logodds)
        copy_loc_logits = self.copy_pointer(output, desc_enc.memory)
        # log p(loc_i | copy, output)
        # shape: batch (=1) x seq length
        copy_loc_logprobs = torch.nn.functional.log_softmax(copy_loc_logits, dim=-1)
        # log p(loc_i, copy | output)
        copy_loc_logprobs += copy_logprob
        log_prob_by_word = {}
        # accumulate_logprobs is needed because the same word may appear
        # multiple times in desc_enc.words.
        accumulate_logprobs(
            log_prob_by_word,
            zip(desc_enc.words, copy_loc_logprobs.squeeze(0)))
        # Generate tokens
        # log p(~copy | output)
        # shape: batch (=1)
        gen_logprob = torch.nn.functional.logsigmoid(gen_logodds)
        token_logits = self.terminal_logits(output)
        # log p(v | ~copy, output)
        # shape: batch (=1) x vocab size
        token_logprobs = torch.nn.functional.log_softmax(token_logits, dim=-1)
        # log p(v, ~copy| output)
        # shape: batch (=1) x vocab size
        token_logprobs += gen_logprob
        accumulate_logprobs(
            log_prob_by_word,
            ((self.terminal_vocab[idx], token_logprobs[0, idx]) for idx in range(token_logprobs.shape[1])))
        return list(log_prob_by_word.items())
def compute_pointer(
self,
node_type,
prev_state,
prev_action_emb,
parent_h,
parent_action_emb,
desc_enc):
new_state, attention_logits = self._update_state(
node_type, prev_state, prev_action_emb, parent_h, parent_action_emb, desc_enc)
# output shape: batch (=1) x emb_size
output = new_state[0]
# pointer_logits shape: batch (=1) x num choices
pointer_logits = self.pointers[node_type](
output, desc_enc.pointer_memories[node_type])
return output, new_state, pointer_logits, attention_logits
def pointer_infer(self, node_type, logits):
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)
return list(zip(
# TODO batching
range(logits.shape[1]),
logprobs[0]))