Corvius commited on
Commit
1872449
Β·
verified Β·
1 Parent(s): 7324de2

un-retard the stop button

Browse files
Files changed (1) hide show
  1. app.py +34 -28
app.py CHANGED
@@ -3,6 +3,7 @@ import requests
3
  import json
4
  import threading
5
  import os
 
6
 
7
  stop_generation = threading.Event()
8
  API_URL = os.environ.get('API_URL')
@@ -13,8 +14,10 @@ headers = {
13
  "Content-Type": "application/json"
14
  }
15
 
 
 
16
  def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
17
- global stop_generation
18
  stop_generation.clear()
19
 
20
  history_format = [{"role": "system", "content": system_prompt}]
@@ -37,32 +40,34 @@ def predict(message, history, system_prompt, temperature, top_p, top_k, frequenc
37
  "max_tokens": max_tokens
38
  }
39
 
40
- response = requests.post(API_URL, headers=headers, data=json.dumps(data), stream=True)
41
-
42
- print("Raw API Response:")
43
- print(response.text)
44
-
45
- partial_message = ""
46
- for line in response.iter_lines():
47
- if stop_generation.is_set():
48
- break
49
- if line:
50
- line = line.decode('utf-8')
51
- if line.startswith("data: "):
52
- if line.strip() == "data: [DONE]":
53
  break
54
- try:
55
- json_data = json.loads(line[6:])
56
- if 'choices' in json_data and json_data['choices']:
57
- content = json_data['choices'][0]['delta'].get('content', '')
58
- if content:
59
- partial_message += content
60
- yield partial_message
61
- except json.JSONDecodeError:
62
- continue
63
-
64
- if partial_message:
65
- yield partial_message
 
 
 
 
 
 
 
 
 
66
 
67
  def import_chat(custom_format_string):
68
  try:
@@ -98,9 +103,10 @@ def export_chat(history, system_prompt):
98
  return export_data
99
 
100
  def stop_generation_func():
101
- global stop_generation
102
  stop_generation.set()
103
-
 
104
 
105
  with gr.Blocks(theme='gradio/monochrome') as demo:
106
  with gr.Row():
 
3
  import json
4
  import threading
5
  import os
6
+ from requests.exceptions import RequestException
7
 
8
  stop_generation = threading.Event()
9
  API_URL = os.environ.get('API_URL')
 
14
  "Content-Type": "application/json"
15
  }
16
 
17
+ session = requests.Session()
18
+
19
  def predict(message, history, system_prompt, temperature, top_p, top_k, frequency_penalty, presence_penalty, repetition_penalty, max_tokens):
20
+ global stop_generation, session
21
  stop_generation.clear()
22
 
23
  history_format = [{"role": "system", "content": system_prompt}]
 
40
  "max_tokens": max_tokens
41
  }
42
 
43
+ try:
44
+ with session.post(API_URL, headers=headers, data=json.dumps(data), stream=True) as response:
45
+ partial_message = ""
46
+ for line in response.iter_lines():
47
+ if stop_generation.is_set():
48
+ response.close()
 
 
 
 
 
 
 
49
  break
50
+ if line:
51
+ line = line.decode('utf-8')
52
+ if line.startswith("data: "):
53
+ if line.strip() == "data: [DONE]":
54
+ break
55
+ try:
56
+ json_data = json.loads(line[6:])
57
+ if 'choices' in json_data and json_data['choices']:
58
+ content = json_data['choices'][0]['delta'].get('content', '')
59
+ if content:
60
+ partial_message += content
61
+ yield partial_message
62
+ except json.JSONDecodeError:
63
+ continue
64
+
65
+ if partial_message:
66
+ yield partial_message
67
+
68
+ except RequestException as e:
69
+ print(f"Request error: {e}")
70
+ yield f"An error occurred: {str(e)}"
71
 
72
  def import_chat(custom_format_string):
73
  try:
 
103
  return export_data
104
 
105
  def stop_generation_func():
106
+ global stop_generation, session
107
  stop_generation.set()
108
+ session.close()
109
+ session = requests.Session()
110
 
111
  with gr.Blocks(theme='gradio/monochrome') as demo:
112
  with gr.Row():