"""Build a character-level bigram model from a list of names.

Reads one name per line from ``data/names.txt``, counts how often each
character is immediately followed by each other character (``TOKEN`` marks
both the start and the end of a name), turns each row of counts into a
probability distribution, and pickles the result to ``model/bigrams.pkl``.

The pickle holds the list ``[P, char_to_int, int_to_char]``.
"""

import os
import pickle

import torch

# Boundary marker placed before the first and after the last character of
# every name, so the model also learns how names start and end.
TOKEN = '.'

# One name per line. The context manager closes the handle deterministically
# and the explicit encoding keeps the read independent of the host locale.
with open('data/names.txt', 'r', encoding='utf-8') as names_file:
    words = names_file.read().splitlines()

# Every distinct character seen in the data plus the boundary marker, in
# sorted order so the index assignment is reproducible between runs.
vocab = sorted(set(''.join(words)) | {TOKEN})

n = len(vocab)

# N[i, j] counts how many times character i is followed by character j.
N = torch.zeros((n, n), dtype=torch.int32)

# Lookup tables between characters and their row/column index in N and P.
char_to_int = {char: i for i, char in enumerate(vocab)}
int_to_char = {value: key for key, value in char_to_int.items()}

for word in words:
    # Wrap the name in boundary markers, then walk over adjacent pairs.
    chars = [TOKEN] + list(word) + [TOKEN]
    for ch1, ch2 in zip(chars, chars[1:]):
        ix1 = char_to_int[ch1]
        ix2 = char_to_int[ch2]
        N[ix1, ix2] += 1

# Normalize each row so P[i, j] is the fraction of times i was followed by j.
# keepdim=True keeps the row sums as an (n, 1) column so the division
# broadcasts across each row rather than across columns.
# With at least one word, no row sums to zero: every vocabulary character
# occurs in some word (so it is followed by a character or TOKEN), and TOKEN
# starts every word.
P = N.float()
P /= P.sum(1, keepdim=True)

# Create the output directory up front so the run does not fail at the very
# last step when 'model/' is missing.
os.makedirs('model', exist_ok=True)
with open('model/bigrams.pkl', 'wb') as file:
    pickle.dump([P, char_to_int, int_to_char], file)