added avatars, dynamic bubble-size, upvote buttons
Browse files
app.py
CHANGED
@@ -54,9 +54,9 @@ def predict(message, chatbot, system_prompt="", temperature=0.9, max_new_tokens=
|
|
54 |
"inputs": input_prompt,
|
55 |
"parameters": {
|
56 |
"max_new_tokens":max_new_tokens,
|
57 |
-
"temperature"
|
58 |
-
"top_p"
|
59 |
-
"repetition_penalty"
|
60 |
"do_sample":True,
|
61 |
},
|
62 |
}
|
@@ -114,9 +114,9 @@ def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_t
|
|
114 |
"inputs": input_prompt,
|
115 |
"parameters": {
|
116 |
"max_new_tokens":max_new_tokens,
|
117 |
-
"temperature"
|
118 |
-
"top_p"
|
119 |
-
"repetition_penalty"
|
120 |
"do_sample":True,
|
121 |
},
|
122 |
}
|
@@ -138,6 +138,12 @@ def predict_batch(message, chatbot, system_prompt="", temperature=0.9, max_new_t
|
|
138 |
print(f"Request failed with status code {response.status_code}")
|
139 |
|
140 |
|
|
|
|
|
|
|
|
|
|
|
|
|
141 |
|
142 |
additional_inputs=[
|
143 |
gr.Textbox("", label="Optional system prompt"),
|
@@ -179,14 +185,36 @@ additional_inputs=[
|
|
179 |
)
|
180 |
]
|
181 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
182 |
|
183 |
# Gradio Demo
|
184 |
with gr.Blocks() as demo:
|
185 |
|
186 |
with gr.Tab("Streaming"):
|
187 |
-
gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
|
188 |
-
|
189 |
-
|
190 |
-
gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
|
191 |
|
|
|
|
|
|
|
|
|
|
|
192 |
demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
|
|
|
54 |
"inputs": input_prompt,
|
55 |
"parameters": {
|
56 |
"max_new_tokens":max_new_tokens,
|
57 |
+
"temperature":temperature,
|
58 |
+
"top_p":top_p,
|
59 |
+
"repetition_penalty":repetition_penalty,
|
60 |
"do_sample":True,
|
61 |
},
|
62 |
}
|
|
|
114 |
"inputs": input_prompt,
|
115 |
"parameters": {
|
116 |
"max_new_tokens":max_new_tokens,
|
117 |
+
"temperature":temperature,
|
118 |
+
"top_p":top_p,
|
119 |
+
"repetition_penalty":repetition_penalty,
|
120 |
"do_sample":True,
|
121 |
},
|
122 |
}
|
|
|
138 |
print(f"Request failed with status code {response.status_code}")
|
139 |
|
140 |
|
141 |
+
def vote(data: gr.LikeData):
|
142 |
+
if data.liked:
|
143 |
+
print("You upvoted this response: " + data.value)
|
144 |
+
else:
|
145 |
+
print("You downvoted this response: " + data.value)
|
146 |
+
|
147 |
|
148 |
additional_inputs=[
|
149 |
gr.Textbox("", label="Optional system prompt"),
|
|
|
185 |
)
|
186 |
]
|
187 |
|
188 |
+
chatbot_stream = gr.Chatbot(avatar_images=('user.png', 'bot2.png'),bubble_full_width = False)
|
189 |
+
chatbot_batch = gr.Chatbot(avatar_images=('user1.png', 'bot1.png'),bubble_full_width = False)
|
190 |
+
chat_interface_stream = gr.ChatInterface(predict,
|
191 |
+
title=title,
|
192 |
+
description=description,
|
193 |
+
chatbot=chatbot,
|
194 |
+
css=css,
|
195 |
+
examples=examples,
|
196 |
+
cache_examples=True,
|
197 |
+
additional_inputs=additional_inputs,)
|
198 |
+
chat_interface_batch = gr.ChatInterface(predict_batch,
|
199 |
+
title=title,
|
200 |
+
description=description,
|
201 |
+
chatbot=chatbot,
|
202 |
+
css=css,
|
203 |
+
examples=examples,
|
204 |
+
cache_examples=True,
|
205 |
+
additional_inputs=additional_inputs,)
|
206 |
|
207 |
# Gradio Demo
|
208 |
with gr.Blocks() as demo:
|
209 |
|
210 |
with gr.Tab("Streaming"):
|
211 |
+
#gr.ChatInterface(predict, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
|
212 |
+
chatbot_stream.like(vote, None, None)
|
213 |
+
chat_interface_stream.render()
|
|
|
214 |
|
215 |
+
with gr.Tab("Batch"):
|
216 |
+
#gr.ChatInterface(predict_batch, title=title, description=description, css=css, examples=examples, cache_examples=True, additional_inputs=additional_inputs,)
|
217 |
+
chatbot_batch.like(vote, None, None)
|
218 |
+
chat_interface_batch.render()
|
219 |
+
|
220 |
demo.queue(concurrency_count=75, max_size=100).launch(debug=True)
|