nullHawk commited on
Commit
41ce832
·
verified ·
1 Parent(s): 5ce9bed

fixed word embedding issue

Browse files
Files changed (1) hide show
  1. app.py +3 -6
app.py CHANGED
@@ -3,9 +3,7 @@ import torch
3
  import model
4
  from utils import line_to_tensor, load_data, N_LETTERS
5
 
6
- def category_from_output(output):
7
- category_index = torch.argmax(output).item()
8
- return all_categories[category_index]
9
 
10
  def run(name):
11
 
@@ -14,10 +12,9 @@ def run(name):
14
  hidden_tensor = rnn.init_hidden()
15
  for i in range(input_tensor.size()[0]):
16
  output, hidden_tensor = rnn(input_tensor[i], hidden_tensor)
17
- return f"It is an {category_from_output(output)} name"
18
 
19
- category_lines, all_categories = load_data()
20
- rnn = model.RNN(N_LETTERS, 128, len(all_categories))
21
  rnn.load_state_dict(torch.load('rnn.pth'))
22
  rnn.eval
23
 
 
3
  import model
4
  from utils import line_to_tensor, load_data, N_LETTERS
5
 
6
+ language_map = ['Polish', 'Irish', 'Italian', 'Korean', 'English', 'Czech', 'Chinese', 'Japanese', 'Portuguese', 'Indian', 'Greek', 'Vietnamese', 'French', 'German', 'Russian', 'Scottish', 'Arabic', 'Dutch', 'Spanish']
 
 
7
 
8
  def run(name):
9
 
 
12
  hidden_tensor = rnn.init_hidden()
13
  for i in range(input_tensor.size()[0]):
14
  output, hidden_tensor = rnn(input_tensor[i], hidden_tensor)
15
+ return f"It is an {language_map[torch.argmax(output).item()]} name"
16
 
17
+ rnn = model.RNN(N_LETTERS, 128, 19)
 
18
  rnn.load_state_dict(torch.load('rnn.pth'))
19
  rnn.eval
20