Merge pull request #9 from CCCBora/dev
Changed files:
- api_wrapper.py +13 -5
- app.py +5 -10
- auto_backgrounds.py +34 -16
- references_generator.py +10 -1
- section_generator.py +21 -17
- utils/gpt_interaction.py +30 -85
- utils/prompts.py +31 -18
- utils/references.py +9 -2
- worker.py +172 -0
api_wrapper.py
CHANGED
@@ -12,18 +12,26 @@ todo:
 If `generator_wrapper` returns nothing or Timeout, or raise any error:
     Change Task status from Running to Failed.
 '''
+import os.path
 
 from auto_backgrounds import generate_draft
-import json
+import json, time
+from utils.file_operations import make_archive
 
 
-GENERATOR_MAPPING = {"draft": generate_draft}
+# GENERATOR_MAPPING = {"draft": generate_draft}
+GENERATOR_MAPPING = {"draft": None}
 
 def generator_wrapper(path_to_config_json):
     # Read configuration file and call corresponding function
     with open(path_to_config_json, "r", encoding='utf-8') as f:
         config = json.load(f)
-
-    generator = GENERATOR_MAPPING.get(config["generator"])
+    print("Configuration:", config)
+    # generator = GENERATOR_MAPPING.get(config["generator"])
+    generator = None
     if generator is None:
-
+        # generate a fake ZIP file and upload
+        time.sleep(150)
+        zip_path = os.path.splitext(path_to_config_json)[0]+".zip"
+        return make_archive(path_to_config_json, zip_path)
+
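With this commit the wrapper is deliberately stubbed: the real generator mapping is commented out, so every task sleeps 150 seconds and then archives the configuration file itself. A smoke test for the stub, as a sketch; the config content below is made up, and `make_archive(source, output)` returning the archive path is an assumption (its signature is not shown in this diff):

    import json
    from api_wrapper import generator_wrapper

    # Hypothetical task config; the stub only loads and prints it.
    with open("task_123.json", "w", encoding="utf-8") as f:
        json.dump({"generator": "draft", "title": "Example paper"}, f)

    zip_path = generator_wrapper("task_123.json")  # blocks ~150 s, then zips the config
    print(zip_path)  # expected: "task_123.zip", assuming make_archive returns its output path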
app.py
CHANGED
@@ -5,8 +5,8 @@ from auto_backgrounds import generate_backgrounds, generate_draft
 from utils.file_operations import hash_name
 from references_generator import generate_top_k_references
 
-# note: App blank-screen bug: allow third-party cookies
 # todo:
+# generation.log sometimes disappears
 # 6. get logs when the procedure is not completed. *
 # 7. own file library; more prompts
 # 8. Decide on how to generate the main part of a paper * (Langchain/AutoGPT
@@ -54,7 +54,7 @@ def clear_inputs_refs(*args):
 
 
 def wrapped_generator(paper_title, paper_description, openai_api_key=None,
-                      paper_template="ICLR2022", tldr=True,
+                      paper_template="ICLR2022", tldr=True, selected_sections=None, bib_refs=None, model="gpt-4",
                       cache_mode=IS_CACHE_AVAILABLE):
     # if `cache_mode` is True, then follow the following steps:
     # check if "title"+"description" have been generated before
@@ -82,16 +82,14 @@ def wrapped_generator(paper_title, paper_description, openai_api_key=None,
             # generate the result.
             # output = fake_generate_backgrounds(title, description, openai_key)
             output = generate_draft(paper_title, paper_description, template=paper_template,
-                                    tldr=tldr,
-                                    sections=selected_sections, bib_refs=bib_refs, model=model)
+                                    tldr=tldr, sections=selected_sections, bib_refs=bib_refs, model=model)
             # output = generate_draft(paper_title, paper_description, template, "gpt-4")
             upload_file(output)
             return output
     else:
         # output = fake_generate_backgrounds(title, description, openai_key)
         output = generate_draft(paper_title, paper_description, template=paper_template,
-                                tldr=tldr,
-                                sections=selected_sections, bib_refs=bib_refs, model=model)
+                                tldr=tldr, sections=selected_sections, bib_refs=bib_refs, model=model)
         return output
 
 
@@ -170,9 +168,6 @@ with gr.Blocks(theme=theme) as demo:
 
             title = gr.Textbox(value="Playing Atari with Deep Reinforcement Learning", lines=1, max_lines=1,
                                label="Title", info="论文标题")
-
-            slider = gr.Slider(minimum=1, maximum=100, value=20, step=1,
-                               interactive=True, visible=False, label="最大参考文献数目")
             with gr.Accordion("高级设置", open=False):
                 with gr.Row():
                     description_pp = gr.Textbox(lines=5, label="Description (Optional)", visible=True,
@@ -264,7 +259,7 @@ with gr.Blocks(theme=theme) as demo:
 
     clear_button_pp.click(fn=clear_inputs, inputs=[title, description_pp], outputs=[title, description_pp])
     submit_button_pp.click(fn=wrapped_generator,
-                           inputs=[title, description_pp, key, template, tldr_checkbox,
+                           inputs=[title, description_pp, key, template, tldr_checkbox, sections, bibtex_file,
                                    model_selection], outputs=file_output)
 
     clear_button_refs.click(fn=clear_inputs_refs, inputs=[title_refs, slider_refs], outputs=[title_refs, slider_refs])
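Since `submit_button_pp.click` passes `inputs` positionally, the component list must stay aligned with the widened `wrapped_generator` signature; the two new components `sections` and `bibtex_file` land on `selected_sections` and `bib_refs`. One submit is equivalent to the following direct call; all values here are illustrative:

    output = wrapped_generator(
        "Playing Atari with Deep Reinforcement Learning",  # title
        "",                                                # description_pp
        "sk-...",                                          # key -> openai_api_key
        "ICLR2022",                                        # template -> paper_template
        True,                                              # tldr_checkbox -> tldr
        ["introduction", "abstract"],                      # sections -> selected_sections
        None,                                              # bibtex_file -> bib_refs
        "gpt-4",                                           # model_selection -> model
    )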
auto_backgrounds.py
CHANGED
@@ -2,8 +2,7 @@ import os.path
 from utils.references import References
 from utils.file_operations import hash_name, make_archive, copy_templates
 from utils.tex_processing import create_copies
-from section_generator import
-from references_generator import generate_top_k_references
+from section_generator import keywords_generation, section_generation # figures_generation, section_generation_bg,
 import logging
 import time
 
@@ -26,14 +25,16 @@ def log_usage(usage, generating_target, print_out=True):
     TOTAL_PROMPTS_TOKENS += prompts_tokens
     TOTAL_COMPLETION_TOKENS += completion_tokens
 
-    message = f"For generating {generating_target}, {total_tokens} tokens have been used
+    message = f"For generating {generating_target}, {total_tokens} tokens have been used " \
+              f"({prompts_tokens} for prompts; {completion_tokens} for completion). " \
               f"{TOTAL_TOKENS} tokens have been used in total.\n\n"
     if print_out:
         print(message)
     logging.info(message)
 
+
 def _generation_setup(title, description="", template="ICLR2022", tldr=False,
-                      max_kw_refs=10,
+                      max_kw_refs=10, bib_refs=None, max_tokens=2048):
     """
     This function handles the setup process for paper generation; it contains three folds
     1. Copy the template to the outputs folder. Create the log file `generation.log`
@@ -44,9 +45,12 @@ def _generation_setup(title, description="", template="ICLR2022", tldr=False,
         title (str): The title of the paper.
         description (str, optional): A short description or abstract for the paper. Defaults to an empty string.
         template (str, optional): The template to be used for paper generation. Defaults to "ICLR2022".
-        tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be
-
-
+        tldr (bool, optional): A flag indicating whether a TL;DR (Too Long; Didn't Read) summary should be used
+            for the collected papers. Defaults to False.
+        max_kw_refs (int, optional): The maximum number of references that can be associated with each keyword.
+            Defaults to 10.
+        max_num_refs (int, optional): The maximum number of references that can be included in the paper.
+            Defaults to 50.
         bib_refs (list, optional): A list of pre-existing references in BibTeX format. Defaults to None.
 
     Returns:
@@ -69,9 +73,8 @@ def _generation_setup(title, description="", template="ICLR2022", tldr=False,
     keywords, usage = keywords_generation(input_dict)
     log_usage(usage, "keywords")
 
-    # generate keywords dictionary
+    # generate keywords dictionary  # todo: in some rare situations, collected papers will be an empty list.
     keywords = {keyword:max_kw_refs for keyword in keywords}
-    print(f"keywords: {keywords}\n\n")
 
     ref = References(title, bib_refs)
     ref.collect_papers(keywords, tldr=tldr)
@@ -109,23 +112,38 @@ def generate_backgrounds(title, description="", template="ICLR2022", model="gpt-
 
 
 def generate_draft(title, description="", template="ICLR2022",
-                   tldr=True, max_kw_refs=10,
+                   tldr=True, max_kw_refs=10, sections=None, bib_refs=None, model="gpt-4"):
+
+    def _filter_sections(sections):
+        ordered_sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion",
+                            "abstract"]
+        return [section for section in ordered_sections if section in sections]
     # pre-processing `sections` parameter;
+    print("================START================")
+    print(f"Generating the paper '{title}'.")
+    print("\n") # todo: use a configuration file to define parameters
     print("================PRE-PROCESSING================")
     if sections is None:
         sections = ["introduction", "related works", "backgrounds", "methodology", "experiments", "conclusion", "abstract"]
+    else:
+        sections = _filter_sections(sections)
 
-
-
+    if model == "gpt-4":
+        max_tokens = 4096
+    else:
+        max_tokens = 2048
+    paper, destination_folder, _ = _generation_setup(title, description, template, tldr, max_kw_refs, bib_refs, max_tokens=max_tokens)
 
     # main components
+    print(f"================PROCESSING================")
     for section in sections:
-        print(f"
+        print(f"Generate {section} part...")
         max_attempts = 4
         attempts_count = 0
         while attempts_count < max_attempts:
             try:
                 usage = section_generation(paper, section, destination_folder, model=model)
+                print(f"{section} part has been generated. ")
                 log_usage(usage, section)
                 break
             except Exception as e:
@@ -153,7 +171,7 @@ if __name__ == "__main__":
     import openai
     openai.api_key = os.getenv("OPENAI_API_KEY")
 
-
-
-    output = generate_draft(
+    target_title = "Using interpretable boosting algorithms for modeling environmental and agricultural data"
+    target_description = ""
+    output = generate_draft(target_title, target_description, tldr=True, max_kw_refs=10)
     print(output)
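One subtlety in the new `generate_draft`: `_filter_sections` does not preserve the caller's ordering; it drops unrecognized names and re-imposes the canonical section order. For example (illustrative input):

    _filter_sections(["abstract", "methodology", "not-a-section"])
    # -> ["methodology", "abstract"]: unknown names dropped, canonical order restored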
references_generator.py
CHANGED
@@ -1,7 +1,16 @@
+'''
+This script is used to generate the most relevant papers for a given title.
+- Search for as many references as possible. For 10~15 keywords, 10 references each.
+- Sort the results from most relevant to least relevant.
+- Return the most relevant ones that fit the token budget.
+
+Note: the auto-draft function does not call this script; the same logic has been integrated there.
+'''
+
 import os.path
 import json
 from utils.references import References
-from section_generator import section_generation_bg,
+from section_generator import keywords_generation # section_generation_bg, #, figures_generation, section_generation
 import itertools
 from gradio_client import Client
 
section_generator.py
CHANGED
@@ -1,5 +1,6 @@
-from utils.prompts import generate_paper_prompts, generate_keywords_prompts, generate_experiments_prompts, generate_bg_summary_prompts
-from utils.
+# from utils.prompts import generate_paper_prompts, generate_keywords_prompts, generate_experiments_prompts, generate_bg_summary_prompts
+from utils.prompts import generate_paper_prompts, generate_bg_summary_prompts
+# from utils.gpt_interaction import get_responses #, extract_responses, extract_keywords, extract_json
 from utils.figures import generate_random_figures
 import time
 import os
@@ -15,8 +16,10 @@ import json
 
 MAX_ATTEMPTS = 6
 
+
 def section_generation_bg(paper, section, save_to_path, model):
     """
+    todo: this part should be revised
     The main pipeline of generating a section.
     1. Generate prompts.
     2. Get responses from AI assistant.
@@ -26,8 +29,9 @@ def section_generation_bg(paper, section, save_to_path, model):
     """
     print(f"Generating {section}...")
     prompts = generate_bg_summary_prompts(paper, section)
-    gpt_response, usage = get_responses(prompts, model)
-
+    # gpt_response, usage = get_responses(prompts, model)
+    gpt_response, usage = get_gpt_responses(prompts, model)
+    output = gpt_response # extract_responses(gpt_response)
     paper["body"][section] = output
     tex_file = os.path.join(save_to_path, f"{section}.tex")
     # tex_file = save_to_path + f"/{section}.tex"
@@ -58,8 +62,8 @@ def section_generation(paper, section, save_to_path, model, research_field="mach
     :return usage
     """
     prompts = generate_paper_prompts(paper, section)
-    output, usage= get_gpt_responses(SECTION_GENERATION_SYSTEM.format(research_field=research_field), prompts,
-
+    output, usage = get_gpt_responses(SECTION_GENERATION_SYSTEM.format(research_field=research_field), prompts,
+                                      model=model, temperature=0.4)
     paper["body"][section] = output
     tex_file = os.path.join(save_to_path, f"{section}.tex")
     with open(tex_file, "w") as f:
@@ -69,7 +73,7 @@ def section_generation(paper, section, save_to_path, model, research_field="mach
 
 
 def keywords_generation(input_dict, default_keywords=None):
-
+    """
     Input:
         input_dict: a dictionary containing the title of a paper.
         default_keywords: if anything went wrong, return this keywords.
@@ -79,13 +83,13 @@ def keywords_generation(input_dict, default_keywords=None):
 
     Input example: {"title": "The title of a Machine Learning Paper"}
     Output Example: {"machine learning": 5, "reinforcement learning": 2}
-
+    """
    title = input_dict.get("title")
    attempts_count = 0
    while (attempts_count < MAX_ATTEMPTS) and (title is not None):
        try:
-            keywords, usage= get_gpt_responses(KEYWORDS_SYSTEM.format(min_refs_num=1, max_refs_num=10), title,
-
+            keywords, usage = get_gpt_responses(KEYWORDS_SYSTEM.format(min_refs_num=1, max_refs_num=10), title,
+                                                model="gpt-3.5-turbo", temperature=0.4)
             print(keywords)
             output = json.loads(keywords)
             return output.keys(), usage
@@ -99,10 +103,10 @@ def keywords_generation(input_dict, default_keywords=None):
     else:
         return default_keywords
 
-def figures_generation(paper, save_to_path, model):
-    # todo: this function is not complete.
-    prompts = generate_experiments_prompts(paper)
-    gpt_response, usage = get_responses(prompts, model)
-    list_of_methods = list(extract_json(gpt_response))
-    generate_random_figures(list_of_methods, os.path.join(save_to_path, "comparison.png"))
-    return usage
+# def figures_generation(paper, save_to_path, model):
+#     # todo: this function is not complete.
+#     prompts = generate_experiments_prompts(paper)
+#     gpt_response, usage = get_responses(prompts, model)
+#     list_of_methods = list(extract_json(gpt_response))
+#     generate_random_figures(list_of_methods, os.path.join(save_to_path, "comparison.png"))
+#     return usage
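A note on the `keywords_generation` contract made explicit by the restored docstring: the assistant reply must parse as a JSON object, and only its keys survive; the scores in the reply are discarded (auto_backgrounds.py later re-assigns `max_kw_refs` to every keyword). A minimal illustration of the parsing step, with a made-up reply string:

    import json

    reply = '{"machine learning": 5, "reinforcement learning": 2}'  # hypothetical model output
    keywords = json.loads(reply).keys()
    print(list(keywords))  # ['machine learning', 'reinforcement learning']; the 5 and 2 are dropped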
utils/gpt_interaction.py
CHANGED
@@ -1,80 +1,11 @@
+import os
 import openai
-import re
-import json
 import logging
+import requests
 
-log = logging.getLogger(__name__)
-
-
-def extract_responses(assistant_message):
-    # pattern = re.compile(r"f\.write\(r'{1,3}(.*?)'{0,3}\){0,1}$", re.DOTALL)
-    pattern = re.compile(r"f\.write\(r['\"]{1,3}(.*?)['\"]{0,3}\){0,1}$", re.DOTALL)
-    match = re.search(pattern, assistant_message)
-    if match:
-        return match.group(1)
-    else:
-        log.info("Responses are not put in Python codes. Directly return assistant_message.\n")
-        log.info(f"assistant_message: {assistant_message}")
-        return assistant_message
-
-
-def extract_keywords(assistant_message, default_keywords=None):
-    if default_keywords is None:
-        default_keywords = {"machine learning": 5}
-
-    try:
-        keywords = json.loads(assistant_message)
-    except ValueError:
-        log.info("Responses are not in json format. Return the default dictionary.\n ")
-        log.info(f"assistant_message: {assistant_message}")
-        return default_keywords
-    return keywords
-
-
-def extract_section_name(assistant_message, default_section_name=""):
-    try:
-        keywords = json.loads(assistant_message)
-    except ValueError:
-        log.info("Responses are not in json format. Return None.\n ")
-        log.info(f"assistant_message: {assistant_message}")
-        return default_section_name
-    return keywords
 
 
-def extract_json(assistant_message, default_output=None):
-    if default_output is None:
-        default_keys = ["Method 1", "Method 2"]
-    else:
-        default_keys = default_output
-    try:
-        dict = json.loads(assistant_message)
-    except:
-        log.info("Responses are not in json format. Return the default keys.\n ")
-        log.info(f"assistant_message: {assistant_message}")
-        return default_keys
-    return dict.keys()
-
-
-def get_responses(user_message, model="gpt-4", temperature=0.4, openai_key=None):
-    if openai.api_key is None and openai_key is None:
-        raise ValueError("OpenAI API key must be provided.")
-    if openai_key is not None:
-        openai.api_key = openai_key
-
-    conversation_history = [
-        {"role": "system", "content": "You are an assistant in writing machine learning papers."}
-    ]
-    conversation_history.append({"role": "user", "content": user_message})
-    response = openai.ChatCompletion.create(
-        model=model,
-        messages=conversation_history,
-        n=1,  # Number of responses you want to generate
-        temperature=temperature,  # Controls the creativity of the generated response
-    )
-    assistant_message = response['choices'][0]["message"]["content"]
-    usage = response['usage']
-    log.info(assistant_message)
-    return assistant_message, usage
+log = logging.getLogger(__name__)
 
 def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
     conversation_history = [
@@ -93,17 +24,31 @@ def get_gpt_responses(systems, prompts, model="gpt-4", temperature=0.4):
     return assistant_message, usage
 
 
-
-
-
-
-
-
-
-
-
-
-
-
-
+def get_gpt_responses_test(systems, prompts, model="gpt-4", temperature=0.4, base_url=None, key=None):
+    end_point = r"/v1/chat/completions"
+    if base_url is None:
+        base_url = r"https://api.openai.com" + end_point
+    if key is None:
+        key = os.getenv("OPENAI_API_KEY")
+
+    url = base_url + end_point
+
+    headers = {
+        'Content-Type': 'application/json',
+        'Authorization': f'Bearer {key}'  # <-- replace fkxxxxx with your own Forward Key; keep the leading "Bearer " and the space after it.
+    }
+
+    message = [{"role": "system", "content": systems},
+               {"role": "user", "content": prompts}]
+    data = {
+        "model": model,
+        "message": message,
+        "temperature": temperature
+    }
+    response = requests.post(url, headers=headers, json=data)
+    response = response.json()
+    return response['choices'][0]["message"]["content"]
+
+
+if __name__ == "__main__":
+    pass
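Two review notes on `get_gpt_responses_test` as committed: when `base_url` is None it already includes `end_point`, so `url = base_url + end_point` appends `/v1/chat/completions` twice; and the OpenAI chat completions REST API expects the field "messages", not "message". A corrected sketch (not part of this commit):

    import os
    import requests

    def get_gpt_responses_fixed(systems, prompts, model="gpt-4", temperature=0.4, base_url=None, key=None):
        end_point = "/v1/chat/completions"
        if base_url is None:
            base_url = "https://api.openai.com"  # bare host; the endpoint is appended once below
        if key is None:
            key = os.getenv("OPENAI_API_KEY")
        url = base_url + end_point
        headers = {"Content-Type": "application/json",
                   "Authorization": f"Bearer {key}"}
        data = {"model": model,
                "messages": [{"role": "system", "content": systems},  # "messages", plural
                             {"role": "user", "content": prompts}],
                "temperature": temperature}
        response = requests.post(url, headers=headers, json=data).json()
        return response["choices"][0]["message"]["content"]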
utils/prompts.py
CHANGED
@@ -9,23 +9,23 @@ log = logging.getLogger(__name__)
 ######################################################################################################################
 # Some basic functions
 ######################################################################################################################
-def generate_keywords_prompts(title, description="", num_refs=5):
-    prompts = f"I am writing a machine learning paper with the title '{title}'. {description}\n" \
-              f"Generate three to five keywords. For each keyword, rate it from 1 to {num_refs}; the larger number means more important." \
-              r"Your response must be in JSON format like {\"keyword1\":1, \"keyword2\":3}."
-    return prompts
-
-def generate_rename_prompts(paper_info, section):
-    prompts = f"Please read the {section} section of the paper {paper_info['title']}: {paper_info['body'][section]}. \n" \
-              f"You need to rename this section to make it more specific to the context. " \
-              r"Response in a dictionary format like {\"option_1\": \"new_section_name_1\", \"option_2\": \"new_section_name_2\", ...}."
-    return prompts
-
-def generate_experiments_prompts(paper_info):
-    prompts = f"I am writing a machine learning paper with the title {paper_info['title']}\n" \
-              f"Please list two to four methods that I should compare my methods with and assign them with scores (5 means most related, 1 means least related). " \
-              r"Response in a dictionary format like {\"method_name_1\": 2, \"method_name_2\": 5, ...}. Use abbreviation to make their names have 5 characters or less."
-    return prompts
+# def generate_keywords_prompts(title, description="", num_refs=5):
+#     prompts = f"I am writing a machine learning paper with the title '{title}'. {description}\n" \
+#               f"Generate three to five keywords. For each keyword, rate it from 1 to {num_refs}; the larger number means more important." \
+#               r"Your response must be in JSON format like {\"keyword1\":1, \"keyword2\":3}."
+#     return prompts
+#
+# def generate_rename_prompts(paper_info, section):
+#     prompts = f"Please read the {section} section of the paper {paper_info['title']}: {paper_info['body'][section]}. \n" \
+#               f"You need to rename this section to make it more specific to the context. " \
+#               r"Response in a dictionary format like {\"option_1\": \"new_section_name_1\", \"option_2\": \"new_section_name_2\", ...}."
+#     return prompts
+#
+# def generate_experiments_prompts(paper_info):
+#     prompts = f"I am writing a machine learning paper with the title {paper_info['title']}\n" \
+#               f"Please list two to four methods that I should compare my methods with and assign them with scores (5 means most related, 1 means least related). " \
+#               r"Response in a dictionary format like {\"method_name_1\": 2, \"method_name_2\": 5, ...}. Use abbreviation to make their names have 5 characters or less."
+#     return prompts
 
 
 
@@ -68,6 +68,10 @@ SECTION_GENERATION_SYSTEM = PromptTemplate(input_variables=["research_field"],
 # Academic Paper
 ######################################################################################################################
 
+# When generating an academic paper, load the instructions.
+# with open("../prompts/instructions.json", "r") as f:
+#     INSTRUCTIONS = json.load(f)
+#
 INSTRUCTIONS = {"introduction":
                     "- Include five paragraph: Establishing the motivation for the research. Explaining its importance and relevance to the AI community. Clearly state the problem you're addressing, your proposed solution, and the specific research questions or objectives. Briefly mention key related works for context and explain the main differences from this work. List three novel contributions of this paper.",
                 "results":
@@ -213,4 +217,13 @@ def generate_bg_summary_prompts(paper_info, section):
         raise NotImplementedError
 
     log.info(f"Generated prompts for {section}: {prompts}")
-    return prompts
+    return prompts
+
+if __name__ == "__main__":
+    # import json
+    # with open("../prompts/instructions.json", "w") as f:
+    #     json.dump(INSTRUCTIONS, f)
+    import json
+    with open("../prompts/instructions.json", "r") as f:
+        ins = json.load(f)
+    print(ins == INSTRUCTIONS)
utils/references.py
CHANGED
@@ -27,6 +27,7 @@ from scholarly import ProxyGenerator
 import tiktoken
 import itertools, uuid, json
 from gradio_client import Client
+import time
 
 
 ######################################################################################################################
@@ -251,8 +252,10 @@ class References:
         comb_keywords = list(itertools.combinations(keywords, 2))
         for comb_keyword in comb_keywords:
             keywords.append(" ".join(comb_keyword))
+        print("Keywords: ", keywords)
         for key in keywords:
             self.papers[key] = _collect_papers_ss(key, 10, tldr)
+        print("Collected papers: ", papers)
         # for key, counts in keywords_dict.items():
         #     self.papers[key] = _collect_papers_ss(key, counts, tldr)
 
@@ -334,8 +337,12 @@ class References:
         prompts = {}
         tokens = 0
         for paper in result:
-
-
+            abstract = paper.get("abstract")
+            if abstract is not None and isinstance(abstract, str):
+                prompts[paper["paper_id"]] = paper["abstract"]
+                tokens += tiktoken_len(paper["abstract"])
+            else:
+                prompts[paper["paper_id"]] = " "
             if tokens >= max_tokens:
                 break
         return prompts
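Review note on the second debug print: `papers` does not appear to be defined in the surrounding code shown, so `print("Collected papers: ", papers)` would raise a NameError on first use; the dictionary filled in the loop above is `self.papers`. The intended line was presumably:

    print("Collected papers: ", self.papers)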
worker.py
ADDED
@@ -0,0 +1,172 @@
+'''
+This script is only used for the service-side host.
+'''
+import boto3
+import os, time
+from api_wrapper import generator_wrapper
+from sqlalchemy import create_engine, Table, MetaData, update, select
+from sqlalchemy.orm import sessionmaker
+from sqlalchemy import inspect
+
+QUEUE_URL = os.getenv('QUEUE_URL')
+AWS_ACCESS_KEY_ID = os.getenv('AWS_ACCESS_KEY_ID')
+AWS_SECRET_ACCESS_KEY = os.getenv('AWS_SECRET_ACCESS_KEY')
+BUCKET_NAME = os.getenv('BUCKET_NAME')
+DB_STRING = os.getenv('DATABASE_STRING')
+
+# Create engine
+ENGINE = create_engine(DB_STRING)
+SESSION = sessionmaker(bind=ENGINE)
+
+
+#######################################################################################################################
+# Amazon SQS Handler
+#######################################################################################################################
+def get_sqs_client():
+    sqs = boto3.client('sqs', region_name="us-east-2",
+                       aws_access_key_id=AWS_ACCESS_KEY_ID,
+                       aws_secret_access_key=AWS_SECRET_ACCESS_KEY)
+    return sqs
+
+
+def receive_message():
+    sqs = get_sqs_client()
+    message = sqs.receive_message(QueueUrl=QUEUE_URL)
+    if message.get('Messages') is not None:
+        receipt_handle = message['Messages'][0]['ReceiptHandle']
+    else:
+        receipt_handle = None
+    return message, receipt_handle
+
+
+def delete_message(receipt_handle):
+    sqs = get_sqs_client()
+    response = sqs.delete_message(QueueUrl=QUEUE_URL, ReceiptHandle=receipt_handle)
+    return response
+
+
+#######################################################################################################################
+# AWS S3 Handler
+#######################################################################################################################
+def get_s3_client():
+    access_key_id = os.getenv('AWS_ACCESS_KEY_ID')
+    secret_access_key = os.getenv('AWS_SECRET_ACCESS_KEY')
+    session = boto3.Session(
+        aws_access_key_id=access_key_id,
+        aws_secret_access_key=secret_access_key,
+    )
+    s3 = session.resource('s3')
+    bucket = s3.Bucket(BUCKET_NAME)
+    return s3, bucket
+
+
+def upload_file(file_name, target_name=None):
+    s3, _ = get_s3_client()
+
+    if target_name is None:
+        target_name = file_name
+    s3.meta.client.upload_file(Filename=file_name, Bucket=BUCKET_NAME, Key=target_name)
+    print(f"The file {file_name} has been uploaded!")
+
+
+def download_file(file_name):
+    """ Download `file_name` from the bucket.
+    Bucket (str) – The name of the bucket to download from.
+    Key (str) – The name of the key to download from.
+    Filename (str) – The path to the file to download to.
+    """
+    s3, _ = get_s3_client()
+    s3.meta.client.download_file(Bucket=BUCKET_NAME, Key=file_name, Filename=os.path.basename(file_name))
+    print(f"The file {file_name} has been downloaded!")
+
+
+#######################################################################################################################
+# AWS SQL Handler
+#######################################################################################################################
+def modify_status(task_id, new_status):
+    session = SESSION()
+    metadata = MetaData()
+    task_to_update = task_id
+    task_table = Table('task', metadata, autoload_with=ENGINE)
+    stmt = select(task_table).where(task_table.c.task_id == task_to_update)
+    # Execute the statement
+    with ENGINE.connect() as connection:
+        result = connection.execute(stmt)
+
+        # Fetch the first result (if exists)
+        task_data = result.fetchone()
+
+        # If task_data is not None, the task exists and we can update its status
+        if task_data:
+            # Update statement
+            stmt = (
+                update(task_table).
+                where(task_table.c.task_id == task_to_update).
+                values(status=new_status)
+            )
+            # Execute the statement and commit
+            result = connection.execute(stmt)
+            connection.commit()
+    # Close the session
+    session.close()
+
+
+#######################################################################################################################
+# Pipeline
+#######################################################################################################################
+def pipeline(message_count=0, query_interval=10):
+    # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed
+
+    # Query a message from SQS
+    msg, handle = receive_message()
+    if handle is None:
+        print("No message in SQS. ")
+        time.sleep(query_interval)
+    else:
+        print("===============================================================================================")
+        print(f"MESSAGE COUNT: {message_count}")
+        print("===============================================================================================")
+        config_s3_path = msg['Messages'][0]['Body']
+        config_s3_dir = os.path.dirname(config_s3_path)
+        config_local_path = os.path.basename(config_s3_path)
+        task_id, _ = os.path.splitext(config_local_path)
+
+        print("Initializing ...")
+        print("Configuration file on S3: ", config_s3_path)
+        print("Configuration file on S3 (Directory): ", config_s3_dir)
+        print("Local file path: ", config_local_path)
+        print("Task id: ", task_id)
+
+        print(f"Success in receiving message: {msg}")
+        print(f"Configuration file path: {config_s3_path}")
+
+        # Process the downloaded configuration file
+        download_file(config_s3_path)
+        modify_status(task_id, 1)  # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed
+        delete_message(handle)
+        print(f"Success in the initialization. Message deleted.")
+
+        print("Running ...")
+        # try:
+        zip_path = generator_wrapper(config_local_path)
+        # Upload the generated file to S3
+        upload_to = os.path.join(config_s3_dir, zip_path).replace("\\", "/")
+
+        print("Local file path (ZIP): ", zip_path)
+        print("Upload to S3: ", upload_to)
+        upload_file(zip_path, upload_to)
+        modify_status(task_id, 2)  # status: 0 - pending (default), 1 - running, 2 - completed, 3 - failed, 4 - deleted
+        print(f"Success in generating the paper.")
+
+        # Complete.
+        print("Task completed.")
+
+
+def initialize_everything():
+    # Clear S3
+
+    # Clear SQS
+    pass
+
+
+if __name__ == "__main__":
+    pipeline()
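As committed, the `__main__` block calls `pipeline()` once, so the worker handles at most one message and exits. A long-running service would presumably poll in a loop; a sketch (not part of this commit), where the counter feeds the MESSAGE COUNT banner:

    if __name__ == "__main__":
        count = 0
        while True:  # pipeline() itself sleeps for query_interval seconds when SQS is empty
            pipeline(message_count=count)
            count += 1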