import yaml

import torch

from .model import Seq2SeqTransformer
from transformers import AutoTokenizer, PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing


def addPreprocessing(tokenizer):
    # Wrap every encoded sequence as "[BOS] $A [EOS]" so the model always sees
    # explicit sentence boundaries.
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
        special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id),
                        (tokenizer.bos_token, tokenizer.bos_token_id)])

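# Illustrative example (names assume the special_tokens configured below):
# after addPreprocessing, encoding "good morning" yields the ids of
# "[BOS] good morning [EOS]", which is the format greedy_decode and
# beam_search_decode below rely on.
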
def load_model(model_checkpoint_path='model.pt', config_path='config.yaml'):
    # Rebuild the architecture from the saved hyperparameters, then load the
    # trained weights.
    with open(config_path, 'r') as yaml_file:
        loaded_model_params = yaml.safe_load(yaml_file)

    model = Seq2SeqTransformer(
        loaded_model_params["num_encoder_layers"],
        loaded_model_params["num_decoder_layers"],
        loaded_model_params["emb_size"],
        loaded_model_params["nhead"],
        loaded_model_params["source_vocab_size"],
        loaded_model_params["target_vocab_size"],
        loaded_model_params["ffn_hid_dim"]
    )

    # Fall back to CPU when no GPU is available.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(model_checkpoint_path, map_location=map_location)
    model.load_state_dict(checkpoint)

    return model

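# Note (assumption): load_model expects `model.pt` to hold a bare state_dict.
# If your checkpoint wraps the weights, e.g. {"model_state_dict": ...}, unpack
# it before calling load_state_dict.
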
def greedy_decode(model, src, src_mask, max_len, start_symbol):
    # Autoregressive decoding that picks the single most probable token at
    # each step.
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask).to(device)

    ys = torch.tensor([[start_symbol]], dtype=torch.long, device=device)

    for _ in range(max_len - 1):
        # Causal mask: -inf above the diagonal blocks attention to future
        # positions.
        tgt_mask = torch.triu(
            torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device),
            diagonal=1)

        out = model.decode(ys, memory, tgt_mask)
        prob = model.generator(out[:, -1])

        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat(
            [ys, torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)

        # Stop as soon as the model emits the end-of-sequence token.
        if next_word == target_tokenizer.eos_token_id:
            break

    return ys

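# For reference, the causal mask built in the decoders for a length-3 prefix is
#     [[0., -inf, -inf],
#      [0.,   0., -inf],
#      [0.,   0.,   0.]]
# so each position attends only to itself and earlier positions.
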
def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, length_penalty):
    # Keep the beam_size most promising partial hypotheses; each beam is a
    # pair (token ids, cumulative log-probability).
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask).to(device)

    beams = [(torch.tensor([[start_symbol]], dtype=torch.long, device=device), 0.0)]
    complete_beams = []

    def normalized(ys, score):
        # GNMT length penalty: lp(Y) = ((5 + |Y|) / 6) ** alpha.
        return score / (((5 + ys.size(1)) / 6) ** length_penalty)

    for _ in range(max_len - 1):
        new_beams = []

        for ys, score in beams:
            tgt_mask = torch.triu(
                torch.full((ys.size(1), ys.size(1)), float('-inf'), device=device),
                diagonal=1)

            out = model.decode(ys, memory, tgt_mask)
            # Log-probabilities, so beam scores can be summed across steps.
            log_prob = torch.log_softmax(model.generator(out[:, -1]), dim=-1)

            top_log_probs, top_indices = torch.topk(log_prob, beam_size, dim=1)

            for j in range(beam_size):
                next_word = top_indices[0][j].item()

                new_ys = torch.cat(
                    [ys, torch.full((1, 1), next_word, dtype=src.dtype, device=device)], dim=1)
                new_score = score + top_log_probs[0][j].item()

                # A hypothesis that emits EOS is finished; do not expand it.
                if next_word == target_tokenizer.eos_token_id:
                    complete_beams.append((new_ys, new_score))
                else:
                    new_beams.append((new_ys, new_score))

        # Keep only the beam_size best unfinished hypotheses.
        new_beams.sort(key=lambda b: normalized(*b), reverse=True)
        beams = new_beams[:beam_size]
        if not beams:
            break

    # Pick the best hypothesis overall under length normalisation.
    candidates = complete_beams + beams
    candidates.sort(key=lambda b: normalized(*b), reverse=True)

    best_beam = candidates[0][0]
    return best_beam

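# beam_search_decode normalises scores with the GNMT length penalty
# lp(Y) = ((5 + |Y|) / 6) ** alpha (Wu et al., 2016). With the default
# alpha = 0.6 used by translate below, this counteracts the bias of summed
# log-probabilities toward short outputs.
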
def translate(model: torch.nn.Module, src_sentence: str, strategy: str = 'greedy',
              length_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
    assert strategy in ['greedy', 'beam search'], \
        'the strategy for decoding has to be either greedy or beam search'

    src = source_tokenizer(src_sentence, **token_config)['input_ids']
    num_tokens = src.shape[1]

    # The source mask is all False: every encoder position is visible.
    src_mask = torch.zeros(num_tokens, num_tokens).type(torch.bool)

    if strategy == 'greedy':
        tgt_tokens = greedy_decode(
            model, src, src_mask,
            max_len=num_tokens + length_extend,
            start_symbol=target_tokenizer.bos_token_id).flatten()
    else:
        tgt_tokens = beam_search_decode(
            model, src, src_mask,
            max_len=num_tokens + length_extend,
            start_symbol=target_tokenizer.bos_token_id,
            beam_size=beam_size,
            length_penalty=length_penalty).flatten()

    return target_tokenizer.decode(
        tgt_tokens, clean_up_tokenization_spaces=True, skip_special_tokens=True)

special_tokens = {'unk_token': "[UNK]",
                  'cls_token': "[CLS]",
                  'eos_token': "[EOS]",
                  'sep_token': "[SEP]",
                  'pad_token': "[PAD]",
                  'mask_token': "[MASK]",
                  'bos_token': "[BOS]"}

# The source side reuses the bert-base-uncased tokenizer; the target side is
# the tokenizer published at Sifal/E2KT.
source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')

addPreprocessing(source_tokenizer)
addPreprocessing(target_tokenizer)

token_config = {
    "add_special_tokens": True,
    "return_tensors": "pt",  # return PyTorch tensors, not Python lists
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model()
model.to(device)
model.eval()
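

# Minimal usage sketch (the example sentence is arbitrary):
if __name__ == "__main__":
    sentence = "How are you today?"
    print(translate(model, sentence))                          # greedy decoding
    print(translate(model, sentence, strategy='beam search'))  # beam search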