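"""Tree traversal that accumulates the training loss for the NL2Code decoder.

`TrainTreeTraversal` follows the gold sequence of decoding choices (teacher
forcing) and appends one loss piece per rule, token, or pointer choice to an
immutable vector; the caller reduces these pieces into the final loss.
"""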
import operator

import attr
import pyrsistent
import torch.nn.functional as F

from seq2struct.models.nl2code.tree_traversal import TreeTraversal


@attr.s
class ChoiceHistoryEntry:
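    """Per-step debug record: candidate rules, their probabilities, and
    (filled in later) which choices were counted as valid."""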
    rule_left = attr.ib()
    choices = attr.ib()
    probs = attr.ib()
    valid_choices = attr.ib()


class TrainTreeTraversal(TreeTraversal):
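    """Tree traversal used at training time.

    Each decoding choice leaves behind a pending "choice point"; once the
    gold choice is known, the choice point is turned into a loss piece.
    """
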
    @attr.s(frozen=True)
    class XentChoicePoint:
        """A pending choice scored by a softmax over `logits`."""

        logits = attr.ib()

        def compute_loss(self, outer, idx, extra_indices):
            if extra_indices:
                # Several choices are acceptable at this step: gather the
                # log-probabilities of all of them and let the model reduce
                # them into a single loss term.
                logprobs = F.log_softmax(self.logits, dim=1)
                valid_logprobs = logprobs[:, [idx] + extra_indices]
                return outer.model.multi_loss_reduction(valid_logprobs)
            else:
                # idx shape: batch (=1)
                idx = outer.model._tensor([idx])
                # loss_piece shape: batch (=1)
                return outer.model.xent_loss(self.logits, idx)

    @attr.s(frozen=True)
    class TokenChoicePoint:
        """A pending token choice; `gen_logodds` scores generating the token
        versus copying it (the model's `gen_token_loss` handles both)."""

        lstm_output = attr.ib()
        gen_logodds = attr.ib()

        def compute_loss(self, outer, token, extra_tokens):
            return outer.model.gen_token_loss(
                self.lstm_output,
                self.gen_logodds,
                token,
                outer.desc_enc)

    def __init__(self, model, desc_enc, debug=False):
        super().__init__(model, desc_enc)

        self.choice_point = None
        # Initialize eagerly so update_using_last_choice can test this
        # attribute before the first pointer_choice has set it.
        self.attention_choice = None

        # pyrsistent's persistent vectors are immutable, so clones can share
        # loss and history without copying.
        self.loss = pyrsistent.pvector()
        self.debug = debug
        self.history = pyrsistent.pvector()

    def clone(self):
        super_clone = super().clone()
        super_clone.choice_point = self.choice_point
        super_clone.attention_choice = self.attention_choice
        super_clone.loss = self.loss
        super_clone.debug = self.debug
        super_clone.history = self.history
        return super_clone

    def rule_choice(self, node_type, rule_logits):
        self.choice_point = self.XentChoicePoint(rule_logits)
        if self.debug:
            # Record every candidate rule and its probability, most probable
            # first; valid_choices is filled in by update_using_last_choice.
            choices = []
            probs = []
            for rule_idx, logprob in sorted(
                    self.model.rule_infer(node_type, rule_logits),
                    key=operator.itemgetter(1),
                    reverse=True):
                _, rule = self.model.preproc.all_rules[rule_idx]
                choices.append(rule)
                probs.append(logprob.exp().item())
            self.history = self.history.append(
                ChoiceHistoryEntry(node_type, choices, probs, None))

    def token_choice(self, output, gen_logodds):
        self.choice_point = self.TokenChoicePoint(output, gen_logodds)

    def pointer_choice(self, node_type, logits, attention_logits):
        # A pointer choice contributes two loss pieces: one for the pointer
        # logits and one for the attention logits
        # (see update_using_last_choice).
        self.choice_point = self.XentChoicePoint(logits)
        self.attention_choice = self.XentChoicePoint(attention_logits)

    def update_using_last_choice(self, last_choice, extra_choice_info, attention_offset):
        super().update_using_last_choice(last_choice, extra_choice_info, attention_offset)
        if last_choice is None:
            return

        if self.debug and isinstance(self.choice_point, self.XentChoicePoint):
            valid_choice_indices = [last_choice] + (
                [] if extra_choice_info is None else extra_choice_info)
            self.history[-1].valid_choices = [
                self.model.preproc.all_rules[rule_idx][1]
                for rule_idx in valid_choice_indices]

        # One loss piece for the pending choice point...
        self.loss = self.loss.append(
            self.choice_point.compute_loss(self, last_choice, extra_choice_info))
        # ...plus one for the attention target of a pointer choice.
        if attention_offset is not None and self.attention_choice is not None:
            self.loss = self.loss.append(
                self.attention_choice.compute_loss(self, attention_offset, extra_indices=None))

        self.choice_point = None
        self.attention_choice = None
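
# Usage sketch (not part of this file, and the exact call sequence is an
# assumption based on how the decoder drives a TreeTraversal): during
# training, the decoder walks the gold tree, feeds each gold choice back into
# the traversal, and then sums the accumulated loss pieces.
#
#     traversal = TrainTreeTraversal(model, desc_enc)
#     traversal.step(None)             # advance to the first choice point
#     for choice in gold_choices:      # hypothetical iterable of gold choices
#         traversal.step(choice)
#     loss = torch.sum(torch.stack(tuple(traversal.loss), dim=0), dim=0)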