# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import copy
import json
import math
import re
import collections
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.autograd import Variable
from torch.nn.parameter import Parameter


def gelu(x):
    return (
        0.5
        * x
        * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))
    )


def swish(x):
    return x * torch.sigmoid(x)


class LayerNorm(nn.Module):
    "Construct a layernorm module in the OpenAI style (epsilon inside the square root)."

    def __init__(self, n_state, e=1e-5):
        super(LayerNorm, self).__init__()
        self.g = nn.Parameter(torch.ones(n_state))
        self.b = nn.Parameter(torch.zeros(n_state))
        self.e = e

    """
    Input:
        x: n_state-dim
    Output:
        o: n_state-dim
    """

    def forward(self, x):
        u = x.mean(-1, keepdim=True)
        s = (x - u).pow(2).mean(-1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.e)
        return self.g * x + self.b
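

# Illustrative usage sketch (not part of the original code): LayerNorm normalizes the
# last dimension and then applies the learned gain g and bias b. The tensor sizes
# below are arbitrary assumptions for demonstration.
def _layernorm_example():
    ln = LayerNorm(16)
    x = torch.randn(2, 5, 16)  # batch x len x n_state
    y = ln(x)                  # same shape, normalized per position
    assert y.shape == x.shape
    return y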
""" | |
Convolution | |
nx is the last input dim | |
nf is the last output dim | |
""" | |
class Conv1D(nn.Module): | |
def __init__(self, nf, nx): | |
super(Conv1D, self).__init__() | |
self.nf = nf | |
w = torch.empty(nx, nf) | |
nn.init.normal_(w, std=0.02) | |
self.w = Parameter(w) | |
self.b = Parameter(torch.zeros(nf)) | |
""" | |
Input: | |
x: batch x len x nx | |
Output: | |
x: batch x len x nf | |
""" | |
def forward(self, x): | |
size_out = x.size()[:-1] + (self.nf,) | |
x = torch.addmm(self.b, x.view(-1, x.size(-1)), self.w) | |
x = x.view(*size_out) | |
return x | |
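

# Illustrative usage sketch (not part of the original code): despite its name, Conv1D
# is a position-wise linear projection of the last dimension from nx to nf
# (equivalent to a 1x1 convolution). The sizes below are arbitrary assumptions.
def _conv1d_example():
    proj = Conv1D(nf=32, nx=16)
    x = torch.randn(2, 5, 16)  # batch x len x nx
    y = proj(x)                # batch x len x nf
    assert y.shape == (2, 5, 32)
    return y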


class PositionalEmbedding(nn.Module):
    def __init__(self, opt, demb):
        super(PositionalEmbedding, self).__init__()
        self.demb = demb
        inv_freq = 1 / (10000 ** (torch.arange(0.0, demb, 2.0) / demb))
        self.pos_discount = float(opt["TRANSFORMER_POS_DISCOUNT"])
        self.register_buffer("inv_freq", inv_freq)

    """
    Input:
        pos_seq: len
    Output:
        pos_emb: len x demb
    """

    def forward(self, pos_seq):
        sinusoid_inp = torch.ger(pos_seq, self.inv_freq)
        pos_emb = (
            torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
            / self.pos_discount
        )
        return pos_emb
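

# Illustrative usage sketch (not part of the original code): the sinusoidal embedding
# maps a float position sequence to a len x demb matrix, scaled down by
# TRANSFORMER_POS_DISCOUNT. The opt value and demb below are assumptions.
def _positional_embedding_example():
    pos_emb = PositionalEmbedding({"TRANSFORMER_POS_DISCOUNT": 1.0}, demb=8)
    pos_seq = torch.arange(4.0)  # positions 0..3 as floats
    emb = pos_emb(pos_seq)       # 4 x 8 (sin half, then cos half)
    assert emb.shape == (4, 8)
    return emb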
""" | |
Splitter | |
""" | |
class Splitter(nn.Module): | |
def __init__(self, nx): | |
super(Splitter, self).__init__() | |
self.nx = nx | |
self.augmenter = Conv1D(nx * 3, nx) | |
""" | |
Input: | |
x: batch x len x nx | |
Output: | |
query,key,value: batch x len x nx | |
""" | |
def forward(self, x): | |
x = self.augmenter(x) | |
# x: batch x len x (3 x nx) | |
query, key, value = x.split(self.nx, dim=2) | |
# query,key,value: batch x len x nx | |
return query, key, value | |
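

# Illustrative usage sketch (not part of the original code): Splitter projects the
# input to 3 * nx and splits it into equal-width query, key, and value tensors.
# The sizes below are arbitrary assumptions.
def _splitter_example():
    splitter = Splitter(16)
    x = torch.randn(2, 5, 16)  # batch x len x nx
    q, k, v = splitter(x)      # each: batch x len x nx
    assert q.shape == k.shape == v.shape == (2, 5, 16)
    return q, k, v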
""" | |
Multi-head Attention | |
""" | |
class Attention(nn.Module): | |
""" | |
nx: input dimension | |
""" | |
def __init__(self, nx, opt): | |
super(Attention, self).__init__() | |
n_state = nx # in Attention: n_state=768 (nx=n_embd) | |
# [switch nx => n_state from Block to Attention to keep identical to TF implem] | |
n_head = int(opt["TRANSFORMER_HEAD"]) | |
resid_pdrop = opt["TRANSFORMER_RESIDUAL_DROPOUT"] | |
attn_pdrop = opt["TRANSFORMER_ATTENTION_DROPOUT"] | |
use_cuda = opt["cuda"] | |
assert n_state % n_head == 0 | |
# if mask is needed, uncomment this | |
self.maxlen = 2048 # beyond this scale | |
self.mask = ( | |
Variable( | |
torch.tril(torch.ones(self.maxlen, self.maxlen)).view( | |
1, 1, self.maxlen, self.maxlen | |
), | |
requires_grad=False, | |
).cuda() | |
if use_cuda | |
else Variable( | |
torch.tril(torch.ones(self.maxlen, self.maxlen)).view( | |
1, 1, self.maxlen, self.maxlen | |
), | |
requires_grad=False, | |
) | |
) | |
self.n_head = n_head | |
self.c_proj = Conv1D(n_state, nx) | |
self.attn_dropout = nn.Dropout(attn_pdrop) | |
self.resid_dropout = nn.Dropout(resid_pdrop) | |
self.use_cuda = use_cuda | |
""" | |
Input: | |
q: batch x n_head x len x dim | |
k: batch x n_head x dim x kv_len | |
v: batch x n_head x kv_len x dim | |
x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention) | |
one_dir_visible: only sees previous history (used for decoder's self-attention) | |
return_attn_weight: if true, also return the attention weights | |
Output: | |
a: batch x n_head x len x n_state x dim | |
attn_weight (if return_attn_weight): attn_weight: batch x n_head x len x kv_len | |
""" | |
def _attn(self, q, k, v, x_mask, one_dir_visible, return_attn_weight): | |
w = torch.matmul(q, k) | |
# batch x n_head x len x kv_len | |
w = w / math.sqrt(v.size(-1)) | |
mask = None | |
if one_dir_visible: # mask "seeing the future" | |
if w.size(-2) <= self.maxlen and w.size(-1) <= self.maxlen: | |
mask = ( | |
self.mask[:, :, : w.size(-2), : w.size(-1)].cuda() | |
if self.use_cuda | |
else self.mask[:, :, : w.size(-2), : w.size(-1)] | |
) | |
else: | |
mask = ( | |
Variable( | |
torch.tril(torch.ones(w.size(-2), w.size(-1))).view( | |
1, 1, w.size(-2), w.size(-1) | |
), | |
requires_grad=False, | |
).cuda() | |
if self.use_cuda | |
else Variable( | |
torch.tril(torch.ones(w.size(-2), w.size(-1))).view( | |
1, 1, w.size(-2), w.size(-1) | |
), | |
requires_grad=False, | |
) | |
) | |
if x_mask is not None: | |
mask = x_mask.unsqueeze(1).unsqueeze(1).expand_as(w).float() | |
# batch x n_head x len x kv_len | |
if mask is not None: | |
w = w * mask + -1e9 * (1 - mask) | |
w_prob = nn.Softmax(dim=-1)(w) | |
w_prob = self.attn_dropout(w_prob) | |
if return_attn_weight: | |
return torch.matmul(w_prob, v), w | |
else: | |
return torch.matmul(w_prob, v) | |

    def merge_heads(self, x):
        x = x.permute(0, 2, 1, 3).contiguous()
        new_x_shape = x.size()[:-2] + (x.size(-2) * x.size(-1),)
        return x.view(*new_x_shape)  # in Tensorflow implem: fct merge_states

    """
    Input:
        x: batch x len x dim
    Output:
        k=False: batch x n_head x len x (dim/n_head)
        k=True:  batch x n_head x (dim/n_head) x len
    """

    def split_heads(self, x, k=False):
        new_x_shape = x.size()[:-1] + (self.n_head, x.size(-1) // self.n_head)
        x = x.view(*new_x_shape)  # in Tensorflow implem: fct split_states
        if k:
            return x.permute(0, 2, 3, 1)
        else:
            return x.permute(0, 2, 1, 3)
""" | |
Input: | |
query: batch x len x n_state | |
key, value: batch x kv_len x n_state | |
x_mask: batch x kv_len # key and value's mask (if not None, used for encoder's self-attention and decoder's src-tgt attention) | |
one_dir_visible: only sees previous history (used for decoder's self-attention) | |
return_attn_weight: if true, also return the attention weights | |
Output: | |
a: batch x len x n_state | |
attn_weight (if return_attn_weight): batch x len x kv_len | |
""" | |
def forward( | |
self, query, key, value, x_mask, one_dir_visible=False, return_attn_weight=False | |
): | |
query = self.split_heads(query) | |
# batch x n_head x len x (n_state/n_head) | |
key = self.split_heads(key, k=True) | |
# batch x n_head x (n_state/n_head) x kv_len | |
value = self.split_heads(value) | |
# batch x n_head x kv_len x (n_state/n_head) | |
out = self._attn(query, key, value, x_mask, one_dir_visible, return_attn_weight) | |
if return_attn_weight: | |
a, attn_weight = out | |
# a: batch x n_head x len x (n_state/n_head) | |
# attn_weight: batch x n_head x len x kv_len | |
attn_weight = attn_weight.permute(0, 2, 3, 1).contiguous() | |
# batch x len x kv_len x n_head | |
attn_weight = torch.sum(attn_weight, dim=3) | |
# batch x len x kv_len | |
else: | |
a = out | |
# batch x n_head x len x (n_state/n_head) | |
a = self.merge_heads(a) | |
# batch x len x n_state | |
a = self.c_proj(a) | |
# batch x len x n_state | |
a = self.resid_dropout(a) | |
# batch x len x n_state | |
if return_attn_weight: | |
return a, attn_weight | |
else: | |
return a | |
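

# Illustrative usage sketch (not part of the original code): running the multi-head
# attention module on dummy data. The opt keys below mirror those read in
# Attention.__init__; the concrete values (4 heads, dropout 0.1, CPU) are assumptions
# for demonstration only.
def _attention_example():
    opt = {
        "TRANSFORMER_HEAD": 4,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "cuda": False,
    }
    attn = Attention(nx=16, opt=opt)
    q, k, v = Splitter(16)(torch.randn(2, 5, 16))  # batch x len x nx each
    x_mask = torch.ones(2, 5)                      # 1 = real token, 0 = padding
    a = attn(q, k, v, x_mask)                      # batch x len x nx
    assert a.shape == (2, 5, 16)
    return a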
""" | |
Two-layer network | |
""" | |
class MLP(nn.Module): | |
""" | |
Input: | |
n_state: intermediate dim | |
""" | |
def __init__(self, n_state, opt): # in MLP: n_state=3072 (4 * n_embd) | |
super(MLP, self).__init__() | |
nx = int(opt["transformer_embed_dim"]) | |
resid_pdrop = opt["TRANSFORMER_RESIDUAL_DROPOUT"] | |
self.c_fc = Conv1D(n_state, nx) | |
self.c_proj = Conv1D(nx, n_state) | |
self.dropout = nn.Dropout(resid_pdrop) | |
""" | |
Input: | |
x: batch x len x nx | |
Output: batch x len x nx | |
""" | |
def forward(self, x): | |
h = F.relu(self.c_fc(x)) | |
h2 = self.c_proj(h) | |
return self.dropout(h2) | |
""" | |
One encoder block of transformer | |
""" | |
class EncoderBlock(nn.Module): | |
def __init__(self, opt): | |
super(EncoderBlock, self).__init__() | |
nx = int(opt["transformer_embed_dim"]) | |
self.one_dir_visible = False | |
if "transformer_encoder_one_dir_visible" in opt: | |
self.one_dir_visible = opt["transformer_encoder_one_dir_visible"] | |
self.splitter = Splitter(nx) | |
self.attn = Attention(nx, opt) | |
self.ln_1 = LayerNorm(nx) | |
self.mlp = MLP(4 * nx, opt) | |
self.ln_2 = LayerNorm(nx) | |
""" | |
Input: | |
x: batch x len x n_state | |
x_mask: batch x len (1 means there's something) | |
Output: | |
h: batch x len x n_state | |
""" | |
def forward(self, x, x_mask): | |
query, key, value = self.splitter(x) | |
if self.one_dir_visible: | |
# in this case, use triangle masking, as it's one_direction | |
a = self.attn(query, key, value, None, one_dir_visible=True) | |
else: | |
# in this case, use x_mask for attention masking | |
a = self.attn(query, key, value, x_mask, one_dir_visible=False) | |
n = self.ln_1(x + a) # residual | |
m = self.mlp(n) | |
h = self.ln_2(n + m) | |
return h | |
""" | |
One encoder block of transformer | |
""" | |
class DecoderBlock(nn.Module): | |
def __init__(self, opt): | |
super(DecoderBlock, self).__init__() | |
nx = int(opt["transformer_embed_dim"]) | |
self.decoder_splitter = Splitter(nx) | |
self.self_attn = Attention(nx, opt) | |
self.cross_attn = Attention(nx, opt) | |
self.ln_1 = LayerNorm(nx) | |
self.ln_2 = LayerNorm(nx) | |
self.mlp = MLP(4 * nx, opt) | |
self.ln_3 = LayerNorm(nx) | |
""" | |
Input: | |
x_mask: batch x len, mask for encoder's input | |
y: batch x len x n_state (decoder part) | |
enc_key: batch x encoder_len x n_state | |
enc_value: batch x encoder_len x n_state | |
lang_model: whether it's for language model training (no encoder part is used) | |
Output: | |
h: batch x len x n_state | |
""" | |
def forward(self, x_mask, y, enc_key, enc_value, lang_model=False): | |
query, key, value = self.decoder_splitter(y) | |
# batch x len x n_state | |
# self-attention | |
a = self.self_attn(query, key, value, None, one_dir_visible=True) | |
# batch x len x n_state | |
n = self.ln_1(y + a) # residual | |
# seq2seq | |
if not lang_model: | |
# src-tgt attention | |
o = self.cross_attn(n, enc_key, enc_value, x_mask) | |
p = self.ln_2(n + o) # residual | |
# batch x len x n_state | |
else: # language model | |
p = n | |
m = self.mlp(p) | |
h = self.ln_3(p + m) | |
return h | |
""" | |
Embedder | |
""" | |
class Embedder(nn.Module): | |
""" | |
Input: | |
vocab: size of vocabulary | |
""" | |
def __init__(self, opt, embed=None): | |
super(Embedder, self).__init__() | |
n_state = int(opt["transformer_embed_dim"]) # n_state | |
embed_dropout_rate = opt["TRANSFORMER_EMBED_DROPOUT"] | |
if embed is None: | |
self.embed = nn.Embedding(opt["vocab_size"], n_state) | |
nn.init.normal_(self.embed.weight, std=0.02) | |
else: | |
self.embed = embed | |
self.drop = nn.Dropout(embed_dropout_rate) | |
self.pos_emb = PositionalEmbedding(opt, n_state) | |
self.use_cuda = opt["cuda"] | |
""" | |
Input: | |
x: batch x len (word_id) | |
Output: | |
h: batch x len x n_state | |
""" | |
def forward(self, x): | |
x_emb = self.embed(x) | |
batch_size = x.shape[0] | |
x_len = x.shape[1] | |
x_pos = self.pos_emb( | |
torch.arange(x_len).type( | |
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor | |
) | |
) # len x n_state | |
x_pos = ( | |
Variable( | |
x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False | |
).cuda() | |
if self.use_cuda | |
else Variable( | |
x_pos.unsqueeze(0).repeat(batch_size, 1, 1), requires_grad=False | |
) | |
) | |
x_input = x_emb + x_pos | |
h = self.drop(x_input) | |
return h | |
""" | |
Transformer encoder | |
""" | |
class TransformerEncoder(nn.Module): | |
""" | |
Input: | |
embed: (if not None) pre-computed vocab embeddings | |
""" | |
def __init__(self, opt, embed=None): | |
super(TransformerEncoder, self).__init__() | |
vocab = int(opt["vocab_size"]) | |
n_state = int(opt["transformer_embed_dim"]) | |
n_layer = int(opt["TRANSFORMER_LAYER"]) | |
if "vae_z_scale_factor" in opt: | |
self.vae_z_scale_factor = float(opt["vae_z_scale_factor"]) | |
self.embedder = Embedder(opt, embed) | |
block = EncoderBlock(opt) | |
self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) | |
self.use_cuda = opt["cuda"] | |
""" | |
Input: | |
x: batch x len (word_id) | |
z (optional): batch x len x n_state (for VAE) | |
Output: | |
h: batch x len x n_state (word_id) | |
""" | |
def forward(self, x, z=None): | |
x_mask = ~x.eq(0) # 1 is PAD_id | |
x_mask = x_mask.type( | |
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor | |
) | |
h = self.embedder(x) | |
if z is not None: | |
z *= self.vae_z_scale_factor | |
h += z | |
for block in self.blocks: | |
h = block(h, x_mask) | |
return h | |
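

# Illustrative usage sketch (not part of the original code): building a small encoder
# and running one forward pass on CPU. Only the opt key names come from this file;
# the concrete values (vocab 100, dim 32, 2 layers, 4 heads, dropout 0.1) are
# assumptions for demonstration.
def _transformer_encoder_example():
    opt = {
        "vocab_size": 100,
        "transformer_embed_dim": 32,
        "TRANSFORMER_LAYER": 2,
        "TRANSFORMER_HEAD": 4,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "cuda": False,
    }
    encoder = TransformerEncoder(opt)
    x = torch.randint(1, 100, (2, 7))  # batch x len word ids; 0 is reserved for PAD
    h = encoder(x)                     # batch x len x transformer_embed_dim
    assert h.shape == (2, 7, 32)
    return h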
""" | |
Transformer decoder | |
""" | |
class TransformerDecoder(nn.Module): | |
""" | |
Input: | |
embed: (if not None) pre-computed vocab embeddings | |
""" | |
def __init__(self, opt, embed=None): | |
super(TransformerDecoder, self).__init__() | |
self.opt = opt | |
vocab_size = int(opt["vocab_size"]) | |
n_state = int(opt["transformer_embed_dim"]) # n_state | |
n_layer = int(opt["TRANSFORMER_LAYER"]) | |
self.embedder = Embedder(opt, embed) | |
self.encoder_splitter = Splitter(n_state) | |
block = DecoderBlock(opt) | |
self.blocks = nn.ModuleList([copy.deepcopy(block) for _ in range(n_layer)]) | |
if embed is None: | |
self.linear = Conv1D(vocab_size, n_state) | |
else: | |
self.linear = nn.Linear(n_state, vocab_size, bias=False) | |
if ( | |
"FINETUNE_RETRAIN_SOFTMAX" not in opt | |
): # if FINETUNE_RETRAIN_SOFTMAX, linear needs to be seperately trained | |
self.linear.weight = embed.weight # share weight | |
self.use_coda = opt["cuda"] | |
""" | |
Input: | |
x: batch x encoder_len (word id) | |
x_out: batch x encoder_len x n_state | |
y: batch x len (word_id) (decoder part) | |
lang_model: whether it's for language model training (no encoder part is used) | |
Output: | |
prob: batch x len x vocab_size (probabilities after softmax) | |
""" | |
def forward(self, x, x_out, y, lang_model=False): | |
# seq2seq | |
if not lang_model: | |
_, enc_key, enc_value = self.encoder_splitter(x_out) | |
# enc_key: batch x encoder_len x n_state | |
# enc_value: batch x encoder_len x n_state | |
x_mask = ~x.eq(0) # 1 is PAD_id | |
x_mask = x_mask.type( | |
torch.cuda.FloatTensor if self.use_cuda else torch.FloatTensor | |
) | |
else: | |
enc_key = None | |
enc_value = None | |
x_mask = None | |
h = self.embedder(y) | |
for block in self.blocks: | |
h = block(x_mask, h, enc_key, enc_value, lang_model) | |
prob = F.softmax(self.linear(h), dim=-1) | |
return prob | |
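

# Illustrative usage sketch (not part of the original code): pairing the encoder with
# a decoder for one teacher-forced forward pass. The opt values are the same
# assumptions as in _transformer_encoder_example; y is a dummy target prefix. In
# practice the embed argument can be used to share embeddings between the two modules.
def _transformer_decoder_example():
    opt = {
        "vocab_size": 100,
        "transformer_embed_dim": 32,
        "TRANSFORMER_LAYER": 2,
        "TRANSFORMER_HEAD": 4,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "cuda": False,
    }
    encoder = TransformerEncoder(opt)
    decoder = TransformerDecoder(opt)
    x = torch.randint(1, 100, (2, 7))  # encoder input word ids
    y = torch.randint(1, 100, (2, 5))  # decoder input word ids
    x_out = encoder(x)
    prob = decoder(x, x_out, y)        # batch x len x vocab_size
    assert prob.shape == (2, 5, 100)
    return prob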


class TransformerBeam:
    """
    Input:
        encoder: TransformerEncoder class
        decoder: TransformerDecoder class
        begin_id: word id of '<BEGIN>'
        vocab: list of words
    """

    def __init__(self, opt, encoder, decoder, begin_id, vocab):
        self.encoder = encoder
        self.decoder = decoder
        self.opt = opt
        self.max_sent_len = int(opt["max_sent_len"])
        self.begin_id = begin_id
        self.vocab = vocab
        self.beam_width = int(opt["beam_width"])
        self.use_cuda = opt["cuda"]

    # each candidate is (index of the source node in last_nodes, log_prob, word_id);
    # merge two candidate lists sorted by log_prob, keeping at most beam_width items
    def merge_candidates(self, cand_A, cand_B):
        C = []
        pA, lA, pB, lB = 0, len(cand_A), 0, len(cand_B)
        lC = 0
        while (pA < lA or pB < lB) and (lC < self.beam_width):
            if pA < lA and (pB >= lB or cand_A[pA][1] > cand_B[pB][1]):
                C.append(cand_A[pA])
                pA += 1
            else:
                C.append(cand_B[pB])
                pB += 1
            lC += 1
        return C
""" | |
Input: | |
x = batch * encoder_len (word_ids) encoder's input | |
k: top-k sampling | |
Output: | |
sents: list of words, with batch items, each one with up to beam_width (sentence, log_prob), each sentence with up to max_sent_len_word words | |
""" | |
def topk(self, x, k): | |
batch_size = x.shape[0] | |
x_len = x.shape[1] | |
x_out = self.encoder(x) | |
# x_out: batch x encoder_len x n_state | |
# sent_ids is the words for each of the batch_size sentences | |
sent_ids = [] | |
for i in range(batch_size): | |
sent_ids.append([self.begin_id]) | |
topk = 1 | |
MIN_GEN_LENGTH = 45 | |
if "MIN_GEN_LENGTH" in self.opt: | |
MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"]) | |
for l in range(self.max_sent_len): | |
y = ( | |
Variable(torch.LongTensor(sent_ids)).cuda() | |
if self.use_cuda | |
else Variable(torch.LongTensor(sent_ids)) | |
) # batch_size x l | |
decoder_outputs = self.decoder(x, x_out, y) | |
probs = decoder_outputs[ | |
:, -1, : | |
] # batch_size x vocab_size (only take the last output) | |
for i in range(batch_size): | |
topk_probs, _ = torch.topk(probs[i], k) | |
threshold = float(topk_probs[-1]) | |
probs[i][probs[i] < threshold] = 0.0 | |
samples = torch.multinomial( | |
probs, 2 | |
) # sample 2 since the first one may be <END> | |
for i in range(batch_size): | |
if l < MIN_GEN_LENGTH and self.vocab[int(samples[i, 0])] == "<END>": | |
sent_ids[i].append(int(samples[i, 1])) | |
else: | |
sent_ids[i].append(int(samples[i, 0])) | |
sents = [] | |
for i in range(batch_size): | |
utt = [] | |
for j in range(len(sent_ids[i])): | |
w = self.vocab[sent_ids[i][j]] | |
if w == "<BEGIN>": | |
continue | |
if w == "<END>": | |
break | |
utt.append(w) | |
sents.append([(utt, 0)]) | |
return sents | |
""" | |
Input: | |
x = batch * encoder_len (word_ids) encoder's input | |
Output: | |
sents: list of words, with batch items, each one with up to beam_width (sentence, log_prob), each sentence with up to max_sent_len_word words | |
""" | |
def beam_search(self, x): | |
batch_size = x.shape[0] | |
x_len = x.shape[1] | |
x_out = self.encoder(x) | |
# x_out: batch x encoder_len x n_state | |
sents = [] | |
topk = 1 | |
history_nodes = [{}] | |
end_nodes = {} | |
for idx in range(batch_size): | |
start_node = BeamSearchNode([self.begin_id], 0, 1) | |
history_nodes[0][idx] = [start_node] | |
end_nodes[idx] = [] | |
for l in range(self.max_sent_len): | |
last_nodes = history_nodes[-1] | |
if sum([len(l) for i, l in last_nodes.items()]) == 0: # no nodes left | |
break | |
ys = [] | |
x_outs = [] | |
xs = [] | |
for idx in range(batch_size): | |
ys.extend([node.word_ids for node in last_nodes[idx]]) | |
x_outs.extend( | |
[x_out[idx, :, :].unsqueeze(0) for node in last_nodes[idx]] | |
) | |
xs.extend([x[idx, :].unsqueeze(0) for node in last_nodes[idx]]) | |
ys = ( | |
Variable(torch.LongTensor(ys)).cuda() | |
if self.use_cuda | |
else Variable(torch.LongTensor(ys)) | |
) # N x l | |
x_outs = torch.cat(x_outs, dim=0) # N x x_len x n_state | |
xs = torch.cat(xs, dim=0) # N x x_len | |
probs = self.decoder(xs, x_outs, ys) | |
log_probs = torch.log( | |
probs[:, -1, :] + 1e-15 | |
) # N x vocab_size (only take the last output) | |
history_nodes.append({}) | |
p = 0 | |
for idx in range(batch_size): | |
history_nodes[-1][idx] = [] | |
N = len(last_nodes[idx]) | |
if N == 0: | |
continue | |
log_prob = log_probs[p : p + N] | |
p += N | |
# log_prob = N x extended_vocab_size | |
# generate | |
candidates = [] | |
for k in range(N): | |
logprobs, ids = torch.topk(log_prob[k], self.beam_width) | |
candidates = self.merge_candidates( | |
candidates, [(k, p, d) for p, d in zip(logprobs, ids)] | |
) | |
candidates = candidates[: self.beam_width] | |
extended_nodes_in_last_nodes = set() | |
for k in range(len(candidates)): | |
h, logp, next_word_id = candidates[ | |
k | |
] # h means "the h-th node in last_nodes" | |
logp = float(logp) | |
next_word_id = int(next_word_id) | |
prev_node = last_nodes[idx][h] | |
next_wordids = prev_node.word_ids + [next_word_id] | |
next_word = self.vocab[next_word_id] | |
next_node = BeamSearchNode( | |
next_wordids, prev_node.log_prob + logp, prev_node.length + 1 | |
) | |
if next_node.duplicate == False: # no duplicate trigram generated | |
extended_nodes_in_last_nodes.add(h) | |
if next_word == "<END>" or l == self.max_sent_len - 1: | |
end_nodes[idx].append((next_node.eval(), next_node)) | |
else: | |
history_nodes[-1][idx].append(next_node) | |
special_words = ["<PAD>", "<UNK>", "<s>", "</s>", "<BEGIN>", "<END>"] | |
for k in range(N): | |
if k not in extended_nodes_in_last_nodes: | |
node = last_nodes[idx][k] | |
effective_word_count = sum( | |
[ | |
1 | |
for x in node.word_ids | |
if self.vocab[x] not in special_words | |
] | |
) | |
if effective_word_count >= 5: | |
end_nodes[idx].append((node.eval(), node)) | |
MIN_GEN_LENGTH = 45 | |
if "MIN_GEN_LENGTH" in self.opt: | |
MIN_GEN_LENGTH = int(self.opt["MIN_GEN_LENGTH"]) | |
for idx in range(batch_size): | |
t = len([w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH]) | |
if t > 0: | |
end_nodes[idx] = [ | |
w for w in end_nodes[idx] if w[1].length > MIN_GEN_LENGTH | |
] | |
end_nodes[idx].sort(key=lambda tup: tup[0], reverse=True) | |
candidates = [] | |
for score, node in end_nodes[idx][:topk]: | |
utt = [self.vocab[x] for x in node.word_ids] | |
utt = [x for x in utt if x not in ["<BEGIN>", "<END>"]] | |
candidates.append((utt, score)) | |
if len(candidates) == 0: | |
candidates.append(("", 0)) | |
sents.append(candidates) | |
return sents | |


class BeamSearchNode(object):
    def __init__(self, word_ids, log_prob, length):
        self.word_ids = word_ids
        self.log_prob = log_prob
        self.length = length
        # mark the hypothesis as a duplicate if it repeats any trigram
        trigram_set = set()
        self.duplicate = False
        for i in range(2, len(word_ids)):
            trigram = (
                str(word_ids[i - 2])
                + " "
                + str(word_ids[i - 1])
                + " "
                + str(word_ids[i])
            )
            if trigram in trigram_set:
                self.duplicate = True
                break
            trigram_set.add(trigram)

    def eval(self):
        # length-normalized log probability (excluding the <BEGIN> token)
        return self.log_prob / float(self.length - 1.0 + 1e-6)

    def __lt__(self, other):
        return self.length < other.length
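

# Illustrative end-to-end sketch (not part of the original code): wiring the encoder,
# decoder, and TransformerBeam together and decoding with top-k sampling. The toy
# vocabulary, special tokens, and all opt values below are assumptions for
# demonstration only; beam.beam_search(x) can be driven the same way.
def _generation_example():
    vocab = ["<PAD>", "<BEGIN>", "<END>", "<UNK>"] + ["w%d" % i for i in range(96)]
    opt = {
        "vocab_size": len(vocab),
        "transformer_embed_dim": 32,
        "TRANSFORMER_LAYER": 2,
        "TRANSFORMER_HEAD": 4,
        "TRANSFORMER_RESIDUAL_DROPOUT": 0.1,
        "TRANSFORMER_ATTENTION_DROPOUT": 0.1,
        "TRANSFORMER_EMBED_DROPOUT": 0.1,
        "TRANSFORMER_POS_DISCOUNT": 1.0,
        "max_sent_len": 10,
        "beam_width": 3,
        "MIN_GEN_LENGTH": 1,
        "cuda": False,
    }
    encoder = TransformerEncoder(opt)
    decoder = TransformerDecoder(opt)
    encoder.eval()
    decoder.eval()
    beam = TransformerBeam(
        opt, encoder, decoder, begin_id=vocab.index("<BEGIN>"), vocab=vocab
    )
    x = torch.randint(4, len(vocab), (2, 7))  # dummy source word ids (no specials)
    with torch.no_grad():
        sents = beam.topk(x, k=5)             # one sampled hypothesis per batch item
    return sents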