# Spaces: Runtime error
import streamlit as st
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Load the GPT-2 tokenizer and model.
#
# Streamlit re-executes this script from top to bottom on every widget
# interaction, so an uncached ``from_pretrained`` call would rebuild the
# model on each button click. ``st.cache_resource`` (Streamlit >= 1.18)
# keeps a single shared copy alive across reruns.
@st.cache_resource
def _load_gpt2():
    """Return the ``(tokenizer, model)`` pair for the "gpt2" checkpoint."""
    gpt2_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
    gpt2_model = GPT2LMHeadModel.from_pretrained("gpt2")
    return gpt2_tokenizer, gpt2_model


tokenizer, model = _load_gpt2()

# Set the maximum length of generated text. This value is handed to
# ``model.generate`` as ``max_length``, which in transformers counts the
# prompt tokens plus the newly generated ones.
max_length = 200
# Define a function to generate text
def generate_text(prompt):
    """Continue *prompt* with GPT-2 beam search and return the decoded text.

    Args:
        prompt: Text for the model to continue.

    Returns:
        The first generated sequence decoded to a string with special
        tokens stripped, or ``""`` when *prompt* is empty.
    """
    # An empty prompt encodes to zero tokens, leaving ``generate`` nothing
    # to continue from; return early instead of letting it raise.
    if not prompt:
        return ""

    # Calling the tokenizer (rather than ``encode``) also yields the
    # attention mask, so ``generate`` does not have to infer it.
    encoded = tokenizer(prompt, return_tensors="pt")

    output = model.generate(
        input_ids=encoded["input_ids"],
        attention_mask=encoded["attention_mask"],
        max_length=max_length,
        num_beams=5,
        no_repeat_ngram_size=2,
        early_stopping=True,
        # GPT-2 defines no pad token; reuse EOS explicitly so generation
        # does not fall back to it with a warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the first returned sequence.
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Set up the Streamlit app
st.title("GPT-2 Text Generator")

# Add a text input widget for the user to enter a prompt
prompt = st.text_input("Enter a prompt:")

# When the user clicks the "Generate" button, generate text
if st.button("Generate"):
    if not prompt:
        # The text box is still empty, so there is nothing for the model to
        # continue; ask for input instead of invoking generation.
        st.warning("Please enter a prompt first.")
    else:
        # Keep the spinner visible only while the model is running, then
        # render the result below the widgets.
        with st.spinner("Generating text..."):
            text = generate_text(prompt)
        st.write(text)