Spaces:
Runtime error
Runtime error
from transformers import GPT2Tokenizer | |
import torch | |
import streamlit as st | |
@st.cache_resource
def _load_resources():
    """Load the GPT-2 tokenizer and the fine-tuned poem model once per server process.

    Streamlit re-executes this whole script on every widget interaction, so
    without caching both objects would be reloaded on each button click.

    Returns:
        A ``(tokenizer, model)`` tuple.
    """
    tok = GPT2Tokenizer.from_pretrained('gpt2')
    # GPT-2 ships without a pad token; reuse EOS so padding/attention masks work.
    tok.pad_token = tok.eos_token
    # Deserialize onto whatever device exists on this host: a checkpoint saved
    # from a GPU session cannot be loaded on a CPU-only machine otherwise.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    # The file stores a whole pickled model object (it is later used via
    # .generate()), not a state_dict, so full unpickling is required;
    # torch>=2.6 defaults to weights_only=True and would reject it.
    # NOTE(security): this unpickles arbitrary objects — only load a trusted local file.
    mdl = torch.load("poem_model.pt", map_location=device, weights_only=False)
    # A pickled module keeps the train/eval mode it was saved in; force eval
    # so dropout is disabled during generation.
    mdl.eval()
    return tok, mdl


tokenizer, model = _load_resources()
def infer(inp):
    """Generate a poem from a text prompt using beam search.

    Args:
        inp: The prompt string typed by the user.

    Returns:
        The decoded text of the top beam (prompt plus generated continuation),
        with special tokens such as ``<|endoftext|>`` removed.
    """
    enc = tokenizer(inp, return_tensors="pt")
    # Keep the inputs on the model's device; a model loaded onto a GPU would
    # otherwise fail with a device-mismatch error on CPU tensors.
    # (Assumes `model` is a transformers PreTrainedModel, which exposes .device.)
    input_ids = enc["input_ids"].to(model.device)
    attention_mask = enc["attention_mask"].to(model.device)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=100,
        min_length=10,
        early_stopping=True,
        num_beams=5,
        no_repeat_ngram_size=2,
        # GPT-2 has no dedicated pad token; pass EOS explicitly so generate()
        # does not have to infer it when padding finished beams.
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the "<|endoftext|>" marker a finished beam ends with, so it is not
    # shown to the user.
    return tokenizer.decode(output[0], skip_special_tokens=True)
st.title("WaltWhitman-GPT By Ilyas")
text = st.text_area("Enter Prompt")
if st.button("Generate Poem"):
    if text:
        # Beam search can take a while on CPU; show progress instead of an
        # apparently frozen page.
        with st.spinner("Generating..."):
            output = infer(text)
        # st.write renders strings as Markdown, which folds single newlines
        # into spaces and collapses the poem's line breaks; st.text keeps them.
        st.text(output)
    else:
        # Generation needs at least one prompt token; tell the user rather
        # than letting the click do nothing.
        st.warning("Please enter a prompt first.")