# ZephyrChat / app.py
# Last commit: 12009aa "Update app.py" by K00B404 (verified); file size 2.86 kB
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json
# Build the shared Hugging Face Inference client. The token is read from the
# HF_TOKEN environment variable; api_key is None when that variable is unset.
api_key = os.environ.get("HF_TOKEN")
client = InferenceClient(api_key=api_key)
# Load or initialize system prompts
PROMPTS_FILE = "system_prompts.json"
if os.path.exists(PROMPTS_FILE):
with open(PROMPTS_FILE, "r") as file:
system_prompts = json.load(file)
else:
system_prompts = {"default": "You are a good image generation prompt engineer for diffuser image generation models"}
def save_prompts(path=None, prompts=None):
    """Write the system prompts to a JSON file.

    The data is first written to a sibling ``.tmp`` file and then moved over
    the target with ``os.replace``. An interrupted write therefore cannot leave
    a truncated JSON file behind that would break loading on the next start.

    Args:
        path: destination file; defaults to ``PROMPTS_FILE``.
        prompts: mapping of prompt name to text; defaults to the module-level
            ``system_prompts``.
    """
    target = PROMPTS_FILE if path is None else path
    data = system_prompts if prompts is None else prompts
    tmp_path = f"{target}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        json.dump(data, file, indent=4)
    # The rename is atomic on POSIX. The temp file sits next to the target, so
    # both are on the same filesystem, as os.replace requires.
    os.replace(tmp_path, target)
def chat_with_model(
    user_input,
    system_prompt,
    *,
    model="HuggingFaceH4/zephyr-7b-beta",
    temperature=0.5,
    max_tokens=2048,
    top_p=0.7,
    inference_client=None,
):
    """Send *user_input* to the chat model under *system_prompt* and return the reply.

    Args:
        user_input: the user's message.
        system_prompt: text sent as the system message.
        model: model id passed to ``chat.completions.create``.
        temperature: sampling temperature forwarded to the request.
        max_tokens: maximum number of tokens to generate.
        top_p: nucleus-sampling threshold forwarded to the request.
        inference_client: client whose ``chat.completions.create`` is called;
            defaults to the module-level ``client``.

    Returns:
        The reply text. If the request fails, returns ``"Error: <message>"``
        instead, because this is a UI callback and the failure should be shown
        to the user rather than raised.
    """
    active_client = client if inference_client is None else inference_client
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]
    try:
        result = active_client.chat.completions.create(
            model=model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=False,  # Stream disabled for simplicity
        )
        return result["choices"][0]["message"]["content"]
    except Exception as e:
        # Keep the traceback in the server logs; the user only sees the message.
        import logging

        logging.getLogger(__name__).exception("Chat completion request failed")
        return f"Error: {str(e)}"
def update_prompt(name, content):
    """Update or add a new system prompt and persist all prompts to disk.

    Args:
        name: prompt name. ``None`` (nothing selected in the dropdown) and
            blank names are rejected.
        content: the system prompt text to store.

    Returns:
        A status message to display in the UI.
    """
    if name is None or not str(name).strip():
        # Without this guard a None name would be stored, and json.dump would
        # write it to the prompts file under the key "null".
        return "Please select or enter a prompt name before saving."
    system_prompts[name] = content
    save_prompts()
    return f"System prompt '{name}' saved."
def get_prompt(name):
    """Look up the stored system prompt called *name*.

    Unknown names yield an empty string instead of raising.
    """
    try:
        return system_prompts[name]
    except KeyError:
        return ""
# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Hugging Face Chatbot with Gradio")
    with gr.Row():
        with gr.Column():
            # Preselect "default" so the dropdown agrees with the prompt text
            # shown below; otherwise "Save" would submit name=None.
            # allow_custom_value lets the user type a new name, so the save
            # button can add prompts as well as update them.
            system_prompt_name = gr.Dropdown(
                choices=list(system_prompts.keys()),
                value="default" if "default" in system_prompts else None,
                label="Select System Prompt",
                allow_custom_value=True,
            )
            system_prompt_content = gr.TextArea(label="System Prompt", value=get_prompt("default"), lines=4)
            save_prompt_button = gr.Button("Save System Prompt")
            save_status = gr.Markdown()
            user_input = gr.TextArea(label="Enter your prompt", placeholder="Describe the character or request a detailed description...", lines=4)
            submit_button = gr.Button("Generate")
        with gr.Column():
            output = gr.TextArea(label="Model Response", interactive=False, lines=10)

    def load_prompt(name):
        """Show the stored text for a known prompt name.

        For a newly typed name, leave the textarea untouched so its current
        text can be saved under that name.
        """
        if name in system_prompts:
            return get_prompt(name)
        return gr.update()

    def refresh_prompt_choices():
        """Rebuild the dropdown choices so newly saved prompts become selectable."""
        return gr.update(choices=list(system_prompts.keys()))

    system_prompt_name.change(
        load_prompt,
        inputs=[system_prompt_name],
        outputs=[system_prompt_content],
    )
    save_prompt_button.click(
        update_prompt,
        inputs=[system_prompt_name, system_prompt_content],
        outputs=[save_status],
    ).then(refresh_prompt_choices, outputs=[system_prompt_name])
    submit_button.click(chat_with_model, inputs=[user_input, system_prompt_content], outputs=[output])

# Run the app
demo.launch()