Spaces:
Running
Running
import pandas as pd | |
import torch | |
import torch.nn as nn | |
import torch.optim as optim | |
from src.language import Language, EOS_token | |
def get_data(lang: str, type: str, data_dir: str = "./aksharantar_sampled") -> list[list[str]]:
    """
    Load one split of the Aksharantar transliteration data for a language.

    The file read is ``{data_dir}/{lang}/{lang}_{type}.csv``. It has no header row.

    Args:
        lang: Language code, used both as the sub-directory and the file prefix.
        type: Split name, e.g. "train". The name shadows the builtin but is kept
            for compatibility with existing keyword callers.
        data_dir: Root directory of the sampled dataset. Defaults to the
            original hard-coded relative path.

    Returns: 'pairs': list of [input_word, target_word] pairs (one list per CSV row)
    """
    path = f"{data_dir}/{lang}/{lang}_{type}.csv"
    # Read every cell as a literal string. By default pandas turns tokens such as
    # "nan", "null", "NA" or "None" into float NaN, and numeric-looking tokens into
    # numbers. Both corrupt a word list and make later character lookups fail.
    df = pd.read_csv(path, header=None, dtype=str, keep_default_na=False)
    return df.values.tolist()
def get_languages(lang: str):
    """
    Build source/target vocabularies from the training split of a language.

    Only the "train" split is scanned. Characters that appear only in other
    splits may therefore be missing from the vocabularies (NOTE(review): confirm
    this is intended).

    Returns
        1. input_lang: input language - English (a Language named 'eng')
        2. output_lang: output language - Given language (a Language named `lang`)
        3. pairs: list of [input_word, target_word] pairs from the train split
    """
    train_pairs = get_data(lang, "train")
    src_vocab, tgt_vocab = Language('eng'), Language(lang)

    # Column 0 holds the English word and column 1 the target-script word.
    for row in train_pairs:
        src_vocab.addWord(row[0])
        tgt_vocab.addWord(row[1])

    return src_vocab, tgt_vocab, train_pairs
def get_cell(cell_type: str):
    """
    Map a cell-type name to the corresponding ``torch.nn`` recurrent layer class.

    Args:
        cell_type: One of "LSTM", "GRU" or "RNN" (case-sensitive).

    Returns:
        The layer class itself (``nn.LSTM``, ``nn.GRU`` or ``nn.RNN``), not an instance.

    Raises:
        ValueError: If ``cell_type`` is not a supported name. ``ValueError`` is a
            subclass of ``Exception``, so callers catching ``Exception`` still work.
    """
    cells = {"LSTM": nn.LSTM, "GRU": nn.GRU, "RNN": nn.RNN}
    try:
        return cells[cell_type]
    except (KeyError, TypeError):  # TypeError covers unhashable input
        raise ValueError(
            f"Invalid cell type: {cell_type!r} (expected one of {list(cells)})"
        ) from None
def get_optimizer(optimizer: str):
    """
    Map an optimizer name to the corresponding ``torch.optim`` optimizer class.

    Args:
        optimizer: Either "SGD" or "ADAM" (case-sensitive).

    Returns:
        The optimizer class itself (``optim.SGD`` or ``optim.Adam``), not an instance.

    Raises:
        ValueError: If ``optimizer`` is not a supported name. ``ValueError`` is a
            subclass of ``Exception``, so callers catching ``Exception`` still work.
    """
    optimizers = {"SGD": optim.SGD, "ADAM": optim.Adam}
    try:
        return optimizers[optimizer]
    except (KeyError, TypeError):  # TypeError covers unhashable input
        raise ValueError(
            f"Invalid optimizer: {optimizer!r} (expected one of {list(optimizers)})"
        ) from None
def indexesFromWord(lang:Language, word:str):
    """
    Convert a word into the list of vocabulary indices of its characters.

    Each character of ``word`` is looked up in ``lang.word2index``, in order.
    Presumably ``word2index`` is a dict, in which case a character missing from
    the vocabulary raises KeyError (verify against Language).
    """
    lookup = lang.word2index
    return list(map(lookup.__getitem__, word))
def tensorFromWord(lang:Language, word:str, device:str):
    """
    Encode a word as a column tensor of character indices terminated by EOS.

    Returns a ``torch.long`` tensor of shape (len(word) + 1, 1) placed on ``device``.
    The last row is ``EOS_token``.
    """
    ids = indexesFromWord(lang, word) + [EOS_token]
    as_tensor = torch.tensor(ids, dtype=torch.long, device=device)
    # Reshape to a column: one index per row.
    return as_tensor.view(-1, 1)
def tensorsFromPair(input_lang:Language, output_lang:Language, pair:list[str], device:str):
    """
    Encode an [input_word, target_word] pair as a pair of EOS-terminated tensors.

    ``pair[0]`` is encoded with ``input_lang`` and ``pair[1]`` with ``output_lang``.
    Both tensors are placed on ``device``.

    Returns: (input_tensor, target_tensor)
    """
    source = tensorFromWord(input_lang, pair[0], device)
    target = tensorFromWord(output_lang, pair[1], device)
    return source, target