chat-ui / app.py
isitcoding's picture
Update app.py
3950cb3 verified
raw
history blame
2.53 kB
import os
import gradio as gr
from transformers import pipeline
# Hub access token taken from the environment; None when "hf_token" is unset.
hf_token = os.environ.get("hf_token")

# Build the text-generation pipeline once, at import time, so every request
# served by this module reuses the same loaded model.
# NOTE(review): newer transformers releases deprecate `use_auth_token` in
# favour of `token` -- confirm against the version this app is pinned to.
generator = pipeline(
    task="text-generation",
    model="isitcoding/gpt2_120_finetuned",
    use_auth_token=hf_token,
)
# Define the response function with additional options for customization
def text_generation(
    prompt: str,
    details: bool = False,
    stream: bool = False,
    model: str = None,
    best_of: int = None,
    decoder_input_details: bool = None,
    do_sample: bool = False,
    frequency_penalty: float = None,
    grammar: None = None,
    max_new_tokens: int = None,
    repetition_penalty: float = None,
):
    """Generate a continuation of ``prompt`` with the module-level ``generator``.

    Args:
        prompt: Text the model should continue.
        details: When true, return a dict holding the generated text and the
            generation parameters that were used instead of the bare string.
        stream: Accepted for interface compatibility; currently ignored.
        model: Accepted for interface compatibility; currently ignored (the
            pipeline stored in ``generator`` is always used).
        best_of: Accepted for interface compatibility; currently ignored.
        decoder_input_details: Accepted for interface compatibility;
            currently ignored.
        do_sample: Forwarded to the pipeline as ``do_sample``.
        frequency_penalty: Not forwarded. The local pipeline has no matching
            generation option, so a non-zero value only triggers a warning.
        grammar: Accepted for interface compatibility; currently ignored.
        max_new_tokens: Upper bound on the number of newly generated tokens.
            When falsy (``None`` or 0) the default ``max_length`` of 518 is
            used instead.
        repetition_penalty: Forwarded to the pipeline when truthy.

    Returns:
        The generated text, or -- when ``details`` is true -- a dict with the
        keys ``"generated_text"`` and ``"params_used"``.
    """
    import warnings

    # Setup the configuration for the model generation
    gen_params = {
        "max_length": 518,  # Default total-length cap when no token budget is given
        "num_return_sequences": 1,
        "do_sample": do_sample,
        "temperature": 0.7,  # Controls randomness
        "top_k": 50,  # Sampling: keep only the 50 most likely tokens
        "top_p": 0.9,  # Sampling: nucleus probability mass
    }

    if max_new_tokens:
        # Hand the budget to the generator directly. Estimating a total
        # length from ``len(prompt.split())`` counts whitespace-separated
        # words, not tokens, and so mis-sizes the output. ``max_length`` is
        # removed so the two limits do not compete.
        del gen_params["max_length"]
        gen_params["max_new_tokens"] = int(max_new_tokens)

    if frequency_penalty:
        # transformers' generate() has no ``frequency_penalty`` option and
        # rejects unknown keyword arguments, so forwarding it would make the
        # whole request fail. Report it and carry on instead.
        warnings.warn(
            "frequency_penalty is not supported by the local text-generation "
            "pipeline and was ignored; use repetition_penalty instead.",
            stacklevel=2,
        )

    if repetition_penalty:
        gen_params["repetition_penalty"] = repetition_penalty

    # Generate the text based on the input prompt and parameters
    generated_text = generator(prompt, **gen_params)[0]["generated_text"]

    if details:
        # Return additional details for debugging if needed
        return {
            "generated_text": generated_text,
            "params_used": gen_params,
        }
    return generated_text
# Create Gradio interface
def _gradio_text_generation(
    prompt, details, stream, model, best_of, frequency_penalty, repetition_penalty
):
    """Adapt the UI inputs to ``text_generation`` using keyword arguments.

    Gradio passes component values positionally, in the order of ``inputs``.
    Binding the widgets straight to ``text_generation`` would therefore feed
    the "Frequency Penalty" slider into ``decoder_input_details`` and the
    "Repetition Penalty" slider into ``do_sample``. Forwarding by name keeps
    every widget attached to the parameter its label describes.
    """
    return text_generation(
        prompt,
        details=details,
        stream=stream,
        model=model,
        best_of=best_of,
        frequency_penalty=frequency_penalty,
        repetition_penalty=repetition_penalty,
    )


iface = gr.Interface(
    fn=_gradio_text_generation,  # Keyword-forwarding wrapper around text_generation
    inputs=[
        gr.Textbox(label="Input Prompt"),  # User input prompt
        # Initial values are given with `value=`; `default=` is not the
        # initial-value parameter of the top-level gr.* components.
        gr.Checkbox(label="Show Details", value=False),  # Option for additional details
        gr.Checkbox(label="Stream Mode", value=False),  # Streaming checkbox (not used in this example)
        gr.Textbox(label="Model (optional)", value=None),  # Optional model name
        gr.Slider(minimum=1, maximum=5, label="Best of (Optional)", value=None),
        gr.Slider(minimum=0.0, maximum=2.0, label="Frequency Penalty (Optional)", value=None),
        gr.Slider(minimum=0.0, maximum=2.0, label="Repetition Penalty (Optional)", value=None),
    ],
    outputs="text",  # Output is plain text
)
# Launch the interface
iface.launch()