|
|
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field |
|
from typing import List |
|
from typing import Literal |
|
from langchain.prompts import ChatPromptTemplate |
|
from langchain_core.utils.function_calling import convert_to_openai_function |
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Maps each retrieval index (vector store) to the document sources it contains:
# - "IPx": international scientific assessment reports (IPCC, IPBES, IPOS)
# - "POC": local French documents (AcclimaTerra regional report, Paris PCAET and Biodiversity plans)
# - "OpenAlex": open scientific literature
ROUTING_INDEX = {
    "IPx":["IPCC", "IPBES", "IPOS"],
    "POC": ["AcclimaTerra", "PCAET","Biodiv"],
    "OpenAlex":["OpenAlex"],
}

# Flat list of every source name across all routing indexes.
POSSIBLE_SOURCES = [y for values in ROUTING_INDEX.values() for y in values]
|
|
|
|
|
|
|
class QueryDecomposition(BaseModel):
    """
    Decompose the user query into smaller parts to think step by step to answer this question
    Act as a simple planning agent
    """
    # NOTE(review): the class docstring and the Field description below are serialized
    # into the OpenAI function schema by convert_to_openai_function and act as the
    # LLM instructions — do not reword them casually, wording changes alter model output.

    # List of 1-3 reformulated standalone search-engine questions produced by the LLM.
    questions: List[str] = Field(
        description="""
        Think step by step to answer this question, and provide one or several search engine questions in the provided language for knowledge that you need.
        Suppose that the user is looking for information about climate change, energy, biodiversity, nature, and everything we can find the IPCC reports and scientific literature
        - If it's already a standalone and explicit question, just return the reformulated question for the search engine
        - If you need to decompose the question, output a list of maximum 2 to 3 questions
        """
    )
|
|
|
|
|
class Location(BaseModel):
    """Geographical location (country + specific place) extracted from a user query."""

    # The Field descriptions are serialized into the OpenAI function schema and act
    # as extraction instructions for the LLM.
    # Fix: "adresses" -> "addresses" (typo, inconsistent with the field below).
    country: str = Field(..., description="The country if directly mentioned or inferred from the location (cities, regions, addresses), ex: France, USA, ...")
    location: str = Field(..., description="The specific place if mentioned (cities, regions, addresses), ex: Marseille, New York, Wisconsin, ...")
|
|
|
class QueryTranslation(BaseModel):
    """Translate the query into a given language"""

    # NOTE(review): the docstring above and the description below are serialized into
    # the OpenAI function schema and serve as the LLM instructions.
    # Fix: "alrealdy" -> "already" (typo in the prompt text).
    question: str = Field(
        description="""
        Translate the questions into the given language
        If the question is already in the given language, just return the same question
        """,
    )
|
|
|
|
|
class QueryAnalysis(BaseModel):
    """
    Analyze the user query to extract the relevant sources

    Deprecated:
    Analyzing the user query to extract topics, sources and date
    Also do query expansion to get alternative search queries
    Also provide simple keywords to feed a search engine
    """
    # NOTE(review): the docstring and Field description are serialized into the OpenAI
    # function schema and act as the LLM instructions.
    # The Literal deliberately excludes "OpenAlex" even though it appears in
    # POSSIBLE_SOURCES — the analysis LLM is never allowed to pick it.
    # Fix: "Eneregie" -> "Energie" (typo in the PCAET expansion shown to the LLM).

    sources: List[Literal["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET","Biodiv"]] = Field(
        ...,
        description="""
        Given a user question choose which documents would be most relevant for answering their question,
        - IPCC is for questions about climate change, energy, impacts, and everything we can find the IPCC reports
        - IPBES is for questions about biodiversity and nature
        - IPOS is for questions about the ocean and deep sea mining
        - AcclimaTerra is for questions about any specific place in, or close to, the french region "Nouvelle-Aquitaine"
        - PCAET is the Plan Climat Energie Territorial for the city of Paris
        - Biodiv is the Biodiversity plan for the city of Paris
        """,
    )
|
|
|
|
|
|
|
def make_query_decomposition_chain(llm):
    """Build a chain that decomposes a user query into search-engine sub-questions.

    Args:
        llm: A LangChain chat model that supports OpenAI function calling.

    Returns:
        A runnable: prompt -> llm (forced ``QueryDecomposition`` call) -> parsed
        JSON dict of the form ``{"questions": [...]}``.
    """
    # Force the model to answer through the QueryDecomposition function schema.
    functions = [convert_to_openai_function(QueryDecomposition)]
    bound_llm = llm.bind(functions=functions, function_call={"name": "QueryDecomposition"})

    messages = [
        ("system", "You are a helpful assistant, you will analyze, translate and reformulate the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt | bound_llm | JsonOutputFunctionsParser()
|
|
|
|
|
def make_query_analysis_chain(llm):
    """Build a chain that analyzes the user query to extract the relevant sources.

    Args:
        llm: A LangChain chat model that supports OpenAI function calling.

    Returns:
        A runnable: prompt -> llm (forced ``QueryAnalysis`` call) -> parsed JSON
        dict of the form ``{"sources": [...]}``.
    """
    # Force the model to answer through the QueryAnalysis function schema.
    functions = [convert_to_openai_function(QueryAnalysis)]
    bound_llm = llm.bind(functions=functions, function_call={"name": "QueryAnalysis"})

    messages = [
        ("system", "You are a helpful assistant, you will analyze the user input message using the function provided"),
        ("user", "input: {input}"),
    ]
    prompt = ChatPromptTemplate.from_messages(messages)

    return prompt | bound_llm | JsonOutputFunctionsParser()
|
|
|
|
|
def make_query_translation_chain(llm):
    """Build a chain that translates the user query into a target language.

    (The previous docstring, "Analyze the user query to extract the relevant
    sources", was a copy-paste error from ``make_query_analysis_chain``.)

    Args:
        llm: A LangChain chat model that supports OpenAI function calling.

    Returns:
        A runnable expecting ``{"input": <question>, "language": <language>}``
        and returning a parsed JSON dict of the form ``{"question": <translated>}``.
    """
    # Force the model to answer through the QueryTranslation function schema.
    openai_functions = [convert_to_openai_function(QueryTranslation)]
    llm_with_functions = llm.bind(functions=openai_functions, function_call={"name": "QueryTranslation"})

    prompt = ChatPromptTemplate.from_messages([
        ("system", "You are a helpful assistant, translate the question into {language}"),
        ("user", "input: {input}"),
    ])

    chain = prompt | llm_with_functions | JsonOutputFunctionsParser()
    return chain
|
|
|
def group_by_sources_types(sources):
    """Group source names by source type.

    Args:
        sources: Iterable of source names (e.g. ["IPCC", "PCAET"]).

    Returns:
        Dict mapping "IPx" (international reports) and/or "POC" (local French
        documents) to the matching sources; a key is present only when at least
        one of its sources was requested. Sources belonging to neither family
        (e.g. "OpenAlex") are dropped. List order within a value is unspecified
        (set intersection).
    """
    families = {
        "IPx": {"IPCC", "IPBES", "IPOS"},
        "POC": {"AcclimaTerra", "PCAET", "Biodiv"},
    }
    grouped = {}
    for family_name, family_sources in families.items():
        matching = set(sources) & family_sources
        if matching:
            grouped[family_name] = list(matching)
    return grouped
|
|
|
|
|
def make_query_transform_node(llm, k_final=15):
    """
    Creates a query transformation node that processes and transforms a given query state.

    Args:
        llm: The language model used for query decomposition, source analysis and translation.
        k_final (int, optional): Kept for interface compatibility; currently unused here. Defaults to 15.

    Returns:
        function: A function that takes a query state and returns a transformed state.

    The returned function performs the following steps:
        1. Checks if the query should be processed in auto mode based on the state.
        2. Retrieves the input sources from the state or defaults to the "IPx" routing index.
        3. Decomposes the query using the decomposition chain.
        4. Analyzes each decomposed question to select relevant sources.
        5. Ensures that the sources returned by the language model are valid,
           falling back to the default international sources otherwise.
        6. Translates each question (English for IPx sources, French for POC sources).
        7. Explodes the questions into one entry per routing index and constructs
           a new state with the transformed questions and their respective sources.
    """
    decomposition_chain = make_query_decomposition_chain(llm)
    query_analysis_chain = make_query_analysis_chain(llm)
    query_translation_chain = make_query_translation_chain(llm)

    # Sources the analysis LLM is allowed to return (must match QueryAnalysis.sources).
    # BUG FIX: this list previously contained the misspelling "IPBS" instead of
    # "IPBES", so any LLM answer containing the valid source "IPBES" failed
    # validation and was silently replaced by the defaults.
    valid_sources = ["IPCC", "IPBES", "IPOS", "AcclimaTerra", "PCAET", "Biodiv"]
    default_sources = ["IPCC", "IPBES", "IPOS"]

    def transform_query(state):
        print("---- Transform query ----")

        auto_mode = state.get("sources_auto", True)
        sources_input = state.get("sources_input", ROUTING_INDEX["IPx"])

        # 1. Decompose the user query into up to 2-3 standalone sub-questions.
        decomposition_output = decomposition_chain.invoke({"input": state["query"]})

        # 2. For each sub-question, ask the LLM which sources are relevant and
        # group them by source type ("IPx" international / "POC" local).
        questions = []
        for question in decomposition_output["questions"]:
            query_analysis_output = query_analysis_chain.invoke({"input": question})

            # Guard against hallucinated or empty source lists from the LLM.
            sources = query_analysis_output["sources"]
            if not sources or not all(source in valid_sources for source in sources):
                sources = default_sources

            for source_type, type_sources in group_by_sources_types(sources).items():
                questions.append({
                    "question": question,
                    "sources": type_sources,
                    "source_type": source_type,
                })

        # 3. Translate each question into the working language of its corpus:
        # English for international (IPx) sources, French for local (POC) ones.
        for q in questions:
            if q["source_type"] == "IPx":
                translation_output = query_translation_chain.invoke({"input": q["question"], "language": "English"})
                q["question"] = translation_output["question"]
            elif q["source_type"] == "POC":
                translation_output = query_translation_chain.invoke({"input": q["question"], "language": "French"})
                q["question"] = translation_output["question"]

        # 4. Explode each question into one entry per routing index that owns at
        # least one of its sources. In manual mode the user-provided sources
        # override whatever the LLM selected.
        new_questions = []
        for q in questions:
            sources = sources_input if not auto_mode else q["sources"]
            for index, index_sources in ROUTING_INDEX.items():
                selected_sources = list(set(sources).intersection(index_sources))
                if selected_sources:
                    new_questions.append({
                        "question": q["question"],
                        "sources": selected_sources,
                        "index": index,
                        "source_type": q["source_type"],
                    })

        # Per-index question counts, used downstream to track retrieval progress.
        n_questions = {
            "total": len(new_questions),
            "IPx": len([q for q in new_questions if q["index"] == "IPx"]),
            "POC": len([q for q in new_questions if q["index"] == "POC"]),
        }

        new_state = {
            "questions_list": new_questions,
            "n_questions": n_questions,
            "handled_questions_index": [],
        }
        print("New questions")
        print(new_questions)
        return new_state

    return transform_query