File size: 5,960 Bytes
b4c2b4c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 13 10:30:56 2024

@author: legalchain
"""
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from prompts import df_prompt, feed_back_prompt, reflection_prompt
from dotenv import load_dotenv
load_dotenv()

llm = ChatOpenAI(model="gpt-4o-mini")
MAX_GENERATIONS = 2
MAX_ROWS: int = 10

class Query(BaseModel): 
    query:str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ")
    
    def clean_query(self):
        # Correction des échappements dans la chaîne de la requête
        corrected_query = self.query.replace("\\'", "\\'")
        # Extraire la condition à l'intérieur des crochets
        import re
        condition = re.search(r"df\[(.*)\]", corrected_query).group(1)
        return condition

class GradeResults(BaseModel):
    binary_score: Literal["yes", "no"] = Field(
        description="Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'"
    )

class GraphState(BaseModel):

    df : Any
    df_head:str
    instructions: Optional[str] = None
    nature_jugement: List = ', '.join([e.value for e in  NatureJugement])
    region:str = ''
    dep:str = ''
    query: Optional[str] = None
    results :Union[str, List[str]] = []
    query_feedbacks: Optional[str] = None
    results_feedbacks: bool = None
    generation_num: int = 0
    retrieval_num: int = 0
    search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM"
    error_query: Optional[Any] = ""
    error_results: Optional[Any] = ""
    truncated: bool = False

    # Méthode pour récupérer le DataFrame
    def get_df(self) -> pd.DataFrame:
       return pd.read_json(self.df)
   
    # Surcharger l'initialisation pour créer les champs 'region' et 'dep'
    def __init__(self, **data):
        super().__init__(**data)
        
        # Extraire le DataFrame
        #df = self.get_df()
        
        # Générer les chaînes pour les régions et départements
        distinct_regions = self.df['region_nom_officiel'].dropna().unique().tolist()
        distinct_departements = self.df['departement_nom_officiel'].dropna().unique().tolist()
        
        # Convertir en chaînes séparées par des virgules
        self.region = ', '.join(distinct_regions)
        self.dep = ', '.join(distinct_departements)

def generate_query_node(state: GraphState):

    prompt  = ChatPromptTemplate.from_messages(messages = df_prompt)
    generate_df_query = prompt | llm.with_structured_output(
        Query,
        include_raw=True, # permet de checker les erreurs en sortie
    )
    # Ajouter le retour erreur de parse_error 
    try : 
        query_generate = generate_df_query.invoke({
            'df_head' : state.df_head, 
            'instructions' : state.instructions,
            'feedback' : state.query_feedbacks,
            'error' : state.error_query,
            'nature_jugement' : state.nature_jugement,
            'dep' : state.dep, 
            'region': state.region
            })
        query_final = query_generate['parsed'].clean_query()
        return {
            "query": query_final,
            "error_query" : "" # si il ya une erreur cela remet le compteur à zéro
        }
    except Exception as e:
        return {'error_query' : e}

def evaluate_query_node(state:GraphState):
    if state.error_query != "":
        return "Il y a une erreur dans la requête.  Je me suis sûrement trompé. Veuillez réessayer."
    else: 
        return "ok"


def generate_results_node(state:GraphState):
    try : 
        query = state.query
        print("query ", query)
        print('je suis dans generate', type(state.df))
        query = eval(query, {"df": state.df})
        new_df = state.df[query]
        print("new_df", new_df.empty)
        if new_df.empty: 
            return {
                "generation_num": state.generation_num + 1}
        elif len(new_df)> MAX_ROWS: 
            return {'results' : new_df.head(MAX_ROWS).to_json(orient='records'),
                "generation_num": state.generation_num + 1, 
                "truncated": True
                }
        else: 
            return {'results' : new_df.to_json(orient='records'),
                "generation_num": state.generation_num + 1, 
               
                }
    except Exception as e :
        return {'error_results' : e,
                "generation_num": state.generation_num + 1}


def evaluate_results_node(state:GraphState): 
    prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt)
    generate_eval = prompt_eval  | llm.with_structured_output(
        GradeResults,
        include_raw=False, # permet de checker les erreurs en sortie
    )

    evaluation = generate_eval.invoke({'df_head' : state.df_head, 
                                       'results' :state.results, 
                                       'instructions' : state.instructions})

    if state.generation_num > MAX_GENERATIONS:
        return "max_generation_reached"

    return evaluation.binary_score

def query_feedback_node(state: GraphState):
    prompt_feed_back = ChatPromptTemplate.from_messages(messages=feed_back_prompt)
    query_feedback_chain = prompt_feed_back| llm |StrOutputParser()
   
    feedback = query_feedback_chain.invoke({
        "df_head" : state.df_head,
        "instructions": state.instructions,
        "results": state.results,
        "query": state.query
    })

    feedback = f"Evaluation de la recherche : {feedback}"
    print(feedback)
    return {"query_feedbacks": feedback}