File size: 768 Bytes
5e27466
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gradio as gr
import torch
import model
from utils import line_to_tensor, load_data, N_LETTERS

def category_from_output(output, categories=None):
    """Return the label of the highest-scoring category in *output*.

    Args:
        output: Score tensor from the RNN's final step. ``torch.argmax`` is
            taken over all elements (flattened). This assumes a single row
            of scores, e.g. shape ``(1, n_categories)`` — TODO confirm
            against ``model.RNN``.
        categories: Sequence of labels indexed by position in *output*.
            Defaults to the module-level ``all_categories`` loaded at startup.

    Returns:
        The label at the arg-max position.
    """
    if categories is None:
        # Looked up at call time: ``all_categories`` is assigned further down
        # the module, after this function is defined.
        categories = all_categories
    category_index = torch.argmax(output).item()
    return categories[category_index]

def _indefinite_article(word):
    """Return ``"an"`` if *word* starts with a vowel letter, else ``"a"``.

    A spelling-based heuristic. It does not handle cases like "an hour" or
    "a European", but it is adequate for nationality labels such as
    "English", "Irish", "German" or "Chinese".
    """
    return "an" if word[:1].lower() in {"a", "e", "i", "o", "u"} else "a"


def run(name):
    """Classify *name* with the RNN and return a human-readable verdict.

    The name is converted with ``line_to_tensor`` and fed to the model one
    slice along dim 0 at a time, threading the hidden state through each
    step. The output after the final step is mapped to a category label.

    Args:
        name: Text entered by the user in the Gradio textbox.

    Returns:
        A sentence such as "It is an English name", or a prompt asking for
        input when *name* is empty or whitespace-only.
    """
    # Surrounding whitespace is not part of a name; trim it so "Smith\n"
    # is classified the same as "Smith".
    name = name.strip() if name else ""
    if not name:
        # With no characters the loop below would never run and ``output``
        # would be unbound (UnboundLocalError).
        return "Please enter a name."

    with torch.no_grad():
        input_tensor = line_to_tensor(name)
        hidden_tensor = rnn.init_hidden()
        # Iterating a tensor yields its slices along dim 0, one per step.
        for letter_tensor in input_tensor:
            output, hidden_tensor = rnn(letter_tensor, hidden_tensor)
        category = category_from_output(output)
    return f"It is {_indefinite_article(category)} {category} name"

# Startup: load the category labels and the trained weights, then serve the UI.
category_lines, all_categories = load_data()

# Hidden size must match the one used to train the checkpoint in rnn.pth;
# load_state_dict raises on a shape mismatch otherwise.
N_HIDDEN = 128
rnn = model.RNN(N_LETTERS, N_HIDDEN, len(all_categories))
# map_location="cpu" lets a checkpoint saved from a GPU run load on a
# CPU-only host; the model here is never moved off the CPU.
rnn.load_state_dict(torch.load('rnn.pth', map_location="cpu"))
# eval() must be *called*: a bare ``rnn.eval`` is a no-op attribute access
# and leaves the module in training mode (affects dropout/batch-norm layers,
# if the model has any).
rnn.eval()


demo = gr.Interface(fn=run, inputs="text", outputs="text")
demo.launch()