harpreetsahota committed
Commit eda2dbf · verified · 1 Parent(s): 00ce2db

Update app.py

Files changed (1): app.py (+139, -43)
app.py CHANGED
@@ -1,67 +1,163 @@
  import gradio as gr
  from huggingface_hub import InferenceClient
-
  from prompt_template import PromptTemplate, PromptLoader
  from assistant import AIAssistant


- """
- For more information on `huggingface_hub` Inference API support, please check the docs: https://huggingface.co/docs/huggingface_hub/v0.22.2/en/guides/inference
- """
- client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")


  def respond(
      message,
      history: list[tuple[str, str]],
      system_message,
      max_tokens,
      temperature,
      top_p,
  ):
-     messages = [{"role": "system", "content": system_message}]
-
-     for val in history:
-         if val[0]:
-             messages.append({"role": "user", "content": val[0]})
-         if val[1]:
-             messages.append({"role": "assistant", "content": val[1]})
-
      messages.append({"role": "user", "content": message})

-     response = ""
-
-     for message in client.chat_completion(
-         messages,
-         max_tokens=max_tokens,
          stream=True,
-         temperature=temperature,
-         top_p=top_p,
      ):
-         token = message.choices[0].delta.content
-
-         response += token
          yield response

-
- """
- For information on how to customize the ChatInterface, peruse the gradio docs: https://www.gradio.app/docs/chatinterface
- """
- demo = gr.ChatInterface(
-     respond,
-     additional_inputs=[
-         gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
-         gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
-         gr.Slider(minimum=0.1, maximum=4.0, value=0.7, step=0.1, label="Temperature"),
-         gr.Slider(
-             minimum=0.1,
-             maximum=1.0,
-             value=0.95,
-             step=0.05,
-             label="Top-p (nucleus sampling)",
-         ),
-     ],
- )
-

  if __name__ == "__main__":
-     demo.launch()
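The removed version is the stock Spaces chat template: it accumulates streamed deltas from InferenceClient.chat_completion into a growing string and yields the partial text after every token. That pattern in isolation, with the model and defaults taken from the removed lines above:

    from huggingface_hub import InferenceClient

    client = InferenceClient("HuggingFaceH4/zephyr-7b-beta")

    def stream_reply(messages, max_tokens=512, temperature=0.7, top_p=0.95):
        # Each streamed chunk carries the next token in choices[0].delta.content.
        response = ""
        for chunk in client.chat_completion(
            messages,
            max_tokens=max_tokens,
            stream=True,
            temperature=temperature,
            top_p=top_p,
        ):
            token = chunk.choices[0].delta.content
            if token:  # the final chunk may carry an empty delta
                response += token
                yield response

The rewritten file below replaces this single hard-coded client with model selection, prompt strategies loaded from YAML, and an AIAssistant wrapper.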
 
  import gradio as gr
  from huggingface_hub import InferenceClient
+ from openai import OpenAI
  from prompt_template import PromptTemplate, PromptLoader
  from assistant import AIAssistant
+ from pathlib import Path
+
+ # Load prompts from YAML
+ prompts = PromptLoader.load_prompts("prompts.yaml")

+ # Available models and their configurations
+ MODELS = {
+     "Zephyr 7B Beta": {
+         "name": "HuggingFaceH4/zephyr-7b-beta",
+         "provider": "huggingface"
+     },
+     "Mistral 7B": {
+         "name": "mistralai/Mistral-7B-v0.1",
+         "provider": "huggingface"
+     },
+     "GPT-3.5 Turbo": {
+         "name": "gpt-3.5-turbo",
+         "provider": "openai"
+     }
+ }

+ # Available prompt strategies
+ PROMPT_STRATEGIES = {
+     "Default": "system_context",
+     "Chain of Thought": "cot_prompt",
+     "Knowledge-based": "knowledge_prompt",
+     "Few-shot Learning": "few_shot_prompt",
+     "Meta-prompting": "meta_prompt"
+ }
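prompts.yaml, PromptTemplate, and PromptLoader live elsewhere in the repo and are not part of this commit, so their schema is an assumption here. For the lookups and calls used below (prompts[...], .format(...), .parameters) to work, load_prompts presumably returns a dict keyed by the strategy identifiers above, each value a PromptTemplate carrying a format string and default generation parameters. A hypothetical sketch:

    # Hypothetical shape only; the real prompts.yaml is not shown in this diff.
    prompts = {
        "system_context": PromptTemplate(
            template="{prompt_strategy}",
            parameters={"max_tokens": 512, "temperature": 0.7, "top_p": 0.95},
        ),
        "cot_prompt": PromptTemplate(
            template="{prompt_strategy}\nReason step by step before answering.",
            parameters={"max_tokens": 1024, "temperature": 0.3, "top_p": 0.9},
        ),
        # ...one entry per value in PROMPT_STRATEGIES
    }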
+
+ def create_assistant(model_name):
+     model_info = MODELS[model_name]
+     if model_info["provider"] == "huggingface":
+         client = InferenceClient(model_info["name"])
+     else:  # OpenAI
+         client = OpenAI()
+
+     return AIAssistant(
+         client=client,
+         model=model_info["name"]
+     )
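assistant.py is likewise outside this diff, so the exact AIAssistant API is not visible. Judging from create_assistant and respond below, it wraps either client behind a single streaming generator; a sketch of the assumed interface:

    # Assumed interface only; the real class is defined in assistant.py.
    class AIAssistant:
        def __init__(self, client, model):
            self.client = client  # InferenceClient or OpenAI
            self.model = model

        def generate_response(self, prompt_template, generation_params, messages, stream=True):
            # A real implementation would dispatch on the client type
            # (client.chat_completion vs. client.chat.completions.create)
            # and yield accumulated response text chunk by chunk.
            raise NotImplementedError

Note that the OpenAI branch constructs OpenAI() with no arguments, which reads the key from the environment, so the Space needs OPENAI_API_KEY set for the GPT-3.5 Turbo option to work.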

  def respond(
      message,
      history: list[tuple[str, str]],
+     model_name,
+     prompt_strategy,
      system_message,
+     override_params: bool,
      max_tokens,
      temperature,
      top_p,
  ):
+     assistant = create_assistant(model_name)
+
+     # Get prompt template
+     prompt_template: PromptTemplate = prompts[PROMPT_STRATEGIES[prompt_strategy]]
+
+     # Generate system message using prompt template
+     formatted_system_message = prompt_template.format(prompt_strategy=system_message)
+
+     # Prepare messages
+     messages = [{"role": "system", "content": formatted_system_message}]
+     for user_msg, assistant_msg in history:
+         if user_msg:
+             messages.append({"role": "user", "content": user_msg})
+         if assistant_msg:
+             messages.append({"role": "assistant", "content": assistant_msg})
      messages.append({"role": "user", "content": message})

+     # Get generation parameters
+     generation_params = prompt_template.parameters if not override_params else {
+         "max_tokens": max_tokens,
+         "temperature": temperature,
+         "top_p": top_p
+     }

+     # Generate response using the assistant
+     for response in assistant.generate_response(
+         prompt_template=prompt_template,
+         generation_params=generation_params,
          stream=True,
+         messages=messages
      ):
          yield response
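Note the parameter precedence: unless the override checkbox is set, the slider values are ignored and prompt_template.parameters wins. Since respond is a plain generator, it can be smoke-tested without the UI; every argument value below comes from the dicts and defaults in this file, though running it still requires the relevant API credentials:

    # Quick check outside Gradio: iterate respond() directly.
    last = ""
    for partial in respond(
        message="What is nucleus sampling?",
        history=[],
        model_name="Zephyr 7B Beta",
        prompt_strategy="Chain of Thought",
        system_message="You are a friendly and helpful AI assistant.",
        override_params=True,
        max_tokens=256,
        temperature=0.7,
        top_p=0.95,
    ):
        last = partial
    print(last)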

+ with gr.Blocks() as demo:
+     with gr.Row():
+         with gr.Column():
+             model_dropdown = gr.Dropdown(
+                 choices=list(MODELS.keys()),
+                 value=list(MODELS.keys())[0],
+                 label="Select Model"
+             )
+             prompt_strategy_dropdown = gr.Dropdown(
+                 choices=list(PROMPT_STRATEGIES.keys()),
+                 value=list(PROMPT_STRATEGIES.keys())[0],
+                 label="Select Prompt Strategy"
+             )
+             system_message = gr.Textbox(
+                 value="You are a friendly and helpful AI assistant.",
+                 label="System Message"
+             )
+
+     with gr.Row():
+         override_params = gr.Checkbox(
+             label="Override Template Parameters",
+             value=False
+         )
+
+     with gr.Row():
+         with gr.Column(visible=False) as param_controls:
+             max_tokens = gr.Slider(
+                 minimum=1,
+                 maximum=2048,
+                 value=512,
+                 step=1,
+                 label="Max new tokens"
+             )
+             temperature = gr.Slider(
+                 minimum=0.1,
+                 maximum=4.0,
+                 value=0.7,
+                 step=0.1,
+                 label="Temperature"
+             )
+             top_p = gr.Slider(
+                 minimum=0.1,
+                 maximum=1.0,
+                 value=0.95,
+                 step=0.05,
+                 label="Top-p (nucleus sampling)"
+             )
+
+     chatbot = gr.ChatInterface(
+         fn=respond,
+         additional_inputs=[
+             model_dropdown,
+             prompt_strategy_dropdown,
+             system_message,
+             override_params,
+             max_tokens,
+             temperature,
+             top_p,
+         ]
+     )
+
+     def toggle_param_controls(override):
+         return gr.Column(visible=override)
+
+     override_params.change(
+         toggle_param_controls,
+         inputs=[override_params],
+         outputs=[param_controls]
+     )

  if __name__ == "__main__":
+     demo.launch()
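The show/hide wiring at the end relies on returning a component constructor with new props from an event handler, which Gradio 4.x applies as an update to the live component (on Gradio 3.x the equivalent would be gr.update(visible=...)). The same pattern in isolation:

    import gradio as gr

    with gr.Blocks() as toggle_demo:
        show = gr.Checkbox(label="Show advanced options", value=False)
        with gr.Column(visible=False) as advanced:
            gr.Slider(0, 1, value=0.5, label="Advanced knob")

        # Returning gr.Column(visible=...) updates the existing column in place.
        show.change(lambda s: gr.Column(visible=s), inputs=[show], outputs=[advanced])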