Fix figures retrieval
Browse files- app.py +38 -16
- climateqa/engine/chains/retrieve_documents.py +6 -3
- climateqa/engine/graph.py +2 -2
- front/utils.py +20 -16
app.py
CHANGED
@@ -113,7 +113,7 @@ vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.gete
|
|
113 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
114 |
|
115 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
116 |
-
reranker = get_reranker("
|
117 |
|
118 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
119 |
|
@@ -142,7 +142,6 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
|
|
142 |
|
143 |
|
144 |
docs = []
|
145 |
-
used_figures=[]
|
146 |
related_contents = []
|
147 |
docs_html = ""
|
148 |
output_query = ""
|
@@ -165,7 +164,7 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
|
|
165 |
if "langgraph_node" in event["metadata"]:
|
166 |
node = event["metadata"]["langgraph_node"]
|
167 |
|
168 |
-
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
|
169 |
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
|
170 |
|
171 |
elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
|
@@ -321,10 +320,19 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
321 |
|
322 |
|
323 |
with gr.Row(elem_id = "input-message"):
|
324 |
-
textbox=gr.Textbox(
|
325 |
-
|
326 |
-
|
327 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
328 |
|
329 |
|
330 |
|
@@ -417,7 +425,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
417 |
with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
|
418 |
|
419 |
with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
|
420 |
-
sources_raw = gr.State()
|
|
|
|
|
421 |
|
422 |
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
423 |
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
@@ -475,9 +485,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
475 |
)
|
476 |
|
477 |
dropdown_external_sources = gr.CheckboxGroup(
|
478 |
-
["IPCC
|
479 |
label="Select database to search for relevant content",
|
480 |
-
value=["IPCC
|
481 |
interactive=True,
|
482 |
)
|
483 |
|
@@ -633,15 +643,25 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
633 |
return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
|
634 |
|
635 |
(textbox
|
636 |
-
.submit(start_chat, [textbox,chatbot, search_only],
|
637 |
-
|
638 |
-
|
639 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
640 |
)
|
641 |
|
|
|
|
|
642 |
(examples_hidden
|
643 |
.change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
644 |
-
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language,
|
645 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
646 |
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
647 |
)
|
@@ -654,7 +674,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
654 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
655 |
|
656 |
|
657 |
-
|
658 |
|
659 |
# update sources numbers
|
660 |
sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
@@ -674,4 +694,6 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
|
|
674 |
|
675 |
demo.queue()
|
676 |
|
|
|
|
|
677 |
demo.launch(ssr_mode=False)
|
|
|
113 |
vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
|
114 |
|
115 |
llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
|
116 |
+
reranker = get_reranker("large")
|
117 |
|
118 |
agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
|
119 |
|
|
|
142 |
|
143 |
|
144 |
docs = []
|
|
|
145 |
related_contents = []
|
146 |
docs_html = ""
|
147 |
output_query = ""
|
|
|
164 |
if "langgraph_node" in event["metadata"]:
|
165 |
node = event["metadata"]["langgraph_node"]
|
166 |
|
167 |
+
if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" and event["data"]["output"] != None:# when documents are retrieved
|
168 |
docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
|
169 |
|
170 |
elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
|
|
|
320 |
|
321 |
|
322 |
with gr.Row(elem_id = "input-message"):
|
323 |
+
textbox = gr.Textbox(
|
324 |
+
placeholder="Ask me anything here!",
|
325 |
+
show_label=False,
|
326 |
+
scale=12,
|
327 |
+
lines=1,
|
328 |
+
interactive=True,
|
329 |
+
elem_id="input-textbox"
|
330 |
+
)
|
331 |
+
|
332 |
+
config_button = gr.Button(
|
333 |
+
"",
|
334 |
+
elem_id="config-button"
|
335 |
+
)
|
336 |
|
337 |
|
338 |
|
|
|
425 |
with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
|
426 |
|
427 |
with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
|
428 |
+
sources_raw = gr.State([])
|
429 |
+
new_figures = gr.State([])
|
430 |
+
used_figures = gr.State([])
|
431 |
|
432 |
with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
|
433 |
gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
|
|
|
485 |
)
|
486 |
|
487 |
dropdown_external_sources = gr.CheckboxGroup(
|
488 |
+
["Figures (IPCC/IPBES)","Papers (OpenAlex)", "Graphs (OurWorldInData)"],
|
489 |
label="Select database to search for relevant content",
|
490 |
+
value=["Figures (IPCC/IPBES)"],
|
491 |
interactive=True,
|
492 |
)
|
493 |
|
|
|
643 |
return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
|
644 |
|
645 |
(textbox
|
646 |
+
.submit(start_chat, [textbox, chatbot, search_only],
|
647 |
+
[textbox, tabs, chatbot],
|
648 |
+
queue=False,
|
649 |
+
api_name="start_chat_textbox")
|
650 |
+
.then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources,
|
651 |
+
dropdown_reports, dropdown_external_sources, search_only],
|
652 |
+
[chatbot, sources_textbox, output_query, output_language,
|
653 |
+
new_figures, current_graphs],
|
654 |
+
concurrency_limit=8,
|
655 |
+
api_name="chat_textbox")
|
656 |
+
.then(finish_chat, None, [textbox],
|
657 |
+
api_name="finish_chat_textbox")
|
658 |
)
|
659 |
|
660 |
+
|
661 |
+
|
662 |
(examples_hidden
|
663 |
.change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
|
664 |
+
.then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, new_figures, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
|
665 |
.then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
|
666 |
# .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
|
667 |
)
|
|
|
674 |
return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
|
675 |
|
676 |
|
677 |
+
new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
|
678 |
|
679 |
# update sources numbers
|
680 |
sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
|
|
|
694 |
|
695 |
demo.queue()
|
696 |
|
697 |
+
|
698 |
+
|
699 |
demo.launch(ssr_mode=False)
|
climateqa/engine/chains/retrieve_documents.py
CHANGED
@@ -87,7 +87,7 @@ def _get_k_images_by_question(n_questions):
|
|
87 |
elif n_questions == 2:
|
88 |
return 5
|
89 |
elif n_questions == 3:
|
90 |
-
return
|
91 |
else:
|
92 |
return 1
|
93 |
|
@@ -98,7 +98,10 @@ def _add_metadata_and_score(docs: List) -> Document:
|
|
98 |
doc.page_content = doc.page_content.replace("\r\n"," ")
|
99 |
doc.metadata["similarity_score"] = score
|
100 |
doc.metadata["content"] = doc.page_content
|
101 |
-
doc.metadata["page_number"]
|
|
|
|
|
|
|
102 |
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
103 |
docs_with_metadata.append(doc)
|
104 |
return docs_with_metadata
|
@@ -222,7 +225,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
|
|
222 |
else:
|
223 |
related_content = []
|
224 |
|
225 |
-
search_figures = "IPCC
|
226 |
search_only = state["search_only"]
|
227 |
|
228 |
# Get the current question
|
|
|
87 |
elif n_questions == 2:
|
88 |
return 5
|
89 |
elif n_questions == 3:
|
90 |
+
return 3
|
91 |
else:
|
92 |
return 1
|
93 |
|
|
|
98 |
doc.page_content = doc.page_content.replace("\r\n"," ")
|
99 |
doc.metadata["similarity_score"] = score
|
100 |
doc.metadata["content"] = doc.page_content
|
101 |
+
if doc.metadata["page_number"] != "N/A":
|
102 |
+
doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
|
103 |
+
else:
|
104 |
+
doc.metadata["page_number"] = 1
|
105 |
# doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
|
106 |
docs_with_metadata.append(doc)
|
107 |
return docs_with_metadata
|
|
|
225 |
else:
|
226 |
related_content = []
|
227 |
|
228 |
+
search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources"]
|
229 |
search_only = state["search_only"]
|
230 |
|
231 |
# Get the current question
|
climateqa/engine/graph.py
CHANGED
@@ -36,7 +36,7 @@ class GraphState(TypedDict):
|
|
36 |
answer: str
|
37 |
audience: str = "experts"
|
38 |
sources_input: List[str] = ["IPCC","IPBES"]
|
39 |
-
relevant_content_sources: List[str] = ["IPCC
|
40 |
sources_auto: bool = True
|
41 |
min_year: int = 1960
|
42 |
max_year: int = None
|
@@ -82,7 +82,7 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
|
|
82 |
return "answer_rag_no_docs"
|
83 |
|
84 |
def route_retrieve_documents(state):
|
85 |
-
if state["search_only"] :
|
86 |
return END
|
87 |
elif len(state["remaining_questions"]) > 0:
|
88 |
return "retrieve_documents"
|
|
|
36 |
answer: str
|
37 |
audience: str = "experts"
|
38 |
sources_input: List[str] = ["IPCC","IPBES"]
|
39 |
+
relevant_content_sources: List[str] = ["Figures (IPCC/IPBES)"]
|
40 |
sources_auto: bool = True
|
41 |
min_year: int = 1960
|
42 |
max_year: int = None
|
|
|
82 |
return "answer_rag_no_docs"
|
83 |
|
84 |
def route_retrieve_documents(state):
|
85 |
+
if len(state["remaining_questions"]) == 0 and state["search_only"] :
|
86 |
return END
|
87 |
elif len(state["remaining_questions"]) > 0:
|
88 |
return "retrieve_documents"
|
front/utils.py
CHANGED
@@ -39,25 +39,29 @@ def parse_output_llm_with_sources(output:str)->str:
|
|
39 |
content_parts = "".join(parts)
|
40 |
return content_parts
|
41 |
|
42 |
-
def process_figures(docs:list)->tuple:
|
43 |
-
|
44 |
-
|
45 |
figures = '<div class="figures-container"><p></p> </div>'
|
|
|
|
|
|
|
46 |
if docs == []:
|
47 |
-
return figures, gallery
|
|
|
|
|
48 |
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
49 |
-
for
|
50 |
-
if doc.metadata["chunk_type"] == "image":
|
51 |
-
|
52 |
-
title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
|
53 |
-
else:
|
54 |
-
title = f"{doc.metadata['short_name']}"
|
55 |
|
56 |
|
57 |
-
if
|
58 |
-
used_figures.append(
|
|
|
|
|
59 |
try:
|
60 |
-
key = f"Image {
|
61 |
|
62 |
image_path = doc.metadata["image_path"].split("documents/")[1]
|
63 |
img = get_image_from_azure_blob_storage(image_path)
|
@@ -70,12 +74,12 @@ def process_figures(docs:list)->tuple:
|
|
70 |
|
71 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
72 |
|
73 |
-
figures = figures + make_html_figure_sources(doc,
|
74 |
gallery.append(img)
|
75 |
except Exception as e:
|
76 |
-
print(f"Skipped adding image {
|
77 |
|
78 |
-
return figures, gallery
|
79 |
|
80 |
|
81 |
def generate_html_graphs(graphs:list)->str:
|
|
|
39 |
content_parts = "".join(parts)
|
40 |
return content_parts
|
41 |
|
42 |
+
def process_figures(docs:list, new_figures:list)->tuple:
|
43 |
+
docs = docs + new_figures
|
44 |
+
|
45 |
figures = '<div class="figures-container"><p></p> </div>'
|
46 |
+
gallery = []
|
47 |
+
used_figures = []
|
48 |
+
|
49 |
if docs == []:
|
50 |
+
return figures, gallery, used_figures
|
51 |
+
|
52 |
+
|
53 |
docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
|
54 |
+
for i_doc, doc in enumerate(docs_figures):
|
55 |
+
if doc.metadata["chunk_type"] == "image":
|
56 |
+
path = doc.metadata["image_path"]
|
|
|
|
|
|
|
57 |
|
58 |
|
59 |
+
if path not in used_figures:
|
60 |
+
used_figures.append(path)
|
61 |
+
figure_number = len(used_figures)
|
62 |
+
|
63 |
try:
|
64 |
+
key = f"Image {figure_number}"
|
65 |
|
66 |
image_path = doc.metadata["image_path"].split("documents/")[1]
|
67 |
img = get_image_from_azure_blob_storage(image_path)
|
|
|
74 |
|
75 |
img_str = base64.b64encode(buffered.getvalue()).decode()
|
76 |
|
77 |
+
figures = figures + make_html_figure_sources(doc, figure_number, img_str)
|
78 |
gallery.append(img)
|
79 |
except Exception as e:
|
80 |
+
print(f"Skipped adding image {figure_number} because of {e}")
|
81 |
|
82 |
+
return docs, figures, gallery
|
83 |
|
84 |
|
85 |
def generate_html_graphs(graphs:list)->str:
|