wangrongsheng circlecircle commited on
Commit
12b83ac
·
1 Parent(s): e55d43b

API Pool and bugs fixing (#4)

Browse files

- fix bugs for valid api (54fb763315a30f14c9a6560762a263c575e1f43f)


Co-authored-by: cc <[email protected]>

Files changed (2) hide show
  1. app.py +332 -208
  2. optimizeOpenAI.py +5 -5
app.py CHANGED
@@ -15,9 +15,10 @@ import json
15
  import tiktoken
16
  import concurrent.futures
17
  from optimizeOpenAI import chatPaper
 
18
  def parse_text(text):
19
  lines = text.split("\n")
20
- for i,line in enumerate(lines):
21
  if "```" in line:
22
  items = line.split('`')
23
  if items[-1]:
@@ -25,12 +26,13 @@ def parse_text(text):
25
  else:
26
  lines[i] = f'</code></pre>'
27
  else:
28
- if i>0:
29
  line = line.replace("<", "&lt;")
30
  line = line.replace(">", "&gt;")
31
- lines[i] = '<br/>'+line.replace(" ", "&nbsp;")
32
  return "".join(lines)
33
 
 
34
  # def get_response(system, context, myKey, raw = False):
35
  # openai.api_key = myKey
36
  # response = openai.ChatCompletion.create(
@@ -47,6 +49,7 @@ def parse_text(text):
47
 
48
  valid_api_keys = []
49
 
 
50
  def api_key_check(api_key):
51
  try:
52
  chat = chatPaper([api_key])
@@ -57,54 +60,62 @@ def api_key_check(api_key):
57
  except:
58
  return None
59
 
 
60
  def valid_apikey(api_keys):
61
  api_keys = api_keys.replace(' ', '')
62
  api_key_list = api_keys.split(',')
63
  print(api_key_list)
64
  global valid_api_keys
65
  with concurrent.futures.ThreadPoolExecutor() as executor:
66
- future_results = {executor.submit(api_key_check, api_key): api_key for api_key in api_key_list}
 
 
 
67
  for future in concurrent.futures.as_completed(future_results):
68
  result = future.result()
69
  if result:
70
  valid_api_keys.append(result)
71
  if len(valid_api_keys) > 0:
72
- return "有效的api-key一共有{}个,分别是:{}, 现在可以提交你的paper".format(len(valid_api_keys), valid_api_keys)
 
73
  return "无效的api-key"
74
 
75
 
76
  class Paper:
 
77
  def __init__(self, path, title='', url='', abs='', authers=[], sl=[]):
78
- # 初始化函数,根据pdf路径初始化Paper对象
79
- self.url = url # 文章链接
80
- self.path = path # pdf路径
81
  self.sl = sl
82
- self.section_names = [] # 段落标题
83
- self.section_texts = {} # 段落内容
84
  self.abs = abs
85
  self.title_page = 0
86
  if title == '':
87
- self.pdf = fitz.open(self.path) # pdf文档
88
  self.title = self.get_title()
89
- self.parse_pdf()
90
  else:
91
  self.title = title
92
- self.authers = authers
93
- self.roman_num = ["I", "II", 'III', "IV", "V", "VI", "VII", "VIII", "IIX", "IX", "X"]
94
- self.digit_num = [str(d+1) for d in range(10)]
 
 
95
  self.first_image = ''
96
-
97
  def parse_pdf(self):
98
- self.pdf = fitz.open(self.path) # pdf文档
99
  self.text_list = [page.get_text() for page in self.pdf]
100
  self.all_text = ' '.join(self.text_list)
101
- self.section_page_dict = self._get_all_page_index() # 段落与页码的对应字典
102
  print("section_page_dict", self.section_page_dict)
103
- self.section_text_dict = self._get_all_page() # 段落与内容的对应字典
104
  self.section_text_dict.update({"title": self.title})
105
  self.section_text_dict.update({"paper_info": self.get_paper_info()})
106
- self.pdf.close()
107
-
108
  def get_paper_info(self):
109
  first_page_text = self.pdf[self.title_page].get_text()
110
  if "Abstract" in self.section_text_dict.keys():
@@ -112,9 +123,10 @@ class Paper:
112
  else:
113
  abstract_text = self.abs
114
  introduction_text = self.section_text_dict['Introduction']
115
- first_page_text = first_page_text.replace(abstract_text, "").replace(introduction_text, "")
 
116
  return first_page_text
117
-
118
  def get_image_path(self, image_path=''):
119
  """
120
  将PDF中的第一张���保存到image.png里面,存到本地目录,返回文件名称,供gitee读取
@@ -131,9 +143,10 @@ class Paper:
131
  # 查看独立页面
132
  page = my_pdf_file[page_number - 1]
133
  # 查看当前页所有图片
134
- images = page.get_images()
135
  # 遍历当前页面所有图片
136
- for image_number, image in enumerate(page.get_images(), start=1):
 
137
  # 访问图片xref
138
  xref_value = image[0]
139
  # 提取图片信息
@@ -148,32 +161,32 @@ class Paper:
148
  if image_size > max_size:
149
  max_size = image_size
150
  image_list.append(image)
151
- for image in image_list:
152
  image_size = image.size[0] * image.size[1]
153
- if image_size == max_size:
154
  image_name = f"image.{ext}"
155
  im_path = os.path.join(image_path, image_name)
156
  print("im_path:", im_path)
157
-
158
  max_pix = 480
159
  origin_min_pix = min(image.size[0], image.size[1])
160
-
161
  if image.size[0] > image.size[1]:
162
- min_pix = int(image.size[1] * (max_pix/image.size[0]))
163
  newsize = (max_pix, min_pix)
164
  else:
165
- min_pix = int(image.size[0] * (max_pix/image.size[1]))
166
  newsize = (min_pix, max_pix)
167
  image = image.resize(newsize)
168
-
169
  image.save(open(im_path, "wb"))
170
  return im_path, ext
171
  return None, None
172
-
173
  # 定义一个函数,根据字体的大小,识别每个章节名称,并返回一个列表
174
- def get_chapter_names(self,):
175
  # # 打开一个pdf文件
176
- doc = fitz.open(self.path) # pdf文档
177
  text_list = [page.get_text() for page in doc]
178
  all_text = ''
179
  for text in text_list:
@@ -186,52 +199,61 @@ class Paper:
186
  point_split_list = line.split('.')
187
  space_split_list = line.split(' ')
188
  if 1 < len(space_split_list) < 5:
189
- if 1 < len(point_split_list) < 5 and (point_split_list[0] in self.roman_num or point_split_list[0] in self.digit_num):
 
 
190
  print("line:", line)
191
- chapter_names.append(line)
192
-
193
  return chapter_names
194
-
195
  def get_title(self):
196
- doc = self.pdf # 打开pdf文件
197
- max_font_size = 0 # 初始化最大字体大小为0
198
- max_string = "" # 初始化最大字体大小对应的字符串为空
199
  max_font_sizes = [0]
200
- for page_index, page in enumerate(doc): # 遍历每一页
201
- text = page.get_text("dict") # 获取页面上的文本信息
202
- blocks = text["blocks"] # 获取文本块列表
203
- for block in blocks: # 遍历每个文本块
204
- if block["type"] == 0 and len(block['lines']): # 如果是文字类型
205
  if len(block["lines"][0]["spans"]):
206
- font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
 
207
  max_font_sizes.append(font_size)
208
- if font_size > max_font_size: # 如果字体大小大于当前最大值
209
- max_font_size = font_size # 更新最大值
210
- max_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应���字符串
211
- max_font_sizes.sort()
 
212
  print("max_font_sizes", max_font_sizes[-10:])
213
  cur_title = ''
214
- for page_index, page in enumerate(doc): # 遍历每一页
215
- text = page.get_text("dict") # 获取页面上的文本信息
216
- blocks = text["blocks"] # 获取文本块列表
217
- for block in blocks: # 遍历每个文本块
218
- if block["type"] == 0 and len(block['lines']): # 如果是文字类型
219
  if len(block["lines"][0]["spans"]):
220
- cur_string = block["lines"][0]["spans"][0]["text"] # 更新最大值对应的字符串
221
- font_flags = block["lines"][0]["spans"][0]["flags"] # 获取第一行第一段文字的字体特征
222
- font_size = block["lines"][0]["spans"][0]["size"] # 获取第一行第一段文字的字体大小
 
 
 
223
  # print(font_size)
224
- if abs(font_size - max_font_sizes[-1]) < 0.3 or abs(font_size - max_font_sizes[-2]) < 0.3:
225
- # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
226
- if len(cur_string) > 4 and "arXiv" not in cur_string:
227
- # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
228
- if cur_title == '' :
229
- cur_title += cur_string
 
 
230
  else:
231
- cur_title += ' ' + cur_string
232
  self.title_page = page_index
233
 
234
- title = cur_title.replace('\n', ' ')
235
  return title
236
 
237
  def _get_all_page_index(self):
@@ -269,7 +291,7 @@ class Paper:
269
  text = ''
270
  text_list = []
271
  section_dict = {}
272
-
273
  # 再处理其他章节:
274
  text_list = [page.get_text() for page in self.pdf]
275
  for sec_index, sec_name in enumerate(self.section_page_dict):
@@ -279,63 +301,90 @@ class Paper:
279
  else:
280
  # 直接考虑后面的内容:
281
  start_page = self.section_page_dict[sec_name]
282
- if sec_index < len(list(self.section_page_dict.keys()))-1:
283
- end_page = self.section_page_dict[list(self.section_page_dict.keys())[sec_index+1]]
 
284
  else:
285
  end_page = len(text_list)
286
  print("start_page, end_page:", start_page, end_page)
287
  cur_sec_text = ''
288
  if end_page - start_page == 0:
289
- if sec_index < len(list(self.section_page_dict.keys()))-1:
290
- next_sec = list(self.section_page_dict.keys())[sec_index+1]
 
 
291
  if text_list[start_page].find(sec_name) == -1:
292
- start_i = text_list[start_page].find(sec_name.upper())
 
293
  else:
294
  start_i = text_list[start_page].find(sec_name)
295
  if text_list[start_page].find(next_sec) == -1:
296
- end_i = text_list[start_page].find(next_sec.upper())
 
297
  else:
298
- end_i = text_list[start_page].find(next_sec)
299
  cur_sec_text += text_list[start_page][start_i:end_i]
300
  else:
301
- for page_i in range(start_page, end_page):
302
- # print("page_i:", page_i)
303
  if page_i == start_page:
304
  if text_list[start_page].find(sec_name) == -1:
305
- start_i = text_list[start_page].find(sec_name.upper())
 
306
  else:
307
  start_i = text_list[start_page].find(sec_name)
308
  cur_sec_text += text_list[page_i][start_i:]
309
  elif page_i < end_page:
310
  cur_sec_text += text_list[page_i]
311
  elif page_i == end_page:
312
- if sec_index < len(list(self.section_page_dict.keys()))-1:
313
- next_sec = list(self.section_page_dict.keys())[sec_index+1]
 
 
 
314
  if text_list[start_page].find(next_sec) == -1:
315
- end_i = text_list[start_page].find(next_sec.upper())
 
316
  else:
317
- end_i = text_list[start_page].find(next_sec)
 
318
  cur_sec_text += text_list[page_i][:end_i]
319
- section_dict[sec_name] = cur_sec_text.replace('-\n', '').replace('\n', ' ')
 
 
320
  return section_dict
321
 
 
322
  # 定义Reader类
323
  class Reader:
324
  # 初始化方法,设置属性
325
- def __init__(self, key_word='', query='', filter_keys='',
 
 
 
326
  root_path='./',
327
  gitee_key='',
328
- sort=arxiv.SortCriterion.SubmittedDate, user_name='defualt', language='cn', api_keys:list = [], model_name="gpt-3.5-turbo", p=1.0, temperature=1.0):
 
 
 
 
 
 
329
  self.api_keys = api_keys
330
- self.chatPaper = chatPaper( api_keys = self.api_keys, apiTimeInterval=10 , temperature=temperature,top_p=p,model_name=model_name) #openAI api封装
331
- self.user_name = user_name # 读者姓名
332
- self.key_word = key_word # 读者感兴趣的关键词
333
- self.query = query # 读者输入的搜索查询
334
- self.sort = sort # 读者选择的排序方式
335
- self.language = language # 读者选择的语言
336
- self.filter_keys = filter_keys # 用于在摘要中筛选的关键词
 
 
 
 
337
  self.root_path = root_path
338
- self.file_format = 'md' # or 'txt',如果为图片,则必须为'md'
339
  self.save_image = False
340
  if self.save_image:
341
  self.gitee_key = self.config.get('Gitee', 'api')
@@ -343,24 +392,25 @@ class Reader:
343
  self.gitee_key = ''
344
  self.max_token_num = 4096
345
  self.encoding = tiktoken.get_encoding("gpt2")
346
-
347
  def get_arxiv(self, max_results=30):
348
- search = arxiv.Search(query=self.query,
349
- max_results=max_results,
350
- sort_by=self.sort,
351
- sort_order=arxiv.SortOrder.Descending,
352
- )
 
353
  return search
354
-
355
  def filter_arxiv(self, max_results=30):
356
  search = self.get_arxiv(max_results=max_results)
357
  print("all search:")
358
  for index, result in enumerate(search.results()):
359
  print(index, result.title, result.updated)
360
-
361
- filter_results = []
362
  filter_keys = self.filter_keys
363
-
364
  print("filter_keys:", self.filter_keys)
365
  # 确保每个关键词都能在摘要中找到,才算是目标论文
366
  for index, result in enumerate(search.results()):
@@ -377,18 +427,20 @@ class Reader:
377
  for index, result in enumerate(filter_results):
378
  print(index, result.title, result.updated)
379
  return filter_results
380
-
381
  def validateTitle(self, title):
382
  # 将论文的乱七八糟的路径格式修正
383
- rstr = r"[\/\\\:\*\?\"\<\>\|]" # '/ \ : * ? " < > |'
384
- new_title = re.sub(rstr, "_", title) # 替换为下划线
385
  return new_title
386
 
387
  def download_pdf(self, filter_results):
388
  # 先创建文件夹
389
- date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
390
- key_word = str(self.key_word.replace(':', ' '))
391
- path = self.root_path + 'pdf_files/' + self.query.replace('au: ', '').replace('title: ', '').replace('ti: ', '').replace(':', ' ')[:25] + '-' + date_str
 
 
392
  try:
393
  os.makedirs(path)
394
  except:
@@ -399,17 +451,18 @@ class Reader:
399
  for r_index, result in enumerate(filter_results):
400
  try:
401
  title_str = self.validateTitle(result.title)
402
- pdf_name = title_str+'.pdf'
403
  # result.download_pdf(path, filename=pdf_name)
404
  self.try_download_pdf(result, path, pdf_name)
405
  paper_path = os.path.join(path, pdf_name)
406
  print("paper_path:", paper_path)
407
- paper = Paper(path=paper_path,
408
- url=result.entry_id,
409
- title=result.title,
410
- abs=result.summary.replace('-\n', '-').replace('\n', ' '),
411
- authers=[str(aut) for aut in result.authors],
412
- )
 
413
  # 下载完毕,开始解析:
414
  paper.parse_pdf()
415
  paper_list.append(paper)
@@ -417,28 +470,31 @@ class Reader:
417
  print("download_error:", e)
418
  pass
419
  return paper_list
420
-
421
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
 
422
  stop=tenacity.stop_after_attempt(5),
423
  reraise=True)
424
  def try_download_pdf(self, result, path, pdf_name):
425
  result.download_pdf(path, filename=pdf_name)
426
-
427
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
 
428
  stop=tenacity.stop_after_attempt(5),
429
  reraise=True)
430
  def upload_gitee(self, image_path, image_name='', ext='png'):
431
  """
432
  上传到码云
433
  :return:
434
- """
435
  with open(image_path, 'rb') as f:
436
  base64_data = base64.b64encode(f.read())
437
  base64_content = base64_data.decode()
438
-
439
- date_str = str(datetime.datetime.now())[:19].replace(':', '-').replace(' ', '-') + '.' + ext
440
- path = image_name+ '-' +date_str
441
-
 
442
  payload = {
443
  "access_token": self.gitee_key,
444
  "owner": self.config.get('Gitee', 'owner'),
@@ -448,16 +504,23 @@ class Reader:
448
  "message": "upload image"
449
  }
450
  # 这里需要修改成你的gitee的账户和仓库名,以及文件夹的名字:
451
- url = f'https://gitee.com/api/v5/repos/'+self.config.get('Gitee', 'owner')+'/'+self.config.get('Gitee', 'repo')+'/contents/'+self.config.get('Gitee', 'path')+'/'+path
 
 
 
452
  rep = requests.post(url, json=payload).json()
453
  print("rep:", rep)
454
  if 'content' in rep.keys():
455
  image_url = rep['content']['download_url']
456
  else:
457
- image_url = r"https://gitee.com/api/v5/repos/"+self.config.get('Gitee', 'owner')+'/'+self.config.get('Gitee', 'repo')+'/contents/'+self.config.get('Gitee', 'path')+'/' + path
458
-
 
 
 
459
  return image_url
460
-
 
461
  def summary_with_chat(self, paper_list):
462
  htmls = []
463
  utoken = 0
@@ -474,48 +537,52 @@ class Reader:
474
  text += list(paper.section_text_dict.values())[0]
475
  #max_token = 2500 * 4
476
  #text = text[:max_token]
477
- chat_summary_text, utoken1, ctoken1, ttoken1 = self.chat_summary(text=text)
 
478
  htmls.append(chat_summary_text)
479
-
480
  # TODO 往md文档中插入论文里的像素最大的一张图片,这个方案可以弄的更加智能一些:
481
  method_key = ''
482
  for parse_key in paper.section_text_dict.keys():
483
- if 'method' in parse_key.lower() or 'approach' in parse_key.lower():
 
484
  method_key = parse_key
485
  break
486
-
487
  if method_key != '':
488
  text = ''
489
  method_text = ''
490
  summary_text = ''
491
  summary_text += "<summary>" + chat_summary_text
492
- # methods
493
- method_text += paper.section_text_dict[method_key]
494
- text = summary_text + "\n<Methods>:\n" + method_text
495
- chat_method_text, utoken2, ctoken2, ttoken2 = self.chat_method(text=text)
496
- htmls.append(chat_method_text)
497
  else:
498
  chat_method_text = ''
 
499
  htmls.append("\n")
500
-
501
  # 第三步总结全文,并打分:
502
  conclusion_key = ''
503
  for parse_key in paper.section_text_dict.keys():
504
  if 'conclu' in parse_key.lower():
505
  conclusion_key = parse_key
506
  break
507
-
508
  text = ''
509
  conclusion_text = ''
510
  summary_text = ''
511
- summary_text += "<summary>" + chat_summary_text + "\n <Method summary>:\n" + chat_method_text
512
  if conclusion_key != '':
513
- # conclusion
514
- conclusion_text += paper.section_text_dict[conclusion_key]
515
- text = summary_text + "\n <Conclusion>:\n" + conclusion_text
516
  else:
517
- text = summary_text
518
- chat_conclusion_text, utoken3, ctoken3, ttoken3 = self.chat_conclusion(text=text)
 
519
  htmls.append(chat_conclusion_text)
520
  htmls.append("\n")
521
  # token统计
@@ -524,26 +591,36 @@ class Reader:
524
  ttoken = ttoken + ttoken1 + ttoken2 + ttoken3
525
  cost = (ttoken / 1000) * 0.002
526
  pos_count = {
527
- "usage_token_used": str(utoken),
528
- "completion_token_used": str(ctoken),
529
- "total_token_used": str(ttoken),
530
- "cost": str(cost),
531
- }
532
  md_text = "\n".join(htmls)
533
  return markdown.markdown(md_text), pos_count
534
-
535
-
536
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
537
  stop=tenacity.stop_after_attempt(5),
538
  reraise=True)
539
  def chat_conclusion(self, text):
540
- conclusion_prompt_token = 650
541
  text_token = len(self.encoding.encode(text))
542
- clip_text_index = int(len(text)*(self.max_token_num-conclusion_prompt_token)/text_token)
 
 
543
  clip_text = text[:clip_text_index]
544
- self.chatPaper.reset(convo_id="chatConclusion",system_prompt="You are a reviewer in the field of ["+self.key_word+"] and you need to critically review this article")
545
- self.chatPaper.add_to_conversation(convo_id="chatConclusion", role="assistant", message="This is the <summary> and <conclusion> part of an English literature, where <summary> you have already summarized, but <conclusion> part, I need your help to summarize the following questions:"+clip_text)# 背景知识,可以参考OpenReview的审稿流程
546
- content = """
 
 
 
 
 
 
 
 
547
  8. Make the following summary.Be sure to use Chinese answers (proper nouns need to be marked in English).
548
  - (1):What is the significance of this piece of work?
549
  - (2):Summarize the strengths and weaknesses of this article in three dimensions: innovation point, performance, and workload.
@@ -556,24 +633,37 @@ class Reader:
556
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
557
  """
558
  result = self.chatPaper.ask(
559
- prompt = content,
560
  role="user",
561
  convo_id="chatConclusion",
562
  )
563
  print(result)
564
  return result[0], result[1], result[2], result[3]
565
-
566
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
 
567
  stop=tenacity.stop_after_attempt(5),
568
  reraise=True)
569
  def chat_method(self, text):
570
- method_prompt_token = 650
571
  text_token = len(self.encoding.encode(text))
572
- clip_text_index = int(len(text)*(self.max_token_num-method_prompt_token)/text_token)
 
 
573
  clip_text = text[:clip_text_index]
574
- self.chatPaper.reset(convo_id="chatMethod",system_prompt="You are a researcher in the field of ["+self.key_word+"] who is good at summarizing papers using concise statements")# chatgpt 角色
575
- self.chatPaper.add_to_conversation(convo_id="chatMethod", role="assistant", message=str("This is the <summary> and <Method> part of an English document, where <summary> you have summarized, but the <Methods> part, I need your help to read and summarize the following questions."+clip_text))
576
- content= """
 
 
 
 
 
 
 
 
 
 
577
  7. Describe in detail the methodological idea of this article. Be sure to use Chinese answers (proper nouns need to be marked in English). For example, its steps are.
578
  - (1):...
579
  - (2):...
@@ -589,24 +679,36 @@ class Reader:
589
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
590
  """
591
  result = self.chatPaper.ask(
592
- prompt = content,
593
  role="user",
594
  convo_id="chatMethod",
595
  )
596
  print(result)
597
  return result[0], result[1], result[2], result[3]
598
-
599
- @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4, max=10),
 
600
  stop=tenacity.stop_after_attempt(5),
601
  reraise=True)
602
  def chat_summary(self, text):
603
- summary_prompt_token = 1000
604
  text_token = len(self.encoding.encode(text))
605
- clip_text_index = int(len(text)*(self.max_token_num-summary_prompt_token)/text_token)
 
 
606
  clip_text = text[:clip_text_index]
607
- self.chatPaper.reset(convo_id="chatSummary",system_prompt="You are a researcher in the field of ["+self.key_word+"] who is good at summarizing papers using concise statements")
608
- self.chatPaper.add_to_conversation(convo_id="chatSummary", role="assistant", message=str("This is the title, author, link, abstract and introduction of an English document. I need your help to read and summarize the following questions: "+clip_text))
609
- content= """
 
 
 
 
 
 
 
 
 
610
  1. Mark the title of the paper (with Chinese translation)
611
  2. list all the authors' names (use English)
612
  3. mark the first author's affiliation (output Chinese translation only)
@@ -630,39 +732,40 @@ class Reader:
630
  - (4):xxx.\n\n
631
 
632
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not have too much repetitive information, numerical values using the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed.
633
- """
634
  result = self.chatPaper.ask(
635
- prompt = content,
636
  role="user",
637
  convo_id="chatSummary",
638
  )
639
  print(result)
640
  return result[0], result[1], result[2], result[3]
641
-
642
  def export_to_markdown(self, text, file_name, mode='w'):
643
  # 使用markdown模块的convert方法,将文本转换为html格式
644
  # html = markdown.markdown(text)
645
  # 打开一个文件,以写入模式
646
  with open(file_name, mode, encoding="utf-8") as f:
647
  # 将html格式的内容写入文件
648
- f.write(text)
649
 
650
  # 定义一个方法,打印出读者信息
651
- def show_info(self):
652
  print(f"Key word: {self.key_word}")
653
  print(f"Query: {self.query}")
654
- print(f"Sort: {self.sort}")
 
655
 
656
  def upload_pdf(api_keys, text, model_name, p, temperature, file):
657
  # 检查两个输入都不为空
658
  api_key_list = None
659
  if api_keys:
660
  api_key_list = api_keys.split(',')
661
- elif not api_keys and valid_api_keys!=[]:
662
  api_key_list = valid_api_keys
663
  if not text or not file or not api_key_list:
664
  return "两个输入都不能为空,请输入字符并上传 PDF 文件!"
665
-
666
  # 判断PDF文件
667
  #if file and file.name.split(".")[-1].lower() != "pdf":
668
  # return '请勿上传非 PDF 文件!'
@@ -671,60 +774,81 @@ def upload_pdf(api_keys, text, model_name, p, temperature, file):
671
  paper_list = [Paper(path=file, sl=section_list)]
672
  # 创建一个Reader对象
673
  print(api_key_list)
674
- reader = Reader(api_keys=api_key_list, model_name=model_name, p=p, temperature=temperature)
675
- sum_info, cost = reader.summary_with_chat(paper_list=paper_list) # type: ignore
 
 
 
 
676
  return cost, sum_info
677
 
 
678
  api_title = "api-key可用验证"
679
  api_description = '''<div align='left'>
680
-
681
  <img src='https://visitor-badge.laobi.icu/badge?page_id=https://huggingface.co/spaces/wangrongsheng/ChatPaper'>
682
-
683
  <img align='right' src='https://i.328888.xyz/2023/03/12/vH9dU.png' width="150">
684
-
685
  Use ChatGPT to summary the papers.Star our Github [🌟ChatPaper](https://github.com/kaixindelele/ChatPaper) .
686
-
687
  💗如果您觉得我们的项目对您有帮助,还请您给我们一些鼓励!💗
688
-
689
  🔴请注意:千万不要用于严肃的学术场景,只能用于论文阅读前的初筛!
690
-
691
  </div>
692
  '''
693
 
694
  api_input = [
695
- gradio.inputs.Textbox(label="请输入你的API-key(必填, 多个API-key请用英文逗号隔开)", default="", type='password')
 
 
696
  ]
697
- api_gui = gradio.Interface(fn=valid_apikey, inputs=api_input, outputs="text", title=api_title, description=api_description)
 
 
 
 
698
 
699
  # 标题
700
  title = "ChatPaper"
701
  # 描述
702
  description = '''<div align='left'>
703
-
704
  <img src='https://visitor-badge.laobi.icu/badge?page_id=https://huggingface.co/spaces/wangrongsheng/ChatPaper'>
705
-
706
  <img align='right' src='https://i.328888.xyz/2023/03/12/vH9dU.png' width="150">
707
-
708
  Use ChatGPT to summary the papers.Star our Github [🌟ChatPaper](https://github.com/kaixindelele/ChatPaper) .
709
-
710
  💗如果您觉得我们的项目对您有帮助,还请您给我们一些鼓励!💗
711
-
712
  🔴请注意:千万不要用于严肃的学术场景,只能用于论文阅读前的初筛!
713
-
714
  </div>
715
  '''
716
  # 创建Gradio界面
717
  ip = [
718
- gradio.inputs.Textbox(label="请输入你的API-key(必填, 多个API-key请用英文逗号隔开),不需要空格", default="", type='password'),
719
- gradio.inputs.Textbox(label="请输入论文大标题索引(用英文逗号隔开,必填)", default="'Abstract,Introduction,Related Work,Background,Preliminary,Problem Formulation,Methods,Methodology,Method,Approach,Approaches,Materials and Methods,Experiment Settings,Experiment,Experimental Results,Evaluation,Experiments,Results,Findings,Data Analysis,Discussion,Results and Discussion,Conclusion,References'"),
720
- gradio.inputs.Radio(choices=["gpt-3.5-turbo", "gpt-3.5-turbo-0301"], default="gpt-3.5-turbo", label="Select model"),
721
- gradio.inputs.Slider(minimum=-0, maximum=1.0, default=1.0, step=0.05, label="Top-p (nucleus sampling)"),
722
- gradio.inputs.Slider(minimum=-0, maximum=5.0, default=0.5, step=0.5, label="Temperature"),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
723
  gradio.inputs.File(label="请上传论文PDF(必填)")
724
  ]
725
 
726
- chatpaper_gui = gradio.Interface(fn=upload_pdf, inputs=ip, outputs=["json", "html"], title=title, description=description)
 
 
 
 
 
727
 
728
  # Start server
729
- gui = gradio.TabbedInterface(interface_list=[api_gui, chatpaper_gui], tab_names=["API-key", "ChatPaper"])
730
- gui.launch(quiet=True,show_api=False)
 
 
15
  import tiktoken
16
  import concurrent.futures
17
  from optimizeOpenAI import chatPaper
18
+
19
  def parse_text(text):
20
  lines = text.split("\n")
21
+ for i, line in enumerate(lines):
22
  if "```" in line:
23
  items = line.split('`')
24
  if items[-1]:
 
26
  else:
27
  lines[i] = f'</code></pre>'
28
  else:
29
+ if i > 0:
30
  line = line.replace("<", "&lt;")
31
  line = line.replace(">", "&gt;")
32
+ lines[i] = '<br/>' + line.replace(" ", "&nbsp;")
33
  return "".join(lines)
34
 
35
+
36
  # def get_response(system, context, myKey, raw = False):
37
  # openai.api_key = myKey
38
  # response = openai.ChatCompletion.create(
 
49
 
50
  valid_api_keys = []
51
 
52
+
53
  def api_key_check(api_key):
54
  try:
55
  chat = chatPaper([api_key])
 
60
  except:
61
  return None
62
 
63
+
64
  def valid_apikey(api_keys):
65
  api_keys = api_keys.replace(' ', '')
66
  api_key_list = api_keys.split(',')
67
  print(api_key_list)
68
  global valid_api_keys
69
  with concurrent.futures.ThreadPoolExecutor() as executor:
70
+ future_results = {
71
+ executor.submit(api_key_check, api_key): api_key
72
+ for api_key in api_key_list
73
+ }
74
  for future in concurrent.futures.as_completed(future_results):
75
  result = future.result()
76
  if result:
77
  valid_api_keys.append(result)
78
  if len(valid_api_keys) > 0:
79
+ return "有效的api-key一共有{}个,分别是:{}, 现在可以提交你的paper".format(
80
+ len(valid_api_keys), valid_api_keys)
81
  return "无效的api-key"
82
 
83
 
84
  class Paper:
85
+
86
  def __init__(self, path, title='', url='', abs='', authers=[], sl=[]):
87
+ # 初始化函数,根据pdf路径初始化Paper对象
88
+ self.url = url # 文章链接
89
+ self.path = path # pdf路径
90
  self.sl = sl
91
+ self.section_names = [] # 段落标题
92
+ self.section_texts = {} # 段落内容
93
  self.abs = abs
94
  self.title_page = 0
95
  if title == '':
96
+ self.pdf = fitz.open(self.path) # pdf文档
97
  self.title = self.get_title()
98
+ self.parse_pdf()
99
  else:
100
  self.title = title
101
+ self.authers = authers
102
+ self.roman_num = [
103
+ "I", "II", 'III', "IV", "V", "VI", "VII", "VIII", "IIX", "IX", "X"
104
+ ]
105
+ self.digit_num = [str(d + 1) for d in range(10)]
106
  self.first_image = ''
107
+
108
  def parse_pdf(self):
109
+ self.pdf = fitz.open(self.path) # pdf文档
110
  self.text_list = [page.get_text() for page in self.pdf]
111
  self.all_text = ' '.join(self.text_list)
112
+ self.section_page_dict = self._get_all_page_index() # 段落与页码的对应字典
113
  print("section_page_dict", self.section_page_dict)
114
+ self.section_text_dict = self._get_all_page() # 段落与内容的对应字典
115
  self.section_text_dict.update({"title": self.title})
116
  self.section_text_dict.update({"paper_info": self.get_paper_info()})
117
+ self.pdf.close()
118
+
119
  def get_paper_info(self):
120
  first_page_text = self.pdf[self.title_page].get_text()
121
  if "Abstract" in self.section_text_dict.keys():
 
123
  else:
124
  abstract_text = self.abs
125
  introduction_text = self.section_text_dict['Introduction']
126
+ first_page_text = first_page_text.replace(abstract_text, "").replace(
127
+ introduction_text, "")
128
  return first_page_text
129
+
130
  def get_image_path(self, image_path=''):
131
  """
132
  将PDF中的第一张���保存到image.png里面,存到本地目录,返回文件名称,供gitee读取
 
143
  # 查看独立页面
144
  page = my_pdf_file[page_number - 1]
145
  # 查看当前页所有图片
146
+ images = page.get_images()
147
  # 遍历当前页面所有图片
148
+ for image_number, image in enumerate(page.get_images(),
149
+ start=1):
150
  # 访问图片xref
151
  xref_value = image[0]
152
  # 提取图片信息
 
161
  if image_size > max_size:
162
  max_size = image_size
163
  image_list.append(image)
164
+ for image in image_list:
165
  image_size = image.size[0] * image.size[1]
166
+ if image_size == max_size:
167
  image_name = f"image.{ext}"
168
  im_path = os.path.join(image_path, image_name)
169
  print("im_path:", im_path)
170
+
171
  max_pix = 480
172
  origin_min_pix = min(image.size[0], image.size[1])
173
+
174
  if image.size[0] > image.size[1]:
175
+ min_pix = int(image.size[1] * (max_pix / image.size[0]))
176
  newsize = (max_pix, min_pix)
177
  else:
178
+ min_pix = int(image.size[0] * (max_pix / image.size[1]))
179
  newsize = (min_pix, max_pix)
180
  image = image.resize(newsize)
181
+
182
  image.save(open(im_path, "wb"))
183
  return im_path, ext
184
  return None, None
185
+
186
  # 定义一个函数,根据字体的大小,识别每个章节名称,并返回一个列表
187
+ def get_chapter_names(self, ):
188
  # # 打开一个pdf文件
189
+ doc = fitz.open(self.path) # pdf文档
190
  text_list = [page.get_text() for page in doc]
191
  all_text = ''
192
  for text in text_list:
 
199
  point_split_list = line.split('.')
200
  space_split_list = line.split(' ')
201
  if 1 < len(space_split_list) < 5:
202
+ if 1 < len(point_split_list) < 5 and (
203
+ point_split_list[0] in self.roman_num
204
+ or point_split_list[0] in self.digit_num):
205
  print("line:", line)
206
+ chapter_names.append(line)
207
+
208
  return chapter_names
209
+
210
  def get_title(self):
211
+ doc = self.pdf # 打开pdf文件
212
+ max_font_size = 0 # 初始化最大字体大小为0
213
+ max_string = "" # 初始化最大字体大小对应的字符串为空
214
  max_font_sizes = [0]
215
+ for page_index, page in enumerate(doc): # 遍历每一页
216
+ text = page.get_text("dict") # 获取页面上的文本信息
217
+ blocks = text["blocks"] # 获取文本块列表
218
+ for block in blocks: # 遍历每个文本块
219
+ if block["type"] == 0 and len(block['lines']): # 如果是文字类型
220
  if len(block["lines"][0]["spans"]):
221
+ font_size = block["lines"][0]["spans"][0][
222
+ "size"] # 获取第一行第一段文字的字体大小
223
  max_font_sizes.append(font_size)
224
+ if font_size > max_font_size: # 如果字体大小大于当前最大值
225
+ max_font_size = font_size # 更新最大值
226
+ max_string = block["lines"][0]["spans"][0][
227
+ "text"] # 更新最大值对应的字符串
228
+ max_font_sizes.sort()
229
  print("max_font_sizes", max_font_sizes[-10:])
230
  cur_title = ''
231
+ for page_index, page in enumerate(doc): # 遍历每一页
232
+ text = page.get_text("dict") # 获取页面上的文本信息
233
+ blocks = text["blocks"] # 获取文本块列表
234
+ for block in blocks: # 遍历每个文本块
235
+ if block["type"] == 0 and len(block['lines']): # 如果是文字类型
236
  if len(block["lines"][0]["spans"]):
237
+ cur_string = block["lines"][0]["spans"][0][
238
+ "text"] # 更新最大值对应的字符串
239
+ font_flags = block["lines"][0]["spans"][0][
240
+ "flags"] # 获取第一行第一段文字的字体特征
241
+ font_size = block["lines"][0]["spans"][0][
242
+ "size"] # 获取第一行第一段文字的字体大小
243
  # print(font_size)
244
+ if abs(font_size - max_font_sizes[-1]) < 0.3 or abs(
245
+ font_size - max_font_sizes[-2]) < 0.3:
246
+ # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
247
+ if len(cur_string
248
+ ) > 4 and "arXiv" not in cur_string:
249
+ # print("The string is bold.", max_string, "font_size:", font_size, "font_flags:", font_flags)
250
+ if cur_title == '':
251
+ cur_title += cur_string
252
  else:
253
+ cur_title += ' ' + cur_string
254
  self.title_page = page_index
255
 
256
+ title = cur_title.replace('\n', ' ')
257
  return title
258
 
259
  def _get_all_page_index(self):
 
291
  text = ''
292
  text_list = []
293
  section_dict = {}
294
+
295
  # 再处理其他章节:
296
  text_list = [page.get_text() for page in self.pdf]
297
  for sec_index, sec_name in enumerate(self.section_page_dict):
 
301
  else:
302
  # 直接考虑后面的内容:
303
  start_page = self.section_page_dict[sec_name]
304
+ if sec_index < len(list(self.section_page_dict.keys())) - 1:
305
+ end_page = self.section_page_dict[list(
306
+ self.section_page_dict.keys())[sec_index + 1]]
307
  else:
308
  end_page = len(text_list)
309
  print("start_page, end_page:", start_page, end_page)
310
  cur_sec_text = ''
311
  if end_page - start_page == 0:
312
+ if sec_index < len(list(
313
+ self.section_page_dict.keys())) - 1:
314
+ next_sec = list(
315
+ self.section_page_dict.keys())[sec_index + 1]
316
  if text_list[start_page].find(sec_name) == -1:
317
+ start_i = text_list[start_page].find(
318
+ sec_name.upper())
319
  else:
320
  start_i = text_list[start_page].find(sec_name)
321
  if text_list[start_page].find(next_sec) == -1:
322
+ end_i = text_list[start_page].find(
323
+ next_sec.upper())
324
  else:
325
+ end_i = text_list[start_page].find(next_sec)
326
  cur_sec_text += text_list[start_page][start_i:end_i]
327
  else:
328
+ for page_i in range(start_page, end_page):
329
+ # print("page_i:", page_i)
330
  if page_i == start_page:
331
  if text_list[start_page].find(sec_name) == -1:
332
+ start_i = text_list[start_page].find(
333
+ sec_name.upper())
334
  else:
335
  start_i = text_list[start_page].find(sec_name)
336
  cur_sec_text += text_list[page_i][start_i:]
337
  elif page_i < end_page:
338
  cur_sec_text += text_list[page_i]
339
  elif page_i == end_page:
340
+ if sec_index < len(
341
+ list(self.section_page_dict.keys())) - 1:
342
+ next_sec = list(
343
+ self.section_page_dict.keys())[sec_index +
344
+ 1]
345
  if text_list[start_page].find(next_sec) == -1:
346
+ end_i = text_list[start_page].find(
347
+ next_sec.upper())
348
  else:
349
+ end_i = text_list[start_page].find(
350
+ next_sec)
351
  cur_sec_text += text_list[page_i][:end_i]
352
+ section_dict[sec_name] = cur_sec_text.replace('-\n',
353
+ '').replace(
354
+ '\n', ' ')
355
  return section_dict
356
 
357
+
358
  # 定义Reader类
359
  class Reader:
360
  # 初始化方法,设置属性
361
+ def __init__(self,
362
+ key_word='',
363
+ query='',
364
+ filter_keys='',
365
  root_path='./',
366
  gitee_key='',
367
+ sort=arxiv.SortCriterion.SubmittedDate,
368
+ user_name='defualt',
369
+ language='cn',
370
+ api_keys: list = [],
371
+ model_name="gpt-3.5-turbo",
372
+ p=1.0,
373
+ temperature=1.0):
374
  self.api_keys = api_keys
375
+ self.chatPaper = chatPaper(api_keys=self.api_keys,
376
+ apiTimeInterval=10,
377
+ temperature=temperature,
378
+ top_p=p,
379
+ model_name=model_name) #openAI api封装
380
+ self.user_name = user_name # 读者姓名
381
+ self.key_word = key_word # 读者感兴趣的关键词
382
+ self.query = query # 读者输入的搜索查询
383
+ self.sort = sort # 读者选择的排序方式
384
+ self.language = language # 读者选择的语言
385
+ self.filter_keys = filter_keys # 用于在摘要中筛选的关键词
386
  self.root_path = root_path
387
+ self.file_format = 'md' # or 'txt',如果为图片,则必须为'md'
388
  self.save_image = False
389
  if self.save_image:
390
  self.gitee_key = self.config.get('Gitee', 'api')
 
392
  self.gitee_key = ''
393
  self.max_token_num = 4096
394
  self.encoding = tiktoken.get_encoding("gpt2")
395
+
396
  def get_arxiv(self, max_results=30):
397
+ search = arxiv.Search(
398
+ query=self.query,
399
+ max_results=max_results,
400
+ sort_by=self.sort,
401
+ sort_order=arxiv.SortOrder.Descending,
402
+ )
403
  return search
404
+
405
  def filter_arxiv(self, max_results=30):
406
  search = self.get_arxiv(max_results=max_results)
407
  print("all search:")
408
  for index, result in enumerate(search.results()):
409
  print(index, result.title, result.updated)
410
+
411
+ filter_results = []
412
  filter_keys = self.filter_keys
413
+
414
  print("filter_keys:", self.filter_keys)
415
  # 确保每个关键词都能在摘要中找到,才算是目标论文
416
  for index, result in enumerate(search.results()):
 
427
  for index, result in enumerate(filter_results):
428
  print(index, result.title, result.updated)
429
  return filter_results
430
+
431
  def validateTitle(self, title):
432
  # 将论文的乱七八糟的路径格式修正
433
+ rstr = r"[\/\\\:\*\?\"\<\>\|]" # '/ \ : * ? " < > |'
434
+ new_title = re.sub(rstr, "_", title) # 替换为下划线
435
  return new_title
436
 
437
  def download_pdf(self, filter_results):
438
  # 先创建文件夹
439
+ date_str = str(datetime.datetime.now())[:13].replace(' ', '-')
440
+ key_word = str(self.key_word.replace(':', ' '))
441
+ path = self.root_path + 'pdf_files/' + self.query.replace(
442
+ 'au: ', '').replace('title: ', '').replace('ti: ', '').replace(
443
+ ':', ' ')[:25] + '-' + date_str
444
  try:
445
  os.makedirs(path)
446
  except:
 
451
  for r_index, result in enumerate(filter_results):
452
  try:
453
  title_str = self.validateTitle(result.title)
454
+ pdf_name = title_str + '.pdf'
455
  # result.download_pdf(path, filename=pdf_name)
456
  self.try_download_pdf(result, path, pdf_name)
457
  paper_path = os.path.join(path, pdf_name)
458
  print("paper_path:", paper_path)
459
+ paper = Paper(
460
+ path=paper_path,
461
+ url=result.entry_id,
462
+ title=result.title,
463
+ abs=result.summary.replace('-\n', '-').replace('\n', ' '),
464
+ authers=[str(aut) for aut in result.authors],
465
+ )
466
  # 下载完毕,开始解析:
467
  paper.parse_pdf()
468
  paper_list.append(paper)
 
470
  print("download_error:", e)
471
  pass
472
  return paper_list
473
+
474
+ @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4,
475
+ max=10),
476
  stop=tenacity.stop_after_attempt(5),
477
  reraise=True)
478
  def try_download_pdf(self, result, path, pdf_name):
479
  result.download_pdf(path, filename=pdf_name)
480
+
481
+ @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4,
482
+ max=10),
483
  stop=tenacity.stop_after_attempt(5),
484
  reraise=True)
485
  def upload_gitee(self, image_path, image_name='', ext='png'):
486
  """
487
  上传到码云
488
  :return:
489
+ """
490
  with open(image_path, 'rb') as f:
491
  base64_data = base64.b64encode(f.read())
492
  base64_content = base64_data.decode()
493
+
494
+ date_str = str(datetime.datetime.now())[:19].replace(':', '-').replace(
495
+ ' ', '-') + '.' + ext
496
+ path = image_name + '-' + date_str
497
+
498
  payload = {
499
  "access_token": self.gitee_key,
500
  "owner": self.config.get('Gitee', 'owner'),
 
504
  "message": "upload image"
505
  }
506
  # 这里需要修改成你的gitee的账户和仓库名,以及文件夹的名字:
507
+ url = f'https://gitee.com/api/v5/repos/' + self.config.get(
508
+ 'Gitee', 'owner') + '/' + self.config.get(
509
+ 'Gitee', 'repo') + '/contents/' + self.config.get(
510
+ 'Gitee', 'path') + '/' + path
511
  rep = requests.post(url, json=payload).json()
512
  print("rep:", rep)
513
  if 'content' in rep.keys():
514
  image_url = rep['content']['download_url']
515
  else:
516
+ image_url = r"https://gitee.com/api/v5/repos/" + self.config.get(
517
+ 'Gitee', 'owner') + '/' + self.config.get(
518
+ 'Gitee', 'repo') + '/contents/' + self.config.get(
519
+ 'Gitee', 'path') + '/' + path
520
+
521
  return image_url
522
+
523
+
524
  def summary_with_chat(self, paper_list):
525
  htmls = []
526
  utoken = 0
 
537
  text += list(paper.section_text_dict.values())[0]
538
  #max_token = 2500 * 4
539
  #text = text[:max_token]
540
+ chat_summary_text, utoken1, ctoken1, ttoken1 = self.chat_summary(
541
+ text=text)
542
  htmls.append(chat_summary_text)
543
+
544
  # TODO 往md文档中插入论文里的像素最大的一张图片,这个方案可以弄的更加智能一些:
545
  method_key = ''
546
  for parse_key in paper.section_text_dict.keys():
547
+ if 'method' in parse_key.lower(
548
+ ) or 'approach' in parse_key.lower():
549
  method_key = parse_key
550
  break
551
+
552
  if method_key != '':
553
  text = ''
554
  method_text = ''
555
  summary_text = ''
556
  summary_text += "<summary>" + chat_summary_text
557
+ # methods
558
+ method_text += paper.section_text_dict[method_key]
559
+ text = summary_text + "\n<Methods>:\n" + method_text
560
+ chat_method_text, utoken2, ctoken2, ttoken2 = self.chat_method(
561
+ text=text)
562
  else:
563
  chat_method_text = ''
564
+ htmls.append(chat_method_text)
565
  htmls.append("\n")
566
+
567
  # 第三步总结全文,并打分:
568
  conclusion_key = ''
569
  for parse_key in paper.section_text_dict.keys():
570
  if 'conclu' in parse_key.lower():
571
  conclusion_key = parse_key
572
  break
573
+
574
  text = ''
575
  conclusion_text = ''
576
  summary_text = ''
577
+ summary_text += "<summary>" + chat_summary_text + "\n <Method summary>:\n" + chat_method_text
578
  if conclusion_key != '':
579
+ # conclusion
580
+ conclusion_text += paper.section_text_dict[conclusion_key]
581
+ text = summary_text + "\n <Conclusion>:\n" + conclusion_text
582
  else:
583
+ text = summary_text
584
+ chat_conclusion_text, utoken3, ctoken3, ttoken3 = self.chat_conclusion(
585
+ text=text)
586
  htmls.append(chat_conclusion_text)
587
  htmls.append("\n")
588
  # token统计
 
591
  ttoken = ttoken + ttoken1 + ttoken2 + ttoken3
592
  cost = (ttoken / 1000) * 0.002
593
  pos_count = {
594
+ "usage_token_used": str(utoken),
595
+ "completion_token_used": str(ctoken),
596
+ "total_token_used": str(ttoken),
597
+ "cost": str(cost),
598
+ }
599
  md_text = "\n".join(htmls)
600
  return markdown.markdown(md_text), pos_count
601
+
602
+ @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4,
603
+ max=10),
604
  stop=tenacity.stop_after_attempt(5),
605
  reraise=True)
606
  def chat_conclusion(self, text):
607
+ conclusion_prompt_token = 650
608
  text_token = len(self.encoding.encode(text))
609
+ clip_text_index = int(
610
+ len(text) * (self.max_token_num - conclusion_prompt_token) /
611
+ text_token)
612
  clip_text = text[:clip_text_index]
613
+ self.chatPaper.reset(
614
+ convo_id="chatConclusion",
615
+ system_prompt="You are a reviewer in the field of [" +
616
+ self.key_word + "] and you need to critically review this article")
617
+ self.chatPaper.add_to_conversation(
618
+ convo_id="chatConclusion",
619
+ role="assistant",
620
+ message=
621
+ "This is the <summary> and <conclusion> part of an English literature, where <summary> you have already summarized, but <conclusion> part, I need your help to summarize the following questions:"
622
+ + clip_text) # 背景知识,可以参考OpenReview的审稿流程
623
+ content = """
624
  8. Make the following summary.Be sure to use Chinese answers (proper nouns need to be marked in English).
625
  - (1):What is the significance of this piece of work?
626
  - (2):Summarize the strengths and weaknesses of this article in three dimensions: innovation point, performance, and workload.
 
633
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
634
  """
635
  result = self.chatPaper.ask(
636
+ prompt=content,
637
  role="user",
638
  convo_id="chatConclusion",
639
  )
640
  print(result)
641
  return result[0], result[1], result[2], result[3]
642
+
643
+ @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4,
644
+ max=10),
645
  stop=tenacity.stop_after_attempt(5),
646
  reraise=True)
647
  def chat_method(self, text):
648
+ method_prompt_token = 650
649
  text_token = len(self.encoding.encode(text))
650
+ clip_text_index = int(
651
+ len(text) * (self.max_token_num - method_prompt_token) /
652
+ text_token)
653
  clip_text = text[:clip_text_index]
654
+ self.chatPaper.reset(
655
+ convo_id="chatMethod",
656
+ system_prompt="You are a researcher in the field of [" +
657
+ self.key_word +
658
+ "] who is good at summarizing papers using concise statements"
659
+ ) # chatgpt 角色
660
+ self.chatPaper.add_to_conversation(
661
+ convo_id="chatMethod",
662
+ role="assistant",
663
+ message=str(
664
+ "This is the <summary> and <Method> part of an English document, where <summary> you have summarized, but the <Methods> part, I need your help to read and summarize the following questions."
665
+ + clip_text))
666
+ content = """
667
  7. Describe in detail the methodological idea of this article. Be sure to use Chinese answers (proper nouns need to be marked in English). For example, its steps are.
668
  - (1):...
669
  - (2):...
 
679
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not repeat the content of the previous <summary>, the value of the use of the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed, ....... means fill in according to the actual requirements, if not, you can not write.
680
  """
681
  result = self.chatPaper.ask(
682
+ prompt=content,
683
  role="user",
684
  convo_id="chatMethod",
685
  )
686
  print(result)
687
  return result[0], result[1], result[2], result[3]
688
+
689
+ @tenacity.retry(wait=tenacity.wait_exponential(multiplier=1, min=4,
690
+ max=10),
691
  stop=tenacity.stop_after_attempt(5),
692
  reraise=True)
693
  def chat_summary(self, text):
694
+ summary_prompt_token = 1000
695
  text_token = len(self.encoding.encode(text))
696
+ clip_text_index = int(
697
+ len(text) * (self.max_token_num - summary_prompt_token) /
698
+ text_token)
699
  clip_text = text[:clip_text_index]
700
+ self.chatPaper.reset(
701
+ convo_id="chatSummary",
702
+ system_prompt="You are a researcher in the field of [" +
703
+ self.key_word +
704
+ "] who is good at summarizing papers using concise statements")
705
+ self.chatPaper.add_to_conversation(
706
+ convo_id="chatSummary",
707
+ role="assistant",
708
+ message=str(
709
+ "This is the title, author, link, abstract and introduction of an English document. I need your help to read and summarize the following questions: "
710
+ + clip_text))
711
+ content = """
712
  1. Mark the title of the paper (with Chinese translation)
713
  2. list all the authors' names (use English)
714
  3. mark the first author's affiliation (output Chinese translation only)
 
732
  - (4):xxx.\n\n
733
 
734
  Be sure to use Chinese answers (proper nouns need to be marked in English), statements as concise and academic as possible, do not have too much repetitive information, numerical values using the original numbers, be sure to strictly follow the format, the corresponding content output to xxx, in accordance with \n line feed.
735
+ """
736
  result = self.chatPaper.ask(
737
+ prompt=content,
738
  role="user",
739
  convo_id="chatSummary",
740
  )
741
  print(result)
742
  return result[0], result[1], result[2], result[3]
743
+
744
  def export_to_markdown(self, text, file_name, mode='w'):
745
  # 使用markdown模块的convert方法,将文本转换为html格式
746
  # html = markdown.markdown(text)
747
  # 打开一个文件,以写入模式
748
  with open(file_name, mode, encoding="utf-8") as f:
749
  # 将html格式的内容写入文件
750
+ f.write(text)
751
 
752
  # 定义一个方法,打印出读者信息
753
+ def show_info(self):
754
  print(f"Key word: {self.key_word}")
755
  print(f"Query: {self.query}")
756
+ print(f"Sort: {self.sort}")
757
+
758
 
759
  def upload_pdf(api_keys, text, model_name, p, temperature, file):
760
  # 检查两个输入都不为空
761
  api_key_list = None
762
  if api_keys:
763
  api_key_list = api_keys.split(',')
764
+ elif not api_keys and valid_api_keys != []:
765
  api_key_list = valid_api_keys
766
  if not text or not file or not api_key_list:
767
  return "两个输入都不能为空,请输入字符并上传 PDF 文件!"
768
+
769
  # 判断PDF文件
770
  #if file and file.name.split(".")[-1].lower() != "pdf":
771
  # return '请勿上传非 PDF 文件!'
 
774
  paper_list = [Paper(path=file, sl=section_list)]
775
  # 创建一个Reader对象
776
  print(api_key_list)
777
+ reader = Reader(api_keys=api_key_list,
778
+ model_name=model_name,
779
+ p=p,
780
+ temperature=temperature)
781
+ sum_info, cost = reader.summary_with_chat(
782
+ paper_list=paper_list) # type: ignore
783
  return cost, sum_info
784
 
785
+
786
  api_title = "api-key可用验证"
787
  api_description = '''<div align='left'>
 
788
  <img src='https://visitor-badge.laobi.icu/badge?page_id=https://huggingface.co/spaces/wangrongsheng/ChatPaper'>
 
789
  <img align='right' src='https://i.328888.xyz/2023/03/12/vH9dU.png' width="150">
 
790
  Use ChatGPT to summary the papers.Star our Github [🌟ChatPaper](https://github.com/kaixindelele/ChatPaper) .
 
791
  💗如果您觉得我们的项目对您有帮助,还请您给我们一些鼓励!💗
 
792
  🔴请注意:千万不要用于严肃的学术场景,只能用于论文阅读前的初筛!
 
793
  </div>
794
  '''
795
 
796
  api_input = [
797
+ gradio.inputs.Textbox(label="请输入你的API-key(必填, 多个API-key请用英文逗号隔开)",
798
+ default="",
799
+ type='password')
800
  ]
801
+ api_gui = gradio.Interface(fn=valid_apikey,
802
+ inputs=api_input,
803
+ outputs="text",
804
+ title=api_title,
805
+ description=api_description)
806
 
807
  # 标题
808
  title = "ChatPaper"
809
  # 描述
810
  description = '''<div align='left'>
 
811
  <img src='https://visitor-badge.laobi.icu/badge?page_id=https://huggingface.co/spaces/wangrongsheng/ChatPaper'>
 
812
  <img align='right' src='https://i.328888.xyz/2023/03/12/vH9dU.png' width="150">
 
813
  Use ChatGPT to summary the papers.Star our Github [🌟ChatPaper](https://github.com/kaixindelele/ChatPaper) .
 
814
  💗如果您觉得我们的项目对您有帮助,还请您给我们一些鼓励!💗
 
815
  🔴请注意:千万不要用于严肃的学术场景,只能用于论文阅读前的初筛!
 
816
  </div>
817
  '''
818
  # 创建Gradio界面
819
  ip = [
820
+ gradio.inputs.Textbox(label="请输入你的API-key(必填, 多个API-key请用英文逗号隔开),不需要空格",
821
+ default="",
822
+ type='password'),
823
+ gradio.inputs.Textbox(
824
+ label="请输入论文大标题索引(用英文逗号隔开,必填)",
825
+ default=
826
+ "'Abstract,Introduction,Related Work,Background,Preliminary,Problem Formulation,Methods,Methodology,Method,Approach,Approaches,Materials and Methods,Experiment Settings,Experiment,Experimental Results,Evaluation,Experiments,Results,Findings,Data Analysis,Discussion,Results and Discussion,Conclusion,References'"
827
+ ),
828
+ gradio.inputs.Radio(choices=["gpt-3.5-turbo", "gpt-3.5-turbo-0301"],
829
+ default="gpt-3.5-turbo",
830
+ label="Select model"),
831
+ gradio.inputs.Slider(minimum=-0,
832
+ maximum=1.0,
833
+ default=1.0,
834
+ step=0.05,
835
+ label="Top-p (nucleus sampling)"),
836
+ gradio.inputs.Slider(minimum=-0,
837
+ maximum=5.0,
838
+ default=0.5,
839
+ step=0.5,
840
+ label="Temperature"),
841
  gradio.inputs.File(label="请上传论文PDF(必填)")
842
  ]
843
 
844
+
845
+ chatpaper_gui = gradio.Interface(fn=upload_pdf,
846
+ inputs=ip,
847
+ outputs=["json", "html"],
848
+ title=title,
849
+ description=description)
850
 
851
  # Start server
852
+ gui = gradio.TabbedInterface(interface_list=[api_gui, chatpaper_gui],
853
+ tab_names=["API-key", "ChatPaper"])
854
+ gui.launch(quiet=True, show_api=False)
optimizeOpenAI.py CHANGED
@@ -74,7 +74,7 @@ class chatPaper:
74
  if(convo_id not in self.conversation):
75
  self.reset(convo_id)
76
  self.conversation[convo_id].append({"role": role, "content": message})
77
-
78
  def __truncate_conversation(self, convo_id: str = "default"):
79
  """
80
  Truncate the conversation
@@ -89,10 +89,10 @@ class chatPaper:
89
  full_conversation = "\n".join([str(x["content"]) for x in self.conversation[convo_id]],)
90
  if len(ENCODER.encode(full_conversation)) > self.max_tokens:
91
  self.conversation_summary(convo_id=convo_id)
 
 
 
92
  while True:
93
- full_conversation = ""
94
- for x in self.conversation[convo_id]:
95
- full_conversation = str(x["content"]) + "\n" + full_conversation
96
  if (len(ENCODER.encode(full_conversation+query)) > self.max_tokens):
97
  query = query[:self.decrease_step]
98
  else:
@@ -170,7 +170,7 @@ class chatPaper:
170
  "https://api.openai.com/v1/chat/completions",
171
  headers={"Authorization": f"Bearer {self.get_api_key()}"},
172
  json={
173
- "model": self.engine,
174
  "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "print A"}],
175
  "stream": True,
176
  # kwargs
 
74
  if(convo_id not in self.conversation):
75
  self.reset(convo_id)
76
  self.conversation[convo_id].append({"role": role, "content": message})
77
+
78
  def __truncate_conversation(self, convo_id: str = "default"):
79
  """
80
  Truncate the conversation
 
89
  full_conversation = "\n".join([str(x["content"]) for x in self.conversation[convo_id]],)
90
  if len(ENCODER.encode(full_conversation)) > self.max_tokens:
91
  self.conversation_summary(convo_id=convo_id)
92
+ full_conversation = ""
93
+ for x in self.conversation[convo_id]:
94
+ full_conversation = str(x["content"]) + "\n" + full_conversation
95
  while True:
 
 
 
96
  if (len(ENCODER.encode(full_conversation+query)) > self.max_tokens):
97
  query = query[:self.decrease_step]
98
  else:
 
170
  "https://api.openai.com/v1/chat/completions",
171
  headers={"Authorization": f"Bearer {self.get_api_key()}"},
172
  json={
173
+ "model": self.model_name,
174
  "messages": [{"role": "system", "content": "You are a helpful assistant."},{"role": "user", "content": "print A"}],
175
  "stream": True,
176
  # kwargs