import glob
import os
import random
import unicodedata

import torch

# Character vocabulary the model can represent: ASCII letters plus a few
# punctuation marks that appear in names.
ALL_LETTERS = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz .,;'-"
N_LETTERS = len(ALL_LETTERS)
def load_data():
    """Read data/names/*.txt into a {category: [names]} dict plus the category list."""
    category_lines = {}
    all_categories = []

    def find_files(path):
        return glob.glob(path)

    def read_lines(filename):
        lines = open(filename, encoding='utf-8').read().strip().split('\n')
        return [unicode_to_ascii(line) for line in lines]

    for filename in find_files('data/names/*.txt'):
        # The category is the file name without its extension, e.g. "English".
        category = os.path.splitext(os.path.basename(filename))[0]
        all_categories.append(category)
        lines = read_lines(filename)
        category_lines[category] = lines

    return category_lines, all_categories
def letter_to_index(letter):
    # Position of a letter in ALL_LETTERS; returns -1 for characters outside the vocabulary.
    return ALL_LETTERS.find(letter)
def letter_to_tensor(letter):
    # One-hot tensor of shape (1, N_LETTERS) for a single letter.
    tensor = torch.zeros(1, N_LETTERS)
    tensor[0][letter_to_index(letter)] = 1
    return tensor
def line_to_tensor(line):
    # One-hot tensor of shape (len(line), 1, N_LETTERS); the middle dimension
    # is a batch size of 1.
    tensor = torch.zeros(len(line), 1, N_LETTERS)
    for i, letter in enumerate(line):
        tensor[i][0][letter_to_index(letter)] = 1
    return tensor
def random_training_example(category_lines, all_categories):
    # Sample a random (category, name) pair and return it with its tensor encodings.
    def random_choice(a):
        random_idx = random.randint(0, len(a) - 1)
        return a[random_idx]

    category = random_choice(all_categories)
    line = random_choice(category_lines[category])
    category_tensor = torch.tensor([all_categories.index(category)], dtype=torch.long)
    line_tensor = line_to_tensor(line)
    return category, line, category_tensor, line_tensor
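
# Typical training-loop usage (a sketch, not part of this module): the names
# `rnn` and `train` below are assumed to be defined elsewhere in the repo.
#   category_lines, all_categories = load_data()
#   category, line, category_tensor, line_tensor = random_training_example(
#       category_lines, all_categories)
#   output, loss = train(rnn, category_tensor, line_tensor)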
# Turn a Unicode string into plain ASCII by stripping accents and dropping
# any character outside ALL_LETTERS.
def unicode_to_ascii(s):
    return ''.join(
        c for c in unicodedata.normalize('NFD', s)
        if unicodedata.category(c) != 'Mn'
        and c in ALL_LETTERS
    )
if __name__ == '__main__':
    print(unicode_to_ascii("O'Néàl"))
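    # Minimal sanity check of the tensor helpers; needs no data files.
    print(letter_to_tensor('J').shape)    # (1, N_LETTERS)
    print(line_to_tensor('Jones').shape)  # (len('Jones'), 1, N_LETTERS)

    # Optional check of the data pipeline. This assumes the data/names/*.txt
    # surname files are present next to this script; uncomment to run.
    # category_lines, all_categories = load_data()
    # category, line, category_tensor, line_tensor = random_training_example(
    #     category_lines, all_categories)
    # print(category, line, category_tensor.shape, line_tensor.shape)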