#!/usr/bin/env python
# coding: utf-8
import ast
import re
from random import shuffle, sample

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langgraph.graph import StateGraph
from llmlingua import PromptCompressor

from ki_gen.prompts import (
    CYPHER_GENERATION_PROMPT,
    CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT,
    SCORE_GRADER_PROMPT,
    RELEVANT_CONCEPTS_PROMPT,
)
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc


def extract_cypher(text: str) -> list[str]:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        A list of candidate Cypher queries: the content of a ```cypher fence
        if present, the content of a bare ``` fence if present, and the raw
        text as a final fallback. Callers try these candidates in order.
    """
    # Patterns to find Cypher code enclosed in triple backticks
    pattern_1 = r"```cypher\n(.*?)```"
    pattern_2 = r"```\n(.*?)```"

    # Find all matches in the input text
    matches_1 = re.findall(pattern_1, text, re.DOTALL)
    matches_2 = re.findall(pattern_2, text, re.DOTALL)

    return [
        matches_1[0] if matches_1 else text,
        matches_2[0] if matches_2 else text,
        text,
    ]


def get_cypher_gen_chain(model: str = "openai"):
    """
    Returns the Cypher generation chain using the specified model.
    This is used when the 'auto' cypher generation method has been configured.
    """
    if model == "openai":
        llm_cypher_gen = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_cypher_gen = ChatGroq(model="mixtral-8x7b-32768")
    cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
    return cypher_gen_chain
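
# Minimal sketch of what extract_cypher yields for a typical fenced LLM
# answer (the sample text below is made up, purely illustrative):
#
#     >>> extract_cypher("```cypher\nMATCH (n:Concept) RETURN n.name\n```")
#     ['MATCH (n:Concept) RETURN n.name\n',
#      '```cypher\nMATCH (n:Concept) RETURN n.name\n```',
#      '```cypher\nMATCH (n:Concept) RETURN n.name\n```']
#
# Note that the bare ``` pattern does not match here (the opening fence is
# followed by "cypher", not a newline), so the raw text is kept as fallback.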
def get_concept_selection_chain(model: str = "openai"):
    """
    Returns a chain to select the most relevant topic using the specified model.
    This is used when the 'guided' cypher generation method has been configured.
    """
    if model == "openai":
        llm_topic_selection = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_topic_selection = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
    print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {llm_topic_selection}")
    topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
    return topic_selection_chain


def get_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (c:Concept) return c"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Fallback for running without a graph: parse a pasted list of records
        # (ast.literal_eval is safer than eval for untrusted input)
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['c']['name'] for concept in concepts]
    return concepts_name


def get_related_concepts(graph: Neo4jGraph, question: str):
    concepts = get_concepts(graph)
    llm = get_model(model='gpt-4o')
    print(f"this is the llm variable : {llm}")

    def parse_answer(llm_answer: str):
        print(f"This is the llm_answer : {llm_answer}")
        # Keep only the numbered list that follows the "Concepts:" marker
        return re.split(r"\n\d+\.\s", llm_answer.split("Concepts:")[-1])[1:]

    related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer

    related_concepts_raw = related_concepts_chain.invoke({"user_query": question, "concepts": '\n'.join(concepts)})

    # We clean up the list we received from the LLM in case there were some hallucinations
    related_concepts_cleaned = []
    for related_concept in related_concepts_raw:
        # If the concept returned by the LLM is in the list, we keep it
        if related_concept in concepts:
            related_concepts_cleaned.append(related_concept)
        else:
            # The LLM sometimes only forgets a few words from the concept name.
            # We check if the generated concept is a substring of an existing
            # one and, if so, add the full name to the list.
            for concept in concepts:
                if related_concept in concept:
                    related_concepts_cleaned.append(concept)
                    break

    # TODO : Add concepts found via similarity search
    return related_concepts_cleaned


def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
    concept_string = ""
    for concept in concept_list:
        concept_description_query = f"""
        MATCH (c:Concept {{name: "{concept}" }}) RETURN c.description
        """
        concept_description = graph.query(concept_description_query)[0]['c.description']
        concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
    return concept_string


def get_global_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (gc:GlobalConcept) return gc"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['gc']['name'] for concept in concepts]
    return concepts_name
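
# A hedged alternative to build_concept_string (a sketch, not used by the
# pipeline): pass the concept name as a Cypher parameter instead of
# interpolating it into the query string, which avoids quoting problems and
# injection for names containing quotes. Neo4jGraph.query accepts a params
# dict as its second argument.
def build_concept_string_parameterized(graph: Neo4jGraph, concept_list: list[str]) -> str:
    concept_string = ""
    for concept in concept_list:
        result = graph.query(
            "MATCH (c:Concept {name: $name}) RETURN c.description",
            params={"name": concept},
        )
        concept_description = result[0]['c.description']
        concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
    return concept_string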
def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
    """
    The node where the Cypher query is generated
    """
    graph = config["configurable"].get("graph")
    question = state['query']
    related_concepts = get_related_concepts(graph, question)
    cyphers = []

    if config["configurable"].get("cypher_gen_method") == 'auto':
        cypher_gen_chain = get_cypher_gen_chain()
        cyphers = cypher_gen_chain.invoke({
            "schema": graph.schema,
            "question": question,
            "concepts": related_concepts,
        })

    if config["configurable"].get("cypher_gen_method") == 'guided':
        concept_selection_chain = get_concept_selection_chain()
        print(f"Concept selection chain is : {concept_selection_chain}")
        selected_topic = concept_selection_chain.invoke({"question": question, "concepts": get_concepts(graph)})
        print(f"Selected topic is : {selected_topic}")
        cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
        print(f"Cyphers are : {cyphers}")

    if config["configurable"].get("validate_cypher"):
        corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
        cypher_corrector = CypherQueryCorrector(corrector_schema)
        cyphers = [cypher_corrector(cypher) for cypher in cyphers]

    return {"cyphers": cyphers}


def generate_cypher_from_topic(selected_concept: str, plan_step: int):
    """
    Helper function used when the 'guided' cypher generation method has been configured
    """
    print(f"PLAN STEP : {plan_step}")
    cypher_el = "(n) return n.title, n.description"
    match plan_step:
        case 0:
            cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
        case 1:
            cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
        case 2:
            cypher_el = "(ki:KeyIssue) RETURN ki.description"
    return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"
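
# Hedged illustration (not used by the graph): the guided queries built for
# each plan step. The concept name below is made up. For step 1 this yields:
#   "MATCH (c:Concept {name:'Network Slicing'})-[:RELATED_TO]-(rp:ResearchPaper) RETURN rp.title, rp.abstract"
def _demo_guided_cypher() -> list[str]:
    return [generate_cypher_from_topic("Network Slicing", step) for step in range(3)]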
def get_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node retrieves docs from the graph using the generated Cypher queries
    """
    graph = config["configurable"].get("graph")
    output = []
    if graph is not None:
        # Try the candidate queries in order and keep the first one that runs
        for cypher in state["cyphers"]:
            try:
                output = graph.query(cypher)
                break
            except Exception as e:
                print(f"Failed to retrieve docs : {e}")

    # Clean up the docs we received as there may be duplicates depending on the cypher query
    all_docs = []
    for doc in output:
        unwinded_doc = {}
        for key in doc:
            if isinstance(doc[key], dict):
                all_docs.append(doc[key])
            else:
                unwinded_doc.update({key: doc[key]})
        if unwinded_doc:
            all_docs.append(unwinded_doc)

    filtered_docs = []
    for doc in all_docs:
        if doc not in filtered_docs:
            filtered_docs.append(doc)

    return {"docs": filtered_docs}


# Data model
class GradeDocumentsBinary(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


def get_binary_grader(model="mixtral-8x7b-32768"):
    """
    Returns a binary grader to evaluate relevance of documents using the specified model.
    This is used when the 'binary' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
    return retrieval_grader_binary


class GradeDocumentsScore(BaseModel):
    """Score for relevance check on retrieved documents."""

    score: float = Field(
        description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
    )


def get_score_grader(model="mixtral-8x7b-32768"):
    """
    Returns a score grader to evaluate relevance of documents using the specified model.
    This is used when the 'score' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_score = ChatGroq(model="mixtral-8x7b-32768", temperature=0)
    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
    retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
    return retrieval_grader_score


def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"):
    '''
    doc : the document to evaluate
    query : the query to which the doc should be relevant
    method : "binary" or "score"
    threshold : for the "score" method, the score above which a doc is considered relevant
    '''
    if method == "binary":
        retrieval_grader_binary = get_binary_grader(model=eval_model)
        return 1 if (retrieval_grader_binary.invoke({"question": query, "document": doc}).binary_score == 'yes') else 0
    elif method == "score":
        retrieval_grader_score = get_score_grader(model=eval_model)
        score = retrieval_grader_score.invoke({"query": query, "document": doc}).score
        if score is not None:
            return score if score >= threshold else 0
        # Couldn't parse a score, mark the document as relevant by default
        return 1
    else:
        raise ValueError("Invalid method")


def eval_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node evaluates the retrieved docs and filters out the irrelevant ones
    """
    eval_method = config["configurable"].get("eval_method") or "binary"
    MAX_DOCS = config["configurable"].get("max_docs") or 15
    valid_doc_scores = []

    # Evaluate at most 25 randomly sampled docs
    for doc in sample(state["docs"], min(25, len(state["docs"]))):
        score = eval_doc(
            doc=format_doc(doc),
            query=state["query"],
            method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
            eval_model=config["configurable"].get("eval_model") or "mixtral-8x7b-32768",
        )
        if score:
            valid_doc_scores.append((doc, score))

    if eval_method == 'score':
        # Keep at most MAX_DOCS items with the highest scores
        valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
        valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
    else:
        # Keep at most MAX_DOCS items at random if the binary method was used
        shuffle(valid_doc_scores)
        valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]

    return {"valid_docs": valid_docs + (state.get("valid_docs") or [])}


def build_data_retriever_graph(memory):
    """
    Builds the data_retriever graph
    """
    graph_builder_doc_retriever = StateGraph(DocRetrieverState)

    graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
    graph_builder_doc_retriever.add_node("get_docs", get_docs)
    graph_builder_doc_retriever.add_node("eval_docs", eval_docs)

    graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
    graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
    graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
    graph_builder_doc_retriever.add_edge("eval_docs", "__end__")

    graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
    return graph_doc_retriever
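
# Hedged usage sketch: one way to drive the compiled retriever graph.
# MemorySaver is langgraph's in-memory checkpointer; the Neo4j credentials
# and the query below are placeholders, and the config keys mirror the ones
# this module reads from config["configurable"].
if __name__ == "__main__":
    from langgraph.checkpoint.memory import MemorySaver

    neo4j_graph = Neo4jGraph(url="bolt://localhost:7687", username="neo4j", password="password")  # placeholder credentials
    retriever = build_data_retriever_graph(MemorySaver())
    final_state = retriever.invoke(
        {
            "query": "What are the key issues related to network slicing?",
            "current_plan_step": 1,  # used by the 'guided' method
            "valid_docs": [],
        },
        config={
            "configurable": {
                "graph": neo4j_graph,
                "cypher_gen_method": "guided",
                "validate_cypher": False,
                "eval_method": "binary",
                "max_docs": 15,
                "thread_id": "demo",  # required by the checkpointer
            }
        },
    )
    print(final_state["valid_docs"])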