import math

import torch
import torch.nn as nn
from torch import Tensor
from torch.nn import Transformer


class PositionalEncoding(nn.Module):
    def __init__(self,
                 emb_size: int,
                 dropout: float,
                 maxlen: int = 5000):
        super().__init__()
        # Standard sinusoidal position encodings (Vaswani et al., 2017).
        # Note: math.log, not torch.log, since the argument is a plain float.
        den = torch.exp(-torch.arange(0, emb_size, 2) * math.log(10000) / emb_size)
        pos = torch.arange(0, maxlen).reshape(maxlen, 1)
        pos_embedding = torch.zeros((maxlen, emb_size))
        pos_embedding[:, 0::2] = torch.sin(pos * den)
        pos_embedding[:, 1::2] = torch.cos(pos * den)
        # Shape (1, maxlen, emb_size) so it broadcasts over batch-first inputs.
        pos_embedding = pos_embedding.unsqueeze(0)

        self.dropout = nn.Dropout(dropout)
        self.register_buffer('pos_embedding', pos_embedding)

    def forward(self, token_embedding: Tensor):
        # token_embedding: (batch, seq_len, emb_size); slice the buffer to the
        # current sequence length (dim 1, since the model runs batch_first).
        return self.dropout(token_embedding + self.pos_embedding[:, :token_embedding.size(1), :])
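

# Quick shape check (an illustrative addition, not part of the original
# pipeline): with batch-first inputs, the (1, maxlen, emb_size) buffer
# broadcasts over the batch dimension.
def _positional_encoding_shape_check():
    pe = PositionalEncoding(emb_size=512, dropout=0.0)
    x = torch.zeros(2, 10, 512)            # (batch, seq_len, emb_size)
    assert pe(x).shape == (2, 10, 512)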


class TokenEmbedding(nn.Module):
    def __init__(self, vocab_size: int, emb_size: int):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, emb_size)
        self.emb_size = emb_size

    def forward(self, tokens: Tensor):
        # Scale embeddings by sqrt(d_model), as in "Attention Is All You Need".
        return self.embedding(tokens.long()) * math.sqrt(self.emb_size)


class Seq2SeqTransformer(nn.Module):
    def __init__(self,
                 num_encoder_layers: int,
                 num_decoder_layers: int,
                 emb_size: int,
                 nhead: int,
                 src_vocab_size: int,
                 tgt_vocab_size: int,
                 dim_feedforward: int = 512,
                 dropout: float = 0.1):
        super().__init__()
        self.transformer = Transformer(d_model=emb_size,
                                       nhead=nhead,
                                       num_encoder_layers=num_encoder_layers,
                                       num_decoder_layers=num_decoder_layers,
                                       dim_feedforward=dim_feedforward,
                                       dropout=dropout,
                                       batch_first=True)
        # Projects decoder states onto the target vocabulary.
        self.generator = nn.Linear(emb_size, tgt_vocab_size)
        self.src_tok_emb = TokenEmbedding(src_vocab_size, emb_size)
        self.tgt_tok_emb = TokenEmbedding(tgt_vocab_size, emb_size)
        self.positional_encoding = PositionalEncoding(emb_size, dropout=dropout)

    def forward(self,
                src: Tensor,
                trg: Tensor,
                src_mask: Tensor,
                tgt_mask: Tensor,
                src_padding_mask: Tensor,
                tgt_padding_mask: Tensor,
                memory_key_padding_mask: Tensor):
        src_emb = self.positional_encoding(self.src_tok_emb(src))
        tgt_emb = self.positional_encoding(self.tgt_tok_emb(trg))
        outs = self.transformer(src_emb, tgt_emb, src_mask, tgt_mask, None,
                                src_padding_mask, tgt_padding_mask, memory_key_padding_mask)
        return self.generator(outs)

    def encode(self, src: Tensor, src_mask: Tensor):
        return self.transformer.encoder(
            self.positional_encoding(self.src_tok_emb(src)), src_mask)

    def decode(self, tgt: Tensor, memory: Tensor, tgt_mask: Tensor):
        return self.transformer.decoder(
            self.positional_encoding(self.tgt_tok_emb(tgt)), memory, tgt_mask)
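

# Hedged sketch (an addition) of how the masks consumed by forward() are
# typically built; `pad_idx` is an assumption, standing in for the
# tokenizers' pad_token_id. This helper is illustrative and never called.
def _make_masks_example(src: Tensor, tgt: Tensor, pad_idx: int):
    # src, tgt: (batch, seq_len) token ids in batch-first layout.
    src_seq_len, tgt_seq_len = src.size(1), tgt.size(1)
    # The encoder may attend everywhere: an all-False boolean mask.
    src_mask = torch.zeros((src_seq_len, src_seq_len), dtype=torch.bool)
    # The decoder is causal: -inf strictly above the diagonal.
    tgt_mask = torch.triu(torch.full((tgt_seq_len, tgt_seq_len), float('-inf')), diagonal=1)
    # Padding masks: True marks positions attention should ignore.
    src_padding_mask = (src == pad_idx)
    tgt_padding_mask = (tgt == pad_idx)
    return src_mask, tgt_mask, src_padding_mask, tgt_padding_mask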


import yaml
from transformers import AutoTokenizer
from transformers import PreTrainedTokenizerFast
from tokenizers.processors import TemplateProcessing


def addPreprocessing(tokenizer):
    # Wrap every encoded sequence with the tokenizer's BOS and EOS tokens.
    tokenizer._tokenizer.post_processor = TemplateProcessing(
        single=tokenizer.bos_token + " $A " + tokenizer.eos_token,
        special_tokens=[(tokenizer.eos_token, tokenizer.eos_token_id),
                        (tokenizer.bos_token, tokenizer.bos_token_id)])
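

# Illustrative check (an addition, assuming a fast tokenizer with bos/eos
# tokens set): after addPreprocessing, every encoded sequence is framed by
# the BOS and EOS ids.
def _bos_eos_check(tokenizer, text="hello"):
    ids = tokenizer(text, add_special_tokens=True)['input_ids']
    return ids[0] == tokenizer.bos_token_id and ids[-1] == tokenizer.eos_token_id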


def load_model(model_checkpoint_dir='model.pt', config_dir='config.yaml'):
    with open(config_dir, 'r') as yaml_file:
        loaded_model_params = yaml.safe_load(yaml_file)

    model = Seq2SeqTransformer(
        loaded_model_params["num_encoder_layers"],
        loaded_model_params["num_decoder_layers"],
        loaded_model_params["emb_size"],
        loaded_model_params["nhead"],
        loaded_model_params["source_vocab_size"],
        loaded_model_params["target_vocab_size"],
        loaded_model_params["ffn_hid_dim"]
    )

    # Load onto GPU when available, otherwise remap the checkpoint to CPU.
    map_location = None if torch.cuda.is_available() else torch.device('cpu')
    checkpoint = torch.load(model_checkpoint_dir, map_location=map_location)
    model.load_state_dict(checkpoint)

    return model
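

# For reference, load_model expects config.yaml to carry exactly these keys.
# The values below are illustrative placeholders, not the trained model's
# actual hyperparameters:
#
#   num_encoder_layers: 3
#   num_decoder_layers: 3
#   emb_size: 512
#   nhead: 8
#   source_vocab_size: 30522
#   target_vocab_size: 8000
#   ffn_hid_dim: 512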


def greedy_decode(model, src, src_mask, max_len, start_symbol):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)

    # Start from the BOS token and extend one token at a time.
    ys = torch.tensor([[start_symbol]], dtype=torch.long, device=device)

    for i in range(max_len - 1):
        memory = memory.to(device)

        # Causal mask: -inf strictly above the diagonal blocks attention
        # to future positions.
        tgt_mask = torch.triu(torch.full((ys.size(1), ys.size(1)),
                                         float('-inf'), device=device), diagonal=1)

        out = model.decode(ys, memory, tgt_mask)
        prob = model.generator(out[:, -1])

        # Greedy choice: take the single most likely next token.
        _, next_word = torch.max(prob, dim=1)
        next_word = next_word.item()

        ys = torch.cat([ys,
                        torch.ones(1, 1).type_as(src.data).fill_(next_word)], dim=1)

        if next_word == target_tokenizer.eos_token_id:
            break

    return ys
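

# The causal mask above matches what PyTorch generates itself; a hedged
# equivalence check (assumes a recent PyTorch where this is a static method):
def _causal_mask_check(sz=4):
    ours = torch.triu(torch.full((sz, sz), float('-inf')), diagonal=1)
    ref = Transformer.generate_square_subsequent_mask(sz)
    return torch.equal(ours, ref)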


def beam_search_decode(model, src, src_mask, max_len, start_symbol, beam_size, length_penalty):
    src = src.to(device)
    src_mask = src_mask.to(device)

    memory = model.encode(src, src_mask)

    # Each beam is a (sequence, cumulative log-probability) pair.
    beams = [(torch.tensor([[start_symbol]], dtype=torch.long, device=device), 0.0)]
    complete_beams = []

    # GNMT-style length normalization (Wu et al., 2016): divide the cumulative
    # log-probability by ((5 + length) / 6) ** length_penalty so longer
    # hypotheses are not unfairly penalized.
    def normalized_score(beam):
        ys, score = beam
        return score / (((5 + ys.size(1)) / 6) ** length_penalty)

    for i in range(max_len - 1):
        new_beams = []

        for ys, score in beams:
            tgt_mask = torch.triu(torch.full((ys.size(1), ys.size(1)),
                                             float('-inf'), device=device), diagonal=1)

            out = model.decode(ys, memory, tgt_mask)

            # Log-probabilities, so beam scores can be summed across steps.
            log_prob = torch.log_softmax(model.generator(out[:, -1]), dim=-1)

            top_scores, top_indices = torch.topk(log_prob, beam_size, dim=1)

            for j in range(beam_size):
                next_word = top_indices[0][j].item()
                new_ys = torch.cat([ys, torch.full((1, 1), next_word,
                                                   dtype=torch.long, device=device)], dim=1)
                new_score = score + top_scores[0][j].item()

                # A beam that emits EOS is finished and set aside.
                if next_word == target_tokenizer.eos_token_id:
                    complete_beams.append((new_ys, new_score))
                else:
                    new_beams.append((new_ys, new_score))

        # Keep only the best `beam_size` unfinished hypotheses.
        new_beams.sort(key=normalized_score, reverse=True)
        beams = new_beams[:beam_size]
        if not beams:
            break

    # Pick the best hypothesis overall, finished or not.
    candidates = complete_beams + beams
    candidates.sort(key=normalized_score, reverse=True)

    best_beam = candidates[0][0]
    return best_beam
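

# Worked example of the length normalization (illustrative numbers): with
# length_penalty=0.6, a 10-token hypothesis with cumulative log-prob -6.0
# is ranked by -6.0 / ((5 + 10) / 6) ** 0.6 ≈ -6.0 / 1.73 ≈ -3.46.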


def translate(model: torch.nn.Module, src_sentence: str, strategy: str = 'greedy',
              length_extend: int = 5, beam_size: int = 5, length_penalty: float = 0.6):
    assert strategy in ['greedy', 'beam search'], \
        'the strategy for decoding has to be either greedy or beam search'

    src = source_tokenizer(src_sentence, **token_config)['input_ids']
    num_tokens = src.shape[1]

    # The encoder attends everywhere, so the source mask is all False.
    src_mask = (torch.zeros(num_tokens, num_tokens)).type(torch.bool)

    if strategy == 'greedy':
        tgt_tokens = greedy_decode(model, src, src_mask,
                                   max_len=num_tokens + length_extend,
                                   start_symbol=target_tokenizer.bos_token_id).flatten()
    else:
        tgt_tokens = beam_search_decode(model, src, src_mask,
                                        max_len=num_tokens + length_extend,
                                        start_symbol=target_tokenizer.bos_token_id,
                                        beam_size=beam_size,
                                        length_penalty=length_penalty).flatten()

    return target_tokenizer.decode(tgt_tokens, clean_up_tokenization_spaces=True,
                                   skip_special_tokens=True)
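

# Example calls, assuming `model` and the tokenizers defined below are in
# scope (as they are once this script has run):
#
#   translate(model, "how are you?")                            # greedy
#   translate(model, "how are you?", strategy='beam search')    # beam search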


special_tokens = {'unk_token': "[UNK]",
                  'cls_token': "[CLS]",
                  'eos_token': "[EOS]",
                  'sep_token': "[SEP]",
                  'pad_token': "[PAD]",
                  'mask_token': "[MASK]",
                  'bos_token': "[BOS]"}

source_tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased", **special_tokens)
target_tokenizer = PreTrainedTokenizerFast.from_pretrained('Sifal/E2KT')

addPreprocessing(source_tokenizer)
addPreprocessing(target_tokenizer)

token_config = {
    "add_special_tokens": True,
    "return_tensors": "pt",  # return PyTorch tensors rather than Python lists
}

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model = load_model()
model.to(device)
model.eval()


import gradio as gr


def gradio_translate(src_sentence, strategy, length_extend, beam_size, length_penalty):
    # Gradio passes only the interface inputs; bind the loaded model here and
    # cast the Number inputs, which arrive as floats, to ints where needed.
    return translate(model, src_sentence, strategy,
                     int(length_extend), int(beam_size), float(length_penalty))


iface = gr.Interface(
    fn=gradio_translate,
    inputs=[
        gr.Textbox(label="Source Sentence", placeholder="Enter a sentence to translate"),
        gr.Radio(['greedy', 'beam search'], value='greedy', label="Decoding Strategy"),
        gr.Number(value=5, label="Length Extend (for greedy)"),
        gr.Number(value=5, label="Beam Size (for beam search)"),
        gr.Number(value=0.6, label="Length Penalty (for beam search)"),
    ],
    outputs=gr.Textbox(label="Translation"),
    title="Translation Interface",
    description="Translate text using a pre-trained model.",
)

iface.launch()