diff --git a/.gitattributes b/.gitattributes
index 0eb8c2e05739c1d905f2c2a19356d46372635988..e436d7c5d104668b7969e23dc6688b3dafbe9c99 100644
--- a/.gitattributes
+++ b/.gitattributes
@@ -44,3 +44,4 @@ documents/climate_gpt_v2_only_giec.faiss filter=lfs diff=lfs merge=lfs -text
 documents/climate_gpt_v2.faiss filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.db filter=lfs diff=lfs merge=lfs -text
 climateqa_v3.faiss filter=lfs diff=lfs merge=lfs -text
+data/drias/drias.db filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitignore b/.gitignore
index 810e6d2a5f4099116c3b903da22346a544afee54..8288a2228a648af2e94d03ef1375299785bfe0c8 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,3 +5,16 @@ __pycache__/utils.cpython-38.pyc
 
 notebooks/
 *.pyc
+
+**/.ipynb_checkpoints/
+**/.flashrank_cache/
+
+data/
+sandbox/
+
+climateqa/talk_to_data/database/
+*.db
+
+data_ingestion/
+.vscode
+*old/
diff --git a/README.md b/README.md
index b1f4ec3bf80bc3b00e8a839e1d0052789970b96a..4bc553e88e65fd5201809ec9ebd12312f96a9816 100644
--- a/README.md
+++ b/README.md
@@ -4,7 +4,7 @@ emoji: ๐
 colorFrom: blue
 colorTo: red
 sdk: gradio
-sdk_version: 4.19.1
+sdk_version: 5.0.2
 app_file: app.py
 fullWidth: true
 pinned: false
diff --git a/app.py b/app.py
index ab849993528f591307fdd3a6b5be50730fc147f4..3f227ea2dcf837b2dd5d80078fa8bed67d95aa7a 100644
--- a/app.py
+++ b/app.py
@@ -1,44 +1,32 @@
-from climateqa.engine.embeddings import get_embeddings_function
-embeddings_function = get_embeddings_function()
-
-from climateqa.papers.openalex import OpenAlex
-from sentence_transformers import CrossEncoder
-
-reranker = CrossEncoder("mixedbread-ai/mxbai-rerank-xsmall-v1")
-oa = OpenAlex()
-
-import gradio as gr
-import pandas as pd
-import numpy as np
+# Import necessary libraries
 import os
-import time
-import re
-import json
+import gradio as gr
 
-# from gradio_modal import Modal
+from azure.storage.fileshare import ShareServiceClient
 
-from io import BytesIO
-import base64
+# Import custom modules
+from climateqa.engine.embeddings import get_embeddings_function
+from climateqa.engine.llm import get_llm
+from climateqa.engine.vectorstore import get_pinecone_vectorstore
+from climateqa.engine.reranker import get_reranker
+from climateqa.engine.graph import make_graph_agent,make_graph_agent_poc
+from climateqa.engine.chains.retrieve_papers import find_papers
+from climateqa.chat import start_chat, chat_stream, finish_chat
+from climateqa.engine.talk_to_data.main import ask_vanna
+
+from front.tabs import (create_config_modal, create_examples_tab, create_papers_tab, create_figures_tab, create_chat_interface, create_about_tab)
+from front.utils import process_figures
+from gradio_modal import Modal
 
-from datetime import datetime
-from azure.storage.fileshare import ShareServiceClient
 from utils import create_user_id
+import logging
+logging.basicConfig(level=logging.WARNING)
+os.environ['TF_CPP_MIN_LOG_LEVEL'] = '2' # Suppresses INFO and WARNING logs
+logging.getLogger().setLevel(logging.WARNING)
 
-# ClimateQ&A imports
-from climateqa.engine.llm import get_llm
-from climateqa.engine.rag import make_rag_chain
-from climateqa.engine.vectorstore import get_pinecone_vectorstore
-from climateqa.engine.retriever import ClimateQARetriever
-from climateqa.engine.embeddings import get_embeddings_function
-from climateqa.engine.prompts import audience_prompts
-from climateqa.sample_questions import QUESTIONS
-from climateqa.constants import POSSIBLE_REPORTS
-from climateqa.utils import get_image_from_azure_blob_storage
-from climateqa.engine.keywords import make_keywords_chain
-from climateqa.engine.rag import make_rag_papers_chain
 
 # Load environment variables in local mode
 try:
@@ -47,6 +35,7 @@ try:
 except Exception as e:
     pass
 
+
 # Set up Gradio Theme
 theme = gr.themes.Base(
     primary_hue="blue",
@@ -54,15 +43,7 @@ theme = gr.themes.Base(
     font=[gr.themes.GoogleFont("Poppins"), "ui-sans-serif", "system-ui", "sans-serif"],
 )
 
-
-
-init_prompt = ""
-
-system_template = {
-    "role": "system",
-    "content": init_prompt,
-}
-
+# Azure Blob Storage credentials
 account_key = os.environ["BLOB_ACCOUNT_KEY"]
 if len(account_key) == 86:
     account_key += "=="
@@ -81,597 +62,273 @@ user_id = create_user_id()
 
 
-def parse_output_llm_with_sources(output):
-    # Split the content into a list of text and "[Doc X]" references
-    content_parts = re.split(r'\[(Doc\s?\d+(?:,\s?Doc\s?\d+)*)\]', output)
-    parts = []
-    for part in content_parts:
-        if part.startswith("Doc"):
-            subparts = part.split(",")
-            subparts = [subpart.lower().replace("doc","").strip() for subpart in subparts]
-            subparts = [f"""{subpart}""" for subpart in subparts]
-            parts.append("".join(subparts))
-        else:
-            parts.append(part)
-    content_parts = "".join(parts)
-    return content_parts
-
-
 # Create vectorstore and retriever
-vectorstore = get_pinecone_vectorstore(embeddings_function)
-llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
-
-
-def make_pairs(lst):
-    """from a list of even lenght, make tupple pairs"""
-    return [(lst[i], lst[i + 1]) for i in range(0, len(lst), 2)]
-
-
-def serialize_docs(docs):
-    new_docs = []
-    for doc in docs:
-        new_doc = {}
-        new_doc["page_content"] = doc.page_content
-        new_doc["metadata"] = doc.metadata
-        new_docs.append(new_doc)
-    return new_docs
-
-
-
-async def chat(query,history,audience,sources,reports):
-    """taking a query and a message history, use a pipeline (reformulation, retriever, answering) to yield a tuple of:
-    (messages in gradio format, messages in langchain format, source documents)"""
-
-    print(f">> NEW QUESTION : {query}")
+embeddings_function = get_embeddings_function()
+vectorstore = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX"))
+vectorstore_graphs = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_OWID"), text_key="description")
+vectorstore_region = get_pinecone_vectorstore(embeddings_function, index_name=os.getenv("PINECONE_API_INDEX_LOCAL_V2"))
 
-    if audience == "Children":
-        audience_prompt = audience_prompts["children"]
-    elif audience == "General public":
-        audience_prompt = audience_prompts["general"]
-    elif audience == "Experts":
-        audience_prompt = audience_prompts["experts"]
-    else:
-        audience_prompt = audience_prompts["experts"]
+llm = get_llm(provider="openai",max_tokens = 1024,temperature = 0.0)
+if os.environ["GRADIO_ENV"] == "local":
+    reranker = get_reranker("nano")
+else :
+    reranker = get_reranker("large")
 
-    # Prepare default values
-    if len(sources) == 0:
-        sources = ["IPCC"]
+agent = make_graph_agent(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0.2)
+agent_poc = make_graph_agent_poc(llm=llm, vectorstore_ipcc=vectorstore, vectorstore_graphs=vectorstore_graphs, vectorstore_region = vectorstore_region, reranker=reranker, threshold_docs=0, version="v4")#TODO put back default 0.2
 
-    if len(reports) == 0:
-        reports = []
-
-    retriever = ClimateQARetriever(vectorstore=vectorstore,sources = sources,min_size = 200,reports = reports,k_summary = 3,k_total = 15,threshold=0.5)
-    rag_chain = make_rag_chain(retriever,llm)
-
-    inputs = {"query": query,"audience": audience_prompt}
-    result = rag_chain.astream_log(inputs) #{"callbacks":[MyCustomAsyncHandler()]})
-    # result = rag_chain.stream(inputs)
-
-    path_reformulation = "/logs/reformulation/final_output"
-    path_keywords = "/logs/keywords/final_output"
-    path_retriever = "/logs/find_documents/final_output"
-    path_answer = "/logs/answer/streamed_output_str/-"
-
-    docs_html = ""
-    output_query = ""
-    output_language = ""
-    output_keywords = ""
-    gallery = []
-
-    try:
-        async for op in result:
-
-            op = op.ops[0]
-
-            if op['path'] == path_reformulation: # reforulated question
-                try:
-                    output_language = op['value']["language"] # str
-                    output_query = op["value"]["question"]
-                except Exception as e:
-                    raise gr.Error(f"ClimateQ&A Error: {e} - The error has been noted, try another question and if the error remains, you can contact us :)")
-
-            if op["path"] == path_keywords:
-                try:
-                    output_keywords = op['value']["keywords"] # str
-                    output_keywords = " AND ".join(output_keywords)
-                except Exception as e:
-                    pass
-
-
-            elif op['path'] == path_retriever: # documents
-                try:
-                    docs = op['value']['docs'] # List[Document]
-                    docs_html = []
-                    for i, d in enumerate(docs, 1):
-                        docs_html.append(make_html_source(d, i))
-                    docs_html = "".join(docs_html)
-                except TypeError:
-                    print("No documents found")
-                    print("op: ",op)
-                    continue
-
-            elif op['path'] == path_answer: # final answer
-                new_token = op['value'] # str
-                # time.sleep(0.01)
-                previous_answer = history[-1][1]
-                previous_answer = previous_answer if previous_answer is not None else ""
-                answer_yet = previous_answer + new_token
-                answer_yet = parse_output_llm_with_sources(answer_yet)
-                history[-1] = (query,answer_yet)
-
-
-
-            else:
-                continue
-
-            history = [tuple(x) for x in history]
-            yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
-
-    except Exception as e:
-        raise gr.Error(f"{e}")
-
-    try:
-        # Log answer on Azure Blob Storage
-        if os.getenv("GRADIO_ENV") != "local":
-            timestamp = str(datetime.now().timestamp())
-            file = timestamp + ".json"
-            prompt = history[-1][0]
-            logs = {
-                "user_id": str(user_id),
-                "prompt": prompt,
-                "query": prompt,
-                "question":output_query,
-                "sources":sources,
-                "docs":serialize_docs(docs),
-                "answer": history[-1][1],
-                "time": timestamp,
-            }
-            log_on_azure(file, logs, share_client)
-    except Exception as e:
-        print(f"Error logging on Azure Blob Storage: {e}")
-        raise gr.Error(f"ClimateQ&A Error: {str(e)[:100]} - The error has been noted, try another question and if the error remains, you can contact us :)")
-
-    image_dict = {}
-    for i,doc in enumerate(docs):
+async def chat(query, history, audience, sources, reports, relevant_content_sources_selection, search_only):
+    print("chat cqa - message received")
+    async for event in chat_stream(agent, query, history, audience, sources, reports, relevant_content_sources_selection, search_only, share_client, user_id):
+        yield event
 
-        if doc.metadata["chunk_type"] == "image":
-            try:
-                key = f"Image {i+1}"
-                image_path = doc.metadata["image_path"].split("documents/")[1]
-                img = get_image_from_azure_blob_storage(image_path)
-
-                # Convert the image to a byte buffer
-                buffered = BytesIO()
-                img.save(buffered, format="PNG")
-                img_str = base64.b64encode(buffered.getvalue()).decode()
-
-                # Embedding the base64 string in Markdown
-                markdown_image = f""
-                image_dict[key] = {"img":img,"md":markdown_image,"caption":doc.page_content,"key":key,"figure_code":doc.metadata["figure_code"]}
-            except Exception as e:
-                print(f"Skipped adding image {i} because of {e}")
-
-    if len(image_dict) > 0:
-
-        gallery = [x["img"] for x in list(image_dict.values())]
-        img = list(image_dict.values())[0]
-        img_md = img["md"]
-        img_caption = img["caption"]
-        img_code = img["figure_code"]
-        if img_code != "N/A":
-            img_name = f"{img['key']} - {img['figure_code']}"
-        else:
-            img_name = f"{img['key']}"
-
-        answer_yet = history[-1][1] + f"\n\n{img_md}\n{img_name} - {img_caption}"
-        history[-1] = (history[-1][0],answer_yet)
-        history = [tuple(x) for x in history]
-
-    # gallery = [x.metadata["image_path"] for x in docs if (len(x.metadata["image_path"]) > 0 and "IAS" in x.metadata["image_path"])]
-    # if len(gallery) > 0:
-    #     gallery = list(set("|".join(gallery).split("|")))
-    #     gallery = [get_image_from_azure_blob_storage(x) for x in gallery]
-
-    yield history,docs_html,output_query,output_language,gallery,output_query,output_keywords
-
-def make_html_source(source,i):
-    meta = source.metadata
-    # content = source.page_content.split(":",1)[1].strip()
-    content = source.page_content.strip()
-
-    toc_levels = []
-    for j in range(2):
-        level = meta[f"toc_level{j}"]
-        if level != "N/A":
-            toc_levels.append(level)
-        else:
-            break
-    toc_levels = " > ".join(toc_levels)
-
-    if len(toc_levels) > 0:
-        name = f"{toc_levels}{content}
-{content}
-AI-generated description
-