Corvius committed on
Commit 1647f17 · verified · 1 Parent(s): 0de7b75

unqueue the queue + random shit fixes idk

Files changed (1):
  app.py  (+60 -31)
app.py CHANGED
@@ -4,9 +4,10 @@ import json
 import threading
 import os
 import datetime
+import queue
+import time
 from requests.exceptions import RequestException
 
-stop_generation = threading.Event()
 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
 
@@ -15,8 +16,6 @@ headers = {
     "Content-Type": "application/json"
 }
 
-session = requests.Session()
-
 DEFAULT_PARAMS = {
     "temperature": 0.8,
     "top_p": 0.95,
@@ -27,12 +26,20 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }
 
+class ThreadLocalStorage:
+    def __init__(self):
+        self.stop_generation = False
+        self.active_requests = set()
+        self.lock = threading.Lock()
+
+thread_local = ThreadLocalStorage()
+
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")
 
 def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-    global stop_generation, session
-    stop_generation.clear()
+    with thread_local.lock:
+        thread_local.stop_generation = False
 
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
@@ -56,7 +63,7 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
     }
 
     non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
-
+
     if non_default_params and not message.startswith(('*', '"')):
         for param, value in non_default_params.items():
             print(f"{param}={value}")
@@ -75,12 +82,14 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
     }
 
     try:
-        with session.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
+        with requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
+            with thread_local.lock:
+                thread_local.active_requests.add(response)
             partial_message = ""
             for line in response.iter_lines():
-                if stop_generation.is_set():
-                    response.close()
-                    break
+                with thread_local.lock:
+                    if thread_local.stop_generation:
+                        return partial_message
                 if line:
                     line = line.decode('utf-8')
                     if line.startswith("data: "):
@@ -100,8 +109,16 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
             yield partial_message
 
     except RequestException as e:
-        print(f"Request error: {e}")
-        yield f"An error occurred: {str(e)}"
+        error_message = f"Request error: {str(e)}"
+        print(error_message)
+        yield error_message
+    except Exception as e:
+        error_message = f"Unexpected error: {str(e)}"
+        print(error_message)
+        yield error_message
+    finally:
+        with thread_local.lock:
+            thread_local.active_requests.discard(response)
 
 def import_chat(custom_format_string):
     try:
@@ -138,12 +155,18 @@ def export_chat(history, system_prompt):
     return export_data
 
 def stop_generation_func():
-    global stop_generation, session
-    stop_generation.set()
-    session.close()
-    session = requests.Session()
-
-with gr.Blocks(theme='gradio/monochrome') as demo:
+    with thread_local.lock:
+        thread_local.stop_generation = True
+        for request in thread_local.active_requests:
+            try:
+                request.close()
+            except Exception as e:
+                print(f"Error closing request: {str(e)}")
+        thread_local.active_requests.clear()
+    time.sleep(0.1)
+    return gr.update(), gr.update()
+
+with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(value=[])
@@ -174,20 +197,22 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         return "", history + [[user_message, None]]
 
     def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        global stop_generation
        history = history or []
         if not history:
             return history
         user_message = history[-1][0]
         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
         history[-1][1] = ""
-        for chunk in bot_message:
-            if stop_generation.is_set():
-                history[-1][1] += " [Generation stopped]"
-                break
-            history[-1][1] = chunk
+        try:
+            for chunk in bot_message:
+                if thread_local.stop_generation:
+                    break
+                history[-1][1] = chunk
+                yield history
+        except Exception as e:
+            print(f"Error in bot function: {str(e)}")
+            history[-1][1] = "An error occurred while generating the response."
         yield history
-        stop_generation.clear()
 
     def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
         if history and len(history) > 0:
@@ -202,18 +227,24 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         imported_history, imported_system_prompt = import_chat(custom_format_string)
         return imported_history, imported_system_prompt
 
-    msg.submit(user, [msg, chatbot], [msg, chatbot], queue=False).then(
+    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot]).then(
         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot
     )
 
-    clear.click(lambda: None, None, chatbot, queue=False)
+    clear.click(lambda: None, None, chatbot)
 
-    regenerate.click(
+    regenerate_event = regenerate.click(
        regenerate_response,
         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
         chatbot
     )
 
+    stop_btn.click(
+        stop_generation_func,
+        inputs=[],
+        outputs=[chatbot, msg]
+    )
+
     import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
 
     export_button.click(
@@ -222,7 +253,5 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         outputs=[import_textbox]
     )
 
-    stop_btn.click(stop_generation_func, inputs=[], outputs=[])
-
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True)
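A note on the pattern this commit settles on: despite its name, ThreadLocalStorage holds process-global state shared by every worker thread (a stop flag plus the set of in-flight streaming responses, guarded by one lock), not per-thread storage. A minimal standalone sketch of the same cancellation pattern, with hypothetical names not taken from app.py:

import threading

class SharedStopState:
    # Shared (not thread-local) stop flag and registry of live streaming
    # responses, guarded by a single lock -- the same shape as the commit's
    # ThreadLocalStorage.
    def __init__(self):
        self.lock = threading.Lock()
        self.stop = False
        self.active = set()

    def register(self, response):
        # Called when a streaming request starts; also resets the flag.
        with self.lock:
            self.stop = False
            self.active.add(response)

    def should_stop(self):
        with self.lock:
            return self.stop

    def stop_all(self):
        # Flip the flag and close every live response so that the worker's
        # iter_lines() loop aborts instead of blocking on the socket.
        with self.lock:
            self.stop = True
            for resp in self.active:
                try:
                    resp.close()
                except Exception:
                    pass
            self.active.clear()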