File size: 1,396 Bytes
ff936e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1f5aa86
 
ff936e0
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
import gradio as gr
from transformers import pipeline, set_seed

# Initialize the text-generation pipeline with DistilGPT2 on CPU (device=-1).
# This runs at import time, so merely importing this module builds the
# pipeline; the resulting object is shared by every call to generate_text.
generator = pipeline("text-generation", model="distilgpt2", device=-1)
# Seed once, after the pipeline has been built.
# NOTE(review): the seed is set a single time at startup and never reset per
# request, so presumably repeated identical prompts within one process will
# not produce identical text — confirm whether per-request determinism is
# wanted.
set_seed(42)

def generate_text(prompt, stop_choice):
    """Generate a continuation of ``prompt`` and cut it at the chosen stop token.

    Args:
        prompt: Text the model should continue.
        stop_choice: Dropdown label, ``"Period"`` or ``"Space"``. Any other
            value (including ``None`` when nothing is selected) disables
            truncation.

    Returns:
        The generated text, truncated just after the first stop token that
        occurs in the continuation. When no stop token applies, or none is
        found, the full generated text is returned. The generated text is
        assumed to begin with the prompt itself, which is why the search
        starts at ``len(prompt)``.
    """
    # Map the dropdown choice to an actual stop character.
    stop_tokens = {"Period": ".", "Space": " "}
    stop_token = stop_tokens.get(stop_choice)

    # Bound the number of *new* tokens instead of the total length. A total
    # length limit is counted in model tokens and includes the prompt, while
    # a whitespace word count under-estimates the prompt's token count, so
    # the old "words + 50" budget could leave little or no room to generate.
    results = generator(prompt, max_new_tokens=50, num_return_sequences=1)
    output = results[0]["generated_text"]

    if not stop_token:
        # str.find("") matches immediately at the start offset, which would
        # discard the entire continuation; with no stop token, keep it all.
        return output

    # Search only the continuation. Skip the whitespace that leads it first:
    # otherwise "Space" matches the separator between the prompt and the
    # first generated word and nothing generated would be returned.
    start = len(prompt)
    while start < len(output) and output[start].isspace():
        start += 1

    idx = output.find(stop_token, start)
    if idx == -1:
        # Stop token never produced: return the full generated text.
        return output
    # Return the output up to and including the stop token.
    return output[: idx + len(stop_token)]

# Define a simple Gradio interface: a multi-line prompt box and a stop-token
# dropdown feed generate_text, whose return value is shown as plain text.
iface = gr.Interface(
    fn=generate_text,
    inputs=[
        gr.Textbox(lines=4, placeholder="Enter your prompt here...", label="Prompt"),
        # Pre-select a stop token explicitly so the callback never receives
        # None from a dropdown the user has not touched, regardless of how
        # the installed Gradio version treats an unset default.
        gr.Dropdown(choices=["Period", "Space"], value="Period", label="Stop Token"),
    ],
    outputs="text",
    title="DistilGPT2 Text Generation",
    description="Enter a prompt and select a stop token (Period or Space) to halt generation.",
)

if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    iface.launch()