Upload 6 files
Browse files- assets/blank.pdf +0 -0
- assets/pic.png +0 -0
- src/optimizeOpenAI.py +233 -0
- src/paper.py +121 -0
- src/reader.py +109 -0
- src/utils.py +5 -0
assets/blank.pdf
ADDED
Binary file (41.2 kB). View file
|
|
assets/pic.png
ADDED
src/optimizeOpenAI.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A simple wrapper for the official ChatGPT API
|
3 |
+
"""
|
4 |
+
import json
|
5 |
+
import os
|
6 |
+
import threading
|
7 |
+
import time
|
8 |
+
import requests
|
9 |
+
import tiktoken
|
10 |
+
from typing import Generator
|
11 |
+
from queue import PriorityQueue as PQ
|
12 |
+
import json
|
13 |
+
import os
|
14 |
+
import time
|
15 |
+
|
16 |
+
class chatPaper:
|
17 |
+
"""
|
18 |
+
Official ChatGPT API
|
19 |
+
"""
|
20 |
+
def __init__(
|
21 |
+
self,
|
22 |
+
api_keys: list,
|
23 |
+
proxy = None,
|
24 |
+
api_proxy = None,
|
25 |
+
max_tokens: int = 4000,
|
26 |
+
temperature: float = 0.5,
|
27 |
+
top_p: float = 1.0,
|
28 |
+
model_name: str = "gpt-3.5-turbo",
|
29 |
+
reply_count: int = 1,
|
30 |
+
system_prompt = "You are ChatArxiv, A paper reading bot",
|
31 |
+
lastAPICallTime = time.time()-100,
|
32 |
+
apiTimeInterval = 20,
|
33 |
+
) -> None:
|
34 |
+
self.model_name = model_name
|
35 |
+
self.system_prompt = system_prompt
|
36 |
+
self.apiTimeInterval = apiTimeInterval
|
37 |
+
self.session = requests.Session()
|
38 |
+
self.api_keys = PQ()
|
39 |
+
for key in api_keys:
|
40 |
+
self.api_keys.put((lastAPICallTime,key))
|
41 |
+
self.proxy = proxy
|
42 |
+
if self.proxy:
|
43 |
+
proxies = {
|
44 |
+
"http": self.proxy,
|
45 |
+
"https": self.proxy,
|
46 |
+
}
|
47 |
+
self.session.proxies = proxies
|
48 |
+
self.max_tokens = max_tokens
|
49 |
+
self.temperature = temperature
|
50 |
+
self.top_p = top_p
|
51 |
+
self.reply_count = reply_count
|
52 |
+
self.decrease_step = 250
|
53 |
+
self.conversation = {}
|
54 |
+
self.ENCODER = tiktoken.get_encoding("gpt2")
|
55 |
+
if self.token_str(self.system_prompt) > self.max_tokens:
|
56 |
+
raise Exception("System prompt is too long")
|
57 |
+
self.lock = threading.Lock()
|
58 |
+
|
59 |
+
def get_api_key(self):
|
60 |
+
with self.lock:
|
61 |
+
apiKey = self.api_keys.get()
|
62 |
+
delay = self._calculate_delay(apiKey)
|
63 |
+
time.sleep(delay)
|
64 |
+
self.api_keys.put((time.time(), apiKey[1]))
|
65 |
+
return apiKey[1]
|
66 |
+
|
67 |
+
def _calculate_delay(self, apiKey):
|
68 |
+
elapsed_time = time.time() - apiKey[0]
|
69 |
+
if elapsed_time < self.apiTimeInterval:
|
70 |
+
return self.apiTimeInterval - elapsed_time
|
71 |
+
else:
|
72 |
+
return 0
|
73 |
+
|
74 |
+
def add_to_conversation(self, message: str, role: str, convo_id: str = "default"):
|
75 |
+
if(convo_id not in self.conversation):
|
76 |
+
self.reset(convo_id)
|
77 |
+
self.conversation[convo_id].append({"role": role, "content": message})
|
78 |
+
|
79 |
+
def __truncate_conversation(self, convo_id: str = "default"):
|
80 |
+
"""
|
81 |
+
Truncate the conversation
|
82 |
+
"""
|
83 |
+
last_dialog = self.conversation[convo_id][-1]
|
84 |
+
query = str(last_dialog['content'])
|
85 |
+
if(len(self.ENCODER.encode(str(query)))>self.max_tokens):
|
86 |
+
query = query[:int(1.5*self.max_tokens)]
|
87 |
+
while(len(self.ENCODER.encode(str(query)))>self.max_tokens):
|
88 |
+
query = query[:self.decrease_step]
|
89 |
+
self.conversation[convo_id] = self.conversation[convo_id][:-1]
|
90 |
+
full_conversation = "\n".join([str(x["content"]) for x in self.conversation[convo_id]],)
|
91 |
+
if len(self.ENCODER.encode(full_conversation)) > self.max_tokens:
|
92 |
+
self.conversation_summary(convo_id=convo_id)
|
93 |
+
full_conversation = ""
|
94 |
+
for x in self.conversation[convo_id]:
|
95 |
+
full_conversation = str(x["content"]) + "\n" + full_conversation
|
96 |
+
while True:
|
97 |
+
if (len(self.ENCODER.encode(full_conversation+query)) > self.max_tokens):
|
98 |
+
query = query[:self.decrease_step]
|
99 |
+
else:
|
100 |
+
break
|
101 |
+
last_dialog['content'] = str(query)
|
102 |
+
self.conversation[convo_id].append(last_dialog)
|
103 |
+
|
104 |
+
def ask_stream(
|
105 |
+
self,
|
106 |
+
prompt: str,
|
107 |
+
role: str = "user",
|
108 |
+
convo_id: str = "default",
|
109 |
+
**kwargs,
|
110 |
+
) -> Generator:
|
111 |
+
if convo_id not in self.conversation.keys():
|
112 |
+
self.reset(convo_id=convo_id)
|
113 |
+
self.add_to_conversation(prompt, "user", convo_id=convo_id)
|
114 |
+
self.__truncate_conversation(convo_id=convo_id)
|
115 |
+
apiKey = self.get_api_key()
|
116 |
+
response = self.session.post(
|
117 |
+
"https://api.openai.com/v1/chat/completions",
|
118 |
+
headers={"Authorization": f"Bearer {kwargs.get('api_key', apiKey)}"},
|
119 |
+
json={
|
120 |
+
"model": self.model_name,
|
121 |
+
"messages": self.conversation[convo_id],
|
122 |
+
"stream": True,
|
123 |
+
# kwargs
|
124 |
+
"temperature": kwargs.get("temperature", self.temperature),
|
125 |
+
"top_p": kwargs.get("top_p", self.top_p),
|
126 |
+
"n": kwargs.get("n", self.reply_count),
|
127 |
+
"user": role,
|
128 |
+
},
|
129 |
+
stream=True,
|
130 |
+
)
|
131 |
+
if response.status_code != 200:
|
132 |
+
raise Exception(
|
133 |
+
f"Error: {response.status_code} {response.reason} {response.text}",
|
134 |
+
)
|
135 |
+
for line in response.iter_lines():
|
136 |
+
if not line:
|
137 |
+
continue
|
138 |
+
# Remove "data: "
|
139 |
+
line = line.decode("utf-8")[6:]
|
140 |
+
if line == "[DONE]":
|
141 |
+
break
|
142 |
+
resp: dict = json.loads(line)
|
143 |
+
choices = resp.get("choices")
|
144 |
+
if not choices:
|
145 |
+
continue
|
146 |
+
delta = choices[0].get("delta")
|
147 |
+
if not delta:
|
148 |
+
continue
|
149 |
+
if "content" in delta:
|
150 |
+
content = delta["content"]
|
151 |
+
yield content
|
152 |
+
|
153 |
+
def ask(self, prompt: str, role: str = "user", convo_id: str = "default", **kwargs):
|
154 |
+
"""
|
155 |
+
Non-streaming ask
|
156 |
+
"""
|
157 |
+
response = self.ask_stream(
|
158 |
+
prompt=prompt,
|
159 |
+
role=role,
|
160 |
+
convo_id=convo_id,
|
161 |
+
**kwargs,
|
162 |
+
)
|
163 |
+
full_response: str = "".join(response)
|
164 |
+
self.add_to_conversation(full_response, role, convo_id=convo_id)
|
165 |
+
usage_token = self.token_str(prompt)
|
166 |
+
com_token = self.token_str(full_response)
|
167 |
+
total_token = self.token_cost(convo_id=convo_id)
|
168 |
+
return full_response, usage_token, com_token, total_token
|
169 |
+
|
170 |
+
def check_api_available(self):
|
171 |
+
response = self.session.post(
|
172 |
+
"https://api.openai.com/v1/chat/completions",
|
173 |
+
headers={"Authorization": f"Bearer {self.get_api_key()}"},
|
174 |
+
json={
|
175 |
+
"model": self.engine,
|
176 |
+
"messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "print A"}],
|
177 |
+
"stream": True,
|
178 |
+
# kwargs
|
179 |
+
"temperature": self.temperature,
|
180 |
+
"top_p": self.top_p,
|
181 |
+
"n": self.reply_count,
|
182 |
+
"user": "user",
|
183 |
+
},
|
184 |
+
stream=True,
|
185 |
+
)
|
186 |
+
if response.status_code == 200:
|
187 |
+
return True
|
188 |
+
else:
|
189 |
+
return False
|
190 |
+
|
191 |
+
def reset(self, convo_id: str = "default", system_prompt = None):
|
192 |
+
"""
|
193 |
+
Reset the conversation
|
194 |
+
"""
|
195 |
+
self.conversation[convo_id] = [
|
196 |
+
{"role": "system", "content": str(system_prompt or self.system_prompt)},
|
197 |
+
]
|
198 |
+
|
199 |
+
def conversation_summary(self, convo_id: str = "default"):
|
200 |
+
input = ""
|
201 |
+
role = ""
|
202 |
+
for conv in self.conversation[convo_id]:
|
203 |
+
if (conv["role"]=='user'):
|
204 |
+
role = 'User'
|
205 |
+
else:
|
206 |
+
role = 'ChatGpt'
|
207 |
+
input+=role+' : '+conv['content']+'\n'
|
208 |
+
prompt = "Your goal is to summarize the provided conversation. Your summary should be concise and focus on the key information to facilitate better dialogue for the large language model.Ensure that you include all necessary details and relevant information while still reducing the length of the conversation as much as possible. Your summary should be clear and easily understandable for the ChatGpt model providing a comprehensive and concise summary of the conversation."
|
209 |
+
if(self.token_str(str(input)+prompt)>self.max_tokens):
|
210 |
+
input = input[self.token_str(str(input))-self.max_tokens:]
|
211 |
+
while self.token_str(str(input)+prompt)>self.max_tokens:
|
212 |
+
input = input[self.decrease_step:]
|
213 |
+
prompt = prompt.replace("{conversation}", input)
|
214 |
+
self.reset(convo_id='conversationSummary')
|
215 |
+
response = self.ask(prompt, convo_id='conversationSummary')
|
216 |
+
while self.token_str(str(response))>self.max_tokens:
|
217 |
+
response = response[:-self.decrease_step]
|
218 |
+
self.reset(convo_id='conversationSummary',system_prompt='Summariaze')
|
219 |
+
self.conversation[convo_id] = [
|
220 |
+
{"role": "system", "content": self.system_prompt},
|
221 |
+
{"role": "user", "content": "Summariaze"},
|
222 |
+
{"role": 'assistant', "content": response},
|
223 |
+
]
|
224 |
+
return self.conversation[convo_id]
|
225 |
+
|
226 |
+
def token_cost(self,convo_id: str = "default"):
|
227 |
+
return len(self.ENCODER.encode("\n".join([x["content"] for x in self.conversation[convo_id]])))
|
228 |
+
|
229 |
+
def token_str(self, content:str):
|
230 |
+
return len(self.ENCODER.encode(content))
|
231 |
+
|
232 |
+
def main():
|
233 |
+
return
|
src/paper.py
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import fitz
|
2 |
+
import os
|
3 |
+
import io
|
4 |
+
import arxiv
|
5 |
+
import tempfile
|
6 |
+
|
7 |
+
from PIL import Image
|
8 |
+
from urllib.parse import urlparse
|
9 |
+
|
10 |
+
class Paper:
|
11 |
+
def __init__(self, url=''):
|
12 |
+
self.url = url
|
13 |
+
self.parse_url()
|
14 |
+
self.get_pdf()
|
15 |
+
self.paper_instance = {
|
16 |
+
'title': self.paper_arxiv.title,
|
17 |
+
'authors': self.paper_arxiv.authors,
|
18 |
+
'arxiv_id': self.paper_id,
|
19 |
+
'abstract': self.paper_arxiv.summary,
|
20 |
+
'pdf_url': self.paper_arxiv.pdf_url,
|
21 |
+
'categories': self.paper_arxiv.categories,
|
22 |
+
'published': self.paper_arxiv.published,
|
23 |
+
'updated': self.paper_arxiv.updated,
|
24 |
+
'content': {}
|
25 |
+
}
|
26 |
+
self.parse_pdf()
|
27 |
+
|
28 |
+
def get_paper(self):
|
29 |
+
return self.paper_instance
|
30 |
+
|
31 |
+
def parse_url(self):
|
32 |
+
self.url = self.url.replace('.pdf', '')
|
33 |
+
parsed_url = urlparse(self.url)
|
34 |
+
paper_id = os.path.basename(parsed_url.path)
|
35 |
+
self.paper_id = paper_id
|
36 |
+
|
37 |
+
def get_pdf(self):
|
38 |
+
search = arxiv.Search(id_list=[self.paper_id], max_results=1)
|
39 |
+
results = search.results()
|
40 |
+
paper_arxiv = next(results)
|
41 |
+
if paper_arxiv:
|
42 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=".pdf") as temp_pdf:
|
43 |
+
paper_path = f'{self.paper_id}.pdf'
|
44 |
+
dir_path = "./pdf"
|
45 |
+
os.makedirs(dir_path, exist_ok=True)
|
46 |
+
save_dir = os.path.join(dir_path, paper_path)
|
47 |
+
if not os.path.exists(save_dir):
|
48 |
+
paper_arxiv.download_pdf(dirpath=dir_path, filename=paper_path)
|
49 |
+
self.paper_arxiv = paper_arxiv
|
50 |
+
self.path = save_dir
|
51 |
+
else:
|
52 |
+
raise Exception("无法找到论文,请检查 URL 是否正确。")
|
53 |
+
|
54 |
+
def parse_pdf(self):
|
55 |
+
self.pdf = fitz.open(self.path)
|
56 |
+
self.text_list = [page.get_text() for page in self.pdf]
|
57 |
+
self.all_text = ' '.join(self.text_list)
|
58 |
+
|
59 |
+
self._parse_paper()
|
60 |
+
self.pdf.close()
|
61 |
+
|
62 |
+
def _get_sections(self):
|
63 |
+
sections = 'Abstract,Introduction,Related Work,Background,Preliminary,Problem Formulation,Methods,Methodology,Method,Approach,Approaches,Materials and Methods,Experiment Settings,Experiment,Experimental Results,Evaluation,Experiments,Results,Findings,Data Analysis,Discussion,Results and Discussion,Conclusion,References'
|
64 |
+
self.sections = sections.split(',')
|
65 |
+
|
66 |
+
def _get_all_page_index(self):
|
67 |
+
section_list = self.sections
|
68 |
+
section_page_dict = {}
|
69 |
+
|
70 |
+
for page_index, page in enumerate(self.pdf):
|
71 |
+
cur_text = page.get_text()
|
72 |
+
for section_name in section_list:
|
73 |
+
section_name_upper = section_name.upper()
|
74 |
+
if "Abstract" == section_name and section_name in cur_text:
|
75 |
+
section_page_dict[section_name] = page_index
|
76 |
+
continue
|
77 |
+
|
78 |
+
if section_name + '\n' in cur_text:
|
79 |
+
section_page_dict[section_name] = page_index
|
80 |
+
elif section_name_upper + '\n' in cur_text:
|
81 |
+
section_page_dict[section_name] = page_index
|
82 |
+
|
83 |
+
self.section_page_dict = section_page_dict
|
84 |
+
|
85 |
+
def _parse_paper(self):
|
86 |
+
"""
|
87 |
+
Return: dict { <Section Name>: <Content> }
|
88 |
+
"""
|
89 |
+
self._get_sections()
|
90 |
+
self._get_all_page_index()
|
91 |
+
|
92 |
+
text_list = [page.get_text() for page in self.pdf]
|
93 |
+
section_keys = list(self.section_page_dict.keys())
|
94 |
+
section_count = len(section_keys)
|
95 |
+
|
96 |
+
section_dict = {}
|
97 |
+
for sec_index, sec_name in enumerate(section_keys):
|
98 |
+
if sec_index == 0:
|
99 |
+
continue
|
100 |
+
|
101 |
+
start_page = self.section_page_dict[sec_name]
|
102 |
+
end_page = self.section_page_dict[section_keys[sec_index + 1]] if sec_index < section_count - 1 else len(text_list)
|
103 |
+
|
104 |
+
cur_sec_text = []
|
105 |
+
for page_i in range(start_page, end_page):
|
106 |
+
page_text = text_list[page_i]
|
107 |
+
|
108 |
+
if page_i == start_page:
|
109 |
+
start_i = page_text.find(sec_name) if sec_name in page_text else page_text.find(sec_name.upper())
|
110 |
+
page_text = page_text[start_i:]
|
111 |
+
|
112 |
+
if page_i == end_page - 1 and sec_index < section_count - 1:
|
113 |
+
next_sec = section_keys[sec_index + 1]
|
114 |
+
end_i = page_text.find(next_sec) if next_sec in page_text else page_text.find(next_sec.upper())
|
115 |
+
page_text = page_text[:end_i]
|
116 |
+
|
117 |
+
cur_sec_text.append(page_text)
|
118 |
+
|
119 |
+
section_dict[sec_name] = ''.join(cur_sec_text).replace('-\n', '').replace('\n', ' ')
|
120 |
+
|
121 |
+
self.paper_instance['content'] = section_dict
|
src/reader.py
ADDED
@@ -0,0 +1,109 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import re
|
3 |
+
import numpy as np
|
4 |
+
import tenacity
|
5 |
+
import arxiv
|
6 |
+
import markdown
|
7 |
+
|
8 |
+
from .paper import Paper
|
9 |
+
from .optimizeOpenAI import chatPaper
|
10 |
+
|
11 |
+
class Reader:
|
12 |
+
def __init__(self,
|
13 |
+
paper: Paper,
|
14 |
+
api_key='',
|
15 |
+
user_name='defualt',
|
16 |
+
language='English'):
|
17 |
+
self.user_name = user_name
|
18 |
+
self.language = language
|
19 |
+
self.paper_instance = paper.get_paper()
|
20 |
+
|
21 |
+
self.chat_api_list = [api_key]
|
22 |
+
self.chatPaper = chatPaper(api_keys=self.chat_api_list, apiTimeInterval=10)
|
23 |
+
self.chatPaper.add_to_conversation(message="You are a professional academic paper reviewer and mentor named Arxiv Bot. As a professional academic paper reviewer and helpful mentor, you possess exceptional logical and critical thinking skills, enabling you to provide concise and insightful responses.", role='assistant', convo_id="chat")
|
24 |
+
self.chatPaper.add_to_conversation(message="You are not allowed to discuss anything about politics, do not comment on anything about that.", role='assistant', convo_id="chat")
|
25 |
+
self.chatPaper.add_to_conversation(message="You will be asked to answer questions about the paper with deep knowledge about it, providing clear and concise explanations in a helpful, friendly manner, using the asker's language.", role='user', convo_id="chat")
|
26 |
+
|
27 |
+
# Read Basic Info of the Paper
|
28 |
+
self._read_basic()
|
29 |
+
|
30 |
+
def _get_intro_prompt(self, intro_content: str = ''):
|
31 |
+
if intro_content == '':
|
32 |
+
intro_key = [k for k in self.paper_instance['content'].keys()][0]
|
33 |
+
intro_content = self.paper_instance['content'][intro_key]
|
34 |
+
prompt = (f"This is an academic paper from {self.paper_instance['categories']} fields,\n\
|
35 |
+
Title of this paper are {self.paper_instance['title']}.\n\
|
36 |
+
Authors of this paper are {self.paper_instance['authors']}.\n\
|
37 |
+
Abstract of this paper is {self.paper_instance['abstract']}.\n\
|
38 |
+
Introduction of this paper is {intro_content}.")
|
39 |
+
return prompt
|
40 |
+
|
41 |
+
def _init_prompt(self, convo_id: str = 'default'):
|
42 |
+
intro_content = ''
|
43 |
+
max_tokens = self.chatPaper.max_tokens
|
44 |
+
|
45 |
+
prompt = self._get_intro_prompt(intro_content)
|
46 |
+
full_conversation_ = "\n".join([str(x["content"]) for x in self.chatPaper.conversation[convo_id]],)
|
47 |
+
full_conversation = str(full_conversation_ + prompt)
|
48 |
+
|
49 |
+
# Try to summarize the intro part
|
50 |
+
if(len(self.chatPaper.ENCODER.encode(str(full_conversation)))>max_tokens):
|
51 |
+
prompt = f'This is the introduction, please summarize it and reduct its length in {max_tokens} tokens: {prompt}'
|
52 |
+
intro_content = self._summarize_content(prompt)
|
53 |
+
prompt = self._get_intro_prompt(intro_content)
|
54 |
+
full_conversation = str(full_conversation_ + prompt)
|
55 |
+
|
56 |
+
# Failed, try to reduce the length of the prompt
|
57 |
+
while(len(self.chatPaper.ENCODER.encode(str(full_conversation)))>max_tokens):
|
58 |
+
prompt = prompt[:self.chatPaper.decrease_step]
|
59 |
+
full_conversation = str(full_conversation_ + prompt)
|
60 |
+
|
61 |
+
return prompt
|
62 |
+
|
63 |
+
def _summarize_content(self, content: str = ''):
|
64 |
+
sys_prompt = "Your goal is to summarize the provided content from an academic paper. Your summary should be concise and focus on the key information of the academic paper, do not miss any important point."
|
65 |
+
self.chatPaper.reset(convo_id='summary', system_prompt=sys_prompt)
|
66 |
+
response = self.chatPaper.ask(content, convo_id='summary')
|
67 |
+
res_txt = str(response[0])
|
68 |
+
return res_txt
|
69 |
+
|
70 |
+
def get_basic_info(self):
|
71 |
+
prompt = f'Introduce this paper (its not necessary to include the basic information like title and author name), comment on this paper based on its abstract and introduction from its 1. Novelty, 2. Improtance, 3. Potential Influence. Relpy in {self.language}'
|
72 |
+
basic_op = self.chatPaper.ask(prompt, convo_id='chat')[0]
|
73 |
+
return basic_op
|
74 |
+
|
75 |
+
def _read_basic(self, convo_id="chat"):
|
76 |
+
prompt = self._init_prompt(convo_id)
|
77 |
+
self.chatPaper.add_to_conversation(
|
78 |
+
convo_id=convo_id,
|
79 |
+
role="assistant",
|
80 |
+
message= prompt
|
81 |
+
)
|
82 |
+
|
83 |
+
def read_paper(self, chapter_list: list = [], convo_id="chat"):
|
84 |
+
for chap in chapter_list:
|
85 |
+
prompt = self.paper_instance['content'][chap]
|
86 |
+
sys_prompt = f'This is the {chap} section of this paper, please read carefully and answer the users questions professionally and friendly basic on the content.\n'
|
87 |
+
prompt = sys_prompt + prompt
|
88 |
+
self.chatPaper.add_to_conversation(
|
89 |
+
convo_id=convo_id,
|
90 |
+
role="assistant",
|
91 |
+
message= prompt
|
92 |
+
)
|
93 |
+
return "我读完了这些章节,让我们开始吧! 🤩"
|
94 |
+
|
95 |
+
|
96 |
+
@tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
|
97 |
+
stop=tenacity.stop_after_attempt(5),
|
98 |
+
reraise=True)
|
99 |
+
def chat_with_paper(self, prompt):
|
100 |
+
result = self.chatPaper.ask(
|
101 |
+
prompt = prompt,
|
102 |
+
role="user",
|
103 |
+
convo_id="chat",
|
104 |
+
)
|
105 |
+
reply = str(result[0])
|
106 |
+
return reply
|
107 |
+
|
108 |
+
|
109 |
+
|
src/utils.py
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
language_dict = {
|
2 |
+
'zh': '中文',
|
3 |
+
'en': 'English',
|
4 |
+
}
|
5 |
+
|