WaltWhitman-GPT / app.py
Isaoudata's picture
Update app.py
b78dd53
raw
history blame contribute delete
902 Bytes
from transformers import GPT2Tokenizer
import torch
import streamlit as st
# GPT-2's tokenizer has no padding token by default; reuse end-of-sequence
# so a pad id is defined for tokenization/generation.
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
tokenizer.pad_token = tokenizer.eos_token

# "poem_model.pt" is loaded as a whole pickled model object (infer() calls
# .generate() on the result), not as a state_dict.
# - weights_only=False: full-model pickles cannot be loaded with the
#   weights-only unpickler, which is the default since torch 2.6. Unpickling
#   can execute arbitrary code, so only ship a checkpoint you trust.
# - map_location="cpu": the original inference path builds CPU tensors, so keep
#   the weights on CPU even if the checkpoint was saved from a GPU.
# NOTE(review): Streamlit re-executes this whole script on every widget
# interaction, so the model is reloaded each time; consider wrapping the load
# in st.cache_resource (Streamlit >= 1.18) once the pinned version allows it.
model = torch.load("poem_model.pt", map_location="cpu", weights_only=False)
# Disable dropout for inference; the checkpoint may have been saved while the
# model was still in training mode.
model.eval()
def infer(inp, *, max_length=100, min_length=10, num_beams=5,
          no_repeat_ngram_size=2):
    """Generate a continuation of a prompt with beam search.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        inp: Prompt text.
        max_length: Maximum total length in tokens, *including* the prompt
            tokens (GPT-2 is decoder-only, so the prompt counts).
        min_length: Minimum total length in tokens, including the prompt.
        num_beams: Beam-search width.
        no_repeat_ngram_size: Forbid any n-gram of this size from repeating.

    Returns:
        The decoded best beam (prompt followed by the generated text), with
        special tokens such as ``<|endoftext|>`` removed.
    """
    encoded = tokenizer(inp, return_tensors="pt")
    # Put the inputs on whatever device the model weights live on, so this
    # works whether the model was loaded on CPU or GPU.
    input_ids = encoded["input_ids"].to(model.device)
    attention_mask = encoded["attention_mask"].to(model.device)
    generated = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=max_length,
        min_length=min_length,
        early_stopping=True,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
    )
    # A finished beam can end with EOS, which is also the pad token here;
    # skip special tokens so "<|endoftext|>" never reaches the user.
    return tokenizer.decode(generated[0], skip_special_tokens=True)
# Streamlit UI: prompt box plus a button that runs generation on click.
st.title("WaltWhitman-GPT By Ilyas")
text = st.text_area("Enter Prompt")
if st.button("Generate Poem"):
    # Ignore empty or whitespace-only prompts instead of generating from them.
    if text.strip():
        # Beam search can take a while; show progress instead of a frozen page.
        with st.spinner("Generating..."):
            output = infer(text)
        # st.write renders strings as Markdown, which merges single line
        # breaks and interprets characters like '*' and '#'; st.text shows
        # the poem verbatim with its line structure intact.
        st.text(output)
    else:
        st.warning("Please enter a prompt first.")