chat_bodacc / nodes.py
bagbreizh
Modification prompt et max_iterations
2da892b
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 13 10:30:56 2024
@author: legalchain
"""
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from prompts import df_prompt, feed_back_prompt, reflection_prompt
llm = ChatOpenAI(model="gpt-4o-mini")
MAX_GENERATIONS = 3
MAX_ROWS: int = 200
class Query(BaseModel):
query:str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ")
def clean_query(self):
# Correction des échappements dans la chaîne de la requête
corrected_query = self.query.replace("\\'", "\\'")
# Extraire la condition à l'intérieur des crochets
import re
condition = re.search(r"df\[(.*)\]", corrected_query).group(1)
return condition
class GradeResults(BaseModel):
binary_score: Literal["yes", "no"] = Field(
description="Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'"
)
class GraphState(BaseModel):
df : Any
df_head:str
instructions: Optional[str] = None
nature_jugement: List = ', '.join([e.value for e in NatureJugement])
region:str = ''
dep:str = ''
query: Optional[str] = None
results :Union[str, List[str]] = []
query_feedbacks: Optional[str] = None
results_feedbacks: bool = None
generation_num: int = 0
retrieval_num: int = 0
search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM"
error_query: Optional[Any] = ""
error_results: Optional[Any] = ""
truncated: bool = False
# Méthode pour récupérer le DataFrame
def get_df(self) -> pd.DataFrame:
return pd.read_json(self.df)
# Surcharger l'initialisation pour créer les champs 'region' et 'dep'
def __init__(self, **data):
super().__init__(**data)
# Générer les chaînes pour les régions et départements
distinct_regions = self.df['region_nom_officiel'].dropna().unique().tolist()
distinct_departements = self.df['departement_nom_officiel'].dropna().unique().tolist()
# Convertir en chaînes séparées par des virgules
self.region = ', '.join(distinct_regions)
self.dep = ', '.join(distinct_departements)
def generate_query_node(state: GraphState):
prompt = ChatPromptTemplate.from_messages(messages = df_prompt)
generate_df_query = prompt | llm.with_structured_output(
Query,
include_raw=True, # permet de checker les erreurs en sortie
)
# TODO : Ajouter le retour erreur de parse_error
try :
query_generate = generate_df_query.invoke({
'df_head' : state.df_head,
'instructions' : state.instructions,
'feedback' : state.query_feedbacks,
'error' : state.error_query,
'nature_jugement' : state.nature_jugement,
'dep' : state.dep,
'region': state.region
})
query_final = query_generate['parsed'].clean_query()
return {
"query": query_final,
"error_query" : "" # si il ya une erreur cela remet le compteur à zéro
}
except Exception as e:
return {'error_query' : e}
def evaluate_query_node(state:GraphState):
if state.error_query != "":
return "Il y a une erreur dans la requête. Je me suis sûrement trompé. Veuillez réessayer."
else:
return "ok"
def generate_results_node(state:GraphState):
try :
query = state.query
print("query ", query)
print('je suis dans generate', type(state.df))
query = eval(query, {"df": state.df})
new_df = state.df[query]
print("new_df", new_df.empty)
if new_df.empty:
return {
"generation_num": state.generation_num + 1}
elif len(new_df)> MAX_ROWS:
return {'results' : new_df.head(MAX_ROWS).to_json(orient='records'),
"generation_num": state.generation_num + 1,
"truncated": True
}
else:
return {'results' : new_df.to_json(orient='records'),
"generation_num": state.generation_num + 1,
}
except Exception as e :
return {'error_results' : e,
"generation_num": state.generation_num + 1}
def evaluate_results_node(state:GraphState):
prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt)
generate_eval = prompt_eval | llm.with_structured_output(
GradeResults,
include_raw=False, # permet de checker les erreurs en sortie
)
evaluation = generate_eval.invoke({'df_head' : state.df_head,
'results' :state.results,
'instructions' : state.instructions})
if state.generation_num > MAX_GENERATIONS:
return "max_generation_reached"
return evaluation.binary_score
def query_feedback_node(state: GraphState):
prompt_feed_back = ChatPromptTemplate.from_messages(messages=feed_back_prompt)
query_feedback_chain = prompt_feed_back| llm |StrOutputParser()
feedback = query_feedback_chain.invoke({
"df_head" : state.df_head,
"instructions": state.instructions,
"results": state.results,
"query": state.query
})
feedback = f"Evaluation de la recherche : {feedback}"
print(feedback)
return {"query_feedbacks": feedback}