Update main.py
Browse files
main.py
CHANGED
@@ -5,24 +5,18 @@ import pandas as pd
|
|
5 |
import numpy as np
|
6 |
from typing import List
|
7 |
from pathlib import Path
|
8 |
-
|
9 |
from langchain_openai import ChatOpenAI, OpenAI
|
10 |
from langchain.schema.runnable.config import RunnableConfig
|
11 |
-
|
12 |
-
from langchain_core.prompts import ChatPromptTemplate
|
13 |
|
14 |
from langchain.agents import AgentExecutor
|
15 |
from langchain.agents.agent_types import AgentType
|
16 |
#from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent, create_csv_agent
|
17 |
-
|
18 |
-
from langchain.chains import create_sql_query_chain
|
19 |
from langchain_community.utilities import SQLDatabase
|
20 |
from sqlalchemy import create_engine
|
21 |
-
from langchain_community.tools.sql_database.tool import QuerySQLDataBaseTool
|
22 |
-
|
23 |
-
|
24 |
-
from langchain_core.output_parsers import StrOutputParser
|
25 |
-
from langchain_core.runnables import RunnablePassthrough
|
26 |
|
27 |
import chainlit as cl
|
28 |
from chainlit.input_widget import TextInput, Select, Switch, Slider
|
@@ -80,28 +74,7 @@ def create_agent(filename: str):
|
|
80 |
db = cl.user_session.get("db")
|
81 |
# Create a SAL agent.
|
82 |
#e.g agent_executor.invoke({"input": "Quel est le nombre de chargé d'affaires en agencement par entreprise?"})
|
83 |
-
|
84 |
-
execute_query = QuerySQLDataBaseTool(db=db)
|
85 |
-
write_query = create_sql_query_chain(llm, db)
|
86 |
-
|
87 |
-
answer_prompt = PromptTemplate.from_template(
|
88 |
-
"""Given the following user question, corresponding SQL query, and SQL result, answer the user question.
|
89 |
-
|
90 |
-
Question: {question}
|
91 |
-
SQL Query: {query}
|
92 |
-
SQL Result: {result}
|
93 |
-
Answer: """
|
94 |
-
)
|
95 |
-
|
96 |
-
chain = (
|
97 |
-
RunnablePassthrough.assign(query=write_query).assign(
|
98 |
-
result=itemgetter("query") | execute_query
|
99 |
-
)
|
100 |
-
| answer_prompt
|
101 |
-
| llm
|
102 |
-
| StrOutputParser()
|
103 |
-
)
|
104 |
-
cl.user_session.set("chain", chain)
|
105 |
|
106 |
def query_agent(agent, query):
|
107 |
"""
|
@@ -218,13 +191,12 @@ async def on_chat_start():
|
|
218 |
async def on_message(message: cl.Message):
|
219 |
await cl.Message(f"> SURVEYIA").send()
|
220 |
agent = create_agent("./public/surveyia.csv")
|
221 |
-
|
222 |
-
#cb = cl.AsyncLangchainCallbackHandler()
|
223 |
try:
|
224 |
#res = await agent.acall("Réponds en langue française à la question suivante : " + message.content, callbacks=[cb])
|
225 |
-
res = await
|
226 |
#res = await agent.ainvoke("Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau.")
|
227 |
-
await cl.Message(author="COPILOT",content=GoogleTranslator(source='auto', target='fr').translate(res)).send()
|
228 |
except ValueError as e:
|
229 |
res = str(e)
|
230 |
resArray = res.split(":")
|
|
|
5 |
import numpy as np
|
6 |
from typing import List
|
7 |
from pathlib import Path
|
8 |
+
|
9 |
from langchain_openai import ChatOpenAI, OpenAI
|
10 |
from langchain.schema.runnable.config import RunnableConfig
|
11 |
+
from langchain.schema import StrOutputParser
|
12 |
+
from langchain_core.prompts import ChatPromptTemplate
|
13 |
|
14 |
from langchain.agents import AgentExecutor
|
15 |
from langchain.agents.agent_types import AgentType
|
16 |
#from langchain_experimental.agents.agent_toolkits import create_pandas_dataframe_agent, create_csv_agent
|
17 |
+
from langchain_community.agent_toolkits import create_sql_agent
|
|
|
18 |
from langchain_community.utilities import SQLDatabase
|
19 |
from sqlalchemy import create_engine
|
|
|
|
|
|
|
|
|
|
|
20 |
|
21 |
import chainlit as cl
|
22 |
from chainlit.input_widget import TextInput, Select, Switch, Slider
|
|
|
74 |
db = cl.user_session.get("db")
|
75 |
# Create a SAL agent.
|
76 |
#e.g agent_executor.invoke({"input": "Quel est le nombre de chargé d'affaires en agencement par entreprise?"})
|
77 |
+
return create_sql_agent(llm, db=db, agent_type="openai-tools", verbose=False)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
78 |
|
79 |
def query_agent(agent, query):
|
80 |
"""
|
|
|
191 |
async def on_message(message: cl.Message):
|
192 |
await cl.Message(f"> SURVEYIA").send()
|
193 |
agent = create_agent("./public/surveyia.csv")
|
194 |
+
cb = cl.AsyncLangchainCallbackHandler()
|
|
|
195 |
try:
|
196 |
#res = await agent.acall("Réponds en langue française à la question suivante : " + message.content, callbacks=[cb])
|
197 |
+
res = await agent.ainvoke({"input": "Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau."})
|
198 |
#res = await agent.ainvoke("Réponds de la manière la plus complète et la plus intelligible, en langue française, à la question suivante : " + message.content + ". Réponds au format markdown ou au format tableau si le résultat nécessite l'affichage d'un tableau.")
|
199 |
+
await cl.Message(author="COPILOT",content=GoogleTranslator(source='auto', target='fr').translate(res['output'])).send()
|
200 |
except ValueError as e:
|
201 |
res = str(e)
|
202 |
resArray = res.split(":")
|