Corvius committed on
Commit c7a2372 · verified · 1 Parent(s): 1647f17

+2 405B instances ez clap + stochastic Gradio glitch prevention system

Files changed (1): app.py +63 −78
app.py CHANGED
@@ -1,12 +1,11 @@
 import gradio as gr
-import requests
+import aiohttp
+import asyncio
 import json
-import threading
 import os
 import datetime
-import queue
 import time
-from requests.exceptions import RequestException
+from concurrent.futures import ThreadPoolExecutor
 
 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
@@ -26,21 +25,15 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }
 
-class ThreadLocalStorage:
-    def __init__(self):
-        self.stop_generation = False
-        self.active_requests = set()
-        self.lock = threading.Lock()
-
-thread_local = ThreadLocalStorage()
+thread_pool = ThreadPoolExecutor(max_workers=10)
 
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")
 
-def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-    with thread_local.lock:
-        thread_local.stop_generation = False
+should_stop = False
 
+async def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+    global should_stop
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
         history_format.append({"role": "user", "content": human})
@@ -81,44 +74,28 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
         "max_tokens": max_tokens
     }
 
-    try:
-        with requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
-            with thread_local.lock:
-                thread_local.active_requests.add(response)
+    async with aiohttp.ClientSession() as session:
+        async with session.post(API_URL, headers=headers, json=data) as response:
             partial_message = ""
-            for line in response.iter_lines():
-                with thread_local.lock:
-                    if thread_local.stop_generation:
-                        return partial_message
-                if line:
-                    line = line.decode('utf-8')
-                    if line.startswith("data: "):
-                        if line.strip() == "data: [DONE]":
-                            break
-                        try:
-                            json_data = json.loads(line[6:])
-                            if 'choices' in json_data and json_data['choices']:
-                                content = json_data['choices'][0]['delta'].get('content', '')
-                                if content:
-                                    partial_message += content
-                                    yield partial_message
-                        except json.JSONDecodeError:
-                            continue
-
-            if partial_message:
-                yield partial_message
-
-    except RequestException as e:
-        error_message = f"Request error: {str(e)}"
-        print(error_message)
-        yield error_message
-    except Exception as e:
-        error_message = f"Unexpected error: {str(e)}"
-        print(error_message)
-        yield error_message
-    finally:
-        with thread_local.lock:
-            thread_local.active_requests.discard(response)
+            async for line in response.content:
+                if should_stop:
+                    break
+                line = line.decode('utf-8')
+                if line.startswith("data: "):
+                    if line.strip() == "data: [DONE]":
+                        break
+                    try:
+                        json_data = json.loads(line[6:])
+                        if 'choices' in json_data and json_data['choices']:
+                            content = json_data['choices'][0]['delta'].get('content', '')
+                            if content:
+                                partial_message += content
+                                yield partial_message
+                    except json.JSONDecodeError:
+                        continue
+
+            if partial_message:
+                yield partial_message
 
 def import_chat(custom_format_string):
     try:
@@ -154,17 +131,10 @@ def export_chat(history, system_prompt):
         export_data += f"<|assistant|> {assistant_msg}\n\n"
     return export_data
 
-def stop_generation_func():
-    with thread_local.lock:
-        thread_local.stop_generation = True
-        for request in thread_local.active_requests:
-            try:
-                request.close()
-            except Exception as e:
-                print(f"Error closing request: {str(e)}")
-        thread_local.active_requests.clear()
-    time.sleep(0.1)
-    return gr.update(), gr.update()
+def stop_generation():
+    global should_stop
+    should_stop = True
+    return gr.update(interactive=True), gr.update(interactive=True)
 
 with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
     with gr.Row():
@@ -196,16 +166,19 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
         history = history or []
         return "", history + [[user_message, None]]
 
-    def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+    async def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+        global should_stop
+        should_stop = False
         history = history or []
         if not history:
-            return history
+            yield history
+            return
         user_message = history[-1][0]
         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
         history[-1][1] = ""
         try:
-            for chunk in bot_message:
-                if thread_local.stop_generation:
+            async for chunk in bot_message:
+                if should_stop:
                     break
                 history[-1][1] = chunk
                 yield history
@@ -213,45 +186,57 @@ with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
             print(f"Error in bot function: {str(e)}")
            history[-1][1] = "An error occurred while generating the response."
            yield history
+        finally:
+            should_stop = False
 
-    def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+    async def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+        global should_stop
+        should_stop = False
         if history and len(history) > 0:
             last_user_message = history[-1][0]
-            history[-1][1] = None
-            for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+            history[-1][1] = None
+            async for new_history in bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
+                if should_stop:
+                    break
                 yield new_history
         else:
             yield []
+        should_stop = False
 
     def import_chat_wrapper(custom_format_string):
         imported_history, imported_system_prompt = import_chat(custom_format_string)
         return imported_history, imported_system_prompt
 
-    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot]).then(
-        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot
+    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+        bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot,
+        concurrency_limit=10
     )
 
-    clear.click(lambda: None, None, chatbot)
+    clear.click(lambda: None, None, chatbot, queue=False)
 
     regenerate_event = regenerate.click(
         regenerate_response,
         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
-        chatbot
+        chatbot,
+        concurrency_limit=10
    )
 
     stop_btn.click(
-        stop_generation_func,
+        stop_generation,
         inputs=[],
-        outputs=[chatbot, msg]
+        outputs=[msg, regenerate],
+        cancels=[submit_event, regenerate_event],
+        queue=False
    )
 
-    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
+    import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt], queue=False)
 
     export_button.click(
         export_chat,
         inputs=[chatbot, system_prompt],
-        outputs=[import_textbox]
+        outputs=[import_textbox],
+        queue=False
    )
 
 if __name__ == "__main__":
-    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True)
+    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True, max_threads=40)
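For reference, the stop mechanism this commit lands combines two layers: the cancels= argument on stop_btn.click, which tells Gradio to cancel the queued submit/regenerate events, and the module-level should_stop flag, which gives the running async generator a clean exit at its next chunk. A minimal standalone sketch of the same pattern follows; the names (slow_stream, stop, the demo widgets) are illustrative and not from app.py:

import time
import gradio as gr

should_stop = False  # plain module-level flag, same single-process assumption as app.py

def slow_stream():
    # Simulated token stream: check the flag between chunks so a stop
    # request takes effect at the next yield even before cancellation lands.
    global should_stop
    should_stop = False
    text = ""
    for word in ["streaming", "one", "word", "at", "a", "time"]:
        if should_stop:
            break
        text += word + " "
        time.sleep(0.5)
        yield text

def stop():
    global should_stop
    should_stop = True

with gr.Blocks() as demo:
    out = gr.Textbox()
    go = gr.Button("Generate")
    halt = gr.Button("Stop")
    gen_event = go.click(slow_stream, None, out)
    # cancels= tears down the queued generation job; the flag covers the
    # window between the stop click and the cancellation reaching the worker.
    halt.click(stop, None, None, cancels=[gen_event])

if __name__ == "__main__":
    demo.launch()

The same caveat applies as in app.py: a global flag is only safe while the Space runs a single process; with multiple workers each process would see its own copy.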