Corvius committed on
Commit
2c15d27
Β·
verified Β·
1 Parent(s): d542f92

????????wtf

Browse files
Files changed (1) hide show
  1. app.py +86 -64
app.py CHANGED
@@ -4,7 +4,7 @@ import os
4
  import datetime
5
  import asyncio
6
  import aiohttp
7
- from aiohttp import ClientSession
8
 
9
  API_URL = os.environ.get('API_URL')
10
  API_KEY = os.environ.get('API_KEY')
@@ -24,8 +24,6 @@ DEFAULT_PARAMS = {
24
  "max_tokens": 512
25
  }
26
 
27
- active_tasks = {}
28
-
29
  def get_timestamp():
30
  return datetime.datetime.now().strftime("%H:%M:%S")
31
 
@@ -71,11 +69,13 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, fr
71
  }
72
 
73
  try:
74
- async with ClientSession() as session:
 
75
  async with session.post(API_URL, headers=headers, json=data) as response:
76
  partial_message = ""
77
  async for line in response.content:
78
  if asyncio.current_task().cancelled():
 
79
  break
80
  if line:
81
  line = line.decode('utf-8')
@@ -95,6 +95,9 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, fr
95
  if partial_message:
96
  yield partial_message
97
 
 
 
 
98
  except Exception as e:
99
  print(f"Request error: {e}")
100
  yield f"An error occurred: {str(e)}"
@@ -135,9 +138,70 @@ def export_chat(history, system_prompt):
135
 
136
  def sanitize_chatbot_history(history):
137
  """Ensure each entry in the chatbot history is a tuple of two items."""
138
- return [tuple(entry[:2]) for entry in history]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
 
140
  with gr.Blocks(theme='gradio/monochrome') as demo:
 
 
141
  with gr.Row():
142
  with gr.Column(scale=2):
143
  chatbot = gr.Chatbot(value=[])
@@ -163,85 +227,43 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
163
  repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
164
  max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
165
 
166
- async def user(user_message, history):
167
- history = sanitize_chatbot_history(history or [])
168
- return "", history + [(user_message, None)]
169
-
170
- async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
171
- history = sanitize_chatbot_history(history or [])
172
- if not history:
173
- yield history
174
- return
175
- user_message = history[-1][0]
176
- bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
177
- history[-1] = (history[-1][0], "") # Ensure it's a tuple
178
- task_id = id(asyncio.current_task())
179
- active_tasks[task_id] = asyncio.current_task()
180
- try:
181
- async for chunk in bot_message:
182
- if task_id not in active_tasks:
183
- break
184
- history[-1] = (history[-1][0], chunk) # Update as a tuple
185
- yield history
186
- except asyncio.CancelledError:
187
- pass
188
- finally:
189
- if task_id in active_tasks:
190
- del active_tasks[task_id]
191
- if history[-1][1] == "":
192
- history[-1] = (history[-1][0], " [Generation stopped]")
193
- yield history
194
-
195
- async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
196
- # Cancel any ongoing generation
197
- for task in list(active_tasks.values()):
198
- task.cancel()
199
-
200
- # Wait for a short time to ensure cancellation is processed
201
- await asyncio.sleep(0.1)
202
-
203
- history = sanitize_chatbot_history(history or [])
204
- if history:
205
- history[-1] = (history[-1][0], None) # Reset last response
206
- async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
207
- yield new_history
208
- else:
209
- yield []
210
-
211
- def import_chat_wrapper(custom_format_string):
212
- imported_history, imported_system_prompt = import_chat(custom_format_string)
213
- return sanitize_chatbot_history(imported_history), imported_system_prompt
214
-
215
  submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
216
- bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
217
- concurrency_limit=5
218
  )
219
 
220
  clear.click(lambda: [], None, chatbot, queue=False)
221
 
222
  regenerate_event = regenerate.click(
223
  regenerate_response,
224
- [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
225
  chatbot,
226
- concurrency_limit=5
227
  )
228
 
229
- import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=5)
230
 
231
  export_button.click(
232
  export_chat,
233
  inputs=[chatbot, system_prompt],
234
  outputs=[import_textbox],
235
- concurrency_limit=5
236
  )
237
 
 
 
 
 
 
 
 
238
  stop_btn.click(
239
- lambda: [task.cancel() for task in list(active_tasks.values())],
240
- None,
241
- None,
242
  cancels=[submit_event, regenerate_event],
243
  queue=False
244
  )
245
 
246
  if __name__ == "__main__":
247
- demo.launch(debug=True, max_threads=20)
 
4
  import datetime
5
  import asyncio
6
  import aiohttp
7
+ from aiohttp import ClientSession, ClientTimeout
8
 
9
  API_URL = os.environ.get('API_URL')
10
  API_KEY = os.environ.get('API_KEY')
 
24
  "max_tokens": 512
25
  }
26
 
 
 
27
def get_timestamp():
    """Return the current local time formatted as HH:MM:SS."""
    now = datetime.datetime.now()
    return now.strftime("%H:%M:%S")
29
 
 
69
  }
70
 
71
  try:
72
+ timeout = ClientTimeout(total=60) # Set a 60-second timeout
73
+ async with ClientSession(timeout=timeout) as session:
74
  async with session.post(API_URL, headers=headers, json=data) as response:
75
  partial_message = ""
76
  async for line in response.content:
77
  if asyncio.current_task().cancelled():
78
+ print("Task cancelled during API request")
79
  break
80
  if line:
81
  line = line.decode('utf-8')
 
95
  if partial_message:
96
  yield partial_message
97
 
98
+ except asyncio.TimeoutError:
99
+ print("Request timed out")
100
+ yield "Request timed out. Please try again."
101
  except Exception as e:
102
  print(f"Request error: {e}")
103
  yield f"An error occurred: {str(e)}"
 
138
 
139
def sanitize_chatbot_history(history):
    """Ensure each entry in the chatbot history is a tuple of two items.

    List/tuple entries are truncated to their first two items and padded
    with ``None`` when shorter (the original code could emit a 1-tuple or
    empty tuple here, breaking the (user, bot) pair contract).  Any other
    entry is coerced to ``(str(entry), None)``.
    """
    sanitized = []
    for entry in history:
        if isinstance(entry, (list, tuple)):
            pair = tuple(entry[:2])
            # Pad short entries so every row really is a two-item pair.
            sanitized.append(pair + (None,) * (2 - len(pair)))
        else:
            sanitized.append((str(entry), None))
    return sanitized
142
+
143
async def user(user_message, history):
    """Append the submitted message as a new (user, None) turn and clear the textbox."""
    cleaned = sanitize_chatbot_history(history or [])
    cleaned.append((user_message, None))
    return "", cleaned
146
+
147
async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
    """Stream the assistant's reply for the last user turn into the chat history.

    Async generator used as a Gradio event handler: yields the updated
    ``history`` after each chunk streamed from ``predict``.  ``task_info`` is a
    shared dict ({'task': Task | None, 'stop_requested': bool}) that lets the
    stop/regenerate handlers cooperatively interrupt this generation.
    """
    history = sanitize_chatbot_history(history or [])
    if not history:
        # Nothing to respond to; echo the (empty) history and stop.
        yield history
        return
    user_message = history[-1][0]
    bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
    history[-1] = (history[-1][0], "")
    # Register this task so stop_generation()/regenerate_response() can cancel it.
    task = asyncio.current_task()
    task_info['task'] = task
    task_info['stop_requested'] = False
    try:
        async for chunk in bot_message:
            if task_info.get('stop_requested', False):
                print("Stop requested, breaking the loop")
                break
            # Each chunk replaces (not appends to) the bot side of the last turn.
            history[-1] = (history[-1][0], chunk)
            yield history
    except asyncio.CancelledError:
        print("Bot generation cancelled")
    except GeneratorExit:
        print("Generator exited")
    except Exception as e:
        print(f"Error in bot generation: {e}")
    finally:
        # If we were stopped before any text arrived, leave a visible marker.
        if history[-1][1] == "":
            history[-1] = (history[-1][0], " [Generation stopped]")
        # Reset the shared state so the next generation starts clean.
        task_info['task'] = None
        task_info['stop_requested'] = False
        # NOTE(review): yielding inside `finally` after GeneratorExit/cancellation
        # can raise RuntimeError in async generators — confirm this path is safe.
        yield history
177
+
178
async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
    """Cancel any in-flight generation, then replay the last user turn through bot()."""
    prior = task_info.get('task')
    if prior:
        print("Cancelling previous task")
        task_info['stop_requested'] = True
        prior.cancel()

    # Give the cancelled task a moment to unwind before restarting.
    await asyncio.sleep(0.1)

    history = sanitize_chatbot_history(history or [])
    if not history:
        yield []
        return

    # Drop the previous assistant reply and stream a fresh one.
    history[-1] = (history[-1][0], None)
    try:
        async for updated in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info):
            yield sanitize_chatbot_history(updated)
    except Exception as e:
        print(f"Error in regenerate_response: {e}")
        yield history
197
+
198
def import_chat_wrapper(custom_format_string):
    """Parse an exported chat string and return (sanitized history, system prompt)."""
    parsed_history, parsed_system_prompt = import_chat(custom_format_string)
    return sanitize_chatbot_history(parsed_history), parsed_system_prompt
201
 
202
  with gr.Blocks(theme='gradio/monochrome') as demo:
203
+ task_info = gr.State({'task': None, 'stop_requested': False})
204
+
205
  with gr.Row():
206
  with gr.Column(scale=2):
207
  chatbot = gr.Chatbot(value=[])
 
227
  repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
228
  max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
229
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
230
  submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
231
+ bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info], chatbot,
232
+ concurrency_limit=10
233
  )
234
 
235
  clear.click(lambda: [], None, chatbot, queue=False)
236
 
237
  regenerate_event = regenerate.click(
238
  regenerate_response,
239
+ [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens, task_info],
240
  chatbot,
241
+ concurrency_limit=10
242
  )
243
 
244
+ import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=10)
245
 
246
  export_button.click(
247
  export_chat,
248
  inputs=[chatbot, system_prompt],
249
  outputs=[import_textbox],
250
+ concurrency_limit=10
251
  )
252
 
253
def stop_generation(task_info):
    """Flag the active generation task (if any) to stop, and cancel it.

    Returns the same ``task_info`` dict so Gradio can write it back to state.
    """
    active = task_info.get('task')
    if active:
        print("Stop requested")
        task_info['stop_requested'] = True
        active.cancel()
    return task_info
259
+
260
  stop_btn.click(
261
+ stop_generation,
262
+ inputs=[task_info],
263
+ outputs=[task_info],
264
  cancels=[submit_event, regenerate_event],
265
  queue=False
266
  )
267
 
268
  if __name__ == "__main__":
269
+ demo.launch(debug=True, max_threads=40)