import ast
import re
from random import shuffle, sample

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field

from langgraph.graph import StateGraph

from ki_gen.prompts import (
    CYPHER_GENERATION_PROMPT,
    CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT,
    SCORE_GRADER_PROMPT,
    RELEVANT_CONCEPTS_PROMPT,
)
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc
|
|
def extract_cypher(text: str) -> list[str]:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        A list of candidate Cypher snippets: the content of a ```cypher fence,
        the content of a bare ``` fence, and the raw text as a fallback.
    """
    pattern_1 = r"```cypher\n(.*?)```"
    pattern_2 = r"```\n(.*?)```"

    matches_1 = re.findall(pattern_1, text, re.DOTALL)
    matches_2 = re.findall(pattern_2, text, re.DOTALL)
    return [
        matches_1[0] if matches_1 else text,
        matches_2[0] if matches_2 else text,
        text,
    ]
|
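# Behaviour sketch for extract_cypher on a hypothetical LLM answer. The first
# candidate is the fenced Cypher; the other entries fall back to the raw text
# when their pattern does not match:
#
#   >>> extract_cypher("```cypher\nMATCH (n) RETURN n\n```")
#   ['MATCH (n) RETURN n\n',
#    '```cypher\nMATCH (n) RETURN n\n```',
#    '```cypher\nMATCH (n) RETURN n\n```']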
|
|
def get_cypher_gen_chain(model: str = "openai"):
    """
    Returns a Cypher generation chain using the specified model.
    This is used when the 'auto' cypher generation method has been configured.
    """
    if model == "openai":
        llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_cypher_gen = ChatGroq(model="mixtral-8x7b-32768")
    cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
    return cypher_gen_chain
|
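# Minimal usage sketch, assuming a connected Neo4jGraph `graph`; the input keys
# mirror the ones generate_cypher passes below:
#
#   chain = get_cypher_gen_chain()
#   cyphers = chain.invoke({
#       "schema": graph.schema,
#       "question": "Which specifications cover network slicing?",
#       "concepts": ["Network Slicing"],
#   })
#   # `cyphers` is the candidate list produced by extract_cypher.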
|
|
def get_concept_selection_chain(model: str = "openai"):
    """
    Returns a chain to select the most relevant topic using the specified model.
    This is used when the 'guided' cypher generation method has been configured.
    """
    if model == "openai":
        llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_topic_selection = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
    print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
    topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
    return topic_selection_chain
|
|
|
def get_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (c:Concept) RETURN c"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Fallback for manual testing: parse a literal list of concept rows
        # (literal_eval instead of eval, to avoid executing arbitrary input).
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)

    concepts_name = [concept['c']['name'] for concept in concepts]
    return concepts_name
|
|
|
def get_related_concepts(graph: Neo4jGraph, question: str):
    concepts = get_concepts(graph)
    llm = get_model(model='gpt-4o')
    print(f"this is the llm variable : {llm}")

    def parse_answer(llm_answer: str):
        # Keep only the numbered list that follows the "Concepts:" marker,
        # stripping whitespace so items match the graph's concept names.
        print(f"This is the llm_answer : {llm_answer}")
        return [c.strip() for c in re.split(r"\n\d+\.\s", llm_answer.split("Concepts:")[1])[1:]]

    related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer

    related_concepts_raw = related_concepts_chain.invoke({"user_query": question, "concepts": '\n'.join(concepts)})

    # Keep only concepts that exist in the graph; fall back to substring
    # matching for answers the LLM reworded slightly.
    related_concepts_cleaned = []
    for related_concept in related_concepts_raw:
        if related_concept in concepts:
            related_concepts_cleaned.append(related_concept)
        else:
            for concept in concepts:
                if related_concept in concept:
                    related_concepts_cleaned.append(concept)
                    break

    return related_concepts_cleaned
|
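# Behaviour sketch for parse_answer on a hypothetical model answer:
#
#   >>> answer = "Sure.\nConcepts:\n1. Network Slicing\n2. RAN\n"
#   >>> [c.strip() for c in re.split(r"\n\d+\.\s", answer.split("Concepts:")[1])[1:]]
#   ['Network Slicing', 'RAN']
#
# The substring fallback above then maps slightly reworded names back onto the
# graph's canonical concept names.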
|
|
def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
    concept_string = ""
    for concept in concept_list:
        # Parameterized query: avoids breaking the Cypher (or injection) when a
        # concept name contains quotes.
        concept_description_query = "MATCH (c:Concept {name: $name}) RETURN c.description"
        concept_description = graph.query(concept_description_query, params={"name": concept})[0]['c.description']
        concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
    return concept_string
|
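# Output sketch for build_concept_string, one block per concept (description
# text is hypothetical):
#
#   name: Network Slicing
#   description: <c.description fetched from the graph>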
|
|
def get_global_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (gc:GlobalConcept) RETURN gc"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Same manual-testing fallback as get_concepts.
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)

    concepts_name = [concept['gc']['name'] for concept in concepts]
    return concepts_name
|
|
|
def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
    """
    The node where the Cypher queries are generated.
    """
    graph = config["configurable"].get("graph")
    question = state['query']
    related_concepts = get_related_concepts(graph, question)
    cyphers = []

    if config["configurable"].get("cypher_gen_method") == 'auto':
        cypher_gen_chain = get_cypher_gen_chain()
        cyphers = cypher_gen_chain.invoke({
            "schema": graph.schema,
            "question": question,
            "concepts": related_concepts
        })

    if config["configurable"].get("cypher_gen_method") == 'guided':
        concept_selection_chain = get_concept_selection_chain()
        print(f"Concept selection chain is : {concept_selection_chain}")
        selected_topic = concept_selection_chain.invoke({"question": question, "concepts": get_concepts(graph)})
        print(f"Selected topic is : {selected_topic}")
        cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
        print(f"Cyphers are : {cyphers}")

    if config["configurable"].get("validate_cypher"):
        corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
        cypher_corrector = CypherQueryCorrector(corrector_schema)
        cyphers = [cypher_corrector(cypher) for cypher in cyphers]

    return {"cyphers": cyphers}
|
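# Config sketch for this node; the keys are the ones read above:
#
#   config = {"configurable": {
#       "graph": graph,                  # Neo4jGraph instance
#       "cypher_gen_method": "guided",   # or "auto"
#       "validate_cypher": True,         # optional relationship-direction check
#   }}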
|
|
def generate_cypher_from_topic(selected_concept: str, plan_step: int):
    """
    Helper function used when the 'guided' cypher generation method has been configured.
    """
    print(f"PLAN STEP : {plan_step}")
    cypher_el = "(n) RETURN n.title, n.description"
    match plan_step:
        case 0:
            cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
        case 1:
            cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
        case 2:
            cypher_el = "(ki:KeyIssue) RETURN ki.description"
    return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
|
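# Example output for plan step 0 with a hypothetical concept name:
#
#   >>> generate_cypher_from_topic("Network Slicing", 0)
#   "MATCH (c:Concept {name:'Network Slicing'})-[:RELATED_TO]-(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"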
|
|
def get_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node retrieves docs from the graph using the generated Cypher queries.
    """
    graph = config["configurable"].get("graph")
    output = []
    if graph is not None:
        # Try each candidate query until one succeeds.
        for cypher in state["cyphers"]:
            try:
                output = graph.query(cypher)
                break
            except Exception as e:
                print(f"Failed to retrieve docs : {e}")

    # Flatten each result row: node dicts become standalone docs, scalar
    # columns are gathered into one doc per row.
    all_docs = []
    for doc in output:
        unwinded_doc = {}
        for key in doc:
            if isinstance(doc[key], dict):
                all_docs.append(doc[key])
            else:
                unwinded_doc.update({key: doc[key]})
        if unwinded_doc:
            all_docs.append(unwinded_doc)

    # Deduplicate while preserving order.
    filtered_docs = []
    for doc in all_docs:
        if doc not in filtered_docs:
            filtered_docs.append(doc)

    return {"docs": filtered_docs}
|
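# Flattening sketch: a result row such as
#
#   {"rp": {"title": "T", "abstract": "A"}, "score": 0.9}
#
# yields two docs: the node dict {"title": "T", "abstract": "A"} and the
# per-row scalar remainder {"score": 0.9}. (The "score" column here is
# hypothetical.)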
|
class GradeDocumentsBinary(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )
|
|
def get_binary_grader(model="mixtral-8x7b-32768"):
    """
    Returns a binary grader to evaluate relevance of documents using the specified model.
    This is used when the 'binary' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
    return retrieval_grader_binary
|
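# Usage sketch (document and question are hypothetical):
#
#   grader = get_binary_grader()
#   grade = grader.invoke({"question": "What is RAN?", "document": "..."})
#   relevant = grade.binary_score == 'yes'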
|
class GradeDocumentsScore(BaseModel):
    """Score for relevance check on retrieved documents."""

    score: float = Field(
        description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
    )
|
|
|
def get_score_grader(model="mixtral-8x7b-32768"):
    """
    Returns a score grader to evaluate relevance of documents using the specified model.
    This is used when the 'score' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_score = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
    retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
    return retrieval_grader_score
|
|
def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"): |
|
''' |
|
doc : the document to evaluate |
|
query : the query to which to doc shoud be relevant |
|
method : "binary" or "score" |
|
threshold : for "score" method, score above which a doc is considered relevant |
|
''' |
|
if method == "binary": |
|
retrieval_grader_binary = get_binary_grader(model=eval_model) |
|
return 1 if (retrieval_grader_binary.invoke({"question": query, "document":doc}).binary_score == 'yes') else 0 |
|
elif method == "score": |
|
retrieval_grader_score = get_score_grader(model=eval_model) |
|
score = retrieval_grader_score.invoke({"query": query, "document":doc}).score or None |
|
if score is not None: |
|
return score if score >= threshold else 0 |
|
else: |
|
|
|
return 1 |
|
else: |
|
raise ValueError("Invalid method") |
|
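# Usage sketch: binary grading returns 0 or 1; score grading returns the raw
# score once it clears the threshold, else 0:
#
#   eval_doc("some doc text", "some query")                        # -> 0 or 1
#   eval_doc("some doc text", "some query",
#            method="score", threshold=0.5)                        # -> 0, or a score >= 0.5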
|
|
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node evaluates the retrieved docs and keeps only the relevant ones.
    """
    eval_method = config["configurable"].get("eval_method") or "binary"
    MAX_DOCS = config["configurable"].get("max_docs") or 15
    valid_doc_scores = []

    # Evaluate at most 25 randomly sampled docs to bound latency and cost.
    for doc in sample(state["docs"], min(25, len(state["docs"]))):
        score = eval_doc(
            doc=format_doc(doc),
            query=state["query"],
            method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
            eval_model=config["configurable"].get("eval_model") or "mixtral-8x7b-32768"
        )
        if score:
            valid_doc_scores.append((doc, score))

    if eval_method == 'score':
        # Keep the MAX_DOCS highest-scoring docs.
        valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
        valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
    else:
        # Binary grading yields no ranking, so keep a random subset.
        shuffle(valid_doc_scores)
        valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]

    return {"valid_docs": valid_docs + (state.get("valid_docs") or [])}
|
|
def build_data_retriever_graph(memory):
    """
    Builds the data_retriever graph
    """
    graph_builder_doc_retriever = StateGraph(DocRetrieverState)

    graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
    graph_builder_doc_retriever.add_node("get_docs", get_docs)
    graph_builder_doc_retriever.add_node("eval_docs", eval_docs)

    graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
    graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
    graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
    graph_builder_doc_retriever.add_edge("eval_docs", "__end__")

    graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)

    return graph_doc_retriever
|
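# Minimal usage sketch (assumes a populated Neo4j graph and a recent langgraph
# with the in-memory checkpointer; state and config keys mirror the nodes above):
#
#   from langgraph.checkpoint.memory import MemorySaver
#
#   retriever = build_data_retriever_graph(MemorySaver())
#   result = retriever.invoke(
#       {"query": "...", "current_plan_step": 0, "valid_docs": []},
#       config={"configurable": {
#           "graph": graph,                 # Neo4jGraph instance
#           "cypher_gen_method": "guided",  # or "auto"
#           "eval_method": "binary",        # or "score"
#           "thread_id": "1",               # required by the checkpointer
#       }},
#   )
#   print(result["valid_docs"])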
|