# NMT-LaVi/modules/inference/beam_search1.py
import numpy as np
import torch
import math, time, operator
import torch.nn.functional as functional
import torch.nn as nn
from torch.autograd import Variable
from torch.nn.utils.rnn import pad_sequence
from modules.inference.decode_strategy import DecodeStrategy
from utils.misc import no_peeking_mask
class BeamSearch1(DecodeStrategy):
def __init__(self, model, max_len, device, beam_size=5, use_synonym_fn=False, replace_unk=None):
"""
Args:
model: the model used for decoding
max_len: the maximum number of decoding timesteps
device: the device on which to perform the computation
beam_size: the number of beams kept per sentence
use_synonym_fn: if set, use the get_synonym fn from wordnet to try to replace <unk>
replace_unk: a (layer, head) tuple designating the attention used to replace unknown words
"""
super(BeamSearch1, self).__init__(model, max_len, device)
self.beam_size = beam_size
self._use_synonym = use_synonym_fn
self._replace_unk = replace_unk
# print("Init BeamSearch ----------------")
def trg_init_vars(self, src, batch_size, trg_init_token, trg_eos_token, single_src_mask):
"""
Calculate the required matrices during translation after the model is finished
Input:
:param src: The batch of sentences
Output: Initialize the first character includes outputs, e_outputs, log_scores
"""
# Initialize target sequence (start with '<sos>' token) [batch_size x k x max_len]
trg = torch.zeros(batch_size, self.beam_size, self.max_len, device=self.device).long()
trg[:, :, 0] = trg_init_token
# Precalc output from model's encoder
e_out = self.model.encoder(src, single_src_mask) # [batch_size x S x d_model]
# Output model prob
trg_mask = no_peeking_mask(1, device=self.device)
# [batch_size x 1]
inp_decoder = trg[:, 0, 0].view(batch_size, 1)
# [batch_size x 1 x vocab_size]
prob = self.model.out(self.model.decoder(inp_decoder, e_out, single_src_mask, trg_mask))
prob = functional.softmax(prob, dim=-1)
# [batch_size x 1 x k]
k_prob, k_index = torch.topk(prob, self.beam_size, dim=-1)
trg[:, :, 1] = k_index.view(batch_size, self.beam_size)
# Init log scores from k beams [batch_size x k x 1]
log_scores = torch.log(k_prob.view(batch_size, self.beam_size, 1))
# Repeat encoder's output k times for searching [(k * batch_size) x S x d_model]
e_outs = torch.repeat_interleave(e_out, self.beam_size, dim=0)
src_mask = torch.repeat_interleave(single_src_mask, self.beam_size, dim=0)
# Create mask for checking eos
sent_eos = torch.tensor([trg_eos_token for _ in range(self.beam_size)], device=self.device).view(1, self.beam_size)
return sent_eos, log_scores, e_outs, e_out, src_mask, trg
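# Shape sketch with assumed toy sizes (batch_size=2, beam_size=3, max_len=10, S=7, d_model=512):
#   trg        -> [2, 3, 10]  (column 0 is <sos>, column 1 holds the top-3 first tokens)
#   log_scores -> [2, 3, 1]
#   e_out      -> [2, 7, 512]; e_outs / src_mask repeated per beam -> [6, 7, 512] / [6, 1, 7]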
def compute_k_best(self, outputs, out, log_scores, i, debug=False):
"""
Compute the k continuations with the highest conditional probability at timestep i.
Args:
outputs: the k previous candidate output sequences. [batch_size*beam_size, max_len]
i: the current timestep to execute. Int
out: the model output up to this timestep. [batch_size*beam_size, seq_len, vocab_size]
log_scores: log-probabilities of the past candidates (in outputs). [batch_size*beam_size]
Returns:
the updated outputs holding the k best candidate sequences, and the log_scores for each of those candidates
"""
row_b = len(out)
batch_size = row_b // self.beam_size
eos_id = self.TRG.vocab.stoi['<eos>']
probs, ix = out[:, -1].data.topk(self.beam_size)
probs_rep = torch.Tensor([[1] + [1e-100] * (self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
ix_rep = torch.LongTensor([[eos_id] + [-1]*(self.beam_size-1)]*row_b).view(row_b, self.beam_size).to(self.device)
check_eos = torch.repeat_interleave((outputs[:, i-1] == eos_id).view(row_b, 1), self.beam_size, 1)
probs = torch.where(check_eos, probs_rep, probs)
ix = torch.where(check_eos, ix_rep, ix)
log_probs = torch.log(probs).to(self.device) + log_scores.to(self.device) # CPU
k_probs, k_ix = log_probs.view(batch_size, -1).topk(self.beam_size)
if(debug):
print("kprobs_after_select: ", log_probs, k_probs, k_ix)
# Use cpu
k_probs, k_ix = torch.Tensor(k_probs.cpu().data.numpy()), torch.LongTensor(k_ix.cpu().data.numpy())
row = k_ix // self.beam_size + torch.LongTensor([[v*self.beam_size] for v in range(batch_size)])
col = k_ix % self.beam_size
if(debug):
print("kprobs row/col", row, col, ix[row.view(-1), col.view(-1)])
assert False # intentionally halt here after printing debug info
outputs[:, :i] = outputs[row.view(-1), :i]
outputs[:, i] = ix[row.view(-1), col.view(-1)]
log_scores = k_probs.view(-1, 1)
return outputs, log_scores
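# Index bookkeeping sketch (assumed beam_size=2, batch_size=1): log_probs.view(1, 4) holds the
# 2*2 candidate scores; a top-k index of 3 maps to row 3 // 2 = 1 (parent beam) and
# col 3 % 2 = 1 (which of that beam's top-2 continuations survives). A runnable illustration
# of this flattening is sketched at the bottom of the file.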
def replace_unknown(self, outputs, sentences, attn, selector_tuple, unknown_token="<unk>"):
"""Replace the unknown words in the outputs with the highest valued attentionized words.
Args:
outputs: the output from decoding. [batchbeam] of list of str, with maximum values being
sentences: the original wordings of the sentences. [batch_size, src_len] of str
attn: the attention received, in the form of list: [layers units of (self-attention, attention) with shapes of [batchbeam, heads, tgt_len, tgt_len] & [batchbeam, heads, tgt_len, src_len] respectively]
selector_tuple: (layer, head) used to select the attention
unknown_token: token used for
Returns:
the replaced version, in the same shape as outputs
"""
layer_used, head_used = selector_tuple
# used_attention = attn[layer_used][-1][:, head_used] # it should be [batchbeam, tgt_len, src_len], as we are using the attention to source
inx = torch.arange(start=0, end=len(attn), step=self.beam_size) # take the first beam of each batch; end=len(attn) so the last batch is kept even when beam_size == 1
used_attention = attn[inx]
select_id_src = torch.argmax(used_attention, dim=-1).cpu().numpy() # [batchbeam, tgt_len] of best indices. Also convert to numpy version (remove sos not needed as it is attention of outputs)
# print(select_id_src, len(select_id_src))
beam_size = select_id_src.shape[0] // len(sentences) # used custom-calculated beam_size as we might not output the entirety of beams. See beam_search fn for details
# print("beam: ", beam_size)
# select per batchbeam. source batch id is found by dividing batchbeam id per beam; we are selecting [tgt_len] indices from [src_len] tokens; then concat at the first dimensions to retrieve [batch_beam, tgt_len] of replacement tokens
# need itemgetter / map to retrieve from list
# print([ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)])
# print([print(sentences[bidx // beam_size], src_idx) for bidx, src_idx in enumerate(select_id_src)])
# replace_tokens = [ operator.itemgetter(*src_idx)(sentences[bidx // beam_size]) for bidx, src_idx in enumerate(select_id_src)]
for i in range(len(outputs)):
for j in range(len(outputs[i])):
if outputs[i][j] == unknown_token:
outputs[i][j] = sentences[i][select_id_src[i][j]]
# print(sentences[0][0], outputs[0][0])
# print(i)
# zip together with sentences; then output { the token if not unk / the replacement if is }. Note that this will trim the orig version down to repl size.
# replaced = [ [tok if tok != unknown_token else rpl for rpl, tok in zip(repl, orig)] for orig, repl in zipped ]
# return replaced
return outputs
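# Replacement sketch (hypothetical toy values): with sentences = [["le", "chat", "noir"]],
# outputs = [["the", "<unk>", "cat"]] and select_id_src = [[0, 2, 1]], the <unk> at target
# position 1 is replaced by the source token at its argmax-attention position 2, i.e. "noir".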
# def beam_search(self, src, max_len, device, k=4):
def beam_search(self, src, src_tokens=None, n_best=1, debug=False):
"""
Beam search for a single sentence
Args:
model : a Transformer instance
src : a batch (tokenized + numerized) sentence [batch_size x S]
Returns:
trg : a batch (tokenized + numerized) sentence [batch_size x T]
"""
src = src.to(self.device)
trg_init_token = self.TRG.vocab.stoi["<sos>"]
trg_eos_token = self.TRG.vocab.stoi["<eos>"]
single_src_mask = (src != self.SRC.vocab.stoi['<pad>']).unsqueeze(1).to(self.device)
batch_size = src.size(0)
sent_eos, log_scores, e_outs, e_out, src_mask, trg = self.trg_init_vars(src, batch_size, trg_init_token, trg_eos_token, single_src_mask)
# The batch indexes
batch_index = torch.arange(batch_size)
finished_batches = torch.zeros(batch_size, device=self.device).long()
log_attn = torch.zeros([self.beam_size*batch_size, self.max_len, len(src[0])])
# Iteratively searching
for i in range(2, self.max_len):
trg_mask = no_peeking_mask(i, self.device)
# Flatten trg tensor for feeding into model [(k * batch_size) x i]
inp_decoder = trg[batch_index, :, :i].view(self.beam_size * len(batch_index), i)
# Output model prob [(k * batch_size) x i x vocab_size]
current_decode, attn = self.model.decoder(inp_decoder, e_outs, src_mask, trg_mask, output_attention=True)
# print(len(attn[0]))
prob = self.model.out(current_decode)
prob = functional.softmax(prob, dim=-1)
# Only keep the probability at the latest position (index i-1)
# [(k * batch_size) x 1 x vocab_size]
prob = prob[:, i-1, :].view(self.beam_size * len(batch_index), 1, -1)
# Truncate prob to top k [(k * batch_size) x 1 x k]
k_prob, k_index = prob.data.topk(self.beam_size, dim=-1)
# Deflatten k_prob & k_index
k_prob = k_prob.view(len(batch_index), self.beam_size, 1, self.beam_size)
k_index = k_index.view(len(batch_index), self.beam_size, 1, self.beam_size)
# Preserve eos beams
# [batch_size x k] -> view -> [batch_size x k x 1 x 1] (broadcastable)
eos_mask = (trg[batch_index, :, i-1] == trg_eos_token).view(len(batch_index), self.beam_size, 1, 1)
k_prob.masked_fill_(eos_mask, 1.0)
k_index.masked_fill_(eos_mask, trg_eos_token)
# Find the best k cases
# Compute log score at i-th timestep
# [batch_size x k x 1 x 1] + [batch_size x k x 1 x k] = [batch_size x k x 1 x k]
combine_probs = log_scores[batch_index].unsqueeze(-1) + torch.log(k_prob)
# [batch_size x k x 1]
log_scores[batch_index], positions = torch.topk(combine_probs.view(len(batch_index), self.beam_size * self.beam_size, 1), self.beam_size, dim=1)
# The rows selected from top k
rows = positions.view(len(batch_index), self.beam_size) // self.beam_size
# The indexes within each beam's top-k corresponding to these rows
cols = positions.view(len(batch_index), self.beam_size) % self.beam_size
batch_sim = torch.arange(len(batch_index)).view(-1, 1)
trg[batch_index, :, :] = trg[batch_index.view(-1, 1), rows, :]
trg[batch_index, :, i] = k_index[batch_sim, rows, :, cols].view(len(batch_index), self.beam_size)
# Update attn
inx = torch.repeat_interleave(finished_batches, self.beam_size, dim=0)
batch_attn = torch.nonzero(inx == 0).view(-1)
# import copy
# x = copy.deepcopy(attn[0][-1][:, 0].to("cpu"))
# log_attn[batch_attn, :i, :] = x
# if i == 7:
# print(log_attn[batch_attn, :i, :].shape, attn[0][-1][:, 0].shape)
# print(log_attn[batch_attn, :i, :])
# Update which sentences finished all its beams
mask = (trg[:, :, i] == sent_eos).all(1).view(-1).to(self.device)
finished_batches.masked_fill_(mask, value=1)
cnt = torch.sum(finished_batches).item()
if cnt == batch_size:
break
# Continue with remaining batches (if any)
batch_index = torch.nonzero(finished_batches == 0).view(-1)
e_outs = torch.repeat_interleave(e_out[batch_index], self.beam_size, dim=0)
src_mask = torch.repeat_interleave(single_src_mask[batch_index], self.beam_size, dim=0)
# End loop
# Get the best beam
log_scores = log_scores.view(batch_size, self.beam_size)
results = []
for t, j in enumerate(torch.argmax(log_scores, dim=-1)):
sent = []
for i in range(self.max_len):
token_id = trg[t, j.item(), i].item()
if token_id == trg_init_token:
continue
if token_id == trg_eos_token:
break
sent.append(self.TRG.vocab.itos[token_id])
results.append(sent)
# if(self._replace_unk and src_tokens is not None):
# # replace unknown words per translated sentences.
# # NOTE: lacking a src_tokens does not raise any warning. Add that in when logging module is available, to support error catching
# # print("Replace unk -----------------------")
# results = self.replace_unknown(results, src_tokens, log_attn, self._replace_unk)
return results
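# Output sketch (toy placeholder values): for a batch of two source sentences the returned
# `results` would look like [["tok_a", "tok_b"], ["tok_c"]], i.e. one token list per sentence,
# with <sos> skipped and everything from the first <eos> onward dropped.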
def translate_single_sentence(self, src, **kwargs):
"""Translate a single sentence. Currently unused."""
raise NotImplementedError
# unreachable for now; once enabled this should simply delegate to the batched version:
# return self.translate_batch_sentence([src], **kwargs)
def translate_batch_sentence(self, src, field_processed=False, src_size_limit=None, output_tokens=False, debug=False):
"""Translate a batch of sentences together. Currently disabling the synonym func.
Args:
src: the batch of sentences to be translated
field_processed: bool, if the sentences have already been processed (i.e. part of batched validation data)
src_size_limit: if set, trim inputs that exceed this value. Added because the current positional encoding only supports <=200 tokens
output_tokens: the output format. False will give a batch of sentences (str), while True will give a batch of token lists (list of str)
debug: enable to print internal values
Return:
the result of translation, with format dictated by output_tokens
"""
start = time.time()
self.model.eval()
# create the indiced batch.
processed_batch = self.preprocess_batch(src, field_processed=field_processed, src_size_limit=src_size_limit, output_tokens=True, debug=debug)
# print("Time preprocess_batch: ", time.time()-start)
sent_ids, sent_tokens = (processed_batch, None) if(field_processed) else processed_batch
assert isinstance(sent_ids, torch.Tensor), "sent_ids is instead {}".format(type(sent_ids))
translated_sentences = self.beam_search(sent_ids, src_tokens=sent_tokens, debug=debug)
# print("Time for one batch: ", time.time()-batch_start)
# if time.time()-batch_start > 2:
# [print("len src >2 : ++++++", len(i.split())) for i in src]
# [print("len translate >2: ++++++", len(i)) for i in translated_sentences]
# else:
# [print("len src : ====", len(i.split())) for i in src]
# [print("len translate : ====", len(i)) for i in translated_sentences]
# print("=====================================")
# time.sleep(4)
if(debug):
print("Time performed for batch {}: {:.2f}s".format(sent_ids.shape))
if(not output_tokens):
translated_sentences = [' '.join(tokens) for tokens in translated_sentences]
return translated_sentences
def preprocess_batch(self, sentences, field_processed=False, pad_token="<pad>", src_size_limit=None, output_tokens=False, debug=True):
"""Adding
src_size_limit: int, option to limit the length of src.
field_processed: bool: if the sentences had been already processed (i.e part of batched validation data)
output_tokens: if set, output a token version aside the id version, in [batch of [src_len]] str. Note that it won't work with field_processed
"""
if(field_processed):
# do nothing, as it had already performed tokenizing/stoi
return sentences
processed_sent = map(self.SRC.preprocess, sentences)
if(src_size_limit):
processed_sent = map(lambda x: x[:src_size_limit], processed_sent)
processed_sent = list(processed_sent)
tokenized_sent = [torch.LongTensor([self._token_to_index(t) for t in s]) for s in processed_sent] # convert to tensors, in indices format
sentences = Variable(pad_sequence(tokenized_sent, True, padding_value=self.SRC.vocab.stoi[pad_token])) # padding sentences
if(debug):
print("Input batch after process: ", sentences.shape, sentences)
if(output_tokens):
return sentences, processed_sent
else:
return sentences
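# Padding sketch (assumed token indices and pad id 1): ["a b c", "d e"] -> tokenized ids
# [[4, 5, 6], [7, 8]] -> pad_sequence(..., batch_first=True, padding_value=1)
# -> tensor([[4, 5, 6], [7, 8, 1]]).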
def translate_batch(self, sentences, **kwargs):
return self.translate_batch_sentence(sentences, **kwargs)
def _token_to_index(self, tok):
"""Override to select, depending on the self._use_synonym param"""
if(self._use_synonym):
return super(BeamSearch1, self)._token_to_index(tok)
else:
return self.SRC.vocab.stoi[tok]
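# The snippet below is an illustrative, self-contained sketch (not part of the class) of the
# flatten-then-topk bookkeeping used in the search loop: per-beam candidate scores are flattened
# to k*k entries, the global top-k is taken, and the parent beam / candidate indices are
# recovered with // and %. All tensor values here are made-up toy numbers.
if __name__ == "__main__":
    beam_size = 2
    # [1 batch x beam_size x beam_size] candidate log-scores (toy values)
    combine_probs = torch.tensor([[[-0.1, -2.0], [-0.5, -0.3]]])
    scores, positions = torch.topk(combine_probs.view(1, beam_size * beam_size), beam_size, dim=1)
    rows = positions // beam_size   # parent beam each survivor came from
    cols = positions % beam_size    # which of that beam's top-k continuations was kept
    print(scores, rows, cols)       # expected: scores [[-0.1, -0.3]], rows [[0, 1]], cols [[0, 1]]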