import torch
import torch.nn as nn
from torch.nn import Transformer
from loguru import logger
import os


# Fallback device; TransformerModel.__init__ overrides this with the
# 'DEVICE' entry from the config.
DEVICE = torch.device("cuda:6" if torch.cuda.is_available() else "cpu")


def get_sinusoid_encoding_table(max_len, d_model):
    """Return a fixed sinusoidal positional-encoding table of shape (1, max_len, d_model)."""
    pos_encoding = torch.zeros(max_len, d_model)
    position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
    div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-torch.log(torch.tensor(10000.0)) / d_model))
    pos_encoding[:, 0::2] = torch.sin(position * div_term)
    pos_encoding[:, 1::2] = torch.cos(position * div_term)
    pos_encoding = pos_encoding.unsqueeze(0)
    # Wrapped in a non-trainable Parameter so it moves with the module's .to(device).
    return nn.Parameter(pos_encoding, requires_grad=False)
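# Shape sketch (illustrative, not part of the model): with max_len=128 and
# d_model=512 the table is (1, 128, 512), so a (batch, seq_len, d_model)
# embedding can add pos_encoding[0, :seq_len] by broadcasting, e.g.
#
#     pe = get_sinusoid_encoding_table(128, 512)   # (1, 128, 512)
#     x = torch.zeros(4, 20, 512)                  # (batch, seq, d_model)
#     x = x + pe[0, :20]                           # broadcast over the batch dim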


class TransformerModel(nn.Module):
    def __init__(self, config, src_tokenizer, tgt_tokenizer):
        super(TransformerModel, self).__init__()

        # The config wrapper exposes its settings as a plain dict on .config;
        # the module-level DEVICE is replaced with the configured device.
        global DEVICE
        config = config.config
        DEVICE = torch.device(config['DEVICE'])
        self.batch_size = config.get('BATCH_SIZE')
        self.epochs = config.get('EPOCHS')
        self.learning_rate = config.get('LEARNING_RATE')
        self.max_seq_len = config.get('MAX_SEQ_LEN')
        self.d_model = config.get('D_MODEL')
        self.n_head = config.get('N_HEAD')
        self.num_layers = config.get('NUM_LAYERS')
        self.dim_feedforward = config.get('DIM_FEEDFORWARD')
        self.dropout = config.get('DROPOUT')
        self.src_tokenizer = src_tokenizer
        self.tgt_tokenizer = tgt_tokenizer

        self.transformer = Transformer(d_model=self.d_model,
                                       nhead=self.n_head,
                                       num_encoder_layers=self.num_layers,
                                       num_decoder_layers=self.num_layers,
                                       dim_feedforward=self.dim_feedforward,
                                       dropout=self.dropout)

        src_vocab_size = len(src_tokenizer.word2int)
        tgt_vocab_size = len(tgt_tokenizer.word2int)
        self.src_embedding = nn.Embedding(src_vocab_size, self.d_model)
        self.tgt_embedding = nn.Embedding(tgt_vocab_size, self.d_model)
        self.fc = nn.Linear(self.d_model, tgt_vocab_size)

        # Fixed (non-trainable) positional encodings, shape (1, max_seq_len, d_model).
        self.pos_encoding = get_sinusoid_encoding_table(self.max_seq_len, self.d_model)
        self.to(DEVICE)

    def forward(self, src, tgt, src_mask=None, tgt_mask=None, src_padding_mask=None, tgt_padding_mask=None):
        # Scale embeddings by sqrt(d_model) and add positional encodings.
        src_emb = self.src_embedding(src) * (self.d_model ** 0.5)
        tgt_emb = self.tgt_embedding(tgt) * (self.d_model ** 0.5)

        src_emb = src_emb + self.pos_encoding[0, :src_emb.size(1)]
        tgt_emb = tgt_emb + self.pos_encoding[0, :tgt_emb.size(1)]

        # nn.Transformer defaults to (seq_len, batch, d_model) inputs, so permute
        # the batch-first embeddings before and after the transformer.
        output = self.transformer(
            src_emb.permute(1, 0, 2),
            tgt_emb.permute(1, 0, 2),
            src_mask=src_mask,
            tgt_mask=tgt_mask,
            src_key_padding_mask=src_padding_mask,
            tgt_key_padding_mask=tgt_padding_mask,
        )
        output = self.fc(output.permute(1, 0, 2))
        return output
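    # Shape sketch (illustrative, assuming batch-first integer inputs):
    #   src: (batch, src_len) token ids    tgt: (batch, tgt_len) token ids
    #   returns: (batch, tgt_len, tgt_vocab_size) unnormalized logits
    # Padding masks, when given, are (batch, seq_len) with True at <PAD> positions.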

    def decode_sentence(self, ids):
        int2word = self.tgt_tokenizer.int2word
        word2int = self.tgt_tokenizer.word2int
        # Drop the special tokens before joining the remaining words.
        special_ids = {word2int["<PAD>"], word2int["<BOS>"], word2int["<EOS>"]}
        tokens = [int2word[token_id] for token_id in ids if token_id not in special_ids]
        return tokens, " ".join(tokens)

    def encode_sentence(self, sentence):
        tokens = sentence.split()

        word2int = self.src_tokenizer.word2int
        max_len = self.max_seq_len
        ids = [word2int.get(token, word2int["<UNK>"]) for token in tokens]
        # Truncate to leave room for <BOS>/<EOS>, then pad up to max_len.
        ids = [word2int["<BOS>"]] + ids[:max_len - 2] + [word2int["<EOS>"]]
        ids += [word2int["<PAD>"]] * (max_len - len(ids))

        return torch.tensor(ids, dtype=torch.long).unsqueeze(0).to(DEVICE)
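    # Example (sketch, actual ids depend on the tokenizer's vocabulary): with
    # MAX_SEQ_LEN = 8, "hello world" becomes
    #   [<BOS>, hello, world, <EOS>, <PAD>, <PAD>, <PAD>, <PAD>]
    # i.e. a (1, 8) LongTensor on DEVICE; out-of-vocabulary words map to <UNK>.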

    def translate(self, sentence):
        # Greedy decoding: at each step append the highest-scoring next token.
        src_tensor = self.encode_sentence(sentence)
        tgt_tensor = torch.tensor([self.tgt_tokenizer.word2int["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE)

        for _ in range(self.max_seq_len):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)

            output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)

            next_token = output[:, -1, :].argmax(dim=-1).item()

            tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1)

            if next_token == self.tgt_tokenizer.word2int["<EOS>"]:
                break
        return self.decode_sentence(tgt_tensor.squeeze(0).tolist())
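    # Usage sketch (assumes a trained model; run under model.eval() and
    # torch.no_grad() for inference):
    #     tokens, text = model.translate("this is a test")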

    def translate_no_unk(self, sentence):
        # Greedy decoding that never emits <UNK>: fall back to the second-best token.
        src_tensor = self.encode_sentence(sentence)
        tgt_tensor = torch.tensor([self.tgt_tokenizer.word2int["<BOS>"]], dtype=torch.long).unsqueeze(0).to(DEVICE)

        for _ in range(self.max_seq_len):
            tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)

            output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)

            logits = output[:, -1, :]
            sorted_indices = torch.argsort(logits, dim=-1, descending=True)
            next_token = sorted_indices[0, 0].item()
            unk_id = self.tgt_tokenizer.word2int["<UNK>"]

            if next_token == unk_id:
                next_token = sorted_indices[0, 1].item()

            tgt_tensor = torch.cat([tgt_tensor, torch.tensor([[next_token]], dtype=torch.long).to(DEVICE)], dim=1)

            if next_token == self.tgt_tokenizer.word2int["<EOS>"]:
                break

        return self.decode_sentence(tgt_tensor.squeeze(0).tolist())

    def translate_beam_search(self, sentence, beam_size=3):
        # Beam search over cumulative log-probabilities; <UNK> candidates are skipped.
        src_tensor = self.encode_sentence(sentence)
        bos_id = self.tgt_tokenizer.word2int["<BOS>"]
        eos_id = self.tgt_tokenizer.word2int["<EOS>"]
        unk_id = self.tgt_tokenizer.word2int["<UNK>"]

        # Each beam is a (token sequence, cumulative log-probability) pair.
        beams = [([bos_id], 0)]

        for _ in range(self.max_seq_len):
            new_beams = []
            for seq, score in beams:
                tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)

                output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
                logits = output[:, -1, :]
                probs = torch.softmax(logits, dim=-1).squeeze(0)

                sorted_indices = torch.argsort(probs, dim=-1, descending=True)
                top_indices = sorted_indices[:beam_size]

                for idx in top_indices:
                    next_token = idx.item()
                    if next_token == unk_id:
                        continue
                    new_seq = seq + [next_token]
                    new_score = score + torch.log(probs[next_token]).item()
                    new_beams.append((new_seq, new_score))

            # Keep only the beam_size best partial hypotheses.
            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]

            # If any surviving beam has just produced <EOS>, return the best of them.
            ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
            if ended_beams:
                best_beam = max(ended_beams, key=lambda x: x[1])
                return self.decode_sentence(best_beam[0])

        # No beam produced <EOS> within max_seq_len steps; return the best partial one.
        best_beam = max(beams, key=lambda x: x[1])
        return self.decode_sentence(best_beam[0])
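    # Scoring sketch (illustrative): each beam accumulates log p of its tokens,
    # so a two-token hypothesis with step probabilities 0.5 and 0.4 scores
    # log(0.5) + log(0.4) ≈ -1.609, and beams are ranked by that sum, e.g.
    #     tokens, text = model.translate_beam_search("this is a test", beam_size=5)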

    def translate_no_unk_beam_search(self, sentence, beam_size=3):
        # Beam search that skips <UNK>; one extra candidate is inspected per beam
        # so that skipping <UNK> can still yield beam_size expansions.
        src_tensor = self.encode_sentence(sentence)
        bos_id = self.tgt_tokenizer.word2int["<BOS>"]
        eos_id = self.tgt_tokenizer.word2int["<EOS>"]
        unk_id = self.tgt_tokenizer.word2int["<UNK>"]

        beams = [([bos_id], 0)]

        for _ in range(self.max_seq_len):
            new_beams = []
            for seq, score in beams:
                tgt_tensor = torch.tensor(seq, dtype=torch.long).unsqueeze(0).to(DEVICE)
                tgt_mask = nn.Transformer.generate_square_subsequent_mask(tgt_tensor.size(1)).to(DEVICE)

                output = self.forward(src_tensor, tgt_tensor, tgt_mask=tgt_mask)
                logits = output[:, -1, :]
                probs = torch.softmax(logits, dim=-1).squeeze(0)

                sorted_indices = torch.argsort(probs, dim=-1, descending=True)
                top_indices = sorted_indices[:beam_size + 1]

                valid_count = 0
                for idx in top_indices:
                    next_token = idx.item()
                    if next_token == unk_id:
                        continue
                    new_seq = seq + [next_token]
                    new_score = score + torch.log(probs[next_token]).item()
                    new_beams.append((new_seq, new_score))
                    valid_count += 1
                    if valid_count >= beam_size:
                        break

            new_beams.sort(key=lambda x: x[1], reverse=True)
            beams = new_beams[:beam_size]

            ended_beams = [beam for beam in beams if beam[0][-1] == eos_id]
            if ended_beams:
                best_beam = max(ended_beams, key=lambda x: x[1])
                return self.decode_sentence(best_beam[0])

        best_beam = max(beams, key=lambda x: x[1])
        return self.decode_sentence(best_beam[0])
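

if __name__ == "__main__":
    # Minimal smoke-test sketch, not part of the project: the real config and
    # tokenizer classes are not shown in this module, so SimpleNamespace
    # stand-ins are used, assuming only that config.config is a dict of the
    # keys read in __init__ and that tokenizers expose word2int / int2word
    # with the <PAD>/<BOS>/<EOS>/<UNK> specials. An untrained model will emit
    # arbitrary tokens; this only checks that the shapes and decoding loop run.
    from types import SimpleNamespace

    words = ["<PAD>", "<BOS>", "<EOS>", "<UNK>", "hello", "world"]
    toy_tokenizer = SimpleNamespace(
        word2int={w: i for i, w in enumerate(words)},
        int2word={i: w for i, w in enumerate(words)},
    )
    toy_config = SimpleNamespace(config={
        'DEVICE': 'cpu',
        'BATCH_SIZE': 2,
        'EPOCHS': 1,
        'LEARNING_RATE': 1e-4,
        'MAX_SEQ_LEN': 16,
        'D_MODEL': 32,
        'N_HEAD': 4,
        'NUM_LAYERS': 1,
        'DIM_FEEDFORWARD': 64,
        'DROPOUT': 0.1,
    })

    model = TransformerModel(toy_config, toy_tokenizer, toy_tokenizer)
    model.eval()
    with torch.no_grad():
        tokens, text = model.translate("hello world")
    logger.info(f"greedy output: {text!r}")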