#!/usr/bin/env python
# coding: utf-8
import ast
import re
from random import shuffle, sample

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain_core.output_parsers import StrOutputParser
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import StateGraph
from ki_gen.prompts import (
    CYPHER_GENERATION_PROMPT,
    CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT,
    SCORE_GRADER_PROMPT,
    RELEVANT_CONCEPTS_PROMPT,
)
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc


def extract_cypher(text: str) -> list[str]:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        A list of candidate Cypher queries, ordered from most to least
        likely: the content of a ```cypher fence, the content of a plain
        ``` fence, and finally the raw text as a last resort.
    """
    # Patterns to find Cypher code enclosed in triple backticks
    pattern_1 = r"```cypher\n(.*?)```"
    pattern_2 = r"```\n(.*?)```"

    # Find all matches in the input text
    matches_1 = re.findall(pattern_1, text, re.DOTALL)
    matches_2 = re.findall(pattern_2, text, re.DOTALL)

    return [
        matches_1[0] if matches_1 else text,
        matches_2[0] if matches_2 else text,
        text,
    ]
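# Illustrative example: for an LLM answer such as
#   "Here is the query:\n```cypher\nMATCH (n) RETURN n\n```"
# the first candidate is "MATCH (n) RETURN n\n" and the remaining entries fall
# back to the raw text; get_docs() below tries each candidate in order until
# one query runs successfully.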


def get_cypher_gen_chain(model: str = "openai"):
    """
    Returns the Cypher generation chain using the specified model.
    This is used when the 'auto' cypher generation method has been configured.
    """
    if model == "openai":
        llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_cypher_gen = ChatGroq(model="mixtral-8x7b-32768")
    cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
    return cypher_gen_chain
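# A minimal usage sketch (the prompt variables match those passed by
# generate_cypher below; the question is hypothetical):
#
#   chain = get_cypher_gen_chain(model="openai")
#   candidates = chain.invoke({
#       "schema": graph.schema,
#       "question": "Which papers discuss network slicing?",
#       "concepts": ["Network slicing"],
#   })  # -> list of candidate Cypher queries produced by extract_cypher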


def get_concept_selection_chain(model: str = "openai"):
    """
    Returns a chain to select the most relevant topic using the specified model.
    This is used when the 'guided' cypher generation method has been configured.
    """
    if model == "openai":
        llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_topic_selection = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
    print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
    topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
    return topic_selection_chain


def get_concepts(graph: Neo4jGraph):
    """Returns the names of all Concept nodes in the graph."""
    concept_cypher = "MATCH (c:Concept) return c"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Fallback for running without a live graph: parse a Python literal
        # (a list of dicts) from stdin. literal_eval is used instead of eval
        # to avoid executing arbitrary code.
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['c']['name'] for concept in concepts]
    return concepts_name


def get_related_concepts(graph: Neo4jGraph, question: str):
    concepts = get_concepts(graph)
    llm = get_model(model='gpt-4o')
    print(f"this is the llm variable : {llm}")

    def parse_answer(llm_answer: str):
        """Parses the numbered list that follows 'Concepts:' in the LLM answer."""
        print(f"This is the llm_answer : {llm_answer}")
        return re.split(r"\n\d+\.\s", llm_answer.split("Concepts:")[1])[1:]

    related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer

    related_concepts_raw = related_concepts_chain.invoke({"user_query": question, "concepts": '\n'.join(concepts)})

    # We clean up the list we received from the LLM in case there were some hallucinations
    related_concepts_cleaned = []
    for related_concept in related_concepts_raw:
        # If the concept returned by the LLM is in the list, we keep it
        if related_concept in concepts:
            related_concepts_cleaned.append(related_concept)
        else:
            # The LLM sometimes drops a few words from the concept name.
            # If the generated concept is a substring of an existing one, keep the existing concept.
            for concept in concepts:
                if related_concept in concept:
                    related_concepts_cleaned.append(concept)
                    break

    # TODO : Add concepts found via similarity search
    return related_concepts_cleaned
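# parse_answer above assumes an answer shaped like (illustrative example):
#   "Concepts:\n1. Network slicing\n2. Quality of Service"
# which re.split(r"\n\d+\.\s", ...) turns into
#   ["Network slicing", "Quality of Service"].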


def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
    """Builds a 'name / description' string for each concept in concept_list."""
    concept_string = ""
    for concept in concept_list:
        # Parameterized query so quotes in concept names can't break the Cypher
        concept_description_query = "MATCH (c:Concept {name: $name}) RETURN c.description"
        concept_description = graph.query(concept_description_query, params={"name": concept})[0]['c.description']
        concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
    return concept_string


def get_global_concepts(graph: Neo4jGraph):
    """Returns the names of all GlobalConcept nodes in the graph."""
    concept_cypher = "MATCH (gc:GlobalConcept) return gc"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Same stdin fallback as get_concepts, without arbitrary code execution
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['gc']['name'] for concept in concepts]
    return concepts_name


def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
    """
    The node where the Cypher queries are generated
    """
    graph = config["configurable"].get("graph")
    question = state['query']
    related_concepts = get_related_concepts(graph, question)
    cyphers = []

    if config["configurable"].get("cypher_gen_method") == 'auto':
        cypher_gen_chain = get_cypher_gen_chain()
        cyphers = cypher_gen_chain.invoke({
            "schema": graph.schema,
            "question": question,
            "concepts": related_concepts
        })

    if config["configurable"].get("cypher_gen_method") == 'guided':
        concept_selection_chain = get_concept_selection_chain()
        print(f"Concept selection chain is : {concept_selection_chain}")
        selected_topic = concept_selection_chain.invoke({"question": question, "concepts": get_concepts(graph)})
        print(f"Selected topic is : {selected_topic}")
        cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
        print(f"Cyphers are : {cyphers}")

    if config["configurable"].get("validate_cypher"):
        corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
        cypher_corrector = CypherQueryCorrector(corrector_schema)
        cyphers = [cypher_corrector(cypher) for cypher in cyphers]

    return {"cyphers": cyphers}


def generate_cypher_from_topic(selected_concept: str, plan_step: int):
    """
    Helper function used when the 'guided' cypher generation method has been configured
    """
    print(f"PLAN STEP : {plan_step}")
    cypher_el = "(n) return n.title, n.description"
    match plan_step:
        case 0:
            cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
        case 1:
            cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
        case 2:
            cypher_el = "(ki:KeyIssue) RETURN ki.description"
    return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"


def get_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node retrieves docs from the graph using the generated Cypher queries
    """
    graph = config["configurable"].get("graph")
    output = []
    if graph is not None:
        # Try each candidate Cypher query in order, keeping the first that succeeds
        for cypher in state["cyphers"]:
            try:
                output = graph.query(cypher)
                break
            except Exception as e:
                print(f"Failed to retrieve docs : {e}")

    # Clean up the docs we received as there may be duplicates depending on the cypher query
    all_docs = []
    for doc in output:
        unwinded_doc = {}
        for key in doc:
            if isinstance(doc[key], dict):
                all_docs.append(doc[key])
            else:
                unwinded_doc.update({key: doc[key]})
        if unwinded_doc:
            all_docs.append(unwinded_doc)

    filtered_docs = []
    for doc in all_docs:
        if doc not in filtered_docs:
            filtered_docs.append(doc)

    return {"docs": filtered_docs}


# Data model
class GradeDocumentsBinary(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


def get_binary_grader(model="mixtral-8x7b-32768"):
    """
    Returns a binary grader to evaluate relevance of documents using the specified model.
    This is used when the 'binary' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    # LLM with function call
    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
    return retrieval_grader_binary


class GradeDocumentsScore(BaseModel):
    """Score for relevance check on retrieved documents."""

    score: float = Field(
        description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
    )


def get_score_grader(model="mixtral-8x7b-32768"):
    """
    Returns a score grader to evaluate relevance of documents using the specified model.
    This is used when the 'score' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_score = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
    retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
    return retrieval_grader_score


def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"):
    '''
    doc : the document to evaluate
    query : the query to which the doc should be relevant
    method : "binary" or "score"
    threshold : for the "score" method, the score above which a doc is considered relevant
    '''
    if method == "binary":
        retrieval_grader_binary = get_binary_grader(model=eval_model)
        return 1 if (retrieval_grader_binary.invoke({"question": query, "document": doc}).binary_score == 'yes') else 0
    elif method == "score":
        retrieval_grader_score = get_score_grader(model=eval_model)
        score = retrieval_grader_score.invoke({"query": query, "document": doc}).score
        if score is not None:
            return score if score >= threshold else 0
        else:
            # Couldn't parse a score, marking the document as relevant by default
            return 1
    else:
        raise ValueError("Invalid method")
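# A minimal usage sketch (hypothetical document and query):
#
#   relevance = eval_doc(
#       doc="title: A survey of 5G network slicing\nabstract: ...",
#       query="What are the key issues in network slicing?",
#       method="score",
#       threshold=0.7,
#   )  # -> the parsed score if it is >= 0.7, else 0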


def eval_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node evaluates the retrieved docs and filters out the irrelevant ones
    """
    eval_method = config["configurable"].get("eval_method") or "binary"
    MAX_DOCS = config["configurable"].get("max_docs") or 15
    valid_doc_scores = []

    for doc in sample(state["docs"], min(25, len(state["docs"]))):
        score = eval_doc(
            doc=format_doc(doc),
            query=state["query"],
            method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
            eval_model=config["configurable"].get("eval_model") or "mixtral-8x7b-32768"
        )
        if score:
            valid_doc_scores.append((doc, score))

    if eval_method == 'score':
        # Keep at most MAX_DOCS items with the highest scores if the score method was used
        valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
        valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
    else:
        # Keep at most MAX_DOCS items at random if the binary method was used
        shuffle(valid_doc_scores)
        valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]

    return {"valid_docs": valid_docs + (state["valid_docs"] or [])}


def build_data_retriever_graph(memory):
    """
    Builds the data_retriever graph
    """
    graph_builder_doc_retriever = StateGraph(DocRetrieverState)

    graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
    graph_builder_doc_retriever.add_node("get_docs", get_docs)
    graph_builder_doc_retriever.add_node("eval_docs", eval_docs)

    graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
    graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
    graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
    graph_builder_doc_retriever.add_edge("eval_docs", "__end__")

    graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
    return graph_doc_retriever
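

# A minimal usage sketch (assumes a reachable Neo4j instance and valid API
# keys; the connection values and query below are placeholders):
#
#   from langgraph.checkpoint.memory import MemorySaver
#
#   neo4j_graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="...")
#   retriever = build_data_retriever_graph(MemorySaver())
#   result = retriever.invoke(
#       {"query": "What are the key issues in network slicing?",
#        "valid_docs": [], "current_plan_step": 0},
#       config={"configurable": {
#           "thread_id": "1",
#           "graph": neo4j_graph,
#           "cypher_gen_method": "guided",
#           "eval_method": "binary",
#       }},
#   )
#   print(result["valid_docs"])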