timeki committed · Commit 96c684f · 1 Parent(s): 7ad11f7

fix figures retrieval

app.py CHANGED
@@ -113,7 +113,7 @@ vectorstore = get_pinecone_vectorstore(embeddings_function, index_name = os.gete
 vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name = os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
 
 llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-reranker = get_reranker("nano")
+reranker = get_reranker("large")
 
 agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, reranker=reranker)
 
@@ -142,7 +142,6 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
 
 
     docs = []
-    used_figures=[]
     related_contents = []
     docs_html = ""
     output_query = ""
@@ -165,7 +164,7 @@ async def chat(query, history, audience, sources, reports, relevant_content_sour
         if "langgraph_node" in event["metadata"]:
             node = event["metadata"]["langgraph_node"]
 
-            if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" :# when documents are retrieved
+            if event["event"] == "on_chain_end" and event["name"] == "retrieve_documents" and event["data"]["output"] != None: # when documents are retrieved
                 docs, docs_html, history, used_documents, related_contents = handle_retrieved_documents(event, history, used_documents)
 
             elif event["event"] == "on_chain_end" and node == "categorize_intent" and event["name"] == "_write": # when the query is transformed
@@ -321,10 +320,19 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
 
     with gr.Row(elem_id = "input-message"):
-        textbox=gr.Textbox(placeholder="Ask me anything here!",show_label=False,scale=7,lines = 1,interactive = True,elem_id="input-textbox")
-
-        config_button = gr.Button("",elem_id="config-button")
-        # config_checkbox_button = gr.Checkbox(label = '⚙️', value="show",visible=True, interactive=True, elem_id="checkbox-config")
+        textbox = gr.Textbox(
+            placeholder="Ask me anything here!",
+            show_label=False,
+            scale=12,
+            lines=1,
+            interactive=True,
+            elem_id="input-textbox"
+        )
+
+        config_button = gr.Button(
+            "",
+            elem_id="config-button"
+        )
 
 
 
@@ -417,7 +425,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
     with gr.Tabs(elem_id = "group-subtabs") as tabs_recommended_content:
 
         with gr.Tab("Figures",elem_id = "tab-figures",id = 3) as tab_figures:
-            sources_raw = gr.State()
+            sources_raw = gr.State([])
+            new_figures = gr.State([])
+            used_figures = gr.State([])
 
             with Modal(visible=False, elem_id="modal_figure_galery") as figure_modal:
                 gallery_component = gr.Gallery(object_fit='scale-down',elem_id="gallery-component", height="80vh")
@@ -475,9 +485,9 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
     )
 
     dropdown_external_sources = gr.CheckboxGroup(
-        ["IPCC figures","OpenAlex", "OurWorldInData"],
+        ["Figures (IPCC/IPBES)","Papers (OpenAlex)", "Graphs (OurWorldInData)"],
         label="Select database to search for relevant content",
-        value=["IPCC figures"],
+        value=["Figures (IPCC/IPBES)"],
         interactive=True,
     )
 
@@ -633,15 +643,25 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
         return gr.update(label = recommended_content_notif_label), gr.update(label = sources_notif_label), gr.update(label = figures_notif_label), gr.update(label = graphs_notif_label), gr.update(label = papers_notif_label)
 
     (textbox
-        .submit(start_chat, [textbox,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_textbox")
-        .then(chat, [textbox,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
-        .then(finish_chat, None, [textbox],api_name = "finish_chat_textbox")
-        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
+        .submit(start_chat, [textbox, chatbot, search_only],
+                [textbox, tabs, chatbot],
+                queue=False,
+                api_name="start_chat_textbox")
+        .then(chat, [textbox, chatbot, dropdown_audience, dropdown_sources,
+                     dropdown_reports, dropdown_external_sources, search_only],
+              [chatbot, sources_textbox, output_query, output_language,
+               new_figures, current_graphs],
+              concurrency_limit=8,
+              api_name="chat_textbox")
+        .then(finish_chat, None, [textbox],
+              api_name="finish_chat_textbox")
     )
 
+
+
    (examples_hidden
        .change(start_chat, [examples_hidden,chatbot, search_only], [textbox,tabs,chatbot],queue = False,api_name = "start_chat_examples")
-        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, sources_raw, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
+        .then(chat, [examples_hidden,chatbot,dropdown_audience, dropdown_sources,dropdown_reports, dropdown_external_sources, search_only] ,[chatbot,sources_textbox,output_query,output_language, new_figures, current_graphs],concurrency_limit = 8,api_name = "chat_textbox")
        .then(finish_chat, None, [textbox],api_name = "finish_chat_examples")
        # .then(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_sources, tab_figures, tab_graphs, tab_papers] )
    )
@@ -654,7 +674,7 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
        return [gr.update(visible=visible_bools[i]) for i in range(len(samples))]
 
 
-    sources_raw.change(process_figures, inputs=[sources_raw], outputs=[figures_cards, gallery_component])
+    new_figures.change(process_figures, inputs=[sources_raw, new_figures], outputs=[sources_raw, figures_cards, gallery_component])
 
    # update sources numbers
    sources_textbox.change(update_sources_number_display, [sources_textbox, figures_cards, current_graphs,papers_html],[tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers])
@@ -674,4 +694,6 @@ with gr.Blocks(title="Climate Q&A", css_paths=os.getcwd()+ "/style.css", theme=t
 
 demo.queue()
 
+
+
 demo.launch(ssr_mode=False)
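
The figure flow in app.py now runs through two gr.State holders: the chat handler writes the figures retrieved in the current turn into new_figures, and the .change() listener on that State folds them into the accumulated sources_raw while re-rendering the cards and gallery. Below is a minimal, self-contained sketch of that wiring; the handlers (fake_chat, accumulate_figures) are illustrative stand-ins for the real chat and process_figures, and it assumes a Gradio version in which gr.State exposes a .change() event, which this app relies on.

import gradio as gr

def fake_chat(message):
    # Stand-in for the real `chat` handler: returns the answer for this turn
    # plus the figures retrieved for it (written into the `new_figures` State).
    return f"Answer to: {message}", [f"figure for '{message}'"]

def accumulate_figures(previous, new):
    # Stand-in for `process_figures`: merge already-shown figures with the new
    # ones, deduplicate, and build the HTML to display.
    merged = previous + [f for f in new if f not in previous]
    return merged, "<br>".join(merged)

with gr.Blocks() as demo:
    textbox = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Answer")
    figures_html = gr.HTML()
    sources_raw = gr.State([])   # every figure shown so far
    new_figures = gr.State([])   # figures produced by the latest turn only

    # Writing into `new_figures` fires its .change() event, which updates
    # `sources_raw` and refreshes the rendered figures.
    textbox.submit(fake_chat, [textbox], [answer, new_figures])
    new_figures.change(accumulate_figures, [sources_raw, new_figures], [sources_raw, figures_html])

demo.launch()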
climateqa/engine/chains/retrieve_documents.py CHANGED
@@ -87,7 +87,7 @@ def _get_k_images_by_question(n_questions):
     elif n_questions == 2:
         return 5
     elif n_questions == 3:
-        return 2
+        return 3
     else:
         return 1
 
@@ -98,7 +98,10 @@ def _add_metadata_and_score(docs: List) -> Document:
         doc.page_content = doc.page_content.replace("\r\n"," ")
         doc.metadata["similarity_score"] = score
         doc.metadata["content"] = doc.page_content
-        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+        if doc.metadata["page_number"] != "N/A":
+            doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
+        else:
+            doc.metadata["page_number"] = 1
         # doc.page_content = f"""Doc {i+1} - {doc.metadata['short_name']}: {doc.page_content}"""
         docs_with_metadata.append(doc)
     return docs_with_metadata
@@ -222,7 +225,7 @@ async def retrieve_documents(state,config, vectorstore,reranker,llm,rerank_by_qu
     else:
         related_content = []
 
-    search_figures = "IPCC figures" in state["relevant_content_sources"]
+    search_figures = "Figures (IPCC/IPBES)" in state["relevant_content_sources"]
     search_only = state["search_only"]
 
     # Get the current question
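
The page-number guard is the part of this file most directly tied to the commit message: some retrieved chunks apparently carry "N/A" instead of a numeric page, and the previous unconditional int(...) call would raise on them. A standalone illustration of the same rule, using a toy document object rather than LangChain's Document:

from dataclasses import dataclass, field

@dataclass
class Doc:
    # Toy stand-in for langchain's Document; only metadata is needed here.
    metadata: dict = field(default_factory=dict)

def normalize_page_number(doc):
    # Same rule as the edited _add_metadata_and_score: numeric page numbers are
    # incremented by one as before, while "N/A" (which int() would reject)
    # falls back to page 1 instead of raising.
    if doc.metadata["page_number"] != "N/A":
        doc.metadata["page_number"] = int(doc.metadata["page_number"]) + 1
    else:
        doc.metadata["page_number"] = 1
    return doc

print(normalize_page_number(Doc({"page_number": "12"})).metadata["page_number"])   # 13
print(normalize_page_number(Doc({"page_number": "N/A"})).metadata["page_number"])  # 1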
climateqa/engine/graph.py CHANGED
@@ -36,7 +36,7 @@ class GraphState(TypedDict):
     answer: str
     audience: str = "experts"
     sources_input: List[str] = ["IPCC","IPBES"]
-    relevant_content_sources: List[str] = ["IPCC figures"]
+    relevant_content_sources: List[str] = ["Figures (IPCC/IPBES)"]
     sources_auto: bool = True
     min_year: int = 1960
     max_year: int = None
@@ -82,7 +82,7 @@ def route_based_on_relevant_docs(state,threshold_docs=0.2):
         return "answer_rag_no_docs"
 
 def route_retrieve_documents(state):
-    if state["search_only"] :
+    if len(state["remaining_questions"]) == 0 and state["search_only"] :
         return END
     elif len(state["remaining_questions"]) > 0:
         return "retrieve_documents"
front/utils.py CHANGED
@@ -39,25 +39,29 @@ def parse_output_llm_with_sources(output:str)->str:
     content_parts = "".join(parts)
     return content_parts
 
-def process_figures(docs:list)->tuple:
-    gallery=[]
-    used_figures =[]
+def process_figures(docs:list, new_figures:list)->tuple:
+    docs = docs + new_figures
+
     figures = '<div class="figures-container"><p></p> </div>'
+    gallery = []
+    used_figures = []
+
     if docs == []:
-        return figures, gallery
+        return figures, gallery, used_figures
+
+
     docs_figures = [d for d in docs if d.metadata["chunk_type"] == "image"]
-    for i, doc in enumerate(docs_figures):
-        if doc.metadata["chunk_type"] == "image":
-            if doc.metadata["figure_code"] != "N/A":
-                title = f"{doc.metadata['figure_code']} - {doc.metadata['short_name']}"
-            else:
-                title = f"{doc.metadata['short_name']}"
+    for i_doc, doc in enumerate(docs_figures):
+        if doc.metadata["chunk_type"] == "image":
+            path = doc.metadata["image_path"]
 
 
-            if title not in used_figures:
-                used_figures.append(title)
+            if path not in used_figures:
+                used_figures.append(path)
+                figure_number = len(used_figures)
+
                 try:
-                    key = f"Image {i+1}"
+                    key = f"Image {figure_number}"
 
                     image_path = doc.metadata["image_path"].split("documents/")[1]
                     img = get_image_from_azure_blob_storage(image_path)
@@ -70,12 +74,12 @@ def process_figures(docs:list)->tuple:
 
                     img_str = base64.b64encode(buffered.getvalue()).decode()
 
-                    figures = figures + make_html_figure_sources(doc, i, img_str)
+                    figures = figures + make_html_figure_sources(doc, figure_number, img_str)
                     gallery.append(img)
                 except Exception as e:
-                    print(f"Skipped adding image {i} because of {e}")
+                    print(f"Skipped adding image {figure_number} because of {e}")
 
-    return figures, gallery
+    return docs, figures, gallery
 
 
 def generate_html_graphs(graphs:list)->str:
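
process_figures now takes both the figures already shown and the ones retrieved in the latest turn, returns the merged list as its first element (so the caller can write it back into the sources_raw State), and deduplicates by image_path rather than by title. A toy walk-through of that accumulation contract, with the Azure image fetch and HTML generation stubbed out:

class FakeDoc:
    # Minimal stand-in for a retrieved image chunk.
    def __init__(self, path):
        self.metadata = {"chunk_type": "image", "image_path": path}

def process_figures_sketch(docs, new_figures):
    # Mirrors the accumulation/dedup logic above: new figures are appended to
    # the running list, but each image_path is rendered only once.
    docs = docs + new_figures
    used_figures, gallery = [], []
    for doc in docs:
        if doc.metadata["chunk_type"] == "image":
            path = doc.metadata["image_path"]
            if path not in used_figures:
                used_figures.append(path)
                gallery.append(path)  # the real code appends the decoded image
    return docs, gallery

state, _ = process_figures_sketch([], [FakeDoc("documents/fig1.png")])
state, gallery = process_figures_sketch(state, [FakeDoc("documents/fig1.png"), FakeDoc("documents/fig2.png")])
print(len(state), gallery)  # 3 accumulated docs, 2 unique figures rendered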