Spaces:
Running
Running
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from src.helper import get_optimizer, tensorsFromPair, get_languages, tensorFromWord, get_data | |
from src.language import SOS_token, EOS_token | |
from src.encoder import Encoder | |
from src.decoder import Decoder | |
import random | |
import time | |
import numpy as np | |
# Training iterations between printed running-average loss reports.
PRINT_EVERY = 5_000
# Training iterations averaged into each point of the returned loss curve.
PLOT_EVERY = 100
class Translator:
    """Encoder/decoder model that maps an input-language word to an
    output-language word one character index at a time.

    Holds the vocabularies and pairs returned by ``get_languages(lang)``,
    an ``Encoder``/``Decoder`` pair, one optimizer per network and an
    ``nn.NLLLoss`` criterion.
    """

    def __init__(self, lang: str, params: dict, device: str):
        """Load the data for ``lang`` and build the networks and optimizers.

        Args:
            lang: Identifier passed to ``get_languages`` / ``get_data``.
            params: Hyper-parameters. Keys read: ``embed_size``,
                ``hidden_size``, ``cell_type``, ``num_layers``, ``dropout``,
                ``optimizer``, ``learning_rate``, ``teacher_forcing_ratio``,
                ``max_length``.
            device: Device the networks and tensors are moved to.
        """
        self.lang = lang
        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
        self.input_size = self.input_lang.n_chars
        self.output_size = self.output_lang.n_chars
        self.device = device
        # Tensorise every pair once so each training step is just an index.
        self.training_pairs = [
            tensorsFromPair(self.input_lang, self.output_lang, pair, self.device)
            for pair in self.pairs
        ]
        self.encoder = Encoder(
            in_sz=self.input_size,
            embed_sz=params["embed_size"],
            hidden_sz=params["hidden_size"],
            cell_type=params["cell_type"],
            n_layers=params["num_layers"],
            dropout=params["dropout"],
            device=self.device,
        ).to(self.device)
        self.decoder = Decoder(
            out_sz=self.output_size,
            embed_sz=params["embed_size"],
            hidden_sz=params["hidden_size"],
            cell_type=params["cell_type"],
            n_layers=params["num_layers"],
            dropout=params["dropout"],
            device=self.device,
        ).to(self.device)
        self.encoder_optimizer = get_optimizer(params["optimizer"])(
            self.encoder.parameters(), lr=params["learning_rate"]
        )
        self.decoder_optimizer = get_optimizer(params["optimizer"])(
            self.decoder.parameters(), lr=params["learning_rate"]
        )
        # NLLLoss expects log-probabilities; assumes the Decoder emits
        # log-softmax outputs — NOTE(review): confirm in src.decoder.
        self.criterion = nn.NLLLoss()
        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
        # Upper bound on the number of characters produced by ``evaluate``.
        self.max_length = params["max_length"]

    def train_single(self, input_tensor, target_tensor):
        """Run one optimisation step on a single (input, target) pair.

        The input is fed to the encoder one position at a time; its final
        hidden/cell state seeds the decoder, which starts from ``SOS_token``.
        With probability ``teacher_forcing_ratio`` the ground-truth character
        is fed back at every step; otherwise the decoder's own top-1
        prediction is fed back and decoding stops once it predicts
        ``EOS_token``.

        Returns:
            The summed loss divided by the full target length (even when
            free-running decoding stopped early), as a Python float.
        """
        encoder_hidden = self.encoder.initHidden()
        encoder_cell = self.encoder.initHidden()
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()

        # Only the final recurrent state is handed to the decoder, so per-step
        # encoder outputs are not collected. Keeping no fixed-size buffer
        # means inputs longer than ``max_length`` are accepted.
        for ei in range(input_tensor.size(0)):
            _, encoder_hidden, encoder_cell = self.encoder(
                input_tensor[ei], encoder_hidden, encoder_cell
            )

        decoder_input = torch.tensor([[SOS_token]], device=self.device)
        decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
        use_teacher_forcing = random.random() < self.teacher_forcing_ratio

        target_length = target_tensor.size(0)
        loss = 0
        for di in range(target_length):
            decoder_output, decoder_hidden, decoder_cell = self.decoder(
                decoder_input, decoder_hidden, decoder_cell
            )
            loss += self.criterion(decoder_output, target_tensor[di])
            if use_teacher_forcing:
                # Feed the ground-truth character as the next input.
                decoder_input = target_tensor[di]
            else:
                # Feed the model's own top-1 prediction as the next input.
                _, topi = decoder_output.topk(1)
                decoder_input = topi.squeeze().detach()
                if decoder_input.item() == EOS_token:
                    break

        loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        return loss.item() / target_length

    def train(self, iters=-1):
        """Train for ``iters`` single-pair steps over the shuffled data.

        ``self.training_pairs`` is shuffled in place first. ``iters == -1``
        means one pass over all pairs; larger values wrap around to the start
        of the shuffled list instead of running off its end. A running
        average loss is printed every ``PRINT_EVERY`` steps.

        Returns:
            One average loss per completed window of ``PLOT_EVERY`` steps.

        Raises:
            ValueError: if steps are requested but there are no training pairs.
        """
        start_time = time.time()
        plot_losses = []
        print_loss_total = 0
        plot_loss_total = 0

        random.shuffle(self.training_pairs)
        n_pairs = len(self.training_pairs)
        if iters == -1:
            iters = n_pairs
        if iters > 0 and n_pairs == 0:
            raise ValueError("no training pairs available to train on")

        for step in range(1, iters + 1):
            training_pair = self.training_pairs[(step - 1) % n_pairs]
            input_tensor = training_pair[0]
            target_tensor = training_pair[1]
            loss = self.train_single(input_tensor, target_tensor)
            print_loss_total += loss
            plot_loss_total += loss

            if step % PRINT_EVERY == 0:
                print_loss_avg = print_loss_total / PRINT_EVERY
                print_loss_total = 0
                current_time = time.time()
                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, step, current_time - start_time))

            if step % PLOT_EVERY == 0:
                plot_losses.append(plot_loss_total / PLOT_EVERY)
                plot_loss_total = 0

        return plot_losses

    def evaluate(self, word):
        """Greedily decode ``word`` and return the produced string.

        Decoding starts from ``SOS_token`` and stops at ``EOS_token`` or after
        ``self.max_length`` characters. Both networks are put in eval mode
        for the call (``torch.no_grad`` alone does not disable dropout) and
        restored to their previous mode afterwards.
        """
        encoder_was_training = self.encoder.training
        decoder_was_training = self.decoder.training
        self.encoder.eval()
        self.decoder.eval()
        try:
            with torch.no_grad():
                input_tensor = tensorFromWord(self.input_lang, word, self.device)
                encoder_hidden = self.encoder.initHidden()
                encoder_cell = self.encoder.initHidden()
                for ei in range(input_tensor.size(0)):
                    _, encoder_hidden, encoder_cell = self.encoder(
                        input_tensor[ei], encoder_hidden, encoder_cell
                    )

                decoder_input = torch.tensor([[SOS_token]], device=self.device)
                decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
                decoded_chars = []
                for _ in range(self.max_length):
                    decoder_output, decoder_hidden, decoder_cell = self.decoder(
                        decoder_input, decoder_hidden, decoder_cell
                    )
                    _, topi = decoder_output.topk(1)
                    index = topi.item()
                    if index == EOS_token:
                        break
                    decoded_chars.append(self.output_lang.index2word[index])
                    decoder_input = topi.squeeze().detach()
                return "".join(decoded_chars)
        finally:
            self.encoder.train(encoder_was_training)
            self.decoder.train(decoder_was_training)

    def test_validate(self, type: str):
        """Return exact-match accuracy of ``evaluate`` on a data split.

        Args:
            type: Split name forwarded to ``get_data(self.lang, type)``
                (shadows the builtin; kept for caller compatibility).

        Returns:
            Fraction of pairs for which ``evaluate(pair[0]) == pair[1]``.
        """
        pairs = get_data(self.lang, type)
        n_correct = np.sum([self.evaluate(pair[0]) == pair[1] for pair in pairs])
        return n_correct / len(pairs)