"""Gradio demo that labels a typed name with one of the categories returned by
``utils.load_data``, using the character-level RNN defined in ``model``."""

import gradio as gr
import torch

import model
from utils import line_to_tensor, load_data, N_LETTERS

# Second constructor argument of model.RNN (presumably the hidden size).
# It must match the value used when ``rnn.pth`` was saved, or
# load_state_dict will fail.
N_HIDDEN = 128


def category_from_output(output):
    """Return the category label with the highest score in ``output``.

    ``torch.argmax`` is taken over the flattened tensor. The resulting index
    is assumed to line up with the module-level ``all_categories`` list
    (one score per category; TODO confirm against model.RNN's output shape).
    """
    category_index = torch.argmax(output).item()
    return all_categories[category_index]


def _indefinite_article(word):
    """Return "an" if ``word`` starts with a vowel letter, otherwise "a"."""
    return "an" if word[:1].lower() in ("a", "e", "i", "o", "u") else "a"


def run(name):
    """Classify ``name`` and return a human-readable sentence for the UI.

    The tensor from ``line_to_tensor`` is fed to the RNN one slice (along
    dim 0) at a time, carrying the hidden state forward. The output after the
    final slice is used for the prediction.
    """
    # With an empty string the step loop below would never assign ``output``
    # and the function would crash with UnboundLocalError.
    if not name:
        return "Please enter a name."
    with torch.no_grad():
        input_tensor = line_to_tensor(name)
        hidden_tensor = rnn.init_hidden()
        # Iterating a tensor yields its dim-0 slices, i.e. input_tensor[i].
        for step_tensor in input_tensor:
            output, hidden_tensor = rnn(step_tensor, hidden_tensor)
    category = category_from_output(output)
    return f"It is {_indefinite_article(category)} {category} name"


category_lines, all_categories = load_data()

rnn = model.RNN(N_LETTERS, N_HIDDEN, len(all_categories))
# map_location="cpu" lets a checkpoint saved on a GPU load on CPU-only hosts.
# NOTE(security): torch.load unpickles the file; only load trusted checkpoints.
rnn.load_state_dict(torch.load("rnn.pth", map_location="cpu"))
# eval() must actually be called; a bare ``rnn.eval`` attribute access is a
# no-op and would leave any dropout/batch-norm layers in training mode.
rnn.eval()

demo = gr.Interface(fn=run, inputs="text", outputs="text")
demo.launch()