Spaces:
Running
Running
File size: 6,542 Bytes
9e582c5 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 |
import torch
import torch.nn as nn
import torch.nn.functional as F
from src.helper import get_optimizer, tensorsFromPair, get_languages, tensorFromWord, get_data
from src.language import SOS_token, EOS_token
from src.encoder import Encoder
from src.decoder import Decoder
import random
import time
import numpy as np
PRINT_EVERY = 5000  # training iterations between console loss reports
PLOT_EVERY = 100  # training iterations between recorded plot-loss points
class Translator:
    """Character-level sequence-to-sequence translator.

    Wraps an ``Encoder``/``Decoder`` pair (and their optimizers) for a given
    language, providing teacher-forced training, greedy-decoding inference,
    and exact-match accuracy evaluation on a held-out split.
    """

    def __init__(self, lang: str, params: dict, device: str):
        """Build encoder, decoder, optimizers, and pre-tensorized data.

        Args:
            lang: language identifier understood by ``get_languages``.
            params: hyper-parameter dict with keys "embed_size",
                "hidden_size", "cell_type", "num_layers", "dropout",
                "optimizer", "learning_rate", "teacher_forcing_ratio",
                and "max_length".
            device: torch device string (e.g. "cpu" or "cuda").
        """
        self.lang = lang
        self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
        self.input_size = self.input_lang.n_chars
        self.output_size = self.output_lang.n_chars
        self.device = device
        # Tensorize every training pair once up front so the training loop
        # does no per-iteration preprocessing.
        self.training_pairs = [
            tensorsFromPair(self.input_lang, self.output_lang, pair, self.device)
            for pair in self.pairs
        ]
        self.encoder = Encoder(in_sz=self.input_size,
                               embed_sz=params["embed_size"],
                               hidden_sz=params["hidden_size"],
                               cell_type=params["cell_type"],
                               n_layers=params["num_layers"],
                               dropout=params["dropout"],
                               device=self.device).to(self.device)
        self.decoder = Decoder(out_sz=self.output_size,
                               embed_sz=params["embed_size"],
                               hidden_sz=params["hidden_size"],
                               cell_type=params["cell_type"],
                               n_layers=params["num_layers"],
                               dropout=params["dropout"],
                               device=self.device).to(self.device)
        # Separate optimizer instances so encoder/decoder step independently.
        self.encoder_optimizer = get_optimizer(params["optimizer"])(
            self.encoder.parameters(), lr=params["learning_rate"])
        self.decoder_optimizer = get_optimizer(params["optimizer"])(
            self.decoder.parameters(), lr=params["learning_rate"])
        # NLLLoss assumes the decoder emits log-probabilities (log_softmax).
        self.criterion = nn.NLLLoss()
        self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
        self.max_length = params["max_length"]

    def train_single(self, input_tensor, target_tensor):
        """Run one optimization step on a single (input, target) pair.

        Args:
            input_tensor: tensor of input character indices, shape (len, 1).
            target_tensor: tensor of target character indices, shape (len, 1).

        Returns:
            Average per-character loss (float) for this pair.
        """
        # initHidden() is used for both hidden and cell state; for non-LSTM
        # cell types the cell state is presumably ignored by Encoder/Decoder
        # (they accept and return it regardless) — TODO confirm.
        encoder_hidden = self.encoder.initHidden()
        encoder_cell = self.encoder.initHidden()
        self.encoder_optimizer.zero_grad()
        self.decoder_optimizer.zero_grad()
        input_length = input_tensor.size(0)
        target_length = target_tensor.size(0)
        loss = 0
        # Feed the source sequence one character at a time; only the final
        # hidden/cell states are carried over to the decoder (no attention).
        for ei in range(input_length):
            _, encoder_hidden, encoder_cell = self.encoder(
                input_tensor[ei], encoder_hidden, encoder_cell)
        decoder_input = torch.tensor([[SOS_token]], device=self.device)
        decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
        use_teacher_forcing = random.random() < self.teacher_forcing_ratio
        if use_teacher_forcing:
            # Teacher forcing: condition each step on the ground-truth char.
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_cell = self.decoder(
                    decoder_input, decoder_hidden, decoder_cell)
                loss += self.criterion(decoder_output, target_tensor[di])
                decoder_input = target_tensor[di]
        else:
            # Free running: condition each step on the model's own argmax.
            for di in range(target_length):
                decoder_output, decoder_hidden, decoder_cell = self.decoder(
                    decoder_input, decoder_hidden, decoder_cell)
                loss += self.criterion(decoder_output, target_tensor[di])
                _, topi = decoder_output.topk(1)
                # detach() cuts the graph so gradients don't flow through
                # the discrete argmax choice.
                decoder_input = topi.squeeze().detach()
                if decoder_input.item() == EOS_token:
                    break
        loss.backward()
        self.encoder_optimizer.step()
        self.decoder_optimizer.step()
        return loss.item() / target_length

    def train(self, iters=-1):
        """Train over (a prefix of) the shuffled training pairs.

        Args:
            iters: number of iterations to run; -1 means one pass over the
                whole training set.

        Returns:
            List of loss averages, one per PLOT_EVERY iterations.
        """
        start_time = time.time()
        plot_losses = []
        print_loss_total = 0
        plot_loss_total = 0
        random.shuffle(self.training_pairs)
        iters = len(self.training_pairs) if iters == -1 else iters
        # `it` (not `iter`) to avoid shadowing the builtin.
        for it in range(1, iters + 1):
            input_tensor, target_tensor = self.training_pairs[it - 1]
            loss = self.train_single(input_tensor, target_tensor)
            print_loss_total += loss
            plot_loss_total += loss
            if it % PRINT_EVERY == 0:
                print_loss_avg = print_loss_total / PRINT_EVERY
                print_loss_total = 0
                current_time = time.time()
                print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(
                    print_loss_avg, it, current_time - start_time))
            if it % PLOT_EVERY == 0:
                plot_loss_avg = plot_loss_total / PLOT_EVERY
                plot_losses.append(plot_loss_avg)
                plot_loss_total = 0
        return plot_losses

    def evaluate(self, word):
        """Greedy-decode the translation of `word`.

        Args:
            word: input-language string.

        Returns:
            Decoded output-language string (without the EOS marker).
        """
        with torch.no_grad():
            input_tensor = tensorFromWord(self.input_lang, word, self.device)
            input_length = input_tensor.size()[0]
            encoder_hidden = self.encoder.initHidden()
            encoder_cell = self.encoder.initHidden()
            for ei in range(input_length):
                _, encoder_hidden, encoder_cell = self.encoder(
                    input_tensor[ei], encoder_hidden, encoder_cell)
            decoder_input = torch.tensor([[SOS_token]], device=self.device)
            decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
            decoded_chars = ""
            # Decode at most max_length characters, stopping early at EOS.
            for di in range(self.max_length):
                decoder_output, decoder_hidden, decoder_cell = self.decoder(
                    decoder_input, decoder_hidden, decoder_cell)
                _, topi = decoder_output.topk(1)
                if topi.item() == EOS_token:
                    break
                decoded_chars += self.output_lang.index2word[topi.item()]
                decoder_input = topi.squeeze().detach()
            return decoded_chars

    def test_validate(self, type: str):
        """Exact-match accuracy on the "test" or "validation" split.

        Args:
            type: split name forwarded to ``get_data`` (parameter name kept
                for caller compatibility even though it shadows the builtin).

        Returns:
            Fraction of pairs whose greedy decoding equals the reference
            (0.0 for an empty split instead of raising ZeroDivisionError).
        """
        pairs = get_data(self.lang, type)
        if not pairs:
            return 0.0
        correct = sum(self.evaluate(pair[0]) == pair[1] for pair in pairs)
        return correct / len(pairs)