Corvius committed
Commit c58b4ed · verified · 1 Parent(s): c7a2372

Update app.py

Files changed (1): app.py +76 -71
app.py CHANGED
@@ -1,11 +1,10 @@
 import gradio as gr
-import aiohttp
-import asyncio
 import json
 import os
 import datetime
-import time
-from concurrent.futures import ThreadPoolExecutor
+import asyncio
+import aiohttp
+from aiohttp import ClientSession
 
 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
@@ -25,15 +24,12 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }
 
-thread_pool = ThreadPoolExecutor(max_workers=10)
+active_tasks = {}
 
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")
 
-should_stop = False
-
 async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-    global should_stop
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
         history_format.append({"role": "user", "content": human})
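Note: the change above swaps the global should_stop flag (and the now-unused ThreadPoolExecutor) for an active_tasks registry keyed by id(task). A minimal standalone sketch of that pattern, with a hypothetical generate() coroutine standing in for the app's bot():

import asyncio

active_tasks = {}

async def generate():
    # Hypothetical stand-in for the app's bot(): register this task,
    # stream some work, and stop if the registry entry disappears.
    task = asyncio.current_task()
    task_id = id(task)
    active_tasks[task_id] = task
    try:
        for i in range(10):
            if task_id not in active_tasks:  # cooperative stop
                break
            await asyncio.sleep(0.1)  # a hard cancel() is raised here
            print(f"chunk {i}")
    except asyncio.CancelledError:
        pass
    finally:
        active_tasks.pop(task_id, None)

async def main():
    task = asyncio.create_task(generate())
    await asyncio.sleep(0.25)
    # What the Stop button's lambda does: cancel every registered task.
    for t in list(active_tasks.values()):
        t.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())

Removing a task from the registry lets the streaming loop stop cooperatively; calling cancel() additionally interrupts any pending await.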
@@ -56,7 +52,7 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
     }
 
     non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
-
+
     if non_default_params and not message.startswith(('*', '"')):
         for param, value in non_default_params.items():
             print(f"{param}={value}")
@@ -74,28 +70,34 @@ async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
         "max_tokens": max_tokens
     }
 
-    async with aiohttp.ClientSession() as session:
-        async with session.post(API_URL, headers=headers, json=data) as response:
-            partial_message = ""
-            async for line in response.content:
-                if should_stop:
-                    break
-                line = line.decode('utf-8')
-                if line.startswith("data: "):
-                    if line.strip() == "data: [DONE]":
+    try:
+        async with ClientSession() as session:
+            async with session.post(API_URL, headers=headers, json=data) as response:
+                partial_message = ""
+                async for line in response.content:
+                    if asyncio.current_task().cancelled():
                         break
-                    try:
-                        json_data = json.loads(line[6:])
-                        if 'choices' in json_data and json_data['choices']:
-                            content = json_data['choices'][0]['delta'].get('content', '')
-                            if content:
-                                partial_message += content
-                                yield partial_message
-                    except json.JSONDecodeError:
-                        continue
-
-    if partial_message:
-        yield partial_message
+                    if line:
+                        line = line.decode('utf-8')
+                        if line.startswith("data: "):
+                            if line.strip() == "data: [DONE]":
+                                break
+                            try:
+                                json_data = json.loads(line[6:])
+                                if 'choices' in json_data and json_data['choices']:
+                                    content = json_data['choices'][0]['delta'].get('content', '')
+                                    if content:
+                                        partial_message += content
+                                        yield partial_message
+                            except json.JSONDecodeError:
+                                continue
+
+                if partial_message:
+                    yield partial_message
+
+    except Exception as e:
+        print(f"Request error: {e}")
+        yield f"An error occurred: {str(e)}"
 
 def import_chat(custom_format_string):
     try:
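Note: the streaming loop above consumes OpenAI-style server-sent events: every payload line starts with "data: ", carries a JSON chunk whose choices[0].delta may contain a content fragment, and "data: [DONE]" marks the end of the stream. A minimal sketch of the same parsing over canned input (the sample payloads are illustrative, not captured API output):

import json

sample_lines = [
    'data: {"choices": [{"delta": {"content": "Hel"}}]}',
    'data: {"choices": [{"delta": {"content": "lo"}}]}',
    'data: [DONE]',
]

partial_message = ""
for line in sample_lines:
    if line.startswith("data: "):
        if line.strip() == "data: [DONE]":
            break
        try:
            json_data = json.loads(line[6:])  # strip the "data: " prefix
            if 'choices' in json_data and json_data['choices']:
                content = json_data['choices'][0]['delta'].get('content', '')
                partial_message += content
        except json.JSONDecodeError:
            continue

print(partial_message)  # -> Hello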
@@ -131,12 +133,11 @@ def export_chat(history, system_prompt):
             export_data += f"<|assistant|> {assistant_msg}\n\n"
     return export_data
 
-def stop_generation():
-    global should_stop
-    should_stop = True
-    return gr.update(interactive=True), gr.update(interactive=True)
+def sanitize_chatbot_history(history):
+    """Ensure each entry in the chatbot history is a tuple of two items."""
+    return [tuple(entry[:2]) for entry in history]
 
-with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
+with gr.Blocks(theme='gradio/monochrome') as demo:
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(value=[])
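Note: sanitize_chatbot_history exists because Gradio hands chat history back as lists of [user, assistant] pairs, while the updated handlers treat every entry as a 2-tuple. A quick illustration of what it normalizes:

def sanitize_chatbot_history(history):
    """Ensure each entry in the chatbot history is a tuple of two items."""
    return [tuple(entry[:2]) for entry in history]

print(sanitize_chatbot_history([["hi", "hello!"], ("next turn", None)]))
# -> [('hi', 'hello!'), ('next turn', None)]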
@@ -162,81 +163,85 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
                 repetition_penalty = gr.Slider(0.01, 5, value=1.1, step=0.01, label="Repetition Penalty")
                 max_tokens = gr.Slider(1, 4096, value=512, step=1, label="Max Output (max_tokens)")
 
-    def user(user_message, history):
-        history = history or []
-        return "", history + [[user_message, None]]
+    async def user(user_message, history):
+        history = sanitize_chatbot_history(history or [])
+        return "", history + [(user_message, None)]
 
     async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        global should_stop
-        should_stop = False
-        history = history or []
+        history = sanitize_chatbot_history(history or [])
         if not history:
             yield history
             return
         user_message = history[-1][0]
         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
-        history[-1][1] = ""
+        history[-1] = (history[-1][0], "")  # Ensure it's a tuple
+        task_id = id(asyncio.current_task())
+        active_tasks[task_id] = asyncio.current_task()
         try:
             async for chunk in bot_message:
-                if should_stop:
+                if task_id not in active_tasks:
                     break
-                history[-1][1] = chunk
+                history[-1] = (history[-1][0], chunk)  # Update as a tuple
                 yield history
-        except Exception as e:
-            print(f"Error in bot function: {str(e)}")
-            history[-1][1] = "An error occurred while generating the response."
-            yield history
+        except asyncio.CancelledError:
+            pass
         finally:
-            should_stop = False
+            if task_id in active_tasks:
+                del active_tasks[task_id]
+            if history[-1][1] == "":
+                history[-1] = (history[-1][0], " [Generation stopped]")
+            yield history
 
     async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        global should_stop
-        should_stop = False
-        if history and len(history) > 0:
-            last_user_message = history[-1][0]
-            history[-1][1] = None
+        # Cancel any ongoing generation
+        for task in list(active_tasks.values()):
+            task.cancel()
+
+        # Wait for a short time to ensure cancellation is processed
+        await asyncio.sleep(0.1)
+
+        history = sanitize_chatbot_history(history or [])
+        if history:
+            history[-1] = (history[-1][0], None)  # Reset last response
             async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-                if should_stop:
-                    break
                 yield new_history
         else:
             yield []
-        should_stop = False
 
     def import_chat_wrapper(custom_format_string):
         imported_history, imported_system_prompt = import_chat(custom_format_string)
-        return imported_history, imported_system_prompt
+        return sanitize_chatbot_history(imported_history), imported_system_prompt
 
     submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
-        concurrency_limit=10
+        concurrency_limit=5
     )
 
-    clear.click(lambda: None, None, chatbot, queue=False)
+    clear.click(lambda: [], None, chatbot, queue=False)
 
     regenerate_event = regenerate.click(
         regenerate_response,
         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
         chatbot,
-        concurrency_limit=10
+        concurrency_limit=5
     )
 
-    stop_btn.click(
-        stop_generation,
-        inputs=[],
-        outputs=[msg, regenerate],
-        cancels=[submit_event, regenerate_event],
-        queue=False
-    )
-
-    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], queue=False)
+    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], concurrency_limit=5)
 
     export_button.click(
         export_chat,
         inputs=[chatbot, system_prompt],
         outputs=[import_textbox],
+        concurrency_limit=5
+    )
+
+    stop_btn.click(
+        lambda: [task.cancel() for task in list(active_tasks.values())],
+        None,
+        None,
+        cancels=[submit_event, regenerate_event],
         queue=False
     )
 
 if __name__ == "__main__":
-    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True, max_threads=40)
+    demo.launch(debug=True, max_threads=20)
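Note: one subtlety in the new cancellation path: Task.cancelled() only returns True once a task has actually finished cancelling, so the in-loop asyncio.current_task().cancelled() check in predict() will rarely fire while the coroutine is still running; the cancel request from the Stop button instead surfaces as a CancelledError at the next await, which is what bot() catches. A small demonstration of that behavior:

import asyncio

async def worker():
    try:
        while True:
            # Inside a running coroutine this stays False even after cancel():
            print("cancelled()?", asyncio.current_task().cancelled())
            await asyncio.sleep(0.1)  # the cancel request lands here as CancelledError
    except asyncio.CancelledError:
        print("caught CancelledError")
        raise

async def main():
    task = asyncio.create_task(worker())
    await asyncio.sleep(0.05)
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass
    print("cancelled()?", task.cancelled())  # True only once the task has exited

asyncio.run(main())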
 