Spaces: Runtime error
????????wtf
app.py CHANGED
@@ -4,7 +4,7 @@ import os
 import datetime
 import asyncio
 import aiohttp
-from aiohttp import ClientSession
+from aiohttp import ClientSession, ClientTimeout
 
 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
@@ -24,8 +24,6 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }
 
-active_tasks = {}
-
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")
 
@@ -71,11 +69,13 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, fr
     }
 
     try:
-
+        timeout = ClientTimeout(total=60)  # Set a 60-second timeout
+        async with ClientSession(timeout=timeout) as session:
             async with session.post(API_URL, headers=headers, json=data) as response:
                 partial_message = ""
                 async for line in response.content:
                     if asyncio.current_task().cancelled():
+                        print("Task cancelled during API request")
                         break
                     if line:
                         line = line.decode('utf-8')
@@ -95,6 +95,9 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, fr
                         if partial_message:
                             yield partial_message
 
+    except asyncio.TimeoutError:
+        print("Request timed out")
+        yield "Request timed out. Please try again."
     except Exception as e:
         print(f"Request error: {e}")
         yield f"An error occurred: {str(e)}"
@@ -135,9 +138,70 @@ def export_chat(history, system_prompt):
 
 def sanitize_chatbot_history(history):
     """Ensure each entry in the chatbot history is a tuple of two items."""
-    return [tuple(entry[:2]) for entry in history]
+    return [tuple(entry[:2]) if isinstance(entry, (list, tuple)) else (str(entry), None) for entry in history]
+
+async def user(user_message, history):
+    history = sanitize_chatbot_history(history or [])
+    return "", history + [(user_message, None)]
+
+async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
+    history = sanitize_chatbot_history(history or [])
+    if not history:
+        yield history
+        return
+    user_message = history[-1][0]
+    bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
+    history[-1] = (history[-1][0], "")
+    task = asyncio.current_task()
+    task_info['task'] = task
+    task_info['stop_requested'] = False
+    try:
+        async for chunk in bot_message:
+            if task_info.get('stop_requested', False):
+                print("Stop requested, breaking the loop")
+                break
+            history[-1] = (history[-1][0], chunk)
+            yield history
+    except asyncio.CancelledError:
+        print("Bot generation cancelled")
+    except GeneratorExit:
+        print("Generator exited")
+    except Exception as e:
+        print(f"Error in bot generation: {e}")
+    finally:
+        if history[-1][1] == "":
+            history[-1] = (history[-1][0], " [Generation stopped]")
+        task_info['task'] = None
+        task_info['stop_requested'] = False
+        yield history
+
+async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
+    if 'task' in task_info and task_info['task']:
+        print("Cancelling previous task")
+        task_info['stop_requested'] = True
+        task_info['task'].cancel()
+
+    await asyncio.sleep(0.1)
+
+    history = sanitize_chatbot_history(history or [])
+    if history:
+        history[-1] = (history[-1][0], None)
+        try:
+            async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
+                yield sanitize_chatbot_history(new_history)
+        except Exception as e:
+            print(f"Error in regenerate_response: {e}")
+            yield history
+    else:
+        yield []
+
+def import_chat_wrapper(custom_format_string):
+    imported_history, imported_system_prompt = import_chat(custom_format_string)
+    return sanitize_chatbot_history(imported_history), imported_system_prompt
 
 with gr.Blocks(theme='gradio/monochrome') as demo:
+    task_info = gr.State({'task': None, 'stop_requested': False})
+
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(value=[])
@@ -163,85 +227,43 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
         max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
 
-    async def user(user_message, history):
-        history = sanitize_chatbot_history(history or [])
-        return "", history + [(user_message, None)]
-
-    async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        history = sanitize_chatbot_history(history or [])
-        if not history:
-            yield history
-            return
-        user_message = history[-1][0]
-        bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
-        history[-1] = (history[-1][0], "")  # Ensure it's a tuple
-        task_id = id(asyncio.current_task())
-        active_tasks[task_id] = asyncio.current_task()
-        try:
-            async for chunk in bot_message:
-                if task_id not in active_tasks:
-                    break
-                history[-1] = (history[-1][0], chunk)  # Update as a tuple
-                yield history
-        except asyncio.CancelledError:
-            pass
-        finally:
-            if task_id in active_tasks:
-                del active_tasks[task_id]
-            if history[-1][1] == "":
-                history[-1] = (history[-1][0], " [Generation stopped]")
-            yield history
-
-    async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        # Cancel any ongoing generation
-        for task in list(active_tasks.values()):
-            task.cancel()
-
-        # Wait for a short time to ensure cancellation is processed
-        await asyncio.sleep(0.1)
-
-        history = sanitize_chatbot_history(history or [])
-        if history:
-            history[-1] = (history[-1][0], None)  # Reset last response
-            async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-                yield new_history
-        else:
-            yield []
-
-    def import_chat_wrapper(custom_format_string):
-        imported_history, imported_system_prompt = import_chat(custom_format_string)
-        return sanitize_chatbot_history(imported_history), imported_system_prompt
-
     submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
-        concurrency_limit=
+        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info], chatbot,
+        concurrency_limit=10
     )
 
     clear.click(lambda: [], None, chatbot, queue=False)
 
     regenerate_event = regenerate.click(
         regenerate_response,
-        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
+        [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info],
         chatbot,
-        concurrency_limit=
+        concurrency_limit=10
    )
 
-    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=
+    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=10)
 
     export_button.click(
         export_chat,
         inputs=[chatbot, system_prompt],
         outputs=[import_textbox],
-        concurrency_limit=
+        concurrency_limit=10
    )
 
+    def stop_generation(task_info):
+        if 'task' in task_info and task_info['task']:
+            print("Stop requested")
+            task_info['stop_requested'] = True
+            task_info['task'].cancel()
+        return task_info
+
     stop_btn.click(
-
-
-
+        stop_generation,
+        inputs=[task_info],
+        outputs=[task_info],
         cancels=[submit_event, regenerate_event],
         queue=False
    )
 
 if __name__ == "__main__":
-    demo.launch(debug=True, max_threads=
+    demo.launch(debug=True, max_threads=40)
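
The heart of this commit is moving cancellation state out of the module-level active_tasks dict and into a per-session gr.State dict (task_info) that is threaded through bot, regenerate_response, and the new stop_generation handler. Below is a minimal, self-contained sketch of that pattern; fake_stream is a placeholder standing in for the real streaming predict call, and everything except the gr.State/task_info wiring is illustrative only:

# Sketch only: per-session cancellation via gr.State, as in the diff above.
# fake_stream is a hypothetical stand-in for the real streaming API call.
import asyncio
import gradio as gr

async def fake_stream():
    # Pretend tokens arrive slowly from an API.
    for i in range(20):
        await asyncio.sleep(0.2)
        yield f"token {i} "

async def generate(history, task_info):
    # Record this session's task so the Stop button can cancel it.
    task_info['task'] = asyncio.current_task()
    task_info['stop_requested'] = False
    history = (history or []) + [("prompt", "")]
    partial = ""
    try:
        async for chunk in fake_stream():
            if task_info.get('stop_requested'):
                break
            partial += chunk
            history[-1] = (history[-1][0], partial)
            yield history
    except asyncio.CancelledError:
        pass
    finally:
        task_info['task'] = None
        yield history

def stop_generation(task_info):
    # Flip the flag and cancel the task recorded for this session only.
    if task_info.get('task'):
        task_info['stop_requested'] = True
        task_info['task'].cancel()
    return task_info

with gr.Blocks() as demo:
    task_info = gr.State({'task': None, 'stop_requested': False})
    chatbot = gr.Chatbot(value=[])
    go = gr.Button("Generate")
    stop = gr.Button("Stop")
    gen_event = go.click(generate, [chatbot, task_info], chatbot)
    stop.click(stop_generation, [task_info], [task_info], cancels=[gen_event], queue=False)

if __name__ == "__main__":
    demo.launch()

Because the dict lives in gr.State, each browser session gets its own copy, so two users streaming at once no longer cancel each other's tasks, which the shared active_tasks dict allowed.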
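The other fix is bounding the request itself: ClientTimeout(total=60) caps the whole POST, including the streamed body, and the new except asyncio.TimeoutError branch turns a timeout into a user-visible message instead of an unhandled error. A condensed sketch of just that path (the URL and payload are placeholders):

# Sketch only: aiohttp total timeout around a streamed response.
import asyncio
import aiohttp

async def fetch_stream(url: str, payload: dict):
    timeout = aiohttp.ClientTimeout(total=60)  # 60 s budget for the whole request
    try:
        async with aiohttp.ClientSession(timeout=timeout) as session:
            async with session.post(url, json=payload) as response:
                async for line in response.content:
                    yield line.decode('utf-8')
    except asyncio.TimeoutError:
        # Raised when the total budget is exhausted, even mid-stream.
        yield "Request timed out. Please try again."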