Spaces: Runtime error
from transformers import GPT2Tokenizer
import torch
import streamlit as st

# Load the GPT-2 tokenizer and reuse the EOS token for padding,
# since GPT-2 has no dedicated pad token.
tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

# Load the fine-tuned poem model saved with torch.save(model, "poem_model.pt").
model = torch.load("poem_model.pt", map_location="cpu")
model.eval()

def infer(prompt):
    # Tokenize the prompt and generate a continuation with beam search.
    enc = tokenizer(prompt, return_tensors="pt")
    input_ids = enc["input_ids"]            # .to(device)
    attention_mask = enc["attention_mask"]  # .to(device)
    output = model.generate(
        input_ids,
        attention_mask=attention_mask,
        max_length=100,
        min_length=10,
        early_stopping=True,
        num_beams=5,
        no_repeat_ngram_size=2,
    )
    return tokenizer.decode(output[0])

output = infer(" I shall go")
# Show the generated poem in an editable text area and log it to the server console.
text = st.text_area("Generated poem", value=output)
print(text)
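For reference, a minimal interactive variant of the same app is sketched below: the tokenizer and model are cached across Streamlit reruns and the prompt comes from a text input instead of being hard-coded. This is only a sketch, assuming poem_model.pt is a full model object saved with torch.save, running on CPU, and that the installed Streamlit provides st.cache_resource; the load_assets helper name is made up for illustration.

import torch
import streamlit as st
from transformers import GPT2Tokenizer

@st.cache_resource
def load_assets():
    # Cache the tokenizer and model so they are loaded once, not on every rerun.
    tok = GPT2Tokenizer.from_pretrained("gpt2")
    tok.pad_token = tok.eos_token
    mdl = torch.load("poem_model.pt", map_location="cpu")  # assumed full-model checkpoint
    mdl.eval()
    return tok, mdl

tokenizer, model = load_assets()

prompt = st.text_input("Prompt", value=" I shall go")
if st.button("Generate"):
    enc = tokenizer(prompt, return_tensors="pt")
    with torch.no_grad():
        ids = model.generate(
            enc["input_ids"],
            attention_mask=enc["attention_mask"],
            max_length=100,
            min_length=10,
            num_beams=5,
            early_stopping=True,
            no_repeat_ngram_size=2,
        )
    st.write(tokenizer.decode(ids[0], skip_special_tokens=True))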