# DistilGPT2 / app.py
# Last commit by TotoB12: 1f5aa86 "Update app.py" (verified)
import gradio as gr
from transformers import pipeline, set_seed
# Build one DistilGPT2 text-generation pipeline at import time so every
# request reuses the same loaded model; device=-1 keeps it on the CPU.
generator = pipeline(
    task="text-generation",
    model="distilgpt2",
    device=-1,
)
# Fix the random seed once at startup so a fresh process produces a
# repeatable sequence of outputs.
set_seed(42)
def generate_text(prompt, stop_choice):
    """Generate a continuation of *prompt* and cut it at the chosen stop character.

    Args:
        prompt: Text for the model to continue.
        stop_choice: Dropdown label, "Period" or "Space". Any other value
            (including ``None`` when nothing is selected) means "no stop
            character".

    Returns:
        The prompt followed by the generated continuation, truncated just
        after the first stop character found in the continuation. Returns
        the full generated text if no stop character is selected or none
        occurs.
    """
    # Map the dropdown label to the literal stop character.
    mapping = {"Period": ".", "Space": " "}
    stop_token = mapping.get(stop_choice, "")

    # max_new_tokens bounds only the continuation. max_length would also count
    # the prompt's tokens, and a word count underestimates the token count,
    # so long or sub-word-heavy prompts could leave no room to generate.
    output = generator(prompt, max_new_tokens=50, num_return_sequences=1)[0][
        "generated_text"
    ]

    # str.find("") matches at the start offset, which would discard the whole
    # continuation; with no stop character, return everything.
    if not stop_token:
        return output

    # The full continuation has already been generated; the stop character is
    # applied afterwards by truncation. Skip whitespace right after the prompt
    # (continuations often begin with a space), otherwise the "Space" stop
    # would return the prompt plus a single space. This is a no-op for ".".
    start = len(prompt)
    while start < len(output) and output[start].isspace():
        start += 1

    idx = output.find(stop_token, start)
    if idx == -1:
        # No stop character in the continuation: return it untruncated.
        return output
    # Keep the stop character itself.
    return output[: idx + len(stop_token)]
# Gradio UI: a prompt box plus a dropdown choosing where to cut the output.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, placeholder="Enter your prompt here...", label="Prompt"),
        # Pre-select "Period" so a stop character is chosen by default rather
        # than the form being submitted with no selection.
        gr.Dropdown(choices=["Period", "Space"], value="Period", label="Stop Token"),
    ],
    outputs="text",
    title="DistilGPT2 Text Generation",
    description="Enter a prompt and select a stop token (Period or Space) to halt generation.",
)

if __name__ == "__main__":
    # Start the web server only when run as a script, not when imported.
    iface.launch()