ysharma HF staff committed on
Commit da6bfd7 · 1 Parent(s): 0be6373

Create app.py

Files changed (1)
  1. app.py +195 -0
app.py ADDED
@@ -0,0 +1,195 @@
+ import gradio as gr
+ import os
+ import json
+ import requests
+ 
+ 
+ HF_TOKEN = os.getenv('HF_TOKEN')
+ HEADERS = {"Authorization": HF_TOKEN}
+ zephyr_7b_beta = os.getenv('zephyr_7b_beta')
+ zephyr_7b_alpha = os.getenv('zephyr_7b_alpha')
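+ # Assumed setup: HF_TOKEN holds the value sent in the Authorization header, and the
+ # two zephyr_* variables hold the inference endpoint URLs for the beta and alpha
+ # models; all three are presumably configured as Space secrets.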
+ 
+ 
+ def build_input_prompt(message, chatbot):
+     """
+     Constructs the input prompt string from the chatbot interactions and the current message.
+     """
+     input_prompt = "<|system|>\n</s>\n<|user|>\n"
+     for interaction in chatbot:
+         input_prompt = input_prompt + str(interaction[0]) + "</s>\n<|assistant|>\n" + str(interaction[1]) + "\n</s>\n<|user|>\n"
+ 
+     input_prompt = input_prompt + str(message) + "</s>\n<|assistant|>"
+     return input_prompt
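+ # For example, with a hypothetical history chatbot = [("Hi", "Hello!")] and
+ # message = "How are you?", build_input_prompt returns:
+ # "<|system|>\n</s>\n<|user|>\nHi</s>\n<|assistant|>\nHello!\n</s>\n<|user|>\nHow are you?</s>\n<|assistant|>"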
+ 
+ 
+ def post_request_beta(payload):
+     """
+     Sends a POST request to the predefined Zephyr-7b-Beta URL and returns the JSON response.
+     """
+     response = requests.post(zephyr_7b_beta, headers=HEADERS, json=payload)
+     response.raise_for_status()  # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
+     return response.json()
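+ # The endpoint is expected to answer with a list like [{"generated_text": "..."}];
+ # error payloads are assumed to arrive as [{"error": "..."}], which is what the
+ # predict_* functions below check for.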
+ 
+ 
+ def post_request_alpha(payload):
+     """
+     Sends a POST request to the predefined Zephyr-7b-Alpha URL and returns the JSON response.
+     """
+     response = requests.post(zephyr_7b_alpha, headers=HEADERS, json=payload)
+     response.raise_for_status()  # Will raise an HTTPError if the HTTP request returned an unsuccessful status code
+     return response.json()
+ 
+ 
+ def predict_beta(message, chatbot, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
+     temperature = float(temperature)
+     top_p = float(top_p)
+ 
+     input_prompt = build_input_prompt(message, chatbot)
+ 
+     data = {
+         "inputs": input_prompt,
+         "parameters": {
+             "max_new_tokens": max_new_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "repetition_penalty": repetition_penalty,
+             "do_sample": True,
+         },
+     }
+ 
+     try:
+         response_data = post_request_beta(data)
+         json_obj = response_data[0]
+ 
+         if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
+             bot_message = json_obj['generated_text']
+             chatbot.append((message, bot_message))
+             return "", chatbot
+         elif 'error' in json_obj:
+             raise gr.Error(json_obj['error'] + ' Please refresh and try again with a smaller input prompt')
+         else:
+             warning_msg = f"Unexpected response: {json_obj}"
+             print(warning_msg)
+             raise ValueError(warning_msg)
+     except requests.HTTPError as e:
+         error_msg = f"Request failed with status code {e.response.status_code}"
+         print(error_msg)
+         raise gr.Error(error_msg)
+     except json.JSONDecodeError as e:
+         error_msg = f"Failed to decode response as JSON: {e}"
+         print(error_msg)
+         raise gr.Error(error_msg)
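+ # predict_beta returns ("", updated_history): the empty string clears the textbox
+ # and the history gains one (message, bot_message) pair, e.g. a hypothetical first
+ # exchange yields [("Hi", "<model reply>")].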
+ 
+ 
+ def predict_alpha(message, chatbot, temperature=0.9, max_new_tokens=256, top_p=0.6, repetition_penalty=1.0):
+     temperature = float(temperature)
+     top_p = float(top_p)
+ 
+     input_prompt = build_input_prompt(message, chatbot)
+ 
+     data = {
+         "inputs": input_prompt,
+         "parameters": {
+             "max_new_tokens": max_new_tokens,
+             "temperature": temperature,
+             "top_p": top_p,
+             "repetition_penalty": repetition_penalty,
+             "do_sample": True,
+         },
+     }
+ 
+     try:
+         response_data = post_request_alpha(data)
+         json_obj = response_data[0]
+ 
+         if 'generated_text' in json_obj and len(json_obj['generated_text']) > 0:
+             bot_message = json_obj['generated_text']
+             chatbot.append((message, bot_message))
+             return "", chatbot
+         elif 'error' in json_obj:
+             raise gr.Error(json_obj['error'] + ' Please refresh and try again with a smaller input prompt')
+         else:
+             warning_msg = f"Unexpected response: {json_obj}"
+             print(warning_msg)
+             raise ValueError(warning_msg)
+     except requests.HTTPError as e:
+         error_msg = f"Request failed with status code {e.response.status_code}"
+         print(error_msg)
+         raise gr.Error(error_msg)
+     except json.JSONDecodeError as e:
+         error_msg = f"Failed to decode response as JSON: {e}"
+         print(error_msg)
+         raise gr.Error(error_msg)
+ 
+ 
+ def retry_fun_beta(chat_history_beta):
+     """
+     Retries the prediction for the last message in the chat history.
+     Removes the last interaction and gets a new prediction for the same message from Zephyr-7b-Beta.
+     """
+     if not chat_history_beta or len(chat_history_beta) < 1:
+         raise gr.Error("Chat history is empty or invalid.")
+ 
+     message = chat_history_beta[-1][0]
+     chat_history_beta.pop()
+     _, updated_chat_history_beta = predict_beta(message, chat_history_beta)
+     return updated_chat_history_beta
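+ # For example, with a hypothetical history [("Hi", "Hello!")], retry drops that
+ # pair and asks the beta model to answer "Hi" again.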
+ 
+ 
+ def retry_fun_alpha(chat_history_alpha):
+     """
+     Retries the prediction for the last message in the chat history.
+     Removes the last interaction and gets a new prediction for the same message from Zephyr-7b-Alpha.
+     """
+     if not chat_history_alpha or len(chat_history_alpha) < 1:
+         raise gr.Error("Chat history is empty or invalid.")
+ 
+     message = chat_history_alpha[-1][0]
+     chat_history_alpha.pop()
+     _, updated_chat_history_alpha = predict_alpha(message, chat_history_alpha)
+     return updated_chat_history_alpha
+ 
+ 
+ # Create chatbot components
+ chat_beta = gr.Chatbot(label="zephyr-7b-beta")
+ chat_alpha = gr.Chatbot(label="zephyr-7b-alpha")
+ 
+ # Create input and button components
+ textbox = gr.Textbox(container=False,
+                      placeholder='Enter text and click the Submit button or press Enter')
+ submit = gr.Button('Submit', variant='primary')
+ retry = gr.Button('🔄Retry', variant='secondary')
+ undo = gr.Button('↩️Undo', variant='secondary')
+ 
+ # Lay out the components using the Gradio Blocks API
+ with gr.Blocks() as demo:
+     with gr.Row():
+         chat_beta.render()
+         chat_alpha.render()
+     with gr.Group():
+         with gr.Row(equal_height=True):
+             with gr.Column(scale=5):
+                 textbox.render()
+             with gr.Column(scale=1):
+                 submit.render()
+         with gr.Row():
+             retry.render()
+             undo.render()
+             clear = gr.ClearButton(value='🗑️Clear',
+                                    components=[textbox,
+                                                chat_beta,
+                                                chat_alpha])
+ 
+     # Assign events to components
+     textbox.submit(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
+     textbox.submit(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])
+     submit.click(predict_beta, [textbox, chat_beta], [textbox, chat_beta])
+     submit.click(predict_alpha, [textbox, chat_alpha], [textbox, chat_alpha])
+ 
+     undo.click(lambda x: x[:-1], [chat_beta], [chat_beta])
+     undo.click(lambda x: x[:-1], [chat_alpha], [chat_alpha])
+ 
+     retry.click(retry_fun_beta, [chat_beta], [chat_beta])
+     retry.click(retry_fun_alpha, [chat_alpha], [chat_alpha])
+ 
+ # Launch the demo
+ demo.launch(debug=True)
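+ # Each submit or Enter press fans the same textbox value out to predict_beta and
+ # predict_alpha, so the two Chatbot panels show the beta and alpha answers to the
+ # same prompt side by side for comparison.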