#!/usr/bin/env python
# coding: utf-8

import ast
import re
import time
from random import shuffle, sample

from langgraph.checkpoint.sqlite import SqliteSaver
from langgraph.graph import StateGraph
from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from pydantic import BaseModel, Field
from llmlingua import PromptCompressor

from ki_gen.prompts import (
    CYPHER_GENERATION_PROMPT,
    CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT,
    SCORE_GRADER_PROMPT,
    RELEVANT_CONCEPTS_PROMPT,
)
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc

def extract_cypher(text: str) -> list[str]:
    """Extract Cypher code from a text.

    Args:
        text: Text to extract Cypher code from.

    Returns:
        A list of Cypher candidates: the first ```cypher``` block if any,
        the first plain ``` block if any, and the raw text as a fallback.
    """
    # Patterns for Cypher code enclosed in triple backticks
    pattern_1 = r"```cypher\n(.*?)```"
    pattern_2 = r"```\n(.*?)```"

    # Find all matches in the input text
    matches_1 = re.findall(pattern_1, text, re.DOTALL)
    matches_2 = re.findall(pattern_2, text, re.DOTALL)

    return [
        matches_1[0] if matches_1 else text,
        matches_2[0] if matches_2 else text,
        text,
    ]

def get_cypher_gen_chain(model: str = "deepseek-r1-distill-llama-70b"):
    """
    Returns a Cypher generation chain using the specified model.
    This is used when the 'auto' cypher generation method has been configured.
    """
    if model == "openai":
        llm_cypher_gen = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_cypher_gen = ChatGroq(model="deepseek-r1-distill-llama-70b")
    cypher_gen_chain = CYPHER_GENERATION_PROMPT | llm_cypher_gen | StrOutputParser() | extract_cypher
    return cypher_gen_chain

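# Example (illustrative; assumes valid Groq/OpenAI credentials in the environment):
#   chain = get_cypher_gen_chain()
#   candidates = chain.invoke({"schema": graph.schema, "question": "...", "concepts": [...]})
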
def get_concept_selection_chain(model: str = "deepseek-r1-distill-llama-70b"):
    """
    Returns a chain that selects the most relevant topic using the specified model.
    This is used when the 'guided' cypher generation method has been configured.
    """
    if model == "openai":
        llm_topic_selection = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/")
    else:
        llm_topic_selection = ChatGroq(model="deepseek-r1-distill-llama-70b")
    print(f"LLM used for topic selection : {llm_topic_selection}")
    topic_selection_chain = CONCEPT_SELECTION_PROMPT | llm_topic_selection | StrOutputParser()
    return topic_selection_chain

def get_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (c:Concept) return c"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Fallback for manual testing without a graph: parse a Python literal.
        # ast.literal_eval is used instead of eval to avoid executing arbitrary code.
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['c']['name'] for concept in concepts]
    return concepts_name

def get_related_concepts(graph: Neo4jGraph, question: str):
    concepts = get_concepts(graph)
    llm = get_model()
    print(f"LLM used for related-concept extraction : {llm}")

    def parse_answer(llm_answer: str):
        try:
            print(f"LLM answer : {llm_answer}")
            return re.split(r"\n\d+\.\s", llm_answer.split("Concepts:")[1])[1:]
        except (IndexError, AttributeError):
            # The answer did not contain a "Concepts:" section
            return []

    related_concepts_chain = RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | parse_answer

    print(f"User question : {question}")
    print(f"Candidate concepts : {concepts}")

    # Groq can reject prompts that exceed the tokens-per-minute limit with a 413 error, e.g.:
    # groq.APIStatusError: Error code: 413 - {'error': {'message': 'Request too large for model
    # `deepseek-r1-distill-llama-70b` ... (TPM): Limit 5000, Requested 17099, please reduce your
    # message size and try again. Visit https://console.groq.com/docs/rate-limits for more
    # information.', 'type': 'tokens', 'code': 'rate_limit_exceeded'}}
    try:
        related_concepts_raw = related_concepts_chain.invoke({"user_query": question, "concepts": '\n'.join(concepts)})
        print(f"related_concepts_raw : {related_concepts_raw}")
    except Exception as e:
        if getattr(e, "status_code", None) == 413:
            msg = e.body["error"]["message"]
            related_concepts_raw = error_concept_groq(msg, concepts, related_concepts_chain, ["user_query", question])
        else:
            raise

    # Clean up the list we received from the LLM in case there were some hallucinations
    related_concepts_cleaned = []
    for related_concept in related_concepts_raw:
        # If the concept returned by the LLM is in the list, we keep it
        if related_concept in concepts:
            related_concepts_cleaned.append(related_concept)
        else:
            # The LLM sometimes drops a few words from the concept name.
            # If the generated concept is a substring of an existing one, keep the existing concept.
            for concept in concepts:
                if related_concept in concept:
                    related_concepts_cleaned.append(concept)
                    break

    # TODO : Add concepts found via similarity search
    return related_concepts_cleaned

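# Example (illustrative output; actual names depend on the graph content):
#   get_related_concepts(graph, "How does network slicing work?")
#   -> ['Network Slicing', 'Network Function Virtualization']
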
def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
    concept_string = ""
    for concept in concept_list:
        # Parameterized query so that quotes in concept names don't break the Cypher
        concept_description_query = "MATCH (c:Concept {name: $name}) RETURN c.description"
        concept_description = graph.query(concept_description_query, params={"name": concept})[0]['c.description']
        concept_string += f"name: {concept}\ndescription: {concept_description}\n\n"
    return concept_string

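# Example output (illustrative; the description comes from the graph):
#   name: Network Slicing
#   description: Partitioning of a physical network into multiple virtual networks.
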
def get_global_concepts(graph: Neo4jGraph):
    concept_cypher = "MATCH (gc:GlobalConcept) return gc"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        # Same manual-testing fallback as in get_concepts
        user_input = input("Topics : ")
        concepts = ast.literal_eval(user_input)
    concepts_name = [concept['gc']['name'] for concept in concepts]
    return concepts_name

def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
    """
    The node where the Cypher queries are generated
    """
    graph = config["configurable"].get("graph")
    question = state['query']
    related_concepts = get_related_concepts(graph, question)
    cyphers = []

    if config["configurable"].get("cypher_gen_method") == 'auto':
        cypher_gen_chain = get_cypher_gen_chain()
        cyphers = cypher_gen_chain.invoke({
            "schema": graph.schema,
            "question": question,
            "concepts": related_concepts
        })

    if config["configurable"].get("cypher_gen_method") == 'guided':
        concept_selection_chain = get_concept_selection_chain()
        print(f"Concept selection chain is : {concept_selection_chain}")
        try:
            selected_topic = concept_selection_chain.invoke({"question": question, "concepts": get_concepts(graph)})
        except Exception as e:
            # Same Groq 413 (request too large) handling as in get_related_concepts
            selected_topic = error_concept_groq(e.body["error"]["message"], get_concepts(graph), concept_selection_chain, ["question", question])
        print(f"Selected topic is : {selected_topic}")
        cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
        print(f"Cyphers are : {cyphers}")

    if config["configurable"].get("validate_cypher"):
        corrector_schema = [Schema(el["start"], el["type"], el["end"]) for el in graph.structured_schema.get("relationships")]
        cypher_corrector = CypherQueryCorrector(corrector_schema)
        cyphers = [cypher_corrector(cypher) for cypher in cyphers]

    return {"cyphers": cyphers}

def generate_cypher_from_topic(selected_concept: str, plan_step: int):
    """
    Helper function used when the 'guided' cypher generation method has been configured
    """
    print(f"PLAN STEP : {plan_step}")
    # Default clause if the plan step is not one of the known cases
    cypher_el = "(n) return n.title, n.description"
    match plan_step:
        case 0:
            cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
        case 1:
            cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
        case 2:
            cypher_el = "(ki:KeyIssue) RETURN ki.description"
    return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}"

def get_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node retrieves docs from the graph using the generated Cypher queries
    """
    graph = config["configurable"].get("graph")
    output = []
    if graph is not None:
        # Try each candidate query in order and keep the first one that runs
        for cypher in state["cyphers"]:
            try:
                output = graph.query(cypher)
                break
            except Exception as e:
                print(f"Failed to retrieve docs : {e}")

    # Clean up the docs we received as there may be duplicates depending on the cypher query
    all_docs = []
    for doc in output:
        unwinded_doc = {}
        for key in doc:
            if isinstance(doc[key], dict):
                all_docs.append(doc[key])
            else:
                unwinded_doc.update({key: doc[key]})
        if unwinded_doc:
            all_docs.append(unwinded_doc)

    filtered_docs = []
    for doc in all_docs:
        if doc not in filtered_docs:
            filtered_docs.append(doc)

    return {"docs": filtered_docs}

# Data model
class GradeDocumentsBinary(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )


def get_binary_grader(model="deepseek-r1-distill-llama-70b"):
    """
    Returns a binary grader to evaluate relevance of documents using the specified model.
    This is used when the 'binary' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_binary = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_binary = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
    retrieval_grader_binary = BINARY_GRADER_PROMPT | structured_llm_grader_binary
    return retrieval_grader_binary

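# Example (illustrative; assumes valid API credentials):
#   grader = get_binary_grader()
#   grade = grader.invoke({"question": "What is network slicing?", "document": "..."})
#   grade.binary_score  # 'yes' or 'no'
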
class GradeDocumentsScore(BaseModel):
    """Score for relevance check on retrieved documents."""

    score: float = Field(
        description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
    )


def get_score_grader(model="deepseek-r1-distill-llama-70b"):
    """
    Returns a score grader to evaluate relevance of documents using the specified model.
    This is used when the 'score' evaluation method has been configured.
    """
    if model == "gpt-4o":
        llm_grader_score = ChatOpenAI(model="gpt-4o", base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        llm_grader_score = ChatGroq(model="deepseek-r1-distill-llama-70b", temperature=0)
    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
    retrieval_grader_score = SCORE_GRADER_PROMPT | structured_llm_grader_score
    return retrieval_grader_score

def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="deepseek-r1-distill-llama-70b"):
    '''
    doc : the document to evaluate
    query : the query to which the doc should be relevant
    method : "binary" or "score"
    threshold : for the "score" method, the score above which a doc is considered relevant
    '''
    if method == "binary":
        retrieval_grader_binary = get_binary_grader(model=eval_model)
        return 1 if (retrieval_grader_binary.invoke({"question": query, "document": doc}).binary_score == 'yes') else 0
    elif method == "score":
        retrieval_grader_score = get_score_grader(model=eval_model)
        score = retrieval_grader_score.invoke({"query": query, "document": doc}).score
        if score is not None:
            return score if score >= threshold else 0
        else:
            # Couldn't parse a score, mark the document as relevant by default
            return 1
    else:
        raise ValueError("Invalid method")

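# Example (illustrative):
#   eval_doc(doc="...", query="What is network slicing?", method="score", threshold=0.5)
#   -> the parsed score if it is >= 0.5, else 0
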
def eval_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    This node evaluates the relevance of the retrieved docs and keeps the valid ones
    """
    eval_method = config["configurable"].get("eval_method") or "binary"
    MAX_DOCS = config["configurable"].get("max_docs") or 15
    valid_doc_scores = []

    for doc in sample(state["docs"], min(25, len(state["docs"]))):
        score = eval_doc(
            doc=format_doc(doc),
            query=state["query"],
            method=eval_method,
            threshold=config["configurable"].get("eval_threshold") or 0.7,
            eval_model=config["configurable"].get("eval_model") or "deepseek-r1-distill-llama-70b"
        )
        if score:
            valid_doc_scores.append((doc, score))

    if eval_method == 'score':
        # Keep at most MAX_DOCS items with the highest scores if the score method was used
        valid_docs = sorted(valid_doc_scores, key=lambda x: x[1], reverse=True)
        valid_docs = [valid_doc[0] for valid_doc in valid_docs[:MAX_DOCS]]
    else:
        # Keep at most MAX_DOCS items at random if the binary method was used
        shuffle(valid_doc_scores)
        valid_docs = [valid_doc[0] for valid_doc in valid_doc_scores[:MAX_DOCS]]

    return {"valid_docs": valid_docs + (state.get("valid_docs") or [])}

def build_data_retriever_graph(memory):
    """
    Builds the data_retriever graph
    """
    # memory can be created e.g. with SqliteSaver.from_conn_string(":memory:")
    graph_builder_doc_retriever = StateGraph(DocRetrieverState)

    graph_builder_doc_retriever.add_node("generate_cypher", generate_cypher)
    graph_builder_doc_retriever.add_node("get_docs", get_docs)
    graph_builder_doc_retriever.add_node("eval_docs", eval_docs)

    graph_builder_doc_retriever.add_edge("__start__", "generate_cypher")
    graph_builder_doc_retriever.add_edge("generate_cypher", "get_docs")
    graph_builder_doc_retriever.add_edge("get_docs", "eval_docs")
    graph_builder_doc_retriever.add_edge("eval_docs", "__end__")

    graph_doc_retriever = graph_builder_doc_retriever.compile(checkpointer=memory)
    return graph_doc_retriever

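# Example (illustrative; state and config keys follow DocRetrieverState and ConfigSchema):
#   with SqliteSaver.from_conn_string(":memory:") as memory:
#       retriever = build_data_retriever_graph(memory)
#       result = retriever.invoke(
#           {"query": "What is network slicing?", "valid_docs": []},
#           config={"configurable": {"graph": graph, "cypher_gen_method": "auto", "thread_id": "1"}},
#       )
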
def error_concept_groq(msg, concepts, groq, question):
    """
    Fallback when Groq rejects a request as too large (HTTP 413): parse the
    requested token count from the error message, split the concept list into
    chunks small enough to fit the 5000 TPM limit, and query chunk by chunk.
    """
    try:
        # The error message contains "... Requested <n>, ..."
        start = msg.find("Requested") + len("Requested ")
        end = msg.find(",", start)
        requested_tokens = int(msg[start:end])

        # Number of chunks needed so that each request stays under the 5000 TPM limit,
        # and the corresponding chunk size (both rounded up)
        n_chunks = requested_tokens // 5000 + (1 if requested_tokens % 5000 != 0 else 0)
        chunk_size = len(concepts) // n_chunks + (1 if len(concepts) % n_chunks != 0 else 0)

        related_concepts = []
        for i in range(n_chunks):
            smaller_concepts = concepts[i * chunk_size:(i + 1) * chunk_size]
            res = groq.invoke({question[0]: question[1], "concepts": '\n'.join(smaller_concepts)})
            for r in res:
                related_concepts.append(r)
        return related_concepts
    except Exception as e:
        if getattr(e, "status_code", None) == 429:
            # Rate limited again: wait for the TPM window to reset, then retry
            time.sleep(65)
            return error_concept_groq(msg, concepts, groq, question)
        raise