transliteration / train_attention.py
Pankaj Singh Rawat
Initial commit
9e582c5
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import random
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import time
import argparse
random.seed()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Language Model
SOS_token = 0
EOS_token = 1
class Language:
def __init__(self, name):
self.name = name
self.word2index = {}
self.word2count = {}
self.index2word = {SOS_token: "<", EOS_token: ">"}
self.n_chars = 2 # Count SOS and EOS
def addWord(self, word):
for char in word:
self.addChar(char)
def addChar(self, char):
if char not in self.word2index:
self.word2index[char] = self.n_chars
self.word2count[char] = 1
self.index2word[self.n_chars] = char
self.n_chars += 1
else:
self.word2count[char] += 1
def get_data(lang: str, type: str) -> list[list[str]]:
"""
Returns: 'pairs': list of [input_word, target_word] pairs
"""
path = "./aksharantar_sampled/{}/{}_{}.csv".format(lang, lang, type)
df = pd.read_csv(path, header=None)
pairs = df.values.tolist()
return pairs
def get_languages(lang: str):
"""
Returns
1. input_lang: input language - English
2. output_lang: output language - Given language
3. pairs: list of [input_word, target_word] pairs
"""
input_lang = Language('eng')
output_lang = Language(lang)
pairs = get_data(lang, "train")
for pair in pairs:
input_lang.addWord(pair[0])
output_lang.addWord(pair[1])
return input_lang, output_lang, pairs
def get_cell(cell_type: str):
if cell_type == "LSTM":
return nn.LSTM
elif cell_type == "GRU":
return nn.GRU
elif cell_type == "RNN":
return nn.RNN
else:
raise Exception("Invalid cell type")
def get_optimizer(optimizer: str):
if optimizer == "SGD":
return optim.SGD
elif optimizer == "ADAM":
return optim.Adam
else:
raise Exception("Invalid optimizer")
class Encoder(nn.Module):
def __init__(self,
in_sz: int,
embed_sz: int,
hidden_sz: int,
cell_type: str,
n_layers: int,
dropout: float):
super(Encoder, self).__init__()
self.hidden_sz = hidden_sz
self.n_layers = n_layers
self.dropout = dropout
self.cell_type = cell_type
self.embedding = nn.Embedding(in_sz, embed_sz)
self.rnn = get_cell(cell_type)(input_size = embed_sz,
hidden_size = hidden_sz,
num_layers = n_layers,
dropout = dropout)
def forward(self, input, hidden, cell):
embedded = self.embedding(input).view(1, 1, -1)
if(self.cell_type == "LSTM"):
output, (hidden, cell) = self.rnn(embedded, (hidden, cell))
else:
output, hidden = self.rnn(embedded, hidden)
return output, hidden, cell
def initHidden(self):
return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
class AttentionDecoder(nn.Module):
def __init__(self,
out_sz: int,
embed_sz: int,
hidden_sz: int,
cell_type: str,
n_layers: int,
dropout: float):
super(AttentionDecoder, self).__init__()
self.hidden_sz = hidden_sz
self.n_layers = n_layers
self.dropout = dropout
self.cell_type = cell_type
self.embedding = nn.Embedding(out_sz, embed_sz)
self.attn = nn.Linear(hidden_sz + embed_sz, 50)
self.attn_combine = nn.Linear(hidden_sz + embed_sz, hidden_sz)
self.rnn = get_cell(cell_type)(input_size = hidden_sz,
hidden_size = hidden_sz,
num_layers = n_layers,
dropout = dropout)
self.out = nn.Linear(hidden_sz, out_sz)
self.softmax = nn.LogSoftmax(dim=1)
def forward(self, input, hidden, cell, encoder_outputs):
embedding = self.embedding(input).view(1, 1, -1)
attn_weights = F.softmax(self.attn(torch.cat((embedding[0], hidden[0]), 1)), dim=1)
attn_applied = torch.bmm(attn_weights.unsqueeze(0), encoder_outputs.unsqueeze(0))
output = torch.cat((embedding[0], attn_applied[0]), 1)
output = self.attn_combine(output).unsqueeze(0)
if(self.cell_type == "LSTM"):
output, (hidden, cell) = self.rnn(output, (hidden, cell))
else:
output, hidden = self.rnn(output, hidden)
output = self.softmax(self.out(output[0]))
return output, hidden, cell, attn_weights
def initHidden(self):
return torch.zeros(self.n_layers, 1, self.hidden_sz, device=device)
def indexesFromWord(lang:Language, word:str):
return [lang.word2index[char] for char in word]
def tensorFromWord(lang:Language, word:str):
indexes = indexesFromWord(lang, word)
indexes.append(EOS_token)
return torch.tensor(indexes, dtype=torch.long, device=device).view(-1, 1)
def tensorsFromPair(input_lang:Language, output_lang:Language, pair:list[str]):
input_tensor = tensorFromWord(input_lang, pair[0])
target_tensor = tensorFromWord(output_lang, pair[1])
return (input_tensor, target_tensor)
def params_definition():
"""
params:
embed_size : size of embedding (input and output) (8, 16, 32, 64)
hidden_size : size of hidden layer (64, 128, 256, 512)
cell_type : type of cell (LSTM, GRU, RNN)
num_layers : number of layers in encoder (1, 2, 3)
dropout : dropout probability
learning_rate : learning rate
teacher_forcing_ratio : teacher forcing ratio (0.5 fixed for now)
optimizer : optimizer (SGD, Adam)
max_length : maximum length of input word (50 fixed for now)
"""
pass
PRINT_EVERY = 5000
PLOT_EVERY = 100
class Translator:
def __init__(self, lang: str, params: dict):
self.lang = lang
self.input_lang, self.output_lang, self.pairs = get_languages(self.lang)
self.input_size = self.input_lang.n_chars
self.output_size = self.output_lang.n_chars
self.training_pairs = [tensorsFromPair(self.input_lang, self.output_lang, pair) for pair in self.pairs]
self.encoder = Encoder(in_sz = self.input_size,
embed_sz = params["embed_size"],
hidden_sz = params["hidden_size"],
cell_type = params["cell_type"],
n_layers = params["num_layers"],
dropout = params["dropout"]).to(device)
self.decoder = AttentionDecoder(out_sz = self.output_size,
embed_sz = params["embed_size"],
hidden_sz = params["hidden_size"],
cell_type = params["cell_type"],
n_layers = params["num_layers"],
dropout = params["dropout"]).to(device)
self.encoder_optimizer = get_optimizer(params["optimizer"])(self.encoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
self.decoder_optimizer = get_optimizer(params["optimizer"])(self.decoder.parameters(), lr=params["learning_rate"], weight_decay=params["weight_decay"])
self.criterion = nn.NLLLoss()
self.teacher_forcing_ratio = params["teacher_forcing_ratio"]
self.max_length = params["max_length"]
def train_single(self, input_tensor, target_tensor):
encoder_hidden = self.encoder.initHidden()
encoder_cell = self.encoder.initHidden()
self.encoder_optimizer.zero_grad()
self.decoder_optimizer.zero_grad()
input_length = input_tensor.size(0)
target_length = target_tensor.size(0)
encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
loss = 0
for ei in range(input_length):
encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
encoder_outputs[ei] = encoder_output[0, 0]
decoder_input = torch.tensor([[SOS_token]], device=device)
decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
use_teacher_forcing = True if random.random() < self.teacher_forcing_ratio else False
if use_teacher_forcing:
for di in range(target_length):
decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
loss += self.criterion(decoder_output, target_tensor[di])
decoder_input = target_tensor[di]
else:
for di in range(target_length):
decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
loss += self.criterion(decoder_output, target_tensor[di])
topv, topi = decoder_output.topk(1)
decoder_input = topi.squeeze().detach()
if decoder_input.item() == EOS_token:
break
loss.backward()
self.encoder_optimizer.step()
self.decoder_optimizer.step()
return loss.item() / target_length
def train(self, iters=-1):
start_time = time.time()
plot_losses = []
print_loss_total = 0
plot_loss_total = 0
random.shuffle(self.training_pairs)
iters = len(self.training_pairs) if iters == -1 else iters
for iter in range(1, iters):
training_pair = self.training_pairs[iter - 1]
input_tensor = training_pair[0]
target_tensor = training_pair[1]
loss = self.train_single(input_tensor, target_tensor)
print_loss_total += loss
plot_loss_total += loss
if iter % PRINT_EVERY == 0:
print_loss_avg = print_loss_total / PRINT_EVERY
print_loss_total = 0
current_time = time.time()
print("Loss: {:.4f} | Iterations: {} | Time: {:.3f}".format(print_loss_avg, iter, current_time - start_time))
if iter % PLOT_EVERY == 0:
plot_loss_avg = plot_loss_total / PLOT_EVERY
plot_losses.append(plot_loss_avg)
plot_loss_total = 0
return plot_losses
def evaluate(self, word):
with torch.no_grad():
input_tensor = tensorFromWord(self.input_lang, word)
input_length = input_tensor.size()[0]
encoder_hidden = self.encoder.initHidden()
encoder_cell = self.encoder.initHidden()
encoder_outputs = torch.zeros(self.max_length, self.encoder.hidden_sz, device=device)
for ei in range(input_length):
encoder_output, encoder_hidden, encoder_cell = self.encoder(input_tensor[ei], encoder_hidden, encoder_cell)
encoder_outputs[ei] += encoder_output[0, 0]
decoder_input = torch.tensor([[SOS_token]], device=device)
decoder_hidden, decoder_cell = encoder_hidden, encoder_cell
decoded_chars = ""
decoder_attentions = torch.zeros(self.max_length, self.max_length)
for di in range(self.max_length):
decoder_output, decoder_hidden, decoder_cell, decoder_attention = self.decoder(decoder_input, decoder_hidden, decoder_cell, encoder_outputs)
decoder_attentions[di] = decoder_attention.data
topv, topi = decoder_output.topk(1)
if topi.item() == EOS_token:
break
else:
decoded_chars += self.output_lang.index2word[topi.item()]
decoder_input = topi.squeeze().detach()
return decoded_chars, decoder_attentions[:di + 1]
def test_validate(self, type:str):
pairs = get_data(self.lang, type)
accuracy = 0
for pair in pairs:
output, _ = self.evaluate(pair[0])
if output == pair[1]:
accuracy += 1
return accuracy / len(pairs)
params = {
"embed_size": 32,
"hidden_size": 256,
"cell_type": "RNN",
"num_layers": 2,
"dropout": 0,
"learning_rate": 0.001,
"optimizer": "SGD",
"teacher_forcing_ratio": 0.5,
"max_length": 50,
"weight_decay": 0.001
}
language = "tam"
parser = argparse.ArgumentParser(description="Transliteration Model with Attention")
parser.add_argument('-es', '--embed_size', type=int, default=32, help='Embedding size')
parser.add_argument('-hs', '--hidden_size', type=int, default=256, help='Hidden size')
parser.add_argument('-ct', '--cell_type', type=str, default='RNN', help='Cell type')
parser.add_argument('-nl', '--num_layers', type=int, default=2, help='Number of layers')
parser.add_argument('-dr', '--dropout', type=float, default=0, help='Dropout')
parser.add_argument('-lr', '--learning_rate', type=float, default=0.001, help='Learning rate')
parser.add_argument('-op', '--optimizer', type=str, default='SGD', help='Optimizer')
parser.add_argument('-wd', '--weight_decay', type=float, default=0.001, help='Weight decay')
parser.add_argument('-l', '--lang', type=str, default='tam', help='Language')
args = parser.parse_args()
for arg in vars(args):
params[arg] = getattr(args, arg)
language = args.lang
print("Language: {}".format(language))
print("Embedding size: {}".format(params['embed_size']))
print("Hidden size: {}".format(params['hidden_size']))
print("Cell type: {}".format(params['cell_type']))
print("Number of layers: {}".format(params['num_layers']))
print("Dropout: {}".format(params['dropout']))
print("Learning rate: {}".format(params['learning_rate']))
print("Optimizer: {}".format(params['optimizer']))
print("Weight decay: {}".format(params['weight_decay']))
print("Teacher forcing ratio: {}".format(params['teacher_forcing_ratio']))
print("Max length: {}".format(params['max_length']))
model = Translator(language, params)
epochs = 10
for epoch in range(epochs):
print("Epoch: {}".format(epoch + 1))
model.train()
train_accuracy = model.test_validate('train')
print("Training Accuracy: {:.4f}".format(train_accuracy))
validation_accuracy = model.test_validate('valid')
print("Validation Accuracy: {:.4f}".format(validation_accuracy))
test_accuracy = model.test_validate('test')
print("Test Accuracy: {:.4f}".format(test_accuracy))