# Hugging Face Spaces status when this file was captured: Runtime error
import os | |
import gradio as gr | |
from transformers import pipeline | |
# Hub access token, read from the `hf_token` environment variable.
# os.getenv returns None when the variable is not set.
hf_token = os.getenv("hf_token")

# Initialize the text generation pipeline once, at import time, so every
# request handled by the interface reuses the same loaded model.
# `token` is the current keyword for Hub authentication; the previously used
# `use_auth_token` keyword is deprecated in recent transformers releases.
generator = pipeline(
    "text-generation",
    model="isitcoding/gpt2_120_finetuned",
    token=hf_token,
)
# Define the response function with additional options for customization.
#
# NOTE: everything between the triple quotes below is DISABLED. It is a bare
# string literal, so none of it is executed; it is kept only as a reference
# for a more configurable interface.
#
# NOTE(review): points to resolve before re-enabling this block:
#   * The Interface passes its seven inputs positionally, but the sixth and
#     seventh parameters of text_generation are `decoder_input_details` and
#     `do_sample`, so the two penalty sliders would not reach
#     `frequency_penalty` / `repetition_penalty`.
#   * `max_new_tokens` is converted to `max_length` using a whitespace word
#     count of the prompt, which is not a token count.
#   * The block ends with `iface.launch()`, while the file already launches
#     an interface after this block.
#   * Confirm that the installed Gradio accepts `default=` on components and
#     that the model's generate() accepts `frequency_penalty` - TODO verify.
'''def text_generation(
    prompt: str,
    details: bool = False,
    stream: bool = False,
    model: str = None,
    best_of: int = None,
    decoder_input_details: bool = None,
    do_sample: bool = False,
    frequency_penalty: float = None,
    grammar: None = None,
    max_new_tokens: int = None,
    repetition_penalty: float = None
):
    # Setup the configuration for the model generation
    gen_params = {
        "max_length": 518, # Default, you can tweak it or set from parameters
        "num_return_sequences": 1,
        "do_sample": do_sample,
        "temperature": 0.7, # Controls randomness
        "top_k": 50, # You can adjust for more control over sampling
        "top_p": 0.9, # Same as above, for sampling
    }
    if max_new_tokens:
        gen_params["max_length"] = max_new_tokens + len(prompt.split())
    if frequency_penalty:
        gen_params["frequency_penalty"] = frequency_penalty
    if repetition_penalty:
        gen_params["repetition_penalty"] = repetition_penalty
    # Generate the text based on the input prompt and parameters
    generated_text = generator(prompt, **gen_params)[0]["generated_text"]
    if details:
        # Return additional details for debugging if needed
        return {
            "generated_text": generated_text,
            "params_used": gen_params
        }
    else:
        return generated_text

# Create Gradio interface
iface = gr.Interface(
    fn=text_generation, # The function we defined
    inputs=[
        gr.Textbox(label="Input Prompt"), # User input prompt
        gr.Checkbox(label="Show Details", default=False), # Option for additional details
        gr.Checkbox(label="Stream Mode", default=False), # Streaming checkbox (not used in this example)
        gr.Textbox(label="Model (optional)", default=None), # Optional model name
        gr.Slider(minimum=1, maximum=5, label="Best of (Optional)", default=None),
        gr.Slider(minimum=0.0, maximum=2.0, label="Frequency Penalty (Optional)", default=None),
        gr.Slider(minimum=0.0, maximum=2.0, label="Repetition Penalty (Optional)", default=None),
    ],
    outputs="text" # Output is plain text
)

# Launch the interface
iface.launch()
'''
# Test model generation
def generate_response(prompt: str) -> str:
    """Generate text for *prompt* with the module-level ``generator`` pipeline.

    Args:
        prompt: Input text supplied by the caller (the Gradio text box).

    Returns:
        The ``"generated_text"`` field of the first result returned by the
        pipeline.
    """
    # Budget the newly generated tokens instead of the total sequence length:
    # `max_length` also counts the prompt's own tokens, so a prompt close to
    # (or beyond) 50 tokens left no room for any generated output.
    response = generator(prompt, max_new_tokens=50)
    return response[0]["generated_text"]
# Gradio interface: a single text box in, a single text box out, served by
# generate_response. The app is launched as soon as this module is executed.
import gradio as gr

gr.Interface(
    fn=generate_response,
    inputs="text",
    outputs="text",
).launch()