# Hugging Face Hub file view: uploaded by nullHawk in commit 5e27466
# ("Added model and data", verified); file size 768 bytes.
import gradio as gr
import torch
import model
from utils import line_to_tensor, load_data, N_LETTERS
def category_from_output(output, categories=None):
    """Return the category label with the highest score in *output*.

    Args:
        output: Tensor of per-category scores. ``torch.argmax`` is called
            without ``dim``, so the index is taken over the flattened tensor;
            for a single-sample output (e.g. shape ``(1, n_categories)``) this
            is the category index.
        categories: Optional sequence of labels indexed by category id.
            Defaults to the module-level ``all_categories`` loaded at startup,
            so existing callers keep working unchanged.

    Returns:
        The label at the position of the largest score.
    """
    if categories is None:
        # Resolved at call time, so the global only needs to exist by then.
        categories = all_categories
    category_index = torch.argmax(output).item()
    return categories[category_index]
def run(name):
    """Classify *name* with the loaded RNN and return a sentence for the UI.

    The tensor from ``line_to_tensor`` is fed to ``rnn`` one dim-0 slice at a
    time (presumably one letter per slice), threading the hidden state through
    each step; the output of the final step decides the category.

    Args:
        name: Text from the Gradio textbox. ``None`` and blank input are
            handled defensively.

    Returns:
        A sentence naming the predicted category, or a prompt asking for a
        name when there is nothing to classify.
    """
    # Strip stray whitespace from the textbox; treat None like an empty string.
    name = (name or "").strip()
    if not name:
        return "Please enter a name."

    output = None
    with torch.no_grad():
        input_tensor = line_to_tensor(name)
        hidden_tensor = rnn.init_hidden()
        # Iterating a tensor yields its slices along dim 0, one step each.
        for step_tensor in input_tensor:
            output, hidden_tensor = rnn(step_tensor, hidden_tensor)

    if output is None:
        # Zero time steps means the loop never ran: there is no prediction,
        # and without this guard the caller would hit an unbound/None output.
        return "Please enter a name."

    category = category_from_output(output)
    # Choose the indefinite article so e.g. "German" reads "a German name".
    article = "an" if category[:1].lower() in ("a", "e", "i", "o", "u") else "a"
    return f"It is {article} {category} name"
# Load the data set; only the category labels are used by this app.
category_lines, all_categories = load_data()

# Presumably the RNN's hidden size (its second constructor argument); it must
# match the checkpoint stored in rnn.pth.
N_HIDDEN = 128

rnn = model.RNN(N_LETTERS, N_HIDDEN, len(all_categories))
# map_location="cpu" lets a checkpoint saved on a GPU load on a CPU-only host.
rnn.load_state_dict(torch.load('rnn.pth', map_location="cpu"))
# eval() must be *called*; merely referencing the bound method leaves the
# module in training mode.
rnn.eval()

demo = gr.Interface(fn=run, inputs="text", outputs="text")
demo.launch()