Spaces:
Running
Running
File size: 5,856 Bytes
b4c2b4c 59f74d3 b4c2b4c 2da892b 743f293 b4c2b4c 59f74d3 b4c2b4c 59f74d3 b4c2b4c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 |
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Oct 13 10:30:56 2024
@author: legalchain
"""
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from prompts import df_prompt, feed_back_prompt, reflection_prompt
llm = ChatOpenAI(model="gpt-4o-mini")
MAX_GENERATIONS = 3
MAX_ROWS: int = 200
class Query(BaseModel):
query:str = Field(..., title="Requête pour filtrer les résultats du dataframe entourée avec des gullemets de type \" ")
def clean_query(self):
# Correction des échappements dans la chaîne de la requête
corrected_query = self.query.replace("\\'", "\\'")
# Extraire la condition à l'intérieur des crochets
import re
condition = re.search(r"df\[(.*)\]", corrected_query).group(1)
return condition
class GradeResults(BaseModel):
binary_score: Literal["yes", "no"] = Field(
description="Les résultats sont satisfaisants -> 'yes' ou il y une erreur ou pas de résultats ou les résultats sont améliorables -> 'no'"
)
class GraphState(BaseModel):
df : Any
df_head:str
instructions: Optional[str] = None
nature_jugement: List = ', '.join([e.value for e in NatureJugement])
region:str = ''
dep:str = ''
query: Optional[str] = None
results :Union[str, List[str]] = []
query_feedbacks: Optional[str] = None
results_feedbacks: bool = None
generation_num: int = 0
retrieval_num: int = 0
search_mode: Literal["vectorstore", "websearch", "QA_LM"] = "QA_LM"
error_query: Optional[Any] = ""
error_results: Optional[Any] = ""
truncated: bool = False
# Méthode pour récupérer le DataFrame
def get_df(self) -> pd.DataFrame:
return pd.read_json(self.df)
# Surcharger l'initialisation pour créer les champs 'region' et 'dep'
def __init__(self, **data):
super().__init__(**data)
# Générer les chaînes pour les régions et départements
distinct_regions = self.df['region_nom_officiel'].dropna().unique().tolist()
distinct_departements = self.df['departement_nom_officiel'].dropna().unique().tolist()
# Convertir en chaînes séparées par des virgules
self.region = ', '.join(distinct_regions)
self.dep = ', '.join(distinct_departements)
def generate_query_node(state: GraphState):
prompt = ChatPromptTemplate.from_messages(messages = df_prompt)
generate_df_query = prompt | llm.with_structured_output(
Query,
include_raw=True, # permet de checker les erreurs en sortie
)
# TODO : Ajouter le retour erreur de parse_error
try :
query_generate = generate_df_query.invoke({
'df_head' : state.df_head,
'instructions' : state.instructions,
'feedback' : state.query_feedbacks,
'error' : state.error_query,
'nature_jugement' : state.nature_jugement,
'dep' : state.dep,
'region': state.region
})
query_final = query_generate['parsed'].clean_query()
return {
"query": query_final,
"error_query" : "" # si il ya une erreur cela remet le compteur à zéro
}
except Exception as e:
return {'error_query' : e}
def evaluate_query_node(state:GraphState):
if state.error_query != "":
return "Il y a une erreur dans la requête. Je me suis sûrement trompé. Veuillez réessayer."
else:
return "ok"
def generate_results_node(state:GraphState):
try :
query = state.query
print("query ", query)
print('je suis dans generate', type(state.df))
query = eval(query, {"df": state.df})
new_df = state.df[query]
print("new_df", new_df.empty)
if new_df.empty:
return {
"generation_num": state.generation_num + 1}
elif len(new_df)> MAX_ROWS:
return {'results' : new_df.head(MAX_ROWS).to_json(orient='records'),
"generation_num": state.generation_num + 1,
"truncated": True
}
else:
return {'results' : new_df.to_json(orient='records'),
"generation_num": state.generation_num + 1,
}
except Exception as e :
return {'error_results' : e,
"generation_num": state.generation_num + 1}
def evaluate_results_node(state:GraphState):
prompt_eval = ChatPromptTemplate.from_messages(messages=reflection_prompt)
generate_eval = prompt_eval | llm.with_structured_output(
GradeResults,
include_raw=False, # permet de checker les erreurs en sortie
)
evaluation = generate_eval.invoke({'df_head' : state.df_head,
'results' :state.results,
'instructions' : state.instructions})
if state.generation_num > MAX_GENERATIONS:
return "max_generation_reached"
return evaluation.binary_score
def query_feedback_node(state: GraphState):
prompt_feed_back = ChatPromptTemplate.from_messages(messages=feed_back_prompt)
query_feedback_chain = prompt_feed_back| llm |StrOutputParser()
feedback = query_feedback_chain.invoke({
"df_head" : state.df_head,
"instructions": state.instructions,
"results": state.results,
"query": state.query
})
feedback = f"Evaluation de la recherche : {feedback}"
print(feedback)
return {"query_feedbacks": feedback} |