chat_bodacc / graph.py
rdassignies's picture
Upload 7 files
b4c2b4c verified
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
Created on Sun Sep 22 15:43:16 2024
@author: Raphaël d'Assignies ([email protected])
"""
import json
from typing import Literal, Optional, List, Union, Any
from langchain_openai import ChatOpenAI
import pandas as pd
from langchain_core.prompts import ChatPromptTemplate
from langgraph.graph import END, StateGraph, START
from langchain_core.output_parsers import StrOutputParser
from pydantic import BaseModel, Field
from models import NatureJugement
from nodes import (GradeResults, GraphState, generate_query_node,
generate_results_node, query_feedback_node,
evaluate_query_node, evaluate_results_node)
import streamlit as st
# Instanciate pipeline
pipeline = StateGraph(GraphState)
pipeline.add_node('generate_query', generate_query_node)
pipeline.add_node('generate_results', generate_results_node)
pipeline.add_node('query_feedback', query_feedback_node)
# Only query
#pipeline.add_edge(START,'generate_query')
#pipeline.add_edge('generate_query', generate_query_node)
#pipeline.add_edge('generate_query', END)
# Full scenario
pipeline.add_edge(START,'generate_query')
pipeline.add_conditional_edges(
'generate_query',
evaluate_query_node,
{'error_query' : 'generate_query',
'ok' : 'generate_results'
})
pipeline.add_conditional_edges(
'generate_results',
evaluate_results_node,
{
"yes": END,
"no": 'query_feedback',
"max_generation_reached": END
}
)
# Création du graph
graph = pipeline.compile()
# Load le dataframe
df = pd.read_json('bodacc.json', orient='table')
# Initialise le dictionnaire
inputs = {
'df_head': df.head().to_csv(),
'df': df
}
# Créé un dictionnaire des sorties vide
outputs = {}
# Titre de l'application
st.title("Chat with BODACC !")
# Message d'avertissement
warning_message = (f"Cet outil, purement pédagogique, est basé sur des données réelles allant de {df['dateparution'].min()} "
f"à {df['dateparution'].max()}, et permet d'interroger le BODACC en langage naturel. Compte tenu de la variabilité des modèles, nous ne pouvons pas garantir la fiabilité des réponses.")
st.warning(warning_message)
# Interface utilisateur pour entrer la requête
user_query = st.text_input("Entrez votre requête:", "Trouve moi les restaurants à reprendre en Bretagne dans les 30 derniers jours")
# Afficher les résultats avec Streamlit
inputs["instructions"] = user_query
# Afficher un bouton pour démarrer la recherche
if st.button("Lancer la recherche"):
config = {"configurable": {"thread_id": "2"}}
# Étape 1 : Afficher le message "Je réfléchis..."
st.write("Je réfléchis...")
# Stream des résultats au fur et à mesure
with st.spinner('Recherche en cours...'):
for output in graph.stream(inputs, stream_mode='values', debug=False):
# Ajouter les résultats au dictionnaire outputs
for k, v in output.items():
if k not in outputs:
outputs[k] = []
outputs[k].append(v)
# Ne pas afficher les messages pour les clés non pertinentes (comme error_query)
if 'query' in output and len(output['query'])>0:
st.write(f"query : {output['query']}")
#st.write(outputs.get('query_feedbacks', 'pas de feedback'))
#st.write(outputs.get('results_feedbacks', 'pas de resultfeedback'))
if "results" in output and len(output["results"]) > 0:
records = json.loads(output['results'])
st.write(f"Résultats intermédiaires trouvés : {len(records)} résultats jusqu'à présent.")
# Après la fin du traitement
if "results" in outputs and len(outputs["results"]) > 0:
# Agréger tous les résultats accumulés
all_results = []
for res in outputs["results"]:
json_data = json.loads(res) # Convertir chaque ensemble de résultats en JSON
all_results.extend(json_data) # Accumuler tous les résultats
results_df = pd.DataFrame(all_results) # Créer un DataFrame avec tous les résultats accumulés
# Afficher un aperçu des résultats (jusqu'à 5 premiers)
num_results = len(results_df)
st.write(f"J'ai trouvé {num_results} résultats.")
if num_results > 0:
preview_count = min(5, num_results) # Gérer le cas où il y a moins de 5 résultats
st.write(f"Voici un aperçu des {preview_count} premiers résultats :")
st.write(results_df.head(preview_count))
trunc = outputs.get('truncated', 'pas de traunc')
if trunc[0] == True:
st.warning("Les résultats de votre recherche ont été tronqués car celle-ci était trop large ! ")
# Convertir tous les résultats en CSV
csv = results_df.to_csv(index=False)
# Ajouter un bouton pour télécharger tous les résultats
st.download_button(
label="Télécharger le résultat complet au format CSV",
data=csv,
file_name="results.csv",
mime="text/csv"
)
else:
# Si aucun résultat n'est trouvé
st.write("Aucun résultat trouvé.")