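"""Gradio demo: a prompt-engineering chat assistant backed by the Hugging Face Inference API.

Named system prompts are persisted to system_prompts.json; responses come from
HuggingFaceH4/zephyr-7b-beta via huggingface_hub.InferenceClient.
"""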
import gradio as gr
from huggingface_hub import InferenceClient
import os
import json

# Initialize Hugging Face Inference Client
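# Note: HF_TOKEN should hold a valid Hugging Face access token (e.g. a Space secret);
# without it, requests are unauthenticated and will typically be rejected or rate-limited.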
api_key = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=api_key)

# Load or initialize system prompts
PROMPTS_FILE = "system_prompts.json"
if os.path.exists(PROMPTS_FILE):
    with open(PROMPTS_FILE, "r") as file:
        system_prompts = json.load(file)
else:
    system_prompts = {"default": "You are a good image generation prompt engineer for diffuser image generation models"}

def save_prompts():
    """Save the current system prompts to a JSON file."""
    with open(PROMPTS_FILE, "w") as file:
        json.dump(system_prompts, file, indent=4)

def chat_with_model(user_input, system_prompt):
    """Send user input to the model and return its response."""
    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input}
    ]

    try:
        result = client.chat.completions.create(
            model="HuggingFaceH4/zephyr-7b-beta",
            messages=messages,
            temperature=0.5,
            max_tokens=2048,
            top_p=0.7,
            stream=False  # Stream disabled for simplicity
        )
        return result.choices[0].message.content
    except Exception as e:
        return f"Error: {str(e)}"

def update_prompt(name, content):
    """Update or add a system prompt, refresh the dropdown choices, and report status."""
    system_prompts[name] = content
    save_prompts()
    return gr.update(choices=list(system_prompts.keys()), value=name), f"System prompt '{name}' saved."

def get_prompt(name):
    """Retrieve a system prompt by name."""
    return system_prompts.get(name, "")

# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Hugging Face Chatbot with Gradio")

    with gr.Row():
        with gr.Column():
            system_prompt_name = gr.Dropdown(choices=list(system_prompts.keys()), value="default", allow_custom_value=True, label="Select System Prompt")
            system_prompt_content = gr.TextArea(label="System Prompt", value=get_prompt("default"), lines=4)
            save_prompt_button = gr.Button("Save System Prompt")
            save_status = gr.Markdown()

            user_input = gr.TextArea(label="Enter your prompt", placeholder="Describe the character or request a detailed description...", lines=4)
            submit_button = gr.Button("Generate")
        
        with gr.Column():
            output = gr.TextArea(label="Model Response", interactive=False, lines=10)

    def load_prompt(name):
        """Load the selected prompt's text into the editor."""
        return get_prompt(name)

    system_prompt_name.change(load_prompt, inputs=[system_prompt_name], outputs=[system_prompt_content])
    save_prompt_button.click(update_prompt, inputs=[system_prompt_name, system_prompt_content], outputs=[system_prompt_name, save_status])
    submit_button.click(chat_with_model, inputs=[user_input, system_prompt_content], outputs=[output])

# Run the app
demo.launch()