File size: 13,476 Bytes
6aaddef
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
#!/usr/bin/env python
# coding: utf-8

import re
from random import shuffle, sample

from langchain_groq import ChatGroq
from langchain_openai import ChatOpenAI
from langchain_core.messages import HumanMessage
from langchain_community.graphs import Neo4jGraph
from langchain_community.chains.graph_qa.cypher_utils import CypherQueryCorrector, Schema
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
from langchain_groq import ChatGroq

from langgraph.graph import StateGraph

from llmlingua import PromptCompressor

from ki_gen.prompts import (
    CYPHER_GENERATION_PROMPT, 
    CONCEPT_SELECTION_PROMPT,
    BINARY_GRADER_PROMPT,
    SCORE_GRADER_PROMPT,
    RELEVANT_CONCEPTS_PROMPT,
)
from ki_gen.utils import ConfigSchema, DocRetrieverState, get_model, format_doc




def extract_cypher(text: str) -> list[str]:
    """Extract candidate Cypher queries from an LLM answer.

    The model may wrap its query in a ```cypher fenced block, in a bare
    ``` fenced block, or not fence it at all. Rather than guessing, every
    candidate is returned so the caller can try them in order.

    Args:
        text: Raw text produced by the LLM.

    Returns:
        A list of exactly three strings:
        0. the body of the first ```cypher fenced block, or ``text`` if none;
        1. the body of the first bare ``` fenced block, or ``text`` if none;
        2. ``text`` unchanged.
    """
    # \r?\n also accepts Windows-style line endings after the opening fence.
    cypher_fenced = re.findall(r"```cypher\r?\n(.*?)```", text, re.DOTALL)
    bare_fenced = re.findall(r"```\r?\n(.*?)```", text, re.DOTALL)
    return [
        cypher_fenced[0] if cypher_fenced else text,
        bare_fenced[0] if bare_fenced else text,
        text,
    ]

def get_cypher_gen_chain(model: str = "openai"):
    """
    Build the chain used by the 'auto' cypher generation method.

    The chain formats CYPHER_GENERATION_PROMPT, sends it to the selected LLM,
    parses the reply to a string and finally runs ``extract_cypher`` on it.
    ``model == "openai"`` selects gpt-4o through ChatOpenAI; any other value
    selects Groq's mixtral-8x7b-32768.
    """
    generator_llm = (
        ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
        if model == "openai"
        else ChatGroq(model = "mixtral-8x7b-32768")
    )
    return CYPHER_GENERATION_PROMPT | generator_llm | StrOutputParser() | extract_cypher

def get_concept_selection_chain(model: str = "openai"):
    """
    Build the chain that asks an LLM for the most relevant concept.

    Used by the 'guided' cypher generation method. ``model == "openai"``
    selects gpt-4o through ChatOpenAI; any other value selects Groq's
    llama3-70b-8192 capped at 8192 output tokens. The chain yields the raw
    string answer of the LLM.
    """
    if model != "openai":
        selector_llm = ChatGroq(model="llama3-70b-8192", max_tokens=8192)
    else:
        selector_llm = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/")
    print(f"FOUND LLM TOPIC SELECTION FOR THE CONCEPT SELECTION PROMPT : {selector_llm}")
    return CONCEPT_SELECTION_PROMPT | selector_llm | StrOutputParser()

def get_concepts(graph: Neo4jGraph):
    """
    Return the names of every ``Concept`` node.

    When ``graph`` is a Neo4jGraph the rows come from
    ``MATCH (c:Concept) return c``. Otherwise (e.g. no graph configured) the
    rows are typed in interactively as a Python literal shaped like that
    query result, e.g. ``[{'c': {'name': 'Foo'}}]``.

    Returns:
        list[str]: the ``name`` property of each concept, in row order.
    """
    concept_cypher = "MATCH (c:Concept) return c"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        import ast

        user_input = input("Topics : ")
        # literal_eval instead of eval: only Python literals are accepted, so
        # whatever is typed here cannot execute arbitrary code.
        concepts = ast.literal_eval(user_input)

    return [concept['c']['name'] for concept in concepts]

def _parse_related_concepts_answer(llm_answer: str) -> list[str]:
    """Split the numbered list following "Concepts:" in an LLM answer into entries."""
    print(f"This the llm_answer : {llm_answer}")
    sections = llm_answer.split("Concepts:")
    # Without the header, parse the whole answer instead of raising IndexError.
    listing = sections[1] if len(sections) > 1 else llm_answer
    # Items are expected as "\n1. foo\n2. bar"; whatever precedes the first
    # numbered item is preamble and is discarded by the [1:].
    items = re.split(r"\n\d+\.\s", listing)[1:]
    # Strip trailing newlines/spaces so exact matching works, and drop blanks:
    # an empty string is a substring of every name and would match anything.
    return [item.strip() for item in items if item.strip()]


def _match_known_concepts(candidates: list[str], known: list[str]) -> list[str]:
    """Map LLM-proposed names onto known concept names, dropping unknown ones."""
    known_set = set(known)
    matched = []
    for candidate in candidates:
        if candidate in known_set:
            resolved = candidate
        else:
            # The LLM sometimes drops a few words of a concept name, so accept
            # the first known name that contains the candidate as a substring.
            resolved = next((name for name in known if candidate in name), None)
        if resolved is not None and resolved not in matched:
            matched.append(resolved)
    return matched


def get_related_concepts(graph: Neo4jGraph, question: str):
    """
    Return the known concepts that an LLM judges related to ``question``.

    All concept names are fetched with ``get_concepts`` and sent, together
    with the question, through RELEVANT_CONCEPTS_PROMPT to the model returned
    by ``get_model(model='gpt-4o')``. The numbered list in its answer is then
    reconciled with the known names so hallucinated entries are dropped.

    Returns:
        list[str]: known concept names, without duplicates, in answer order.
    """
    concepts = get_concepts(graph)
    llm = get_model(model='gpt-4o')
    print(f"this is the llm variable : {llm}")
    related_concepts_chain = (
        RELEVANT_CONCEPTS_PROMPT | llm | StrOutputParser() | _parse_related_concepts_answer
    )
    related_concepts_raw = related_concepts_chain.invoke(
        {"user_query" : question, "concepts" : '\n'.join(concepts)}
    )

    # TODO : Add concepts found via similarity search
    return _match_known_concepts(related_concepts_raw, concepts)

def build_concept_string(graph: Neo4jGraph, concept_list: list[str]):
    """
    Render the name and description of each concept as one text block.

    Each concept contributes a ``name:`` line, a ``description:`` line and a
    blank line, in the order of ``concept_list``. A concept with no matching
    node is rendered with description ``None``, the same as a node whose
    description property is missing.

    Args:
        graph: graph exposing ``query(cypher, params=...)``.
        concept_list: concept names to look up.

    Returns:
        str: the concatenated sections (empty string for an empty list).
    """
    # The name is passed as a query parameter rather than formatted into the
    # Cypher text, so names containing quotes cannot break (or inject into)
    # the query.
    description_query = "MATCH (c:Concept {name: $name}) RETURN c.description"
    sections = []
    for concept in concept_list:
        rows = graph.query(description_query, params={"name": concept})
        concept_description = rows[0]['c.description'] if rows else None
        sections.append(f"name: {concept}\ndescription: {concept_description}\n\n")
    return "".join(sections)

def get_global_concepts(graph: Neo4jGraph):
    """
    Return the names of every ``GlobalConcept`` node.

    When ``graph`` is a Neo4jGraph the rows come from
    ``MATCH (gc:GlobalConcept) return gc``. Otherwise the rows are typed in
    interactively as a Python literal shaped like that query result, e.g.
    ``[{'gc': {'name': 'Foo'}}]``.

    Returns:
        list[str]: the ``name`` property of each global concept, in row order.
    """
    concept_cypher = "MATCH (gc:GlobalConcept) return gc"
    if isinstance(graph, Neo4jGraph):
        concepts = graph.query(concept_cypher)
    else:
        import ast

        user_input = input("Topics : ")
        # literal_eval instead of eval: only Python literals are accepted, so
        # whatever is typed here cannot execute arbitrary code.
        concepts = ast.literal_eval(user_input)

    return [concept['gc']['name'] for concept in concepts]

def generate_cypher(state: DocRetrieverState, config: ConfigSchema):
    """
    Graph node that produces the candidate Cypher queries for ``get_docs``.

    Reads ``graph``, ``cypher_gen_method`` and ``validate_cypher`` from
    ``config["configurable"]`` and ``query`` / ``current_plan_step`` from
    ``state``:

    * ``'auto'``: an LLM writes the query from the graph schema, the question
      and the concepts related to it (see ``get_related_concepts``).
    * ``'guided'``: an LLM selects a concept and ``generate_cypher_from_topic``
      fills in a fixed query template for the current plan step.
    * any other value: no query is generated.

    When ``validate_cypher`` is truthy every query is passed through a
    CypherQueryCorrector built from the graph's relationship schema.

    Returns:
        dict: ``{"cyphers": [...]}`` with the candidate queries.
    """
    configurable = config["configurable"]
    graph = configurable.get("graph")
    question = state['query']
    cypher_gen_method = configurable.get("cypher_gen_method")
    cyphers = []

    if cypher_gen_method == 'auto':
        # Related concepts cost an LLM call plus a graph query and are only
        # used by this branch, so they are not computed for 'guided'.
        related_concepts = get_related_concepts(graph, question)
        cypher_gen_chain = get_cypher_gen_chain()
        cyphers = cypher_gen_chain.invoke({
            "schema": graph.schema,
            "question": question,
            "concepts": related_concepts
        })
    elif cypher_gen_method == 'guided':
        concept_selection_chain = get_concept_selection_chain()
        print(f"Concept selection chain is : {concept_selection_chain}")
        selected_topic = concept_selection_chain.invoke({"question" : question, "concepts": get_concepts(graph)})
        print(f"Selected topic are : {selected_topic}")
        cyphers = [generate_cypher_from_topic(selected_topic, state['current_plan_step'])]
        print(f"Cyphers are : {cyphers}")

    if configurable.get("validate_cypher"):
        corrector_schema = [
            Schema(el["start"], el["type"], el["end"])
            for el in graph.structured_schema.get("relationships")
        ]
        cypher_corrector = CypherQueryCorrector(corrector_schema)
        cyphers = [cypher_corrector(cypher) for cypher in cyphers]

    return {"cyphers" : cyphers}

def generate_cypher_from_topic(selected_concept: str, plan_step: int):
    """
    Helper function used when the 'guided' cypher generation method has been configured
    """

    print(f"L.176 PLAN STEP : {plan_step}")
    cypher_el = "(n) return n.title, n.description"
    match plan_step:
        case 0:
            cypher_el = "(ts:TechnicalSpecification) RETURN ts.title, ts.scope, ts.description"
        case 1:
            cypher_el = "(rp:ResearchPaper) RETURN rp.title, rp.abstract"
        case 2:
            cypher_el = "(ki:KeyIssue) RETURN ki.description"
    return f"MATCH (c:Concept {{name:'{selected_concept}'}})-[:RELATED_TO]-{cypher_el}" 

def get_docs(state:DocRetrieverState, config:ConfigSchema):
    """
    Graph node that runs the candidate Cypher queries and collects documents.

    The queries in ``state["cyphers"]`` are tried in order. The rows of the
    first one that executes without raising are used, and the rest are
    skipped. Each row is flattened: every dict-valued column (e.g. a returned
    node) becomes its own document, and the row's remaining columns are
    grouped into one extra document. Exact duplicates are then removed,
    keeping first-seen order.

    Returns:
        dict: ``{"docs": [...]}`` — empty when no graph is configured or every
        query failed.
    """
    graph = config["configurable"].get("graph")
    output = []
    if graph is not None:
        for cypher in state["cyphers"] or []:
            # Skip empty candidates; there is nothing to execute.
            if not cypher:
                continue
            try:
                output = graph.query(cypher)
                break
            except Exception as e:
                # Best effort: report and fall through to the next candidate.
                print(f"Failed to retrieve docs : {e}")

    # Clean up the docs we received as there may be duplicates depending on the cypher query
    all_docs = []
    for row in output:
        scalar_columns = {}
        for key, value in row.items():
            if isinstance(value, dict):
                all_docs.append(value)
            else:
                scalar_columns[key] = value
        if scalar_columns:
            all_docs.append(scalar_columns)

    # Docs are dicts (unhashable), so de-duplicate with a list scan.
    filtered_docs = []
    for doc in all_docs:
        if doc not in filtered_docs:
            filtered_docs.append(doc)

    return {"docs": filtered_docs}





# Data model
class GradeDocumentsBinary(BaseModel):
    """Binary score for relevance check on retrieved documents."""

    # NOTE(review): the class docstring and the Field description are runtime
    # data — with_structured_output(GradeDocumentsBinary) presumably forwards
    # them to the LLM as the output schema — so edit them with care.
    # eval_doc compares this value against the string 'yes'.
    binary_score: str = Field(
        description="Documents are relevant to the question, 'yes' or 'no'"
    )

# LLM with function call
# llm_grader_binary = ChatGroq(model="mixtral-8x7b-32768", temperature=0)

def get_binary_grader(model="mixtral-8x7b-32768"):
    """
    Build a yes/no relevance grader for retrieved documents.

    Used when the 'binary' evaluation method is configured. The chain formats
    BINARY_GRADER_PROMPT and asks the LLM (temperature 0) for a
    GradeDocumentsBinary object.

    Args:
        model: "gpt-4o" selects ChatOpenAI with the configured base_url; any
            other value is used as the Groq model name.

    Returns:
        A runnable whose ``invoke`` yields a GradeDocumentsBinary.
    """
    if model == "gpt-4o":
        llm_grader_binary = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        # Honour the requested Groq model instead of always using mixtral.
        llm_grader_binary = ChatGroq(model=model, temperature=0)
    structured_llm_grader_binary = llm_grader_binary.with_structured_output(GradeDocumentsBinary)
    return BINARY_GRADER_PROMPT | structured_llm_grader_binary


class GradeDocumentsScore(BaseModel):
    """Score for relevance check on retrieved documents."""

    # NOTE(review): the class docstring and the Field description are runtime
    # data — with_structured_output(GradeDocumentsScore) presumably forwards
    # them to the LLM as the output schema — so edit them with care.
    # eval_doc compares this value against its relevance threshold.
    score: float = Field(
        description="Documents are relevant to the question, score between 0 (completely irrelevant) and 1 (perfectly relevant)"
    )

def get_score_grader(model="mixtral-8x7b-32768"):
    """
    Build a numeric (0 to 1) relevance grader for retrieved documents.

    Used when the 'score' evaluation method is configured. The chain formats
    SCORE_GRADER_PROMPT and asks the LLM (temperature 0) for a
    GradeDocumentsScore object.

    Args:
        model: "gpt-4o" selects ChatOpenAI with the configured base_url; any
            other value is used as the Groq model name.

    Returns:
        A runnable whose ``invoke`` yields a GradeDocumentsScore.
    """
    if model == "gpt-4o":
        llm_grader_score = ChatOpenAI(model='gpt-4o', base_url="https://llm.synapse.thalescloud.io/", temperature=0)
    else:
        # Honour the requested Groq model instead of always using mixtral.
        llm_grader_score = ChatGroq(model=model, temperature=0)
    structured_llm_grader_score = llm_grader_score.with_structured_output(GradeDocumentsScore)
    return SCORE_GRADER_PROMPT | structured_llm_grader_score


def eval_doc(doc, query, method="binary", threshold=0.7, eval_model="mixtral-8x7b-32768"):
    '''
    Grade how relevant ``doc`` is to ``query`` with an LLM.

    doc : the document text to evaluate
    query : the query to which the doc should be relevant
    method : "binary" or "score"
    threshold : for "score" method, score above which a doc is considered relevant
    eval_model : model name handed to the grader factory

    Returns 1/0 for "binary". For "score" it returns the score when it reaches
    ``threshold``, otherwise 0, and 1 when no score could be obtained.
    Raises ValueError for an unknown method.
    '''
    if method == "binary":
        retrieval_grader_binary = get_binary_grader(model=eval_model)
        verdict = retrieval_grader_binary.invoke({"question": query, "document": doc}).binary_score
        # Tolerate "Yes", " yes\n", ... from the model.
        return 1 if verdict.strip().lower() == "yes" else 0
    if method == "score":
        retrieval_grader_score = get_score_grader(model=eval_model)
        # NOTE(review): this prompt is fed "query" while the binary one gets
        # "question" — presumably matching each template's variables; confirm.
        result = retrieval_grader_score.invoke({"query": query, "document": doc})
        # Test for None explicitly: a legitimate score of 0.0 is falsy and must
        # not be mistaken for a parse failure.
        score = getattr(result, "score", None)
        if score is None:
            # Couldn't parse score, marking document as relevant by default
            return 1
        return score if score >= threshold else 0
    raise ValueError("Invalid method")

def eval_docs(state: DocRetrieverState, config: ConfigSchema):
    """
    Graph node that grades the retrieved docs and keeps the relevant ones.

    At most 25 documents, drawn at random from ``state["docs"]``, are graded
    with ``eval_doc``. Of those judged relevant, at most ``max_docs`` (default
    15) are kept. The 'score' method keeps the highest-scoring ones; any other
    method keeps a random selection. The kept docs are placed in front of the
    existing ``state["valid_docs"]``.

    Config keys read from ``config["configurable"]``: ``eval_method``
    (default "binary"), ``max_docs`` (15), ``eval_threshold`` (0.7) and
    ``eval_model`` ("mixtral-8x7b-32768").

    Returns:
        dict: ``{"valid_docs": [...]}``.
    """
    configurable = config["configurable"]
    eval_method = configurable.get("eval_method") or "binary"
    MAX_DOCS = configurable.get("max_docs") or 15
    # Hoisted out of the loop: these settings do not change per document.
    threshold = configurable.get("eval_threshold") or 0.7
    eval_model = configurable.get("eval_model") or "mixtral-8x7b-32768"

    docs = state["docs"]
    valid_doc_scores = []
    for doc in sample(docs, min(25, len(docs))):
        score = eval_doc(
            doc=format_doc(doc),
            query=state["query"],
            method=eval_method,
            threshold=threshold,
            eval_model=eval_model,
        )
        if score:
            valid_doc_scores.append((doc, score))

    if eval_method == 'score':
        # Highest scores first, so the slice below keeps the best MAX_DOCS docs.
        valid_doc_scores.sort(key=lambda pair: pair[1], reverse=True)
    else:
        # Get at most MAX_DOCS items at random if binary method was used
        shuffle(valid_doc_scores)
    valid_docs = [doc for doc, _ in valid_doc_scores[:MAX_DOCS]]

    return {"valid_docs": valid_docs + (state["valid_docs"] or [])}



def build_data_retriever_graph(memory):
    """
    Assemble and compile the document-retrieval workflow.

    The workflow is a straight line from "__start__" through generate_cypher,
    get_docs and eval_docs to "__end__". ``memory`` is passed to ``compile``
    as the checkpointer.
    """
    builder = StateGraph(DocRetrieverState)

    pipeline = [
        ("generate_cypher", generate_cypher),
        ("get_docs", get_docs),
        ("eval_docs", eval_docs),
    ]
    for node_name, node_fn in pipeline:
        builder.add_node(node_name, node_fn)

    # Chain consecutive stops: __start__ -> node1 -> node2 -> node3 -> __end__.
    route = ["__start__"] + [node_name for node_name, _ in pipeline] + ["__end__"]
    for source, target in zip(route, route[1:]):
        builder.add_edge(source, target)

    return builder.compile(checkpointer=memory)