File size: 4,249 Bytes
25ddf29
f42689b
25ddf29
de3eaf5
25ddf29
 
 
9657faa
 
de3eaf5
 
 
 
 
 
 
 
 
 
 
 
 
f42689b
25ddf29
 
de3eaf5
25ddf29
 
 
 
 
f42689b
25ddf29
 
 
 
 
 
 
 
 
 
de3eaf5
 
 
 
 
 
 
 
 
 
2bdeffd
 
 
 
 
 
 
 
 
f42689b
 
 
 
8d65776
 
 
 
 
 
f42689b
 
 
e358e48
25ddf29
 
f42689b
25ddf29
 
 
f42689b
 
e358e48
f42689b
 
 
e358e48
f42689b
 
2bdeffd
67fee7f
 
f42689b
 
de3eaf5
12009aa
de3eaf5
 
 
25ddf29
f42689b
25ddf29
de3eaf5
 
f42689b
 
 
 
25ddf29
f42689b
 
de3eaf5
f42689b
9657faa
25ddf29
f42689b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
import json
import logging
import os

import gradio as gr
from huggingface_hub import InferenceClient, list_models

# Initialize Hugging Face Inference Client.
# HF_TOKEN is read from the environment; os.getenv returns None when it is unset.
api_key = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=api_key)

# Load or initialize system prompts (name -> prompt text).
PROMPTS_FILE = "system_prompts.json"
system_prompts = {"default": "You are a good image generation prompt engineer for diffuser image generation models"}
if os.path.exists(PROMPTS_FILE):
    try:
        with open(PROMPTS_FILE, "r", encoding="utf-8") as file:
            _loaded_prompts = json.load(file)
    except (OSError, json.JSONDecodeError) as err:
        # An unreadable or corrupt file should not stop the app from starting.
        # Keep the built-in default; the file is rewritten on the next save.
        logging.getLogger(__name__).warning(
            "Could not load %s (%s); using the default system prompt.", PROMPTS_FILE, err
        )
    else:
        # The UI calls .keys() on system_prompts, so only accept a JSON object.
        if isinstance(_loaded_prompts, dict):
            system_prompts = _loaded_prompts
        else:
            logging.getLogger(__name__).warning(
                "%s does not contain a JSON object; using the default system prompt.", PROMPTS_FILE
            )

def save_prompts():
    """Persist the module-level ``system_prompts`` dict to PROMPTS_FILE as JSON.

    The JSON is written to a sibling temporary file first and then moved over
    PROMPTS_FILE with ``os.replace``. A crash mid-write therefore leaves the
    previous file intact instead of a truncated one (the replace is atomic on
    POSIX file systems).
    """
    tmp_path = f"{PROMPTS_FILE}.tmp"
    with open(tmp_path, "w", encoding="utf-8") as file:
        # ensure_ascii=False keeps non-ASCII prompt text human-readable on disk.
        json.dump(system_prompts, file, indent=4, ensure_ascii=False)
    os.replace(tmp_path, PROMPTS_FILE)

def chat_with_model(user_input, system_prompt, selected_model, *,
                    temperature=0.5, max_tokens=2048, top_p=0.7):
    """Send *user_input* to *selected_model* and return the reply text.

    Args:
        user_input: The user's message.
        system_prompt: Text sent as the system message.
        selected_model: Model id passed to the chat-completions endpoint.
        temperature: Sampling temperature forwarded to the API.
        max_tokens: Maximum number of tokens to generate.
        top_p: Nucleus-sampling cutoff forwarded to the API.

    Returns:
        The assistant's reply. On invalid input or any API failure, it returns
        a string starting with "Error:" instead, which the UI shows in the
        output box.
    """
    if not selected_model:
        return "Error: no model selected."
    if not user_input or not user_input.strip():
        return "Error: please enter a prompt."

    messages = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_input},
    ]

    try:
        result = client.chat.completions.create(
            model=selected_model,
            messages=messages,
            temperature=temperature,
            max_tokens=max_tokens,
            top_p=top_p,
            stream=False,  # Stream disabled for simplicity
        )
        # Defensively map a missing (None) content to "" so the UI always gets a string.
        return result.choices[0].message.content or ""
    except Exception as e:  # UI boundary: surface any API/network failure to the user
        return f"Error: {e}"

def update_prompt(name, content):
    """Create or overwrite the system prompt stored under *name* and persist it.

    Args:
        name: Prompt name. Surrounding whitespace is stripped. ``None`` or a blank
            name is rejected, because it would otherwise be saved as a
            ``"null"``/``""`` key in the JSON file.
        content: Prompt text to store.

    Returns:
        A human-readable status message.
    """
    key = (name or "").strip()
    if not key:
        return "Error: please enter a name for the system prompt."
    system_prompts[key] = content
    save_prompts()
    return f"System prompt '{key}' saved."

def get_prompt(name):
    """Look up the stored text of the system prompt called *name*.

    An unknown name yields an empty string rather than raising.
    """
    try:
        return system_prompts[name]
    except KeyError:
        return ""

    
def update_dropdown_choices(models):
    """Compute ``(choices, selected_value)`` for a dropdown from *models*.

    A non-empty *models* sequence is returned as-is, with its first entry
    selected. Anything falsy (empty list, ``None``) produces no choices and no
    selection.
    """
    if models:
        first = models[0]
        return models, first
    return [], None



# Curated chat models appended to every task's list so the chat UI always has
# something selectable, even when the Hub query fails.
# NOTE(review): "HuggingFaceH4/zephyr-6b" may not exist on the Hub — confirm.
DEFAULT_CHAT_MODELS = [
    "HuggingFaceH4/zephyr-7b-beta",
    "HuggingFaceH4/zephyr-7b-alpha",
    "HuggingFaceH4/zephyr-6b",
]


def fetch_models(task, limit=50):
    """Return model ids for *task* from the Hugging Face Hub, plus curated defaults.

    Args:
        task: Tag to filter Hub models on, e.g. "text-generation".
        limit: Maximum number of Hub models to request. ``list_models`` is lazy
            and, with no limit, fetches every matching model (many pages for
            popular tasks).

    Returns:
        A de-duplicated list of model id strings. The most-downloaded Hub
        matches come first, followed by DEFAULT_CHAT_MODELS. If the Hub query
        fails, only DEFAULT_CHAT_MODELS is returned and a warning is logged.
    """
    try:
        hub_models = list_models(
            filter=task,
            sort="downloads",
            direction=-1,
            limit=limit,
        )
        # list_models returns a lazy iterator of ModelInfo objects (not a list,
        # so it cannot be extended with +=). Materialize the ids here so that
        # network errors surface inside this try block.
        model_ids = [model.id for model in hub_models]
    except Exception as err:  # network / auth / Hub errors: degrade gracefully
        logging.getLogger(__name__).warning(
            "Could not fetch models for task %r: %s", task, err
        )
        model_ids = []
    # dict.fromkeys de-duplicates while preserving first-seen order.
    return list(dict.fromkeys(model_ids + DEFAULT_CHAT_MODELS))
task_choices = ["text-generation", "image-classification", "text-classification", "translation"]


def update_model_list(task):
    """Return a Dropdown update listing the models for *task*, first one selected."""
    models = fetch_models(task)
    # gr.update works across Gradio versions; gr.Dropdown.update was removed in Gradio 4.
    return gr.update(choices=models, value=models[0] if models else None)


def load_selected_prompt(name, current_content):
    """Return the stored text for *name*, or keep *current_content* for unknown names.

    Keeping the current text means that typing a brand-new prompt name does not
    wipe the text the user is about to save under that name.
    """
    if name in system_prompts:
        return get_prompt(name)
    return current_content


def save_prompt_and_refresh(name, content):
    """Save the prompt; return (status message, refreshed prompt-name Dropdown)."""
    status = update_prompt(name, content)
    return status, gr.update(choices=list(system_prompts.keys()), value=name)


# Gradio Interface
with gr.Blocks() as demo:
    gr.Markdown("## Hugging Face Chatbot with Dynamic Model Selection")

    with gr.Row():
        with gr.Column():
            # Task selection
            task_selector = gr.Dropdown(
                choices=task_choices,
                label="Select Task",
                value=task_choices[0],
            )
            # Model selector: pass choices/value to the constructor instead of
            # mutating attributes afterwards, and tolerate an empty model list.
            initial_models = fetch_models(task_choices[0])
            model_selector = gr.Dropdown(
                choices=initial_models,
                value=initial_models[0] if initial_models else None,
                label="Select Model",
            )

            # System prompt and input
            system_prompt_name = gr.Dropdown(
                choices=list(system_prompts.keys()),
                value="default" if "default" in system_prompts else None,
                label="Select System Prompt",
                allow_custom_value=True,  # lets the user type a new name to save under
            )
            system_prompt_content = gr.TextArea(label="System Prompt", value=get_prompt("default"), lines=4)
            save_prompt_button = gr.Button("Save System Prompt")
            save_status = gr.Textbox(label="Status", interactive=False)

            user_input = gr.TextArea(label="Enter your prompt", placeholder="Describe the character or request a detailed description...", lines=4)
            submit_button = gr.Button("Generate")

        with gr.Column():
            output = gr.TextArea(label="Model Response", interactive=False, lines=10)

    # Event bindings
    task_selector.change(update_model_list, inputs=[task_selector], outputs=[model_selector])
    # Selecting a stored prompt loads its text into the editor.
    system_prompt_name.change(
        load_selected_prompt,
        inputs=[system_prompt_name, system_prompt_content],
        outputs=[system_prompt_content],
    )
    save_prompt_button.click(
        save_prompt_and_refresh,
        inputs=[system_prompt_name, system_prompt_content],
        outputs=[save_status, system_prompt_name],
    )
    submit_button.click(chat_with_model, inputs=[user_input, system_prompt_content, model_selector], outputs=[output])

# Run the app (guarded so importing this module, e.g. for Gradio reload mode, does not launch it)
if __name__ == "__main__":
    demo.launch()