datacipen commited on
Commit
7754b99
·
verified ·
1 Parent(s): ac418c5

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -8
main.py CHANGED
@@ -5,18 +5,24 @@ import pandas as pd
5
  import numpy as np
6
  from typing import List
7
  from pathlib import Path
 
8
  from langchain_openai import ChatOpenAI, OpenAI
9
  from langchain.schema.runnable.config import RunnableConfig
10
- from langchain.schema import StrOutputParser
11
- from langchain_core.prompts import ChatPromptTemplate
12
 
13
  from langchain.agents import AgentExecutor
14
  from langchain.agents.agent_types import AgentType
15
  #from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent, create_csv_agent
16
- from langchain_community.agent_toolkits import create_sql_agent
17
-
18
  from langchain_community.utilities import SQLDatabase
19
  from sqlalchemy import create_engine
 
 
 
 
 
20
 
21
  import chainlit as cl
22
  from chainlit.input_widget import TextInput, Select, Switch, Slider
@@ -74,8 +80,28 @@ def create_agent(filename: str):
74
  db = cl.user_session.get("db")
75
  # Create a SAL agent.
76
  #e.g agent_executor.invoke({"input": "Quel est le nombre de chargé d'affaires en agencement par entreprise?"})
77
- #return create_pandas_dataframe_agent(llm, df, verbose=False, allow_dangerous_code=True, handle_parsing_errors=True, agent_type=AgentType.OPENAI_FUNCTIONS)
78
- return create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=False)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
79
 
80
  def query_agent(agent, query):
81
  """
@@ -192,10 +218,11 @@ async def on_chat_start():
192
  async def on_message(message: cl.Message):
193
  await cl.Message(f"> SURVEYIA").send()
194
  agent = create_agent("./public/surveyia.csv")
195
- cb = cl.AsyncLangchainCallbackHandler()
 
196
  try:
197
  #res = await agent.acall("Réponds en langue française à la question suivante : " + message.content, callbacks=[cb])
198
- res = await agent.ainvoke({"input": "Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau."})
199
  #res = await agent.ainvoke("Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau.")
200
  await cl.Message(author="COPILOT",content=GoogleTranslator(source='auto', target='fr').translate(res['output'])).send()
201
  except ValueError as e:
 
5
  import numpy as np
6
  from typing import List
7
  from pathlib import Path
8
+ from operator import itemgetter
9
  from langchain_openai import ChatOpenAI, OpenAI
10
  from langchain.schema.runnable.config import RunnableConfig
11
+ #from langchain.schema import StrOutputParser
12
+ from langchain_core.prompts import ChatPromptTemplate, PromptTemplate
13
 
14
  from langchain.agents import AgentExecutor
15
  from langchain.agents.agent_types import AgentType
16
  #from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent, create_csv_agent
17
+ #from langchain_community.agent_toolkits import create_sql_agent
18
+ from langchain.chains import create_sql_query_chain
19
  from langchain_community.utilities import SQLDatabase
20
  from sqlalchemy import create_engine
21
+ from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
22
+
23
+
24
+ from langchain_core.output_parsers import StrOutputParser
25
+ from langchain_core.runnables import RunnablePassthrough
26
 
27
  import chainlit as cl
28
  from chainlit.input_widget import TextInput, Select, Switch, Slider
 
80
  db = cl.user_session.get("db")
81
  # Create a SAL agent.
82
  #e.g agent_executor.invoke({"input": "Quel est le nombre de chargé d'affaires en agencement par entreprise?"})
83
+ #return create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=False)
84
+ execute_query = QuerySQLDataBaseTool(db=db)
85
+ write_query = create_sql_query_chain(llm, db)
86
+
87
+ answer_prompt = PromptTemplate.from_template(
88
+ """Given the following user question, corresponding SQL query, and SQL result, answer the user question.
89
+
90
+ Question: {question}
91
+ SQL Query: {query}
92
+ SQL Result: {result}
93
+ Answer: """
94
+ )
95
+
96
+ chain = (
97
+ RunnablePassthrough.assign(query=write_query).assign(
98
+ result=itemgetter("query") | execute_query
99
+ )
100
+ | answer_prompt
101
+ | llm
102
+ | StrOutputParser()
103
+ )
104
+ cl.user_session.set("chain", chain)
105
 
106
  def query_agent(agent, query):
107
  """
 
218
  async def on_message(message: cl.Message):
219
  await cl.Message(f"> SURVEYIA").send()
220
  agent = create_agent("./public/surveyia.csv")
221
+ chain_executor = cl.user_session.get("chain")
222
+ #cb = cl.AsyncLangchainCallbackHandler()
223
  try:
224
  #res = await agent.acall("Réponds en langue française à la question suivante : " + message.content, callbacks=[cb])
225
+ res = await chain_executor.ainvoke({"question": "Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau."})
226
  #res = await agent.ainvoke("Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau.")
227
  await cl.Message(author="COPILOT",content=GoogleTranslator(source='auto', target='fr').translate(res['output'])).send()
228
  except ValueError as e: