Merged in feature/improve_parsing_and_retrieval (pull request #3)
Browse files- app.py +2 -2
- climateqa/engine/chains/prompts.py +51 -1
- climateqa/engine/chains/retrieve_documents.py +226 -16
- climateqa/engine/graph.py +19 -3
app.py
CHANGED
@@ -66,7 +66,7 @@ user_id = create_user_id()
|
|
66 |
embeddings_function = get_embeddings_function()
|
67 |
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
|
68 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
69 |
-
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("
|
70 |
|
71 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
72 |
if os.environ["GRADIO_ENV"] == "local":
|
@@ -75,7 +75,7 @@ else :
|
|
75 |
reranker = get_reranker("large")
|
76 |
|
77 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
78 |
-
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0)#TODO put back default 0.2
|
79 |
|
80 |
|
81 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
|
|
66 |
embeddings_function = get_embeddings_function()
|
67 |
vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
|
68 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
69 |
+
vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
|
70 |
|
71 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
72 |
if os.environ["GRADIO_ENV"] == "local":
|
|
|
75 |
reranker = get_reranker("large")
|
76 |
|
77 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
|
78 |
+
agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
|
79 |
|
80 |
|
81 |
async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
|
climateqa/engine/chains/prompts.py
CHANGED
@@ -198,4 +198,54 @@ Graphs and their HTML embedding:
|
|
198 |
{format_instructions}
|
199 |
|
200 |
Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
|
201 |
-
"""
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
198 |
{format_instructions}
|
199 |
|
200 |
Output the result as json with a key "graphs" containing a list of dictionaries of the relevant graphs with keys 'embedding', 'category', and 'source'. Do not modify the graph HTML embedding, the category or the source. Do not put any message or text before or after the JSON output.
|
201 |
+
"""
|
202 |
+
|
203 |
+
# Prompt used to ask the LLM which level 0 chapters of the documents' tables of
# contents are relevant to a user question.
# Placeholders: {query} (the user question) and {doc_list} (documents with their
# table of contents). Literal braces are doubled so the template can be formatted.
# The example response at the end is intentionally strict JSON (no trailing
# commas) so that it agrees with the "Return **valid JSON** result" guideline.
retrieve_chapter_prompt_template = """Given the user question and a list of documents with their table of contents, retrieve the 5 most relevant level 0 chapters which could help to answer to the question while taking account their sub-chapters.

The table of contents is structured like that :
{{
"level": 0,
"Chapter 1": {{}},
"Chapter 2" : {{
"level": 1,
"Chapter 2.1": {{
...
}}
}},
}}

Here level is the level of the chapter. For example, Chapter 1 and Chapter 2 are at level 0, and Chapter 2.1 is at level 1.

### Guidelines ###
- Keep all the list of documents that is given to you
- Each chapter must keep **EXACTLY** its assigned level in the table of contents. **DO NOT MODIFY THE LEVELS. **
- Check systematically the level of a chapter before including it in the answer.
- Return **valid JSON** result.

--------------------
User question :
{query}

List of documents with their table of contents :
{doc_list}

--------------------

Return a JSON result with a list of relevant chapters with the following keys **WITHOUT** the json markdown indicator ```json at the beginning:
- "document" : the document in which we can find the chapter
- "chapter" : the title of the chapter

**IMPORTANT : Make sure that the levels of the answer are exactly the same as the ones in the table of contents**

Example of a JSON response:
[
{{
"document": "Document A",
"chapter": "Chapter 1"
}},
{{
"document": "Document B",
"chapter": "Chapter 5"
}}
]
"""
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -15,6 +15,14 @@ from ..utils import log_event
|
|
15 |
from langchain_core.vectorstores import VectorStore
|
16 |
from typing import List
|
17 |
from langchain_core.documents.base import Document
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
import asyncio
|
19 |
|
20 |
from typing import Any, Dict, List, Tuple
|
@@ -119,6 +127,21 @@ def remove_duplicates_chunks(docs):
|
|
119 |
result.append(doc)
|
120 |
return result
|
121 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
122 |
async def get_POC_relevant_documents(
|
123 |
query: str,
|
124 |
vectorstore:VectorStore,
|
@@ -169,6 +192,86 @@ async def get_POC_relevant_documents(
|
|
169 |
"docs_question" : docs_question,
|
170 |
"docs_images" : docs_images
|
171 |
}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
172 |
|
173 |
|
174 |
async def get_IPCC_relevant_documents(
|
@@ -271,6 +374,7 @@ def concatenate_documents(index, source_type, docs_question_dict, k_by_question,
|
|
271 |
return docs_question, images_question
|
272 |
|
273 |
|
|
|
274 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
275 |
# @chain
|
276 |
async def retrieve_documents(
|
@@ -279,6 +383,7 @@ async def retrieve_documents(
|
|
279 |
source_type: str,
|
280 |
vectorstore: VectorStore,
|
281 |
reranker: Any,
|
|
|
282 |
search_figures: bool = False,
|
283 |
search_only: bool = False,
|
284 |
reports: list = [],
|
@@ -286,7 +391,9 @@ async def retrieve_documents(
|
|
286 |
k_images_by_question: int = 5,
|
287 |
k_before_reranking: int = 100,
|
288 |
k_by_question: int = 5,
|
289 |
-
k_summary_by_question: int = 3
|
|
|
|
|
290 |
) -> Tuple[List[Document], List[Document]]:
|
291 |
"""
|
292 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
@@ -316,6 +423,7 @@ async def retrieve_documents(
|
|
316 |
|
317 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
318 |
|
|
|
319 |
if source_type == "IPx":
|
320 |
docs_question_dict = await get_IPCC_relevant_documents(
|
321 |
query = question,
|
@@ -331,19 +439,36 @@ async def retrieve_documents(
|
|
331 |
reports = reports,
|
332 |
)
|
333 |
|
334 |
-
if source_type ==
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
343 |
-
|
344 |
-
|
345 |
-
|
346 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
347 |
|
348 |
# Rerank
|
349 |
if reranker is not None and rerank_by_question:
|
@@ -382,7 +507,10 @@ async def retrieve_documents_for_all_questions(
|
|
382 |
reranker,
|
383 |
rerank_by_question=True,
|
384 |
k_final=15,
|
385 |
-
k_before_reranking=100
|
|
|
|
|
|
|
386 |
):
|
387 |
"""
|
388 |
Retrieve documents in parallel for all questions.
|
@@ -403,6 +531,7 @@ async def retrieve_documents_for_all_questions(
|
|
403 |
k_images_by_question = _get_k_images_by_question(n_questions)
|
404 |
k_before_reranking=100
|
405 |
|
|
|
406 |
tasks = [
|
407 |
retrieve_documents(
|
408 |
current_question=question,
|
@@ -417,7 +546,10 @@ async def retrieve_documents_for_all_questions(
|
|
417 |
k_images_by_question=k_images_by_question,
|
418 |
k_before_reranking=k_before_reranking,
|
419 |
k_by_question=k_by_question,
|
420 |
-
k_summary_by_question=k_summary_by_question
|
|
|
|
|
|
|
421 |
)
|
422 |
for i, question in enumerate(questions_list) if i in to_handle_questions_index
|
423 |
]
|
@@ -429,6 +561,32 @@ async def retrieve_documents_for_all_questions(
|
|
429 |
new_state["related_contents"].extend(images_question)
|
430 |
return new_state
|
431 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
432 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
433 |
|
434 |
async def retrieve_IPx_docs(state, config):
|
@@ -496,4 +654,56 @@ def make_POC_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_
|
|
496 |
return retrieve_POC_docs_node
|
497 |
|
498 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
499 |
|
|
|
15 |
from langchain_core.vectorstores import VectorStore
|
16 |
from typing import List
|
17 |
from langchain_core.documents.base import Document
|
18 |
+
from ..llm import get_llm
|
19 |
+
from .prompts import retrieve_chapter_prompt_template
|
20 |
+
from langchain_core.prompts import ChatPromptTemplate
|
21 |
+
from langchain_core.output_parsers import StrOutputParser
|
22 |
+
from ..vectorstore import get_pinecone_vectorstore
|
23 |
+
from ..embeddings import get_embeddings_function
|
24 |
+
|
25 |
+
|
26 |
import asyncio
|
27 |
|
28 |
from typing import Any, Dict, List, Tuple
|
|
|
127 |
result.append(doc)
|
128 |
return result
|
129 |
|
130 |
+
def get_ToCs(version: str, index_name: str = "climateqa-v2"):
    """
    Fetch the table-of-contents chunks stored for one version of the parsed documents.

    Args:
        - version : version of the parsed documents (e.g. "v4")
        - index_name : name of the Pinecone index to query (defaults to the
          index name that used to be hard-coded in this function)

    Returns:
        The search results after `remove_duplicates_chunks`.
        NOTE(review): the search call is `similarity_search_with_score`, so the
        entries are presumably (document, score) pairs - confirm against
        `remove_duplicates_chunks`.
    """
    toc_filter = {
        "chunk_type": "toc",
        "version": version,
    }

    # NOTE(review): the embeddings function and the vector store are rebuilt on
    # every call - consider creating them once if this becomes a bottleneck.
    embeddings_function = get_embeddings_function()
    vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=index_name)

    # The query text is empty: the selection relies on the metadata filter only.
    # NOTE(review): no `k` is passed, so the vector store's default result limit
    # applies - confirm that it is large enough to return every ToC.
    tocs = vectorstore.similarity_search_with_score(query="", filter=toc_filter)

    # remove duplicates or almost duplicates
    return remove_duplicates_chunks(tocs)
|
144 |
+
|
145 |
async def get_POC_relevant_documents(
|
146 |
query: str,
|
147 |
vectorstore:VectorStore,
|
|
|
192 |
"docs_question" : docs_question,
|
193 |
"docs_images" : docs_images
|
194 |
}
|
195 |
+
|
196 |
+
async def get_POC_documents_by_ToC_relevant_documents(
    query: str,
    tocs: list,
    vectorstore:VectorStore,
    version: str,
    sources:list = ["Acclimaterra","PCAET","Plan Biodiversite"],
    search_figures:bool = False,
    search_only:bool = False,
    k_documents:int = 10,
    threshold:float = 0.6,
    k_images: int = 5,
    reports:list = [],
    min_size:int = 200,
    proportion: float = 0.5,
) :
    """
    Retrieve text chunks (and optionally images) for a query, reserving a share
    of the results for chunks that belong to the chapters selected from the
    tables of contents.

    Two searches are combined: one restricted to the selected level 0 chapters
    and one over every text chunk of the given version. The merged results are
    de-duplicated, filtered on `threshold` and on `min_size`.

    Args:
        - query : the question used for the similarity searches
        - tocs : list with the table of contents of each document
        - vectorstore : vector store queried with `similarity_search_with_score`
        - version : version of the parsed documents (e.g. "v4")
        - sources : currently unused (source selection is still a TODO)
        - search_figures : when True, image chunks are searched as well
        - search_only : currently unused in this function
        - k_documents : total number of text chunks requested from the vector store
        - threshold : only results whose score is strictly greater are kept
        - k_images : number of image chunks requested when `search_figures` is True
        - reports : currently unused (source selection is still a TODO)
        - min_size : only documents whose `page_content` is strictly longer are kept
        - proportion : share of documents retrieved using ToCs

    Returns:
        A dict with the keys "docs_question" and "docs_images".
    """
    # Prepare base search kwargs
    filters = {}
    docs_question = []
    docs_images = []

    # TODO add source selection
    # if len(reports) > 0:
    #     filters["short_name"] = {"$in":reports}
    # else:
    #     filters["source"] = { "$in": sources}

    relevant_tocs = await get_relevant_toc_level_for_query(query, tocs)

    print(f"Relevant ToCs : {relevant_tocs}")
    # Transform the ToC dict {"document": str, "chapter": str} into a list of string.
    # The entries come from an LLM answer, so malformed ones are skipped instead
    # of raising a KeyError / TypeError.
    toc_filters = [
        toc["chapter"]
        for toc in (relevant_tocs if isinstance(relevant_tocs, list) else [])
        if isinstance(toc, dict) and toc.get("chapter")
    ]

    # When no chapter was selected the whole budget goes to the unrestricted
    # search: an empty "$in" list would not match anything useful.
    k_documents_toc = round(k_documents * proportion) if toc_filters else 0
    k_documents_global = k_documents - k_documents_toc

    if k_documents_toc > 0:
        filters_text_toc = {
            **filters,
            "chunk_type":"text",
            "toc_level0": {"$in": toc_filters},
            "version": version
            # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
        }
        docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text_toc,k = k_documents_toc)

    if k_documents_global > 0:
        filters_text = {
            **filters,
            "chunk_type":"text",
            "version": version
            # "report_type": {}, # TODO to be completed to choose the right documents / chapters according to the analysis of the question
        }
        docs_question += vectorstore.similarity_search_with_score(query=query,filter = filters_text,k = k_documents_global)

    # remove duplicates or almost duplicates (both searches can return the same chunk)
    docs_question = remove_duplicates_chunks(docs_question)
    docs_question = [x for x in docs_question if x[1] > threshold]

    if search_figures:
        # Images
        # NOTE(review): unlike the text filters, this filter does not restrict
        # on "version" - confirm that this is intended.
        filters_image = {
            **filters,
            "chunk_type":"image"
        }
        docs_images = vectorstore.similarity_search_with_score(query=query,filter = filters_image,k = k_images)

    docs_question, docs_images = _add_metadata_and_score(docs_question), _add_metadata_and_score(docs_images)

    # Drop chunks that are too short
    docs_question = [x for x in docs_question if len(x.page_content) > min_size]

    return {
        "docs_question" : docs_question,
        "docs_images" : docs_images
    }
|
275 |
|
276 |
|
277 |
async def get_IPCC_relevant_documents(
|
|
|
374 |
return docs_question, images_question
|
375 |
|
376 |
|
377 |
+
|
378 |
# The chain callback is not necessary, but it propagates the langchain callbacks to the astream_events logger to display intermediate results
|
379 |
# @chain
|
380 |
async def retrieve_documents(
|
|
|
383 |
source_type: str,
|
384 |
vectorstore: VectorStore,
|
385 |
reranker: Any,
|
386 |
+
version: str = "",
|
387 |
search_figures: bool = False,
|
388 |
search_only: bool = False,
|
389 |
reports: list = [],
|
|
|
391 |
k_images_by_question: int = 5,
|
392 |
k_before_reranking: int = 100,
|
393 |
k_by_question: int = 5,
|
394 |
+
k_summary_by_question: int = 3,
|
395 |
+
tocs: list = [],
|
396 |
+
by_toc=False
|
397 |
) -> Tuple[List[Document], List[Document]]:
|
398 |
"""
|
399 |
Unpack the first question of the remaining questions, and retrieve and rerank corresponding documents, based on the question and selected_sources
|
|
|
423 |
|
424 |
print(f"""---- Retrieve documents from {current_question["source_type"]}----""")
|
425 |
|
426 |
+
|
427 |
if source_type == "IPx":
|
428 |
docs_question_dict = await get_IPCC_relevant_documents(
|
429 |
query = question,
|
|
|
439 |
reports = reports,
|
440 |
)
|
441 |
|
442 |
+
if source_type == 'POC':
|
443 |
+
if by_toc == True:
|
444 |
+
print("---- Retrieve documents by ToC----")
|
445 |
+
docs_question_dict = await get_POC_documents_by_ToC_relevant_documents(
|
446 |
+
query=question,
|
447 |
+
tocs = tocs,
|
448 |
+
vectorstore=vectorstore,
|
449 |
+
version=version,
|
450 |
+
search_figures = search_figures,
|
451 |
+
sources = sources,
|
452 |
+
threshold = 0.5,
|
453 |
+
search_only = search_only,
|
454 |
+
reports = reports,
|
455 |
+
min_size= 200,
|
456 |
+
k_documents= k_before_reranking,
|
457 |
+
k_images= k_by_question
|
458 |
+
)
|
459 |
+
else :
|
460 |
+
docs_question_dict = await get_POC_relevant_documents(
|
461 |
+
query = question,
|
462 |
+
vectorstore=vectorstore,
|
463 |
+
search_figures = search_figures,
|
464 |
+
sources = sources,
|
465 |
+
threshold = 0.5,
|
466 |
+
search_only = search_only,
|
467 |
+
reports = reports,
|
468 |
+
min_size= 200,
|
469 |
+
k_documents= k_before_reranking,
|
470 |
+
k_images= k_by_question
|
471 |
+
)
|
472 |
|
473 |
# Rerank
|
474 |
if reranker is not None and rerank_by_question:
|
|
|
507 |
reranker,
|
508 |
rerank_by_question=True,
|
509 |
k_final=15,
|
510 |
+
k_before_reranking=100,
|
511 |
+
version: str = "",
|
512 |
+
tocs: list[dict] = [],
|
513 |
+
by_toc: bool = False
|
514 |
):
|
515 |
"""
|
516 |
Retrieve documents in parallel for all questions.
|
|
|
531 |
k_images_by_question = _get_k_images_by_question(n_questions)
|
532 |
k_before_reranking=100
|
533 |
|
534 |
+
print(f"Source type here is {source_type}")
|
535 |
tasks = [
|
536 |
retrieve_documents(
|
537 |
current_question=question,
|
|
|
546 |
k_images_by_question=k_images_by_question,
|
547 |
k_before_reranking=k_before_reranking,
|
548 |
k_by_question=k_by_question,
|
549 |
+
k_summary_by_question=k_summary_by_question,
|
550 |
+
tocs=tocs,
|
551 |
+
version=version,
|
552 |
+
by_toc=by_toc
|
553 |
)
|
554 |
for i, question in enumerate(questions_list) if i in to_handle_questions_index
|
555 |
]
|
|
|
561 |
new_state["related_contents"].extend(images_question)
|
562 |
return new_state
|
563 |
|
564 |
+
# ToC Retriever
|
565 |
+
async def get_relevant_toc_level_for_query(
    query: str,
    tocs: list[tuple[Document, float]],
) -> list[dict] :
    """
    Ask the LLM which chapters of the given tables of contents are relevant to the query.

    Args:
        - query : the user question
        - tocs : table-of-contents search results; each entry is indexed as
          `entry[0]` to reach the document, whose `metadata['name']` and
          `page_content` are sent to the LLM

    Returns:
        The list of dicts parsed from the LLM answer (the prompt asks for the
        keys "document" and "chapter"). An empty list is returned when the
        answer cannot be parsed.
    """
    doc_list = [
        {'document': doc[0].metadata['name'], 'toc': doc[0].page_content}
        for doc in tocs
    ]

    llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)

    prompt = ChatPromptTemplate.from_template(retrieve_chapter_prompt_template)
    chain = prompt | llm | StrOutputParser()
    # Use the async entry point so the event loop is not blocked while the
    # LLM answers (this coroutine is awaited from other coroutines).
    response = await chain.ainvoke({"query": query, "doc_list": doc_list})

    return _parse_toc_response(response)


def _parse_toc_response(response: str) -> list[dict]:
    """Parse the LLM answer into a list of dicts without executing it; return [] on failure."""
    import ast
    import json

    text = response.strip()
    # Tolerate an answer wrapped in a markdown code fence (```json ... ```)
    if text.startswith("```"):
        text = text.strip("`").strip()
        if text[:4].lower() == "json":
            text = text[4:].strip()

    try:
        parsed = json.loads(text)
    except json.JSONDecodeError:
        try:
            # Fallback for Python-literal style answers (single quotes, trailing
            # commas). Unlike eval(), literal_eval never executes code.
            parsed = ast.literal_eval(text)
        except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError) as e:
            print(f" Failed to parse the result because of : {e}")
            return []

    # A single chapter returned as a bare object is accepted as a one-item list
    if isinstance(parsed, dict):
        parsed = [parsed]
    if not isinstance(parsed, list):
        print(f" Failed to parse the result because of : unexpected type {type(parsed).__name__}")
        return []

    return [item for item in parsed if isinstance(item, dict)]
|
588 |
+
|
589 |
+
|
590 |
def make_IPx_retriever_node(vectorstore,reranker,llm,rerank_by_question=True, k_final=15, k_before_reranking=100, k_summary=5):
|
591 |
|
592 |
async def retrieve_IPx_docs(state, config):
|
|
|
654 |
return retrieve_POC_docs_node
|
655 |
|
656 |
|
657 |
+
def make_POC_by_ToC_retriever_node(
    vectorstore: VectorStore,
    reranker,
    llm,
    version: str = "",
    rerank_by_question=True,
    k_final=15,
    k_before_reranking=100,
    k_summary=5,
):
    """
    Build the graph node that retrieves "POC" documents with the help of the tables of contents.

    Args:
        - vectorstore : vector store forwarded to `retrieve_documents_for_all_questions`
        - reranker : reranker forwarded to `retrieve_documents_for_all_questions`
        - llm : accepted for signature parity with the other node factories; not used here
        - version : version of the parsed documents (e.g. "v4")
        - rerank_by_question, k_final, k_before_reranking : forwarded to
          `retrieve_documents_for_all_questions`
        - k_summary : accepted but not used in this function

    Returns:
        The async node function `retrieve_POC_docs_node(state, config)`.
    """

    async def retrieve_POC_docs_node(state, config):
        # Nothing to do when the user did not select the POC source
        if "POC region" not in state["relevant_content_sources_selection"]:
            return {}

        search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources_selection"]
        search_only = state["search_only"]
        reports = state["reports"]
        questions_list = state["questions_list"]
        n_questions = state["n_questions"]["total"]

        # NOTE(review): the ToCs are fetched again on every node execution -
        # consider caching them if they do not change at runtime.
        tocs = get_ToCs(version=version)

        source_type = "POC"
        # Only the questions routed to the POC source are handled by this node
        POC_questions_index = [
            i for i, x in enumerate(questions_list) if x["source_type"] == "POC"
        ]

        state = await retrieve_documents_for_all_questions(
            search_figures=search_figures,
            search_only=search_only,
            config=config,
            reports=reports,
            questions_list=questions_list,
            n_questions=n_questions,
            source_type=source_type,
            to_handle_questions_index=POC_questions_index,
            vectorstore=vectorstore,
            reranker=reranker,
            rerank_by_question=rerank_by_question,
            k_final=k_final,
            k_before_reranking=k_before_reranking,
            tocs=tocs,
            version=version,
            by_toc=True
        )
        return state

    return retrieve_POC_docs_node
|
705 |
+
|
706 |
+
|
707 |
+
|
708 |
+
|
709 |
|
climateqa/engine/graph.py
CHANGED
@@ -19,7 +19,7 @@ from .chains.answer_ai_impact import make_ai_impact_node
|
|
19 |
from .chains.query_transformation import make_query_transform_node
|
20 |
from .chains.translation import make_translation_node
|
21 |
from .chains.intent_categorization import make_intent_categorization_node
|
22 |
-
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
@@ -211,8 +211,23 @@ def make_graph_agent(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_regi
|
|
211 |
app = workflow.compile()
|
212 |
return app
|
213 |
|
214 |
-
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, threshold_docs=0.2):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
215 |
|
|
|
216 |
workflow = StateGraph(GraphState)
|
217 |
|
218 |
# Define the node functions
|
@@ -223,7 +238,8 @@ def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_
|
|
223 |
answer_ai_impact = make_ai_impact_node(llm)
|
224 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
225 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
226 |
-
retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
|
|
227 |
answer_rag = make_rag_node(llm, with_docs=True)
|
228 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
229 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|
|
|
19 |
from .chains.query_transformation import make_query_transform_node
|
20 |
from .chains.translation import make_translation_node
|
21 |
from .chains.intent_categorization import make_intent_categorization_node
|
22 |
+
from .chains.retrieve_documents import make_IPx_retriever_node, make_POC_retriever_node, make_POC_by_ToC_retriever_node
|
23 |
from .chains.answer_rag import make_rag_node
|
24 |
from .chains.graph_retriever import make_graph_retriever_node
|
25 |
from .chains.chitchat_categorization import make_chitchat_intent_categorization_node
|
|
|
211 |
app = workflow.compile()
|
212 |
return app
|
213 |
|
214 |
+
def make_graph_agent_poc(llm, vectorstore_ipcc, vectorstore_graphs, vectorstore_region, reranker, version:str, threshold_docs=0.2):
|
215 |
+
"""_summary_
|
216 |
+
|
217 |
+
Args:
|
218 |
+
llm (_type_): _description_
|
219 |
+
vectorstore_ipcc (_type_): _description_
|
220 |
+
vectorstore_graphs (_type_): _description_
|
221 |
+
vectorstore_region (_type_): _description_
|
222 |
+
reranker (_type_): _description_
|
223 |
+
version (str): version of the parsed documents (e.g "v4")
|
224 |
+
threshold_docs (float, optional): _description_. Defaults to 0.2.
|
225 |
+
|
226 |
+
Returns:
|
227 |
+
_type_: _description_
|
228 |
+
"""
|
229 |
|
230 |
+
|
231 |
workflow = StateGraph(GraphState)
|
232 |
|
233 |
# Define the node functions
|
|
|
238 |
answer_ai_impact = make_ai_impact_node(llm)
|
239 |
retrieve_documents = make_IPx_retriever_node(vectorstore_ipcc, reranker, llm)
|
240 |
retrieve_graphs = make_graph_retriever_node(vectorstore_graphs, reranker)
|
241 |
+
# retrieve_local_data = make_POC_retriever_node(vectorstore_region, reranker, llm)
|
242 |
+
retrieve_local_data = make_POC_by_ToC_retriever_node(vectorstore_region, reranker, llm, version=version)
|
243 |
answer_rag = make_rag_node(llm, with_docs=True)
|
244 |
answer_rag_no_docs = make_rag_node(llm, with_docs=False)
|
245 |
chitchat_categorize_intent = make_chitchat_intent_categorization_node(llm)
|