Spaces:
Runtime error
Update app.py
app.py
CHANGED
@@ -1,11 +1,10 @@
 import gradio as gr
-import aiohttp
-import asyncio
 import json
 import os
 import datetime
-import
-
+import asyncio
+import aiohttp
+from aiohttp import ClientSession

 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
@@ -25,15 +24,12 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }

-
+active_tasks = {}

 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")

-should_stop = False
-
 async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-    global should_stop
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
         history_format.append({"role": "user", "content": human})
@@ -56,7 +52,7 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
     }

     non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
-
+
     if non_default_params and not message.startswith(('*', '"')):
         for param, value in non_default_params.items():
             print(f"{param}={value}")
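The hunk above keeps the existing "log only overridden sampling parameters" check: a dict comprehension compares the current request against DEFAULT_PARAMS and prints whatever differs. A minimal, self-contained sketch of that idea; the default values below are placeholders for illustration, not copied from app.py:

# Sketch of the "log only overridden parameters" check kept in predict().
# These DEFAULT_PARAMS values are made up for the example.
DEFAULT_PARAMS = {"temperature": 0.8, "top_p": 0.95, "max_tokens": 512}
current_params = {"temperature": 0.6, "top_p": 0.95, "max_tokens": 512}

non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
for param, value in non_default_params.items():
    print(f"{param}={value}")  # prints: temperature=0.6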
@@ -74,28 +70,34 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
         "max_tokens": max_tokens
     }

-
-    async with
-
-
-
-
-                    line = line.decode('utf-8')
-                    if line.startswith("data: "):
-                        if line.strip() == "data: [DONE]":
+    try:
+        async with ClientSession() as session:
+            async with session.post(API_URL, headers=headers, json=data) as response:
+                partial_message = ""
+                async for line in response.content:
+                    if asyncio.current_task().cancelled():
                         break
-
-
-                        if
-
-
-
-
-
-
-
-
-
+                    if line:
+                        line = line.decode('utf-8')
+                        if line.startswith("data: "):
+                            if line.strip() == "data: [DONE]":
+                                break
+                            try:
+                                json_data = json.loads(line[6:])
+                                if 'choices' in json_data and json_data['choices']:
+                                    content = json_data['choices'][0]['delta'].get('content', '')
+                                    if content:
+                                        partial_message += content
+                                        yield partial_message
+                            except json.JSONDecodeError:
+                                continue
+
+                if partial_message:
+                    yield partial_message
+
+    except Exception as e:
+        print(f"Request error: {e}")
+        yield f"An error occurred: {str(e)}"

 def import_chat(custom_format_string):
     try:
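The rewritten predict() streams the completion by parsing server-sent events by hand: each streamed line is expected to look like `data: {json}`, `data: [DONE]` marks end of stream, and a JSONDecodeError guard skips keep-alives and partial lines. A runnable sketch of just that parsing logic, using fabricated sample lines in place of a live aiohttp response:

# Self-contained sketch of the SSE parsing added in predict() above.
# The sample lines are fabricated; no network call is made.
import json

sample_lines = [
    b'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    b'data: {"choices": [{"delta": {"content": "lo"}}]}',
    b'data: [DONE]',
]

partial_message = ""
for raw in sample_lines:
    line = raw.decode("utf-8")
    if not line.startswith("data: "):
        continue
    if line.strip() == "data: [DONE]":
        break
    try:
        payload = json.loads(line[6:])  # strip the "data: " prefix
    except json.JSONDecodeError:
        continue  # ignore malformed or partial lines
    choices = payload.get("choices") or []
    if choices:
        partial_message += choices[0]["delta"].get("content", "")

print(partial_message)  # -> "Hello"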
@@ -131,12 +133,11 @@ def export_chat(history, system_prompt):
         export_data += f"<|assistant|> {assistant_msg}\n\n"
     return export_data

-def
-
-
-    return gr.update(interactive=True), gr.update(interactive=True)
+def sanitize_chatbot_history(history):
+    """Ensure each entry in the chatbot history is a tuple of two items."""
+    return [tuple(entry[:2]) for entry in history]

-with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
+with gr.Blocks(theme='gradio/monochrome') as demo:
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(value=[])
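sanitize_chatbot_history() exists because the tuple-style gr.Chatbot hands histories back as lists of pairs, and a pair can arrive as a list (for example after a JSON round-trip through import_chat) rather than a tuple; normalizing every entry to a 2-tuple keeps the later `history[-1] = (...)` updates consistent. A quick demonstration:

# What sanitize_chatbot_history() does to a mixed history.
def sanitize_chatbot_history(history):
    return [tuple(entry[:2]) for entry in history]

mixed = [["hi", "hello!"], ("how are you?", None)]
print(sanitize_chatbot_history(mixed))
# -> [('hi', 'hello!'), ('how are you?', None)]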
@@ -162,81 +163,85 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
             repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
             max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")

-    def user(user_message, history):
-        history = history or []
-        return "", history + [
+    async def user(user_message, history):
+        history = sanitize_chatbot_history(history or [])
+        return "", history + [(user_message, None)]

     async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-
-        should_stop = False
-        history = history or []
+        history = sanitize_chatbot_history(history or [])
         if not history:
             yield history
             return
         user_message = history[-1][0]
         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
-        history[-1][1]
+        history[-1] = (history[-1][0], "")  # Ensure it's a tuple
+        task_id = id(asyncio.current_task())
+        active_tasks[task_id] = asyncio.current_task()
         try:
             async for chunk in bot_message:
-                if
+                if task_id not in active_tasks:
                     break
-                history[-1][1]
+                history[-1] = (history[-1][0], chunk)  # Update as a tuple
                 yield history
-        except
-
-            history[-1][1] = "An error occurred while generating the response."
-            yield history
+        except asyncio.CancelledError:
+            pass
         finally:
-
+            if task_id in active_tasks:
+                del active_tasks[task_id]
+            if history[-1][1] == "":
+                history[-1] = (history[-1][0], " [Generation stopped]")
+                yield history

     async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-
-
-
-
-
+        # Cancel any ongoing generation
+        for task in list(active_tasks.values()):
+            task.cancel()
+
+        # Wait for a short time to ensure cancellation is processed
+        await asyncio.sleep(0.1)
+
+        history = sanitize_chatbot_history(history or [])
+        if history:
+            history[-1] = (history[-1][0], None)  # Reset last response
             async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-                if should_stop:
-                    break
                 yield new_history
         else:
             yield []
-        should_stop = False

     def import_chat_wrapper(custom_format_string):
         imported_history, imported_system_prompt = import_chat(custom_format_string)
-        return imported_history, imported_system_prompt
+        return sanitize_chatbot_history(imported_history), imported_system_prompt

     submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
-        concurrency_limit=
+        concurrency_limit=5
     )

-    clear.click(lambda:
+    clear.click(lambda: [], None, chatbot, queue=False)

     regenerate_event = regenerate.click(
         regenerate_response,
         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
         chatbot,
-        concurrency_limit=
+        concurrency_limit=5
     )

-
-        stop_generation,
-        inputs=[],
-        outputs=[msg, regenerate],
-        cancels=[submit_event, regenerate_event],
-        queue=False
-    )
-
-    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], queue=False)
+    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=5)

     export_button.click(
         export_chat,
         inputs=[chatbot, system_prompt],
         outputs=[import_textbox],
+        concurrency_limit=5
+    )
+
+    stop_btn.click(
+        lambda: [task.cancel() for task in list(active_tasks.values())],
+        None,
+        None,
+        cancels=[submit_event, regenerate_event],
         queue=False
     )

 if __name__ == "__main__":
-    demo.launch(debug=True,
+    demo.launch(debug=True, max_threads=20)
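The net effect of this commit is to replace the old `should_stop` global with a per-task registry: every bot() run records its asyncio task in `active_tasks`, regenerate and the stop button cancel whatever is registered, and the finally block cleans up. A stand-alone sketch of that pattern, with a fake generate() standing in for the real streaming bot():

# Sketch of the task-registry cancellation pattern introduced above.
# Runnable on its own; generate() fakes the streaming loop.
import asyncio

active_tasks = {}

async def generate(n):
    task = asyncio.current_task()
    task_id = id(task)
    active_tasks[task_id] = task
    try:
        for i in range(n):
            await asyncio.sleep(0.05)   # pretend to stream a chunk
            print(f"chunk {i}")
    except asyncio.CancelledError:
        print("stopped")                # mirrors bot()'s CancelledError handler
    finally:
        active_tasks.pop(task_id, None)

async def main():
    task = asyncio.create_task(generate(100))
    await asyncio.sleep(0.2)
    for t in list(active_tasks.values()):  # what stop_btn.click's lambda does
        t.cancel()
    await asyncio.gather(task, return_exceptions=True)

asyncio.run(main())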