Spaces:
Runtime error
Runtime error
import gradio as gr | |
from transformers import GPT2LMHeadModel, GPT2Tokenizer | |
# Load the pre-trained model and tokenizer once, at module level, so every
# call to generate_keywords reuses the same instances.
# Other GPT-2-architecture checkpoints can be swapped in by Hub id, e.g.
# "gpt2-medium", "gpt2-large" or "gpt2-xl" (larger is slower on CPU).
# GPT-3 has no downloadable weights and cannot be loaded with GPT2LMHeadModel.
model_name = "gpt2"
model = GPT2LMHeadModel.from_pretrained(model_name)
tokenizer = GPT2Tokenizer.from_pretrained(model_name)
# Function to generate keywords based on a prompt
def generate_keywords(prompt):
    """Generate ad keywords by letting the language model continue *prompt*.

    The prompt is suffixed with a short instruction and passed to the
    module-level ``model``/``tokenizer``. Only the newly generated
    continuation is returned; the echoed prompt is not included.

    Args:
        prompt: Free-text description entered in the UI textbox.

    Returns:
        The generated continuation as a stripped string, or ``""`` when the
        prompt is empty or whitespace-only.
    """
    # Nothing to condition on; skip the (expensive) model call.
    if not prompt or not prompt.strip():
        return ""

    # GPT-2 is not instruction-tuned, so this suffix only nudges the
    # continuation; it does not guarantee a clean keyword list.
    prompt_with_instruction = prompt + " Only provide a list of keywords, no additional text."
    # Calling the tokenizer (instead of .encode) also returns an
    # attention_mask, which generate() otherwise warns about.
    inputs = tokenizer(prompt_with_instruction, return_tensors="pt")
    prompt_length = inputs["input_ids"].shape[-1]

    outputs = model.generate(
        **inputs,
        # Bound only the continuation: max_length would also count the
        # prompt tokens, leaving no room to generate for long prompts.
        max_new_tokens=50,
        num_return_sequences=1,
        no_repeat_ngram_size=2,
        # top_k/top_p only take effect when sampling is enabled; without
        # do_sample=True decoding is greedy and both are ignored.
        do_sample=True,
        top_k=50,
        top_p=0.95,
        # GPT-2 defines no pad token; reuse EOS to silence the warning.
        pad_token_id=tokenizer.eos_token_id,
    )

    # The returned sequence begins with the prompt tokens. Decode only what
    # the model appended. Splitting the full text on the instruction and
    # keeping the first part would return the user's own prompt instead.
    new_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()
# Gradio interface: one prompt textbox in, one keywords textbox out.
# live=False runs generation only when the user clicks Submit. With
# live=True every input change (each keystroke) triggered a full model
# generate() call, which is slow and piles up redundant requests.
iface = gr.Interface(
    fn=generate_keywords,
    inputs=gr.Textbox(label="Enter Ad Prompt", placeholder="E.g., Generate ad keywords for wireless headphones"),
    outputs=gr.Textbox(label="Generated Keywords"),
    live=False,
)

# Start the Gradio web server.
iface.launch()