|
import functools |
|
from utils.paper_retriever import RetrieverFactory |
|
from utils.paper_client import PaperClient |
|
from utils.llms_api import APIHelper |
|
from utils.header import ConfigReader |
|
from omegaconf import OmegaConf |
|
import click |
|
import json |
|
from loguru import logger |
|
import warnings |
|
import time |
|
import os |
|
from utils.hash import check_env, check_embedding |
|
import threading |
|
|
|
# Silence every warning category process-wide; this runs at import time, so it
# also affects any module that imports this file.
warnings.filterwarnings("ignore")
|
|
|
|
|
def extract_problem(problem, background):
    """Return the "Research Problem" section of a generated problem text.

    The section starts at the ``**Research Problem**`` heading (the heading
    itself is kept) and stops right before the ``**Rationales**`` heading.

    Args:
        problem: Text to search. May be ``None`` or empty.
        background: Fallback returned when the section cannot be located.

    Returns:
        The stripped section, or ``background`` when ``problem`` is empty,
        either heading is missing, or ``**Rationales**`` does not appear
        after ``**Research Problem**``.
    """
    start_keyword = "**Research Problem**"
    end_keyword = "**Rationales**"

    if not problem:
        return background

    start_index = problem.find(start_keyword)
    if start_index == -1:
        return background

    # Look for the closing heading only *after* the opening one; searching
    # from the beginning would turn a misordered text into an empty slice.
    end_index = problem.find(end_keyword, start_index + len(start_keyword))
    if end_index == -1:
        return background

    return problem[start_index:end_index].strip()
|
|
|
|
|
def extract_ideas(idea_str):
    """Split a numbered idea list into one string per idea.

    Ideas are delimited by headings that begin with ``**Idea 1``,
    ``**Idea 2``, ... in ascending order. Each returned chunk starts at its
    own heading and ends right before the next one (or at the end of the
    text for the last idea). Text before ``**Idea 1`` is discarded.

    Args:
        idea_str: Text to split, or ``None``.

    Returns:
        ``""`` when ``idea_str`` is ``None`` (kept for backward
        compatibility), otherwise a list of stripped idea chunks. When no
        ``**Idea 1`` heading exists the list holds ``idea_str`` unchanged.
    """
    if idea_str is None:
        return ""

    ideas = []
    number = 1
    start_index = idea_str.find(f"**Idea {number}")
    while start_index != -1:
        # Search for the next heading strictly after the current one, so an
        # earlier mention of "**Idea N+1" cannot yield an empty/overlapping
        # slice, and the loop always makes forward progress.
        next_index = idea_str.find(f"**Idea {number + 1}", start_index + 1)
        if next_index == -1:
            ideas.append(idea_str[start_index:].strip())
            break
        ideas.append(idea_str[start_index:next_index].strip())
        start_index = next_index
        number += 1

    return ideas if ideas else [idea_str]
|
|
|
|
|
class IdeaGenerator:
    """Sequence the ``APIHelper`` calls that turn a background into ideas.

    All text generation is delegated to ``self.api_helper``; this class only
    decides which helper calls run, in which order, and how their results
    are bundled for the caller.
    """

    # Brainstorm modes understood by generate() / generate_by_inspiration().
    _BS_MODES = ("mode_a", "mode_b", "mode_c")
    # Human-readable labels used in log lines, keyed by the `mode` argument.
    _MODE_NAMES = {"backtracking": "Backtrack", "new_idea": "Generate new idea"}
    # Seconds to wait for each worker thread in generate_ins_bs().
    _JOIN_TIMEOUT = 120

    def __init__(
        self,
        config,
        paper_list: list[dict] = None,
        brainstorm: str = None,
    ) -> None:
        """Create a generator.

        Args:
            config: Configuration object forwarded to ``APIHelper``.
            paper_list: Related papers used as context. ``None`` means no
                papers; a fresh list is created per instance so instances
                never share a default list.
            brainstorm: Brainstorm text used by the ``*_bs`` pipelines.
        """
        self.api_helper = APIHelper(config)
        self.paper_list = [] if paper_list is None else paper_list
        self.brainstorm = brainstorm

    # ------------------------------------------------------------------
    # private helpers
    # ------------------------------------------------------------------
    def _check_bs_mode(self, bs_mode):
        """Raise ``ValueError`` unless ``bs_mode`` is a supported mode."""
        if bs_mode not in self._BS_MODES:
            raise ValueError(
                f"Unsupported bs_mode {bs_mode!r}; expected one of {self._BS_MODES}"
            )

    def _draft_idea(self, background):
        """Generate a problem and an idea directly from the paper list."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        idea = self.api_helper.generate_idea(problem, self.paper_list)
        return message_input, problem, idea

    def _draft_idea_from_inspirations(self, background):
        """Generate a problem, one inspiration per paper, then an idea."""
        problem, message_input = self.api_helper.generate_problem(
            background, self.paper_list
        )
        research_problem = extract_problem(problem, background)
        inspirations = [
            self.api_helper.generate_inspiration(research_problem, paper)
            for paper in self.paper_list
        ]
        idea = self.api_helper.generate_idea_by_inspiration(problem, inspirations)
        return message_input, problem, inspirations, idea

    def _run_threads(self, target, jobs):
        """Run ``target(*args)`` in one thread per job and wait for each.

        Every thread is joined with ``_JOIN_TIMEOUT`` seconds, so a stuck
        call cannot block the pipeline forever; such a thread keeps running
        in the background.
        """
        threads = []
        for args in jobs:
            worker = threading.Thread(target=target, args=args)
            threads.append(worker)
            worker.start()
        for worker in threads:
            worker.join(self._JOIN_TIMEOUT)

    # ------------------------------------------------------------------
    # pipelines
    # ------------------------------------------------------------------
    def generate_without_cue_words(self, background: str):
        """Generate ideas without cue words and brainstorm.

        Returns:
            ``(message_input, problem, idea, idea_filtered)`` where
            ``idea_filtered`` comes from ``api_helper.filter_idea``.
        """
        message_input, problem, idea = self._draft_idea(background)
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_bs(self, background: str):
        """Generate ideas without cue words, but with brainstorm.

        Returns:
            ``(message_input, problem, idea, idea_filtered)`` where
            ``idea_filtered`` comes from ``api_helper.integrate_idea`` fed
            with ``self.brainstorm``.
        """
        message_input, problem, idea = self._draft_idea(background)
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, idea, idea_filtered

    def generate_without_cue_words_ins(self, background: str):
        """Generate ideas from per-paper inspirations, without brainstorm.

        Returns:
            ``(message_input, problem, inspirations, idea, idea_filtered)``.
        """
        message_input, problem, inspirations, idea = (
            self._draft_idea_from_inspirations(background)
        )
        idea_filtered = self.api_helper.filter_idea(idea, background)
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_without_cue_words_ins_bs(self, background: str):
        """Generate ideas from per-paper inspirations and brainstorm.

        Returns:
            ``(message_input, problem, inspirations, idea, idea_filtered)``.
        """
        message_input, problem, inspirations, idea = (
            self._draft_idea_from_inspirations(background)
        )
        idea_filtered = self.api_helper.integrate_idea(
            background, self.brainstorm, idea
        )
        return message_input, problem, inspirations, idea, idea_filtered

    def generate_ins_bs(self, detail_background: str):
        """Generate ideas with inspiration and brainstorm, using threads.

        One thread per paper produces an inspiration; the ideas derived from
        them are merged with ``self.brainstorm`` and each merged idea is then
        expanded in its own thread.

        Returns:
            ``(None, None, inspirations, initial_ideas, ideas_filtered,
            final_ideas)``. ``inspirations`` follows the order of
            ``self.paper_list``. ``final_ideas`` holds the string ``"None"``
            for any idea whose expansion did not finish within the timeout.
        """
        # One slot per paper keeps the result order deterministic regardless
        # of which thread finishes first.
        inspiration_slots = [None] * len(self.paper_list)

        def generate_inspiration(paper, i):
            detail_method = self.api_helper.generate_concise_method(paper["methodology"])
            inspiration = self.api_helper.generate_inspiration_with_detail_method(detail_background, detail_method)
            logger.info(f"Generate inspiration for related paper {i} succeed")
            if not isinstance(inspiration, str):
                # Nothing usable came back; skip instead of crashing the thread.
                logger.warning(f"Inspiration for related paper {i} is not a string, skipped")
                return
            # The helper signals "no inspiration" with a (short) "None" answer.
            if inspiration.startswith("None") or (len(inspiration) < 100 and "None" in inspiration):
                return
            inspiration_slots[i] = inspiration

        self._run_threads(
            generate_inspiration,
            [(paper, i) for i, paper in enumerate(self.paper_list)],
        )
        # Snapshot: a thread that outlived the join timeout can no longer
        # change the list that is used below and handed to the caller.
        inspirations = [item for item in inspiration_slots if item is not None]

        logger.info("Generate inspirations for all related papers succeed")
        idea = self.api_helper.generate_idea_by_inspiration(detail_background, inspirations)
        initial_ideas = extract_ideas(idea)
        logger.info("Generate ideas from inspirations succeed")
        idea_filtered = self.api_helper.integrate_idea(detail_background, self.brainstorm, idea)
        logger.info("Idea integration succeed")

        ideas_filtered = extract_ideas(idea_filtered)
        expanded_slots = ["None"] * len(ideas_filtered)

        def expand_idea(detail_background: str, idea: str, i):
            expanded_slots[i] = self.api_helper.expand_idea(detail_background, idea)
            logger.info(f"Expand the {i}th idea succeed")

        self._run_threads(
            expand_idea,
            [(detail_background, item, i) for i, item in enumerate(ideas_filtered)],
        )
        # Same snapshot reasoning as for the inspirations above.
        final_ideas = list(expanded_slots)

        return None, None, inspirations, initial_ideas, ideas_filtered, final_ideas

    def generate(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run the plain pipeline selected by ``bs_mode`` and refine its idea.

        Args:
            background: Research background text.
            mode: ``"backtracking"`` or ``"new_idea"``; only used for the
                log label.
            bs_mode: ``"mode_a"`` (no brainstorm) or ``"mode_b"`` /
                ``"mode_c"`` (brainstorm integrated).
            use_cue_words: Accepted for interface compatibility; not read.

        Returns:
            ``(message_input, idea_modified, median)`` where ``median`` maps
            ``"problem"``, ``"initial_idea"`` and ``"filtered_idea"`` to the
            intermediate results.

        Raises:
            ValueError: If ``bs_mode`` is not a supported mode.
        """
        # Fail early with a clear message instead of an UnboundLocalError
        # further down.
        self._check_bs_mode(bs_mode)
        mode_name = self._MODE_NAMES.get(mode)
        logger.info(
            "{} using brainstorm_{} without cue words.".format(mode_name, bs_mode)
        )
        if bs_mode == "mode_a":
            pipeline = self.generate_without_cue_words
        else:
            pipeline = self.generate_without_cue_words_bs
        message_input, problem, idea, idea_filtered = pipeline(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median

    def generate_by_inspiration(
        self,
        background: str,
        mode: str,
        bs_mode: str = None,
        use_cue_words: bool = False,
    ):
        """Run the inspiration pipeline selected by ``bs_mode`` and refine it.

        Args:
            background: Research background text.
            mode: ``"backtracking"`` or ``"new_idea"``; only used for the
                log label.
            bs_mode: ``"mode_a"`` (no brainstorm) or ``"mode_b"`` /
                ``"mode_c"`` (brainstorm integrated).
            use_cue_words: Accepted for interface compatibility; not read.

        Returns:
            ``(message_input, idea_modified, median)`` where ``median`` maps
            ``"problem"``, ``"inspirations"``, ``"initial_idea"`` and
            ``"filtered_idea"`` to the intermediate results.

        Raises:
            ValueError: If ``bs_mode`` is not a supported mode.
        """
        self._check_bs_mode(bs_mode)
        mode_name = self._MODE_NAMES.get(mode)
        logger.info(
            "{} using brainstorm_{} without cue words.".format(mode_name, bs_mode)
        )
        if bs_mode == "mode_a":
            pipeline = self.generate_without_cue_words_ins
        else:
            pipeline = self.generate_without_cue_words_ins_bs
        message_input, problem, inspirations, idea, idea_filtered = pipeline(background)

        idea_modified = self.api_helper.modify_idea(background, idea_filtered)
        median = {
            "problem": problem,
            "inspirations": inspirations,
            "initial_idea": idea,
            "filtered_idea": idea_filtered,
        }
        return message_input, idea_modified, median
|
|
|
|
|
@click.group()
@click.pass_context
def main(ctx):
    """
    Training and evaluation
    """
    # Announce which subcommand click is about to dispatch to.
    subcommand = ctx.invoked_subcommand
    print("Mode:", subcommand)
|
|
|
|
|
@main.command()
@click.option(
    "-c",
    "--config-path",
    default="./configs/datasets.yaml",
    type=click.File(),
    required=True,
    help="Dataset configuration file in YAML",
)
@click.option(
    "--ids-path",
    default="./assets/data/test_background.json",
    type=click.File(),
    required=True,
    help="JSON file with a list of records, each holding a 'background' field",
)
@click.option(
    "--out-path",
    default="./assets/output_idea/",
    type=str,
    required=True,
    help="Directory the result file is written to",
)
@click.option(
    "--out-file",
    default="out-file.json",
    type=str,
    required=True,
    help="Name of the JSON result file inside --out-path",
)
@click.option(
    "-r",
    "--retriever-name",
    default="SNKG",
    type=str,
    required=True,
    help="Retrieve method",
)
@click.option(
    "--brainstorm-mode",
    default="mode_c",
    type=click.Choice(["mode_a", "mode_b", "mode_c"]),
    required=True,
    help="Choose your brainstorm mode (mode_a: no brainstorm, mode_b: brainstorm for idea generation, mode_c: brainstorm for idea generation and retrival)",
)
@click.option(
    "--use-inspiration",
    default=False,
    type=bool,
    required=True,
    help="Use inspiration in generation",
)
@click.option(
    "--expand-intermediate",
    default=False,
    type=bool,
    help="Also expand every separate brainstorm and every initial idea",
)
@click.option(
    "--num",
    default=100,
    type=int,
    required=False,
    help="The number of data you want to process",
)
def new_idea(
    config_path,
    ids_path,
    out_path,
    out_file,
    retriever_name,
    brainstorm_mode,
    use_inspiration,
    expand_intermediate,
    num,
    **kwargs,
):
    """Generate ideas for every background in --ids-path and save them as JSON."""
    # The run is resumable: records already present in the output file are
    # loaded first and their backgrounds are skipped. `use_inspiration` is
    # accepted for CLI compatibility but is not read below.
    check_env()
    logger.add(
        "log/generate_{}_{}.log".format(time.time(), retriever_name), level="DEBUG"
    )
    logger.info("Retrieve name: {}".format(retriever_name))

    config = ConfigReader.load(config_path, **kwargs)
    api_helper = APIHelper(config)
    paper_client = PaperClient()
    check_embedding(config.DEFAULT.embedding)

    batch_size = 1  # flush the output file after this many new records
    use_brainstorm = brainstorm_mode in ("mode_b", "mode_c")

    os.makedirs(out_path, exist_ok=True)
    output_file = os.path.join(out_path, out_file)

    def save_results():
        with open(output_file, "w", encoding="utf-8") as f:
            json.dump(eval_data, f, ensure_ascii=False, indent=4)

    # Resume from a previous run when its output file can be parsed.
    eval_data = []
    if os.path.exists(output_file):
        with open(output_file, "r", encoding="utf-8") as f:
            try:
                eval_data = json.load(f)
            except json.JSONDecodeError:
                eval_data = []
    bg_ids = {data["background"] for data in eval_data}
    cur_num = len(eval_data)
    logger.debug(f"{cur_num} datas have been processed.")

    all_input = json.load(ids_path)
    for data_num, data in enumerate(all_input, start=1):
        # Check the limit before doing any work so a resumed run that already
        # reached `num` does not process one extra record.
        if cur_num >= num:
            break

        if "background" not in data:
            print("This data doesn't have background...")
            continue
        bg = data["background"]
        if bg in bg_ids:
            print(f"Skipping already processed data_{data_num}.")
            continue

        # --- background -> entities -> expanded background ---------------
        entities = api_helper.generate_entity_list(bg)
        # join() also copes with an empty entity list, unlike reduce()
        # without an initial value.
        keywords_str = ", ".join(str(entity) for entity in entities)
        expanded_background = api_helper.expand_background(bg, keywords_str)

        # --- brainstorm (mode_b / mode_c only) ---------------------------
        # Defaults are set for every mode so the record below can always be
        # built; previously mode_a left two of these names unbound.
        brainstorm = None
        seperate_brainstorm = []
        expanded_brainstorms = []
        if use_brainstorm:
            brainstorm = api_helper.generate_brainstorm(expanded_background)
            seperate_brainstorm = extract_ideas(brainstorm)
            if expand_intermediate:
                for i, sb in enumerate(seperate_brainstorm):
                    expanded_brainstorms.append(api_helper.expand_idea(expanded_background, sb))
                    logger.info(f"Expand the {i}th brainstorm succeed")

        # --- entities used for retrieval ---------------------------------
        logger.debug("Original entities from background: {}".format(entities))
        if brainstorm_mode == "mode_c":
            entities_bs = api_helper.generate_entity_list(brainstorm, 10)
            logger.debug("Original entities from brainstorm: {}".format(entities_bs))
            entities_all = list(set(entities) | set(entities_bs))
        else:
            entities_bs = None
            entities_all = entities

        # --- retrieval ---------------------------------------------------
        rt = RetrieverFactory.get_retriever_factory().create_retriever(
            retriever_name, config
        )
        result = rt.retrieve(
            expanded_background, entities_all, need_evaluate=False, target_paper_id_list=[]
        )
        related_paper = result["related_paper"]
        logger.info("Find {} related papers...".format(len(related_paper)))
        entities_rt = result["entities"]
        for paper in related_paper:
            if "detail_method" not in paper:
                paper["detail_method"] = api_helper.generate_concise_method(paper["methodology"])
                # Persist only a usable (string) result.
                if isinstance(paper["detail_method"], str):
                    paper_client.insert_new_field(paper["hash_id"], "detail_method", paper["detail_method"])
                    logger.info(f"Add new field detail method to paper: {paper['hash_id']} succeed")
        logger.info("Generate detail methods for all related papers succeed")

        # --- idea generation ---------------------------------------------
        idea_generator = IdeaGenerator(config, related_paper, brainstorm)
        _, _, inspirations, initial_ideas, idea_filtered, final_ideas = idea_generator.generate_ins_bs(expanded_background)
        expanded_initial_ideas = []
        if expand_intermediate:
            for i, initial_idea in enumerate(initial_ideas):
                expanded_initial_ideas.append(api_helper.expand_idea(expanded_background, initial_idea))
                logger.info(f"Expand the {i}th initial idea succeed")

        eval_data.append(
            {
                "background": bg,
                "expanded_background": expanded_background,
                "entities_bg": entities,
                "brainstorm": brainstorm,
                "seperate_brainstorm": seperate_brainstorm,
                "entities_bs": entities_bs,
                "entities_rt": entities_rt,
                "related_paper": [p["title"] for p in related_paper],
                "inspirations": inspirations,
                "initial_ideas": initial_ideas,
                "filtered_ideas": idea_filtered,
                "expanded_final_ideas": final_ideas,
                "expanded_brainstorms": expanded_brainstorms,
                "expanded_initial_ideas": expanded_initial_ideas,
            }
        )
        # Remember the background so a duplicate later in the same input
        # file is skipped as well.
        bg_ids.add(bg)
        cur_num += 1
        if cur_num % batch_size == 0:
            save_results()

    logger.info("=== Finish ===")
    save_results()
|
|
|
|
|
# Run the click command group only when this file is executed as a script,
# not when it is imported.
if __name__ == "__main__":
    main()
|
|