K00B404 committed on
Commit
c49e8f0
·
verified ·
1 Parent(s): 2c59fff

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -28
app.py CHANGED
@@ -1,5 +1,5 @@
1
  import gradio as gr
2
- from huggingface_hub import InferenceClient, list_models
3
  import os
4
  import json
5
 
@@ -50,33 +50,22 @@ def get_prompt(name):
50
  """Retrieve a system prompt by name."""
51
  return system_prompts.get(name, "")
52
 
53
- def fetch_models(task):
54
- """Fetch models for a specific task from Hugging Face Hub."""
55
- try:
56
- all_models=list_models()
57
- print(all_models)
58
- models = list_models(filter=f"pipeline_tags:{task}")
59
- return [model.modelId for model in models]
60
- except Exception as e:
61
- return [f"Error fetching models: {str(e)}"]
62
 
63
  # Gradio Interface
64
  with gr.Blocks() as demo:
65
- gr.Markdown("## Hugging Face Chatbot with Dynamic Model Selection")
66
 
67
  with gr.Row():
68
  with gr.Column():
69
- # Task selection
70
- task_selector = gr.Dropdown(
71
- choices=["text-generation", "image-classification", "text-classification", "translation"],
72
- label="Select Task",
73
- value="text-generation"
74
- )
75
-
76
- # Model selector
77
- model_selector = gr.Dropdown(choices=[], label="Select Model")
78
 
79
- # System prompt and input
80
  system_prompt_name = gr.Dropdown(choices=list(system_prompts.keys()), label="Select System Prompt")
81
  system_prompt_content = gr.TextArea(label="System Prompt", value=get_prompt("default"), lines=4)
82
  save_prompt_button = gr.Button("Save System Prompt")
@@ -87,14 +76,14 @@ with gr.Blocks() as demo:
87
  with gr.Column():
88
  output = gr.TextArea(label="Model Response", interactive=False, lines=10)
89
 
90
- # Update model list when task changes
91
- def update_model_list(task):
92
- models = fetch_models(task)
93
- print(f"Models:{models}")
94
- return gr.Dropdown.update(choices=models, value=models[0] if models else None)
95
 
96
- # Event bindings
97
- task_selector.change(update_model_list, inputs=[task_selector], outputs=[model_selector])
 
 
 
98
  save_prompt_button.click(update_prompt, inputs=[system_prompt_name, system_prompt_content], outputs=[])
99
  submit_button.click(chat_with_model, inputs=[user_input, system_prompt_content, model_selector], outputs=[output])
100
 
 
1
  import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
  import os
4
  import json
5
 
 
50
  """Retrieve a system prompt by name."""
51
  return system_prompts.get(name, "")
52
 
53
+ # List of available models
54
+ available_models = [
55
+ "gpt-3.5-turbo",
56
+ "gpt-4",
57
+ "HuggingFaceH4/zephyr-7b-beta",
58
+ "HuggingFaceH4/zephyr-7b-alpha"
59
+ ]
 
 
60
 
61
  # Gradio Interface
62
  with gr.Blocks() as demo:
63
+ gr.Markdown("## Hugging Face Chatbot with Gradio")
64
 
65
  with gr.Row():
66
  with gr.Column():
67
+ model_selector = gr.Dropdown(choices=available_models, label="Select Model", value=available_models[0])
 
 
 
 
 
 
 
 
68
 
 
69
  system_prompt_name = gr.Dropdown(choices=list(system_prompts.keys()), label="Select System Prompt")
70
  system_prompt_content = gr.TextArea(label="System Prompt", value=get_prompt("default"), lines=4)
71
  save_prompt_button = gr.Button("Save System Prompt")
 
76
  with gr.Column():
77
  output = gr.TextArea(label="Model Response", interactive=False, lines=10)
78
 
79
+ def load_prompt(name):
80
+ return get_prompt(name)
 
 
 
81
 
82
+ system_prompt_name.change(
83
+ lambda name: (name, get_prompt(name)),
84
+ inputs=[system_prompt_name],
85
+ outputs=[system_prompt_name, system_prompt_content]
86
+ )
87
  save_prompt_button.click(update_prompt, inputs=[system_prompt_name, system_prompt_content], outputs=[])
88
  submit_button.click(chat_with_model, inputs=[user_input, system_prompt_content, model_selector], outputs=[output])
89