# Spaces: Runtime error
import pickle | |
import torch | |
import gradio as gr | |
from gpt import GPTLanguageModel | |
# Character-level vocabulary saved at training time: ``stoi`` maps each
# character to its integer id and ``itos`` maps ids back to characters.
# NOTE: pickle.load can execute arbitrary code; only load a trusted file that
# ships with this app.
with open('stoi_itos.pkl', 'rb') as file:
    stoi, itos = pickle.load(file)


def encode(s: str) -> list:
    """Encode a string as a list of integer token ids, one per character.

    Raises:
        KeyError: If ``s`` contains a character that is not in ``stoi``.
    """
    return [stoi[c] for c in s]


def decode(l) -> str:
    """Decode a sequence of integer token ids back into a string."""
    return ''.join([itos[i] for i in l])
# Instantiate the architecture, then overwrite its parameters with the trained
# checkpoint. map_location='cpu' lets the weights load on a machine without a
# GPU. GPTLanguageModel is presumably a torch.nn.Module, so eval() switches it
# to evaluation mode (e.g. disabling any dropout) — confirm against gpt.py.
lm = GPTLanguageModel()
_weights = torch.load('shakespeare_lm.pt', map_location='cpu')
lm.load_state_dict(_weights)
del _weights  # the model now holds its own copy; drop the temporary name
lm.eval()
def inference(prompt: str) -> str:
    """Continue ``prompt`` by sampling 500 new characters from the model.

    Characters that are not in the model's vocabulary (``stoi``) are dropped
    before encoding. Without this, ``encode`` would raise ``KeyError`` and the
    whole request would fail on e.g. an emoji or an unusual symbol.

    Args:
        prompt: Seed text typed by the user.

    Returns:
        The decoded text of the full sequence returned by ``lm.generate``.
        This presumably begins with the filtered prompt; confirm against
        gpt.py.

    Raises:
        gr.Error: If no character of ``prompt`` is in the vocabulary. Gradio
            shows this message to the user instead of a generic failure.
    """
    known = ''.join(ch for ch in prompt if ch in stoi)
    if not known:
        # An empty context gives the model nothing to condition on.
        raise gr.Error(
            "Please enter a prompt containing at least one character "
            "the model knows."
        )
    encoded_prompt = torch.tensor(encode(known), dtype=torch.long).unsqueeze(0)
    # Pure inference: don't let autograd record the 500-step sampling loop.
    with torch.no_grad():
        generated = lm.generate(encoded_prompt, max_new_tokens=500)
    return decode(generated[0].tolist())
# Web UI: a prompt text box feeds ``inference``; its return value is shown in
# the output text box. The examples are one-click prompts.
gr_interface = gr.Interface(
    inference,
    inputs=[
        gr.Textbox("man walking on the streets", label="Prompt"),
    ],
    outputs=[
        # ``height`` is not a gr.Textbox parameter. Older Gradio versions only
        # warn about it; newer ones reject it. Textbox size is governed by
        # ``lines`` / ``max_lines`` instead.
        gr.Textbox(label="Generated story"),
    ],
    title="Stories generated by a language model trained on Shakespeare's work",
    examples=[
        ["Sunrise rising"],
        ["A big blast sound"],
    ],
)
# debug=True blocks the main thread while the server runs and surfaces errors
# in the console output.
gr_interface.launch(debug=True)