Spaces:
Runtime error
Runtime error
File size: 902 Bytes
21b6296 270b967 21b6296 29074a9 b78dd53 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 |
from transformers import GPT2Tokenizer
import torch
import streamlit as st
# Tokenizer for the fine-tuned GPT-2 model. GPT-2 ships without a padding
# token, so the end-of-sequence token is reused as the pad token.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the tokenizer and model are reloaded each time. Consider
# wrapping this loading in `st.cache_resource` if the installed Streamlit
# version provides it.
#
# `poem_model.pt` holds a fully pickled model object and torch.load unpickles
# it, which can execute arbitrary code -- only point this at a trusted file
# (assumed here to be an artifact shipped with the app).
# map_location="cpu" lets a checkpoint saved from a GPU load on a CPU-only
# host, and keeps the model on the same device as CPU-built inputs.
try:
    # torch >= 2.6 defaults to weights_only=True, which refuses to unpickle a
    # full nn.Module, so opt out explicitly.
    model = torch.load("poem_model.pt", map_location="cpu", weights_only=False)
except TypeError:
    # Older torch releases do not accept the weights_only keyword.
    model = torch.load("poem_model.pt", map_location="cpu")
# A pickled module keeps its train/eval flag; force eval mode so dropout is
# disabled during generation.
model.eval()
def infer(inp, *, max_length=100, num_beams=5):
    """Generate a poem continuing the prompt *inp* with the fine-tuned model.

    Uses beam search, blocking any repeated bigram. Reads the module-level
    ``tokenizer`` and ``model``.

    Args:
        inp: Prompt text.
        max_length: Upper bound on the total token count (prompt plus
            generated tokens), passed through to ``model.generate``.
        num_beams: Beam width for beam search.

    Returns:
        The decoded text of the best beam (prompt followed by the generated
        continuation), with special tokens such as ``<|endoftext|>`` removed.
    """
    encoded = tokenizer(inp, return_tensors="pt")
    # Place the inputs on whichever device the model lives on, so a model
    # loaded onto a GPU does not hit a CPU/GPU tensor mismatch.
    # NOTE(review): assumes `model` is a Hugging Face PreTrainedModel, which
    # exposes `.device` (it is already used via `.generate`).
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        # GPT-2 has no dedicated pad token and the tokenizer pads with EOS;
        # pass the same id so generate() does not fall back with a warning.
        pad_token_id=tokenizer.eos_token_id,
        max_length=max_length,
        min_length=10,
        early_stopping=True,
        num_beams=num_beams,
        no_repeat_ngram_size=2,
    )
    # skip_special_tokens drops EOS/padding tokens that would otherwise appear
    # as literal "<|endoftext|>" text in the displayed poem.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Streamlit page: a prompt box plus a button that generates and shows a poem.
st.title("WaltWhitman-GPT By Ilyas")
text = st.text_area("Enter Prompt")
if st.button("Generate Poem"):
    # Treat a blank or whitespace-only prompt as missing, and tell the user,
    # instead of silently doing nothing.
    if text.strip():
        # Beam search can take a while; show progress while it runs.
        with st.spinner("Generating..."):
            output = infer(text)
        # st.text shows the poem verbatim. st.write renders Markdown, which
        # collapses single line breaks and interprets characters like * and #.
        st.text(output)
    else:
        st.warning("Please enter a prompt first.")