"""Gradio entry point for the ClimateQ&A application.

Builds the vector stores, LLM, reranker and the two graph agents at import
time, wires the Gradio event handlers for both chat tabs, and launches the UI.
"""

# Import necessary libraries
import logging
import os

import gradio as gr
from azure.storage.fileshare import ShareServiceClient

# Import custom modules
from climateqa.engine.embeddings import get_embeddings_function
from climateqa.engine.llm import get_llm
from climateqa.engine.vectorstore import get_pinecone_vectorstore
from climateqa.engine.reranker import get_reranker
from climateqa.engine.graph import make_graph_agent, make_graph_agent_poc
from climateqa.engine.chains.retrieve_papers import find_papers
from climateqa.chat import start_chat, chat_stream, finish_chat
from front.tabs import create_config_modal, cqa_tab, create_about_tab
from front.tabs import MainTabPanel, ConfigPanel
from front.tabs.tab_drias import create_drias_tab
from front.utils import process_figures
from gradio_modal import Modal
from utils import create_user_id

logging.basicConfig(level=logging.WARNING)
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"  # Suppresses INFO and WARNING logs
logging.getLogger().setLevel(logging.WARNING)

# Load environment variables in local mode. python-dotenv is optional: when it
# is not installed the process environment is used as-is.
try:
    from dotenv import load_dotenv
except ImportError:
    pass
else:
    load_dotenv()

# Set up Gradio Theme
theme = gr.themes.Base(
    primary_hue="blue",
    secondary_hue="red",
    font=[
        gr.themes.GoogleFont("Poppins"),
        "ui-sans-serif",
        "system-ui",
        "sans-serif",
    ],
)

# Azure File Share credentials
account_key = os.environ["BLOB_ACCOUNT_KEY"]
# NOTE(review): presumably restores base64 padding that was stripped from the
# stored key -- confirm against how BLOB_ACCOUNT_KEY is provisioned.
if len(account_key) == 86:
    account_key += "=="

credential = {
    "account_key": account_key,
    "account_name": os.environ["BLOB_ACCOUNT_NAME"],
}

account_url = os.environ["BLOB_ACCOUNT_URL"]
file_share_name = "climateqa"
service = ShareServiceClient(account_url=account_url, credential=credential)
share_client = service.get_share_client(file_share_name)

user_id = create_user_id()

# Create vectorstore and retriever
embeddings_function = get_embeddings_function()
vectorstore = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX")
)
vectorstore_graphs = get_pinecone_vectorstore(
    embeddings_function,
    index_name=os.getenv("PINECONE_API_INDEX_OWID"),
    text_key="description",
)
vectorstore_region = get_pinecone_vectorstore(
    embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2")
)

llm = get_llm(provider="openai", max_tokens=1024, temperature=0.0)

# Use the small reranker only when explicitly running locally; an unset
# GRADIO_ENV falls back to the large one instead of raising KeyError.
if os.environ.get("GRADIO_ENV") == "local":
    reranker = get_reranker("nano")
else:
    reranker = get_reranker("large")

agent = make_graph_agent(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0.2,
)
agent_poc = make_graph_agent_poc(
    llm=llm,
    vectorstore_ipcc=vectorstore,
    vectorstore_graphs=vectorstore_graphs,
    vectorstore_region=vectorstore_region,
    reranker=reranker,
    threshold_docs=0,
    version="v4",
)  # TODO put back default 0.2

# Substrings counted in the rendered HTML to derive the tab badge numbers.
# NOTE(review): these literals were unreadable in the reviewed copy of this
# file (markup had been stripped); they are restored as one "<h2>" heading per
# source/figure/paper card and one "<iframe" per graph -- confirm against the
# HTML actually produced by the rendering helpers.
_CARD_MARKER = "<h2>"
_GRAPH_MARKER = "<iframe"


async def chat(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream chat events from the main agent for the "ClimateQ&A" tab.

    Falsy inputs are replaced by defaults before being forwarded to
    ``chat_stream``: audience "Experts", sources ["IPCC", "IPBES"], no report
    filter, relevant content ["Figures (IPCC/IPBES)"], ``search_only`` False.

    Yields:
        Every event produced by ``chat_stream``, unchanged.
    """
    print("chat cqa - message received")
    # Ensure default values if components are not set
    audience = audience or "Experts"
    sources = sources or ["IPCC", "IPBES"]
    reports = reports or []
    relevant_content_sources_selection = relevant_content_sources_selection or [
        "Figures (IPCC/IPBES)"
    ]
    search_only = bool(search_only)  # Convert to boolean if None

    async for event in chat_stream(
        agent,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    ):
        yield event


async def chat_poc(
    query,
    history,
    audience,
    sources,
    reports,
    relevant_content_sources_selection,
    search_only,
):
    """Stream chat events from the POC agent for the "France - Local Q&A" tab.

    Unlike ``chat``, the arguments are forwarded to ``chat_stream`` as received,
    without substituting defaults.

    Yields:
        Every event produced by ``chat_stream``, unchanged.
    """
    print("chat poc - message received")
    async for event in chat_stream(
        agent_poc,
        query,
        history,
        audience,
        sources,
        reports,
        relevant_content_sources_selection,
        search_only,
        share_client,
        user_id,
    ):
        yield event


# --------------------------------------------------------------------
# Gradio
# --------------------------------------------------------------------


# Function to update modal visibility
def update_config_modal_visibility(config_open):
    """Toggle the configuration modal.

    Args:
        config_open: Current open/closed flag of the modal.

    Returns:
        A ``(Modal, bool)`` pair: a modal update carrying the negated flag as
        its visibility, and the negated flag itself.
    """
    print(config_open)
    new_config_visibility_status = not config_open
    return Modal(visible=new_config_visibility_status), new_config_visibility_status


def _count_marker(content, marker):
    """Return how many times *marker* occurs in *content* (0 when it is empty or None)."""
    return content.count(marker) if content else 0


def update_sources_number_display(
    sources_textbox, figures_cards, current_graphs, papers_html
):
    """Recompute the counters shown in the result tab labels.

    Each argument is the current content of the matching component; items are
    counted by occurrences of a marker substring. Empty or ``None`` content
    counts as zero.

    Returns:
        Five ``gr.update`` label updates, in this order: recommended content
        (figures + graphs + papers), sources, figures, graphs, papers.
    """
    sources_number = _count_marker(sources_textbox, _CARD_MARKER)
    figures_number = _count_marker(figures_cards, _CARD_MARKER)
    graphs_number = _count_marker(current_graphs, _GRAPH_MARKER)
    papers_number = _count_marker(papers_html, _CARD_MARKER)

    sources_notif_label = f"Sources ({sources_number})"
    figures_notif_label = f"Figures ({figures_number})"
    graphs_notif_label = f"Graphs ({graphs_number})"
    papers_notif_label = f"Papers ({papers_number})"
    recommended_content_notif_label = (
        f"Recommended content ({figures_number + graphs_number + papers_number})"
    )

    return (
        gr.update(label=recommended_content_notif_label),
        gr.update(label=sources_notif_label),
        gr.update(label=figures_notif_label),
        gr.update(label=graphs_notif_label),
        gr.update(label=papers_notif_label),
    )


def config_event_handling(
    main_tabs_components: list[MainTabPanel], config_componenets: ConfigPanel
):
    """Make every config button, and the modal's close button, toggle the modal.

    Args:
        main_tabs_components: Panels whose ``config_button`` should toggle the
            configuration modal.
        config_componenets: Panel holding the modal, its open/closed state and
            its close button.
    """
    config_open = config_componenets.config_open
    config_modal = config_componenets.config_modal
    close_config_modal = config_componenets.close_config_modal_button

    toggle_buttons = [close_config_modal]
    toggle_buttons.extend(panel.config_button for panel in main_tabs_components)
    for button in toggle_buttons:
        button.click(
            fn=update_config_modal_visibility,
            inputs=[config_open],
            outputs=[config_modal, config_open],
        )


def event_handling(
    main_tab_components: MainTabPanel,
    config_components: ConfigPanel,
    tab_name="ClimateQ&A",
):
    """Attach all event listeners of one chat tab.

    Args:
        main_tab_components: Components of the tab being wired.
        config_components: Shared configuration components (dropdowns, flags,
            output fields).
        tab_name: "ClimateQ&A" wires the chat chain to ``chat``;
            "France - Local Q&A" wires it to ``chat_poc``. Any other value
            skips the chat chain and only wires the shared listeners.
    """
    chatbot = main_tab_components.chatbot
    textbox = main_tab_components.textbox
    tabs = main_tab_components.tabs
    sources_raw = main_tab_components.sources_raw
    new_figures = main_tab_components.new_figures
    current_graphs = main_tab_components.current_graphs
    examples_hidden = main_tab_components.examples_hidden
    sources_textbox = main_tab_components.sources_textbox
    figures_cards = main_tab_components.figures_cards
    gallery_component = main_tab_components.gallery_component
    papers_direct_search = main_tab_components.papers_direct_search
    papers_html = main_tab_components.papers_html
    citations_network = main_tab_components.citations_network
    papers_summary = main_tab_components.papers_summary
    tab_recommended_content = main_tab_components.tab_recommended_content
    tab_sources = main_tab_components.tab_sources
    tab_figures = main_tab_components.tab_figures
    tab_graphs = main_tab_components.tab_graphs
    tab_papers = main_tab_components.tab_papers
    graphs_container = main_tab_components.graph_container
    follow_up_examples = main_tab_components.follow_up_examples
    follow_up_examples_hidden = main_tab_components.follow_up_examples_hidden

    dropdown_sources = config_components.dropdown_sources
    dropdown_reports = config_components.dropdown_reports
    dropdown_external_sources = config_components.dropdown_external_sources
    search_only = config_components.search_only
    dropdown_audience = config_components.dropdown_audience
    after = config_components.after
    output_query = config_components.output_query
    output_language = config_components.output_language

    new_sources_html = gr.State([])
    # Not referenced below; kept so the set of components created for the tab
    # stays the same.
    ttd_data = gr.State([])  # noqa: F841

    if tab_name == "ClimateQ&A":
        print("chat cqa - message sent")
        chat_fn = chat
    elif tab_name == "France - Local Q&A":
        print("chat poc - message sent")
        chat_fn = chat_poc
    else:
        chat_fn = None

    def wire_chat_chain(listener, trigger):
        """Register start_chat -> chat_fn -> finish_chat on *listener*.

        *trigger* is the component whose value carries the query; it is also
        the first output of ``start_chat`` and its ``elem_id`` names the three
        API endpoints.
        """
        (
            listener(
                start_chat,
                [trigger, chatbot, search_only],
                [trigger, tabs, chatbot, sources_raw],
                queue=False,
                api_name=f"start_chat_{trigger.elem_id}",
            )
            .then(
                chat_fn,
                [
                    trigger,
                    chatbot,
                    dropdown_audience,
                    dropdown_sources,
                    dropdown_reports,
                    dropdown_external_sources,
                    search_only,
                ],
                [
                    chatbot,
                    new_sources_html,
                    output_query,
                    output_language,
                    new_figures,
                    current_graphs,
                    follow_up_examples.dataset,
                ],
                concurrency_limit=8,
                api_name=f"chat_{trigger.elem_id}",
            )
            .then(
                finish_chat,
                None,
                [textbox],
                api_name=f"finish_chat_{trigger.elem_id}",
            )
        )

    if chat_fn is not None:
        # Event for textbox
        wire_chat_chain(textbox.submit, textbox)
        # Event for examples_hidden
        wire_chat_chain(examples_hidden.change, examples_hidden)
        # Event for follow-up examples: same chat function as the rest of the
        # tab, and endpoint names derived from its own elem_id.
        wire_chat_chain(follow_up_examples_hidden.change, follow_up_examples_hidden)

    # Mirror the state values into the components that display them.
    new_sources_html.change(
        lambda x: x, inputs=[new_sources_html], outputs=[sources_textbox]
    )
    current_graphs.change(
        lambda x: x, inputs=[current_graphs], outputs=[graphs_container]
    )
    new_figures.change(
        process_figures,
        inputs=[sources_raw, new_figures],
        outputs=[sources_raw, figures_cards, gallery_component],
    )

    # Update sources numbers
    for component in [sources_textbox, figures_cards, current_graphs, papers_html]:
        component.change(
            update_sources_number_display,
            [sources_textbox, figures_cards, current_graphs, papers_html],
            [tab_recommended_content, tab_sources, tab_figures, tab_graphs, tab_papers],
        )

    # Search for papers
    for component in [textbox, examples_hidden, papers_direct_search]:
        component.submit(
            find_papers,
            [component, after, dropdown_external_sources],
            [papers_html, citations_network, papers_summary],
        )

    # TODO: for the "France - Local Q&A" tab, enable the Drias search once its
    # results are good enough:
    #   textbox.submit(ask_vanna, [textbox], [vanna_sql_query, vanna_table, vanna_display])


def main_ui():
    """Build the Gradio Blocks application and return it with its queue enabled."""
    with gr.Blocks(
        title="Climate Q&A",
        css_paths=os.getcwd() + "/style.css",
        theme=theme,
        elem_id="main-component",
    ) as demo:
        config_components = create_config_modal()

        with gr.Tabs():
            cqa_components = cqa_tab(tab_name="ClimateQ&A")
            local_cqa_components = cqa_tab(tab_name="France - Local Q&A")
            create_drias_tab(share_client=share_client, user_id=user_id)

            create_about_tab()

        event_handling(cqa_components, config_components, tab_name="ClimateQ&A")
        event_handling(
            local_cqa_components, config_components, tab_name="France - Local Q&A"
        )

        config_event_handling([cqa_components, local_cqa_components], config_components)

        demo.queue()

    return demo


# Built and launched at import time; `demo` stays a module-level name.
demo = main_ui()
demo.launch(ssr_mode=False)