Spaces:
Runtime error
Runtime error
sc_ma
committed on
Commit
·
1012e47
1
Parent(s):
49d990e
bug fix.
Browse files- utils/gpt_interaction.py +74 -103
utils/gpt_interaction.py
CHANGED
@@ -1,18 +1,70 @@
|
|
1 |
-
import os
|
2 |
-
import time
|
3 |
-
|
4 |
import openai
|
5 |
-
import
|
6 |
-
import
|
7 |
import json
|
8 |
-
|
9 |
log = logging.getLogger(__name__)
|
10 |
|
11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
12 |
conversation_history = [
|
13 |
-
{"role": "system", "content":
|
14 |
-
{"role": "user", "content": prompts}
|
15 |
]
|
|
|
16 |
response = openai.ChatCompletion.create(
|
17 |
model=model,
|
18 |
messages=conversation_history,
|
@@ -25,98 +77,17 @@ def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
|
|
25 |
return assistant_message, usage
|
26 |
|
27 |
|
28 |
-
class GPTModel_API2D_SUPPORT:
    """Chat-completion client that posts to an OpenAI-compatible HTTP endpoint.

    Uses `requests` directly instead of the `openai` SDK so an alternative
    `url` (e.g. an API2D proxy) can be targeted.
    """

    def __init__(self, model="gpt-4", temperature=0, presence_penalty=0,
                 frequency_penalty=0, url=None, key=None, max_attempts=1, delay=20):
        # Fall back to the official endpoint / the environment key when not given.
        if url is None:
            url = "https://api.openai.com/v1/chat/completions"
        if key is None:
            key = os.getenv("OPENAI_API_KEY")

        self.model = model
        self.temperature = temperature
        self.url = url
        self.key = key
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.max_attempts = max_attempts  # tries before giving up
        self.delay = delay                # seconds to wait between tries

    def __call__(self, systems, prompts, return_json=False):
        """Send one system+user exchange and return (assistant_message, usage).

        When `return_json` is true, the assistant message is additionally
        decoded with `json.loads`. Raises RuntimeError after `max_attempts`
        failed tries.
        """
        headers = {
            # Declare the charset explicitly: the body is sent as UTF-8 bytes.
            "Content-Type": "application/json; charset=utf-8",
            "Authorization": f"Bearer {self.key}",
        }

        data = {
            "model": f"{self.model}",
            "messages": [
                {"role": "system", "content": systems},
                {"role": "user", "content": prompts}],
            "temperature": self.temperature,
            "n": 1,
            "stream": False,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty
        }
        for attempt in range(self.max_attempts):
            try:
                # Encode the payload to UTF-8 ourselves. Passing a str made
                # `requests` encode it with the locale codec, which raised
                # UnicodeEncodeError ("'gbk' codec can't encode character ...")
                # on non-ASCII content — the bug noted in the original TODO.
                payload = json.dumps(data, ensure_ascii=False).encode("utf-8")
                response = requests.post(self.url, headers=headers, data=payload)
                response = response.json()
                assistant_message = response['choices'][0]["message"]["content"]
                usage = response['usage']
                log.info(assistant_message)
                if return_json:
                    assistant_message = json.loads(assistant_message)
                return assistant_message, usage
            except Exception as e:
                # Log through the module logger instead of bare print.
                log.error(f"Failed to get response. Error: {e}")
                # Only back off when another attempt is still coming; the old
                # code slept `delay` seconds even before raising.
                if attempt + 1 < self.max_attempts:
                    time.sleep(self.delay)
        raise RuntimeError("Failed to get response from OpenAI.")
|
78 |
-
|
79 |
-
|
80 |
-
class GPTModel:
    """Chat-completion client built on the `openai` SDK, with a simple retry loop."""

    def __init__(self, model="gpt-4", temperature=0.9, presence_penalty=0,
                 frequency_penalty=0, max_attempts=1, delay=20):
        self.model = model
        self.temperature = temperature
        self.presence_penalty = presence_penalty
        self.frequency_penalty = frequency_penalty
        self.max_attempts = max_attempts  # tries before giving up
        self.delay = delay                # seconds to wait between tries

    def __call__(self, systems, prompts, return_json=False):
        """Send one system+user exchange and return (assistant_message, usage).

        When `return_json` is true, the assistant message is additionally
        decoded with `json.loads`. Raises RuntimeError after `max_attempts`
        failed tries.
        """
        conversation_history = [
            {"role": "system", "content": systems},
            {"role": "user", "content": prompts}
        ]
        for attempt in range(self.max_attempts):
            try:
                response = openai.ChatCompletion.create(
                    model=self.model,
                    messages=conversation_history,
                    n=1,
                    temperature=self.temperature,
                    presence_penalty=self.presence_penalty,
                    frequency_penalty=self.frequency_penalty,
                    stream=False
                )
                assistant_message = response['choices'][0]["message"]["content"]
                usage = response['usage']
                log.info(assistant_message)
                if return_json:
                    assistant_message = json.loads(assistant_message)
                return assistant_message, usage
            except Exception as e:
                # Log through the module logger instead of bare print.
                log.error(f"Failed to get response. Error: {e}")
                # Only back off when another attempt remains; the old code
                # slept `delay` seconds even before raising.
                if attempt + 1 < self.max_attempts:
                    time.sleep(self.delay)
        raise RuntimeError("Failed to get response from OpenAI.")
|
116 |
-
|
117 |
-
|
118 |
-
|
119 |
if __name__ == "__main__":
|
120 |
-
|
121 |
-
|
122 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
import openai
|
2 |
+
import re
|
3 |
+
import os
|
4 |
import json
|
5 |
+
import logging
|
6 |
log = logging.getLogger(__name__)
|
7 |
|
8 |
+
# todo: 将api_key通过函数传入; 需要改很多地方
|
9 |
+
# openai.api_key = os.environ['OPENAI_API_KEY']
|
10 |
+
|
11 |
+
def extract_responses(assistant_message):
    """Pull the payload out of an ``f.write(r'...')``-style wrapper.

    Handles single/double quotes, triple quotes, and truncated wrappers.
    If the message is not wrapped in such a call, it is returned unchanged.
    """
    writer_pattern = re.compile(
        r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL
    )
    found = writer_pattern.search(assistant_message)
    if found is None:
        # No f.write wrapper detected: fall back to the raw message.
        log.info("Responses are not put in Python codes. Directly return assistant_message.\n")
        log.info(f"assistant_message: {assistant_message}")
        return assistant_message
    return found.group(1)
|
21 |
+
|
22 |
+
def extract_keywords(assistant_message, default_keywords=None):
    """Parse `assistant_message` as a JSON keyword->weight mapping.

    Returns `default_keywords` ({"machine learning": 5} unless overridden)
    when the message is not valid JSON or does not decode to a dict, so
    callers can always rely on receiving a dict.
    """
    if default_keywords is None:
        default_keywords = {"machine learning": 5}

    try:
        keywords = json.loads(assistant_message)
    except ValueError:
        log.info("Responses are not in json format. Return the default dictionary.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_keywords
    if not isinstance(keywords, dict):
        # Valid JSON but not a mapping (e.g. a list or a scalar): treat it the
        # same as a parse failure rather than leaking a non-dict to callers.
        log.info("Responses are not in json format. Return the default dictionary.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_keywords
    return keywords
|
33 |
+
|
34 |
+
def extract_section_name(assistant_message, default_section_name=""):
    """Decode a section name from a JSON-formatted assistant message.

    Falls back to `default_section_name` when the message is not valid JSON.
    """
    try:
        section_name = json.loads(assistant_message)
    except ValueError:
        # Not JSON: hand back the caller-supplied default.
        log.info("Responses are not in json format. Return None.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_section_name
    return section_name
|
42 |
+
|
43 |
+
|
44 |
+
def extract_json(assistant_message, default_output=None):
    """Return the top-level keys of a JSON object encoded in `assistant_message`.

    The keys come back as a list (instead of a dict_keys view) so the success
    path has the same type as the fallback. When the message cannot be parsed
    as JSON, `default_output` (or ["Method 1", "Method 2"]) is returned.
    """
    if default_output is None:
        default_keys = ["Method 1", "Method 2"]
    else:
        default_keys = default_output
    try:
        # Narrowed exception handling: json.loads raises ValueError
        # (JSONDecodeError) on bad text and TypeError on non-string input.
        # The original bare `except:` also swallowed KeyboardInterrupt and
        # SystemExit, and bound the result to `dict`, shadowing the builtin.
        parsed = json.loads(assistant_message)
    except (ValueError, TypeError):
        log.info("Responses are not in json format. Return the default keys.\n ")
        log.info(f"assistant_message: {assistant_message}")
        return default_keys
    return list(parsed.keys())
|
56 |
+
|
57 |
+
|
58 |
+
def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key = None):
|
59 |
+
if openai.api_key is None and openai_key is None:
|
60 |
+
raise ValueError("OpenAI API key must be provided.")
|
61 |
+
if openai_key is not None:
|
62 |
+
openai.api_key = openai_key
|
63 |
+
|
64 |
conversation_history = [
|
65 |
+
{"role": "system", "content": "You are an assistant in writing machine learning papers."}
|
|
|
66 |
]
|
67 |
+
conversation_history.append({"role": "user", "content": user_message})
|
68 |
response = openai.ChatCompletion.create(
|
69 |
model=model,
|
70 |
messages=conversation_history,
|
|
|
77 |
return assistant_message, usage
|
78 |
|
79 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
80 |
if __name__ == "__main__":
    # Smoke-test the f.write extraction regex against well-formed and
    # truncated wrappers in both quote styles.
    writer_pattern = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
    test_strings = [r"f.write(r'hello world')", r"f.write(r'''hello world''')", r"f.write(r'''hello world",
                    r"f.write(r'''hello world'", r'f.write(r"hello world")', r'f.write(r"""hello world""")',
                    r'f.write(r"""hello world"', r'f.write(r"""hello world']
    for input_string in test_strings:
        print("input_string: ", input_string)
        # Pattern is compiled once above instead of once per iteration.
        found = writer_pattern.search(input_string)
        if found:
            print("Extracted string:", found.group(1))
        else:
            print("No match found")
|