unqueue the queue + misc fixes
app.py
CHANGED
@@ -4,9 +4,10 @@ import json
 import threading
 import os
 import datetime
+import queue
+import time
 from requests.exceptions import RequestException
 
-stop_generation = threading.Event()
 API_URL = os.environ.get('API_URL')
 API_KEY = os.environ.get('API_KEY')
 
@@ -15,8 +16,6 @@ headers = {
     "Content-Type": "application/json"
 }
 
-session = requests.Session()
-
 DEFAULT_PARAMS = {
     "temperature": 0.8,
     "top_p": 0.95,
@@ -27,12 +26,20 @@ DEFAULT_PARAMS = {
     "max_tokens": 512
 }
 
+class ThreadLocalStorage:
+    def __init__(self):
+        self.stop_generation = False
+        self.active_requests = set()
+        self.lock = threading.Lock()
+
+thread_local = ThreadLocalStorage()
+
 def get_timestamp():
     return datetime.datetime.now().strftime("%H:%M:%S")
 
 def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-
-
+    with thread_local.lock:
+        thread_local.stop_generation = False
 
     history_format = [{"role": "system", "content": system_prompt}]
     for human, assistant in history:
@@ -56,7 +63,7 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
     }
 
     non_default_params = {k: v for k, v in current_params.items() if v != DEFAULT_PARAMS[k]}
-
+
     if non_default_params and not message.startswith(('*', '"')):
         for param, value in non_default_params.items():
             print(f"{param}={value}")
@@ -75,12 +82,14 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
     }
 
     try:
-        with
+        with requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
+            with thread_local.lock:
+                thread_local.active_requests.add(response)
             partial_message = ""
             for line in response.iter_lines():
-
-
-
+                with thread_local.lock:
+                    if thread_local.stop_generation:
+                        return partial_message
                 if line:
                     line = line.decode('utf-8')
                     if line.startswith("data: "):
@@ -100,8 +109,16 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
                             yield partial_message
 
     except RequestException as e:
-
-
+        error_message = f"Request error: {str(e)}"
+        print(error_message)
+        yield error_message
+    except Exception as e:
+        error_message = f"Unexpected error: {str(e)}"
+        print(error_message)
+        yield error_message
+    finally:
+        with thread_local.lock:
+            thread_local.active_requests.discard(response)
 
 def import_chat(custom_format_string):
     try:
@@ -138,12 +155,18 @@ def export_chat(history, system_prompt):
     return export_data
 
 def stop_generation_func():
-
-
-
-
-
-with gr.Blocks(theme='gradio/monochrome') as demo:
+    with thread_local.lock:
+        thread_local.stop_generation = True
+        for request in thread_local.active_requests:
+            try:
+                request.close()
+            except Exception as e:
+                print(f"Error closing request: {str(e)}")
+        thread_local.active_requests.clear()
+    time.sleep(0.1)
+    return gr.update(), gr.update()
+
+with gr.Blocks(theme=gr.themes.Monochrome()) as demo:
     with gr.Row():
         with gr.Column(scale=2):
             chatbot = gr.Chatbot(value=[])
@@ -174,20 +197,22 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         return "", history + [[user_message, None]]
 
     def bot(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
-        global stop_generation
         history = history or []
         if not history:
             return history
        user_message = history[-1][0]
         bot_message = predict(user_message, history[:-1], system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens)
         history[-1][1] = ""
-
-
-
-
-
+        try:
+            for chunk in bot_message:
+                if thread_local.stop_generation:
+                    break
+                history[-1][1] = chunk
+                yield history
+        except Exception as e:
+            print(f"Error in bot function: {str(e)}")
+            history[-1][1] = "An error occurred while generating the response."
             yield history
-        stop_generation.clear()
 
     def regenerate_response(history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
         if history and len(history) > 0:
@@ -202,18 +227,24 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         imported_history, imported_system_prompt = import_chat(custom_format_string)
         return imported_history, imported_system_prompt
 
-    msg.submit(user, [msg, chatbot], [msg, chatbot]
+    submit_event = msg.submit(user, [msg, chatbot], [msg, chatbot]).then(
         bot, [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens], chatbot
     )
 
-    clear.click(lambda: None, None, chatbot
+    clear.click(lambda: None, None, chatbot)
 
-    regenerate.click(
+    regenerate_event = regenerate.click(
         regenerate_response,
         [chatbot, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens],
         chatbot
     )
 
+    stop_btn.click(
+        stop_generation_func,
+        inputs=[],
+        outputs=[chatbot, msg]
+    )
+
     import_button.click(import_chat_wrapper, inputs=[import_textbox], outputs=[chatbot, system_prompt])
 
     export_button.click(
@@ -222,7 +253,5 @@ with gr.Blocks(theme='gradio/monochrome') as demo:
         outputs=[import_textbox]
     )
 
-    stop_btn.click(stop_generation_func, inputs=[], outputs=[])
-
 if __name__ == "__main__":
-    demo.launch(debug=True)
+    demo.launch(debug=True, server_name="0.0.0.0", server_port=7860, share=True)
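A note on the new state object: despite its name, ThreadLocalStorage is not threading.local(); it is a single module-level instance shared by every request and guarded by one lock, so pressing Stop flips the flag for all in-flight generations, not just the caller's. A minimal per-call alternative would hand each generation its own token (a sketch; CancelToken and its wiring are hypothetical, not part of this commit):

import threading

class CancelToken:
    """Hypothetical per-generation cancel flag (not in this commit)."""
    def __init__(self):
        self._event = threading.Event()

    def cancel(self):
        self._event.set()

    def cancelled(self):
        return self._event.is_set()

# Sketch of the wiring: predict() would take one token per call and poll
# token.cancelled() in its streaming loop; the Stop handler would cancel
# only the token belonging to that session.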
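The hunk above elides the body that consumes each "data: " line (app.py lines 96-108 are hidden by the diff). For an OpenAI-compatible streaming endpoint the parsing typically looks like the sketch below; the payload keys (choices, delta, content) and the [DONE] sentinel are assumptions about the hidden code, not taken from it.

import json

def parse_sse_line(line, partial_message):
    # Strip the "data: " prefix and check for the end-of-stream sentinel.
    payload = line[len("data: "):].strip()
    if payload == "[DONE]":  # assumed sentinel, not confirmed by the diff
        return partial_message, True
    # Accumulate the streamed delta, if any.
    chunk = json.loads(payload)
    content = chunk["choices"][0]["delta"].get("content", "")
    return partial_message + content, False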
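One caveat in the new finally block: if requests.post() itself raises (DNS failure, connection refused), the with statement never binds response, so thread_local.active_requests.discard(response) raises NameError and masks the original error. A defensive sketch, reusing the names from app.py:

response = None  # pre-bind so the finally block cannot raise NameError
try:
    with requests.post(API_URL, headers=headers, data=json.dumps(data),
                       stream=True) as response:
        with thread_local.lock:
            thread_local.active_requests.add(response)
        ...  # streaming loop as in the diff
finally:
    if response is not None:
        with thread_local.lock:
            thread_local.active_requests.discard(response)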
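Also worth noting: bot() polls thread_local.stop_generation without taking the lock that every writer holds. Under CPython's GIL a plain attribute read is safe in practice, but the consistent pattern reads under the same lock; a sketch of the loop with that change:

for chunk in bot_message:
    # Read the shared flag under the same lock the writers use.
    with thread_local.lock:
        should_stop = thread_local.stop_generation
    if should_stop:
        break
    history[-1][1] = chunk
    yield history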
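The wiring above stores submit_event and regenerate_event but never uses them. In recent Gradio versions, event listeners accept a cancels argument that stops the queued generator directly, which is the idiomatic counterpart to the manual flag; a sketch under that assumption:

stop_btn.click(
    stop_generation_func,  # still flips the flag and closes live responses
    inputs=[],
    outputs=[chatbot, msg],
    cancels=[submit_event, regenerate_event],  # Gradio-level cancellation
)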
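Two notes on the new launch line. On Hugging Face Spaces, server_name="0.0.0.0" and server_port=7860 match what Spaces expects, while share=True is redundant there since the Space is already served publicly. Separately, the commit adds import queue, yet no queue.Queue appears in the visible hunks; if the title's "unqueue the queue" refers to Gradio's request queue, that is normally toggled as below (a sketch of the usual switches, an assumption about intent):

demo.queue()  # enable Gradio's request queue (required for generators on some versions)
demo.launch(debug=True, server_name="0.0.0.0", server_port=7860)
# ...or disable queuing per listener: msg.submit(..., queue=False)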