File size: 3,574 Bytes
479f67b
 
 
 
02069d7
479f67b
c8709b2
479f67b
02069d7
479f67b
 
 
 
 
 
 
 
02069d7
 
479f67b
 
02069d7
 
 
 
479f67b
 
 
 
 
 
 
 
 
 
 
 
 
 
c8709b2
 
 
479f67b
c8709b2
 
 
 
479f67b
c8709b2
 
 
 
 
 
 
479f67b
 
 
 
 
 
 
c8709b2
 
 
 
 
 
 
88253fe
479f67b
02069d7
c8709b2
02069d7
c8709b2
 
 
 
 
 
02069d7
c8709b2
 
479f67b
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import json
from utils.paper_retriever import RetrieverFactory
from utils.llms_api import APIHelper
from utils.header import ConfigReader
from utils.hash import check_env, check_embedding
from generator import IdeaGenerator
import functools


class Backend(object):
    """Wire the retriever, LLM API helper and idea generator into stage callbacks.

    Each ``*_callback`` method runs one step of the pipeline
    background -> entities -> expanded background / brainstorm -> literature
    -> ideas, and returns plain Python values.
    """

    def __init__(
        self,
        *,
        config_path: str = "./configs/datasets.yaml",
        example_path: str = "./assets/data/example.json",
        use_inspiration: bool = True,
        brainstorm_mode: str = "mode_c",
    ) -> None:
        """Load configuration and build the helper objects used by the callbacks.

        All arguments are keyword-only and default to the values that were
        previously hard-coded, so ``Backend()`` behaves exactly as before.

        Args:
            config_path: Path passed to ``ConfigReader.load``.
            example_path: JSON file holding the demo examples.
            use_inspiration: Stored on ``self.use_inspiration``.
            brainstorm_mode: Stored on ``self.brainstorm_mode``.
        """
        self.config = ConfigReader.load(config_path)
        # Validate the environment / embedding setup before building any clients.
        check_env()
        check_embedding(self.config.DEFAULT.embedding)
        retriever_name = self.config.RETRIEVE.retriever_name
        self.api_helper = APIHelper(self.config)
        self.retriever_factory = (
            RetrieverFactory.get_retriever_factory().create_retriever(
                retriever_name, self.config
            )
        )
        self.idea_generator = IdeaGenerator(self.config, None)
        self.use_inspiration = use_inspiration
        self.brainstorm_mode = brainstorm_mode
        self.examples = self.load_examples(example_path)

    def load_examples(self, path):
        """Return the parsed JSON content of ``path``, or ``[]`` if it can't be read.

        This is best-effort: I/O errors, non-UTF-8 content and malformed JSON
        are reported on stdout and yield an empty list instead of raising.
        """
        try:
            # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
            with open(path, "r", encoding="utf-8") as f:
                return json.load(f)
        except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
            print(f"Error loading examples from {path}: {e}")
            return []

    def background2entities_callback(self, background):
        """Extract an entity list from ``background`` via the API helper."""
        return self.api_helper.generate_entity_list(background)

    def background2expandedbackground_callback(self, background, entities):
        """Expand ``background`` using ``entities`` joined as ``"a, b, c"``.

        An empty ``entities`` yields an empty keyword string; the previous
        ``functools.reduce`` implementation raised ``TypeError`` in that case.
        """
        keywords_str = ", ".join(map(str, entities))
        return self.api_helper.expand_background(background, keywords_str)

    def background2brainstorm_callback(self, expanded_background):
        """Generate a brainstorm from the expanded background."""
        return self.api_helper.generate_brainstorm(expanded_background)

    def brainstorm2entities_callback(self, brainstorm, entities):
        """Merge ``entities`` with up to 10 entities extracted from ``brainstorm``.

        Duplicates are removed while preserving first-seen order (original
        entities first). This keeps the result deterministic, unlike a
        set union whose order varies with hash randomization.
        """
        entities_bs = self.api_helper.generate_entity_list(brainstorm, 10)
        return list(dict.fromkeys([*entities, *entities_bs]))

    def upload_json_callback(self, input):
        """Read the JSON file at ``input`` and return ``[background, raw_text]``.

        ``input`` shadows the builtin, but the name is kept because callers
        may pass it by keyword.

        Raises:
            OSError / json.JSONDecodeError / KeyError if the file is missing,
            malformed, or has no ``"background"`` key.
        """
        with open(input, "r", encoding="utf-8") as json_file:
            contents = json_file.read()
        json_contents = json.loads(contents)
        return [json_contents["background"], contents]

    def entities2literature_callback(self, expanded_background, entities):
        """Retrieve related papers and format each as ``"Title. VENUE YEAR."``.

        Returns:
            A tuple ``(formatted_lines, related_papers)``, where
            ``related_papers`` is the raw ``"related_paper"`` list returned
            by the retriever.
        """
        result = self.retriever_factory.retrieve(
            expanded_background, entities, need_evaluate=False, target_paper_id_list=[]
        )
        papers = result["related_paper"]
        res = [
            f'{p["title"]}. {p["venue_name"].upper()} {p["year"]}.' for p in papers
        ]
        return res, papers

    def literature2initial_ideas_callback(
        self, expanded_background, brainstorms, retrieved_literature
    ):
        """Run idea generation on the retrieved papers and brainstorm.

        Mutates ``self.idea_generator`` (``paper_list``, ``brainstorm``) and
        then calls ``generate_ins_bs``, which yields a 6-tuple.

        Returns:
            ``(idea_filtered, final_ideas)``: the last two elements of that tuple.
        """
        self.idea_generator.paper_list = retrieved_literature
        self.idea_generator.brainstorm = brainstorms
        # Only the last two outputs are surfaced; the rest are intermediate.
        _, _, _, _, idea_filtered, final_ideas = (
            self.idea_generator.generate_ins_bs(expanded_background)
        )
        return idea_filtered, final_ideas

    def initial2final_callback(self, initial_ideas, final_ideas):
        """Pass ``final_ideas`` through unchanged; ``initial_ideas`` is ignored."""
        return final_ideas

    def get_demo_i(self, i):
        """Return the ``"background"`` of example ``i``, or a message if unavailable."""
        if 0 <= i < len(self.examples):
            return self.examples[i].get("background", "Background not found.")
        return "Example not found. Please select a valid index."