Alfred828 commited on
Commit
422ae9b
·
verified ·
1 Parent(s): 2ddb4c1

Create tools/encyclopedia.py

Browse files
Files changed (1) hide show
  1. tools/encyclopedia.py +113 -0
tools/encyclopedia.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import datasets
4
+ from langchain.docstore.document import Document
5
+ from langchain_community.retrievers import BM25Retriever
6
+ from pydantic import BaseModel, Field
7
+
8
+
9
+ class EncyclopediaDataSets:
10
+ @staticmethod
11
+ def gaia(base_path: str) -> list[Document]:
12
+ datasets_dir = os.path.join(base_path, "tools/.datasets/gaia")
13
+
14
+ try:
15
+ gaia_dataset: (
16
+ datasets.DatasetDict
17
+ | datasets.Dataset
18
+ | datasets.IterableDatasetDict
19
+ | datasets.IterableDataset
20
+ ) = datasets.load_from_disk(datasets_dir)
21
+ # print("load local")
22
+ except Exception as e:
23
+ # print(f"{e}load online")
24
+ gaia_dataset: (
25
+ datasets.DatasetDict
26
+ | datasets.Dataset
27
+ | datasets.IterableDatasetDict
28
+ | datasets.IterableDataset
29
+ ) = datasets.load_dataset(
30
+ "gaia-benchmark/GAIA",
31
+ "2023_all",
32
+ )
33
+
34
+ gaia_dataset.save_to_disk(datasets_dir)
35
+
36
+ # dict_keys(['task_id', 'Question', 'Level', 'Final answer', 'file_name', 'file_path', 'Annotator Metadata'])
37
+ gaia_dataset_list = (
38
+ gaia_dataset["test"].to_list() + gaia_dataset["validation"].to_list()
39
+ )
40
+
41
+ # Convert dataset entries into Document objects
42
+ docs: list[Document] = [
43
+ Document(
44
+ page_content="\n".join(
45
+ [
46
+ f"task_id: {gdl['task_id']}",
47
+ f"Question: {gdl['Question']}",
48
+ f"Final answer: {gdl['Final answer']}",
49
+ ]
50
+ ),
51
+ metadata={"Question": gdl["Question"]},
52
+ )
53
+ for gdl in gaia_dataset_list
54
+ ]
55
+
56
+ return docs
57
+
58
+
59
+ class EncyclopediaRetrieveInput(BaseModel):
60
+ question: str = Field(description="使用者欲搜尋的完整問題。")
61
+
62
+
63
+ class EncyclopediaRetriever:
64
+ def __init__(self, needed_doc_names: list[str], base_path: str):
65
+ self.bm25_retriever = BM25Retriever.from_documents(
66
+ self.prepare_docs(needed_doc_names, base_path)
67
+ )
68
+
69
+ def prepare_docs(self, needed_doc_names: list[str], base_path: str):
70
+ """
71
+ 準備所需的 Document 文件列表。
72
+
73
+ Args:
74
+ needed_doc_names (list[str]): 需要載入的百科資料集合名稱列表。
75
+ base_path (str): 存放本地資料集的基礎路徑。
76
+
77
+ Returns:
78
+ list[Document]: 經由所有指定來源整合而成的 Document 物件列表。
79
+
80
+ 說明:
81
+ 根據傳入的資料集名稱逐一載入相關文件,支援多來源文檔的彙整。
82
+ 目前僅支援 "gaia" 資料集,其它來源可根據需求擴充。
83
+ """
84
+
85
+ docs = []
86
+
87
+ for ndn in needed_doc_names:
88
+ if ndn == "gaia":
89
+ docs.extend(EncyclopediaDataSets.gaia(base_path))
90
+
91
+ return docs
92
+
93
+ def get_related_question(self, question: str) -> str:
94
+ """
95
+ 依據輸入問題檢索相關百科內容。
96
+
97
+ Args:
98
+ question (str): 使用者欲搜尋的完整問題。
99
+
100
+ Returns:
101
+ str: 與問題最相關的百科內容(文本格式),如無符合則傳回提示訊息。
102
+
103
+ 說明:
104
+ 本方法會使用 BM25 向量檢索器對 Document 集合進行檢索,回傳結果內容合併為字串輸出。
105
+ """
106
+
107
+ results: list[Document] = self.bm25_retriever.invoke(question)
108
+
109
+ results_in_str: str = "No matching guest information found."
110
+ if results:
111
+ results_in_str: str = "\n\n".join([doc.page_content for doc in results])
112
+
113
+ return results_in_str