# Source: venkyyuvy — commit 7cfe36b ("batch dim")
import pickle
import torch
import gradio as gr
from gpt import GPTLanguageModel
# Load the character-level vocabulary mappings saved at training time.
# stoi: character -> integer token id; itos: integer token id -> character.
with open('stoi_itos.pkl', 'rb') as file:
    stoi, itos = pickle.load(file)


def encode(s: str) -> list[int]:
    """Encode a string as a list of integer token ids, one per character."""
    return [stoi[c] for c in s]


def decode(l: list[int]) -> str:
    """Decode a list of integer token ids back into a string."""
    return ''.join([itos[i] for i in l])


# Rebuild the architecture and restore the trained weights on CPU.
lm = GPTLanguageModel()
lm.load_state_dict(torch.load('shakespeare_lm.pt', map_location='cpu'))
lm.eval()  # inference mode: disables dropout etc. in the module tree
def inference(prompt: str) -> str:
    """Generate a 500-token continuation of *prompt* with the Shakespeare LM.

    Characters not present in the training vocabulary are dropped before
    encoding, so arbitrary user input cannot raise a KeyError from `stoi`.

    Args:
        prompt: Free-form seed text typed by the user.

    Returns:
        The decoded generation (includes the encoded prompt prefix).
    """
    # Keep only characters the character-level tokenizer knows about.
    cleaned = ''.join(c for c in prompt if c in stoi)
    if not cleaned:
        # Seed with a newline when nothing survives filtering.
        # NOTE(review): assumes '\n' is in the training vocabulary — confirm.
        cleaned = '\n'
    encoded_prompt = torch.tensor(encode(cleaned), dtype=torch.long).unsqueeze(0)
    # Inference only: skip autograd bookkeeping for speed and memory.
    with torch.no_grad():
        output = decode(lm.generate(encoded_prompt, max_new_tokens=500)[0].tolist())
    return output
# Build the web UI: a single prompt textbox feeding `inference`, with a
# textbox displaying the generated continuation.
prompt_box = gr.Textbox("man walking on the streets", label="Prompt")
story_box = gr.Textbox(
    label="Generated story",
    # NOTE(review): `height` is not an obvious gr.Textbox kwarg — verify
    # against the installed Gradio version; preserved as-is here.
    height="auto",
)
gr_interface = gr.Interface(
    inference,
    inputs=[prompt_box],
    outputs=[story_box],
    title="Stories generated by a language model trained on Shakespeare's work",
    examples=[["Sunrise rising"], ["A big blast sound"]],
)
# debug=True keeps the process in the foreground and prints server errors.
gr_interface.launch(debug=True)