# Hugging Face Spaces status: Running
import gradio as gr
from transformers import pipeline, set_seed

# Model setup, performed once at import time so every request reuses the
# same loaded pipeline.
_TASK = "text-generation"
_MODEL_NAME = "distilgpt2"
_CPU_DEVICE = -1  # -1 selects the CPU for the pipeline
_SEED = 42

# Shared DistilGPT2 text-generation pipeline used by generate_text.
generator = pipeline(_TASK, model=_MODEL_NAME, device=_CPU_DEVICE)

# Fix the random seed after the pipeline is built.
set_seed(_SEED)
def generate_text(prompt, stop_choice):
    """Generate a continuation of *prompt* and cut it at the chosen stop token.

    Args:
        prompt: Text for the model to continue.
        stop_choice: Dropdown label, "Period" or "Space". Any other value
            (including ``None`` when nothing is selected) disables
            truncation.

    Returns:
        The generated text up to and including the first stop token that
        occurs after the prompt, or the full generated text when no stop
        token applies or none is found.
    """
    # Map the dropdown label to the literal stop character.
    stop_token = {"Period": ".", "Space": " "}.get(stop_choice, "")

    # Budget the continuation with max_new_tokens rather than max_length:
    # max_length is measured in tokens and includes the prompt, so deriving
    # it from a whitespace word count can leave little or no room to generate
    # when the prompt tokenizes into many more tokens than words.
    output = generator(
        prompt,
        max_new_tokens=50,
        num_return_sequences=1,
    )[0]["generated_text"]

    # An empty stop token would match immediately at len(prompt) and throw
    # away the entire continuation; treat it as "no truncation" instead.
    if not stop_token:
        return output

    # Search only past the prompt so stop characters typed by the user are
    # ignored. NOTE(review): assumes generated_text starts with the prompt
    # verbatim (the pipeline's full-text output) - confirm.
    idx = output.find(stop_token, len(prompt))
    if idx == -1:
        # No stop token in the continuation: return everything generated.
        return output

    # Keep the text up to and including the stop token.
    return output[: idx + len(stop_token)]
# Gradio UI: a prompt box and a stop-token dropdown feed generate_text, and
# the returned string is shown as plain text.
_prompt_box = gr.Textbox(lines=4, placeholder="Enter your prompt here...", label="Prompt")
_stop_dropdown = gr.Dropdown(choices=["Period", "Space"], label="Stop Token")

iface = gr.Interface(
    fn=generate_text,
    inputs=[_prompt_box, _stop_dropdown],
    outputs="text",
    title="DistilGPT2 Text Generation",
    description="Enter a prompt and select a stop token (Period or Space) to halt generation.",
)

# Launch the web app only when this file is executed as a script.
if __name__ == "__main__":
    iface.launch()