Spaces:
Paused
Paused
Add initial implementation of SQL agent with few-shot learning and ChainLit integration for dynamic SQL query generation and execution based on user input.
Browse files- .tables +0 -0
- SELECT +0 -0
- app.py +30 -0
- prompt_templates.py +60 -0
- requirements.txt +6 -0
- sql_agent.py +63 -0
.tables
ADDED
File without changes
|
SELECT
ADDED
File without changes
|
app.py
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import chainlit as cl
|
2 |
+
from langchain.schema.runnable.config import RunnableConfig
|
3 |
+
from sql_agent import SQLAgent
|
4 |
+
|
5 |
+
# Test the agent
|
6 |
+
# agent.invoke({"input": "How many artists are there?"})
|
7 |
+
|
8 |
+
# ChainLit Integration
|
9 |
+
@cl.on_chat_start
async def on_chat_start():
    """Attach the SQL agent to the chat session that has just started.

    The object imported as ``SQLAgent`` is stored under the ``"agent"`` key
    of the Chainlit user session so the message handler can look it up.
    """
    session = cl.user_session
    session.set("agent", SQLAgent)
|
12 |
+
|
13 |
+
@cl.on_message
async def on_message(message: cl.Message):
    """Answer one incoming chat message with the session's SQL agent.

    The agent stored under the ``"agent"`` session key is invoked with the
    message text, and its final answer is sent back as a Chainlit message.

    Args:
        message: The Chainlit message received from the user; its
            ``content`` is used as the agent's question.
    """
    agent = cl.user_session.get("agent")  # Stored when the chat started
    cb = cl.AsyncLangchainCallbackHandler(stream_final_answer=True)
    config = RunnableConfig(callbacks=[cb])

    # The agent prompt is keyed on "{input}", so the question is passed as a
    # mapping rather than as a bare string.
    result = await agent.ainvoke({"input": message.content}, config=config)

    # ainvoke resolves to the finished result, not an async stream of chunks;
    # the answer text is expected under the "output" key.
    answer = result.get("output", "") if isinstance(result, dict) else result

    # If the callback handler already streamed the final answer into the UI,
    # sending it again would duplicate the reply.
    if getattr(cb, "has_streamed_final_answer", False):
        return

    await cl.Message(content=str(answer)).send()
|
27 |
+
|
28 |
+
# Run the app
|
29 |
+
if __name__ == "__main__":
    # Start the Chainlit server for this file when it is executed directly;
    # this mirrors launching it from the shell with `chainlit run app.py`.
    from chainlit.cli import run_chainlit

    run_chainlit(__file__)
|
prompt_templates.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Few-shot examples
|
2 |
+
# (question, SQL) pairs that serve as few-shot demonstrations for the agent.
# NOTE(review): several queries use `SELECT *`, while the system prompt in
# this module tells the model never to query all columns — confirm that the
# mixed signal is intended.
# NOTE(review): the "year 2000" query assumes Album has a ReleaseDate column
# — verify against the actual database schema.
_EXAMPLE_PAIRS = (
    (
        "List all artists.",
        "SELECT * FROM Artist;",
    ),
    (
        "Find all albums for the artist 'AC/DC'.",
        "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
    ),
    (
        "List all tracks in the 'Rock' genre.",
        "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
    ),
    (
        "Find the total duration of all tracks.",
        "SELECT SUM(Milliseconds) FROM Track;",
    ),
    (
        "List all customers from Canada.",
        "SELECT * FROM Customer WHERE Country = 'Canada';",
    ),
    (
        "How many tracks are there in the album with ID 5?",
        "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
    ),
    (
        "Find the total number of invoices.",
        "SELECT COUNT(*) FROM Invoice;",
    ),
    (
        "List all tracks that are longer than 5 minutes.",
        "SELECT * FROM Track WHERE Milliseconds > 300000;",
    ),
    (
        "Who are the top 5 customers by total purchase?",
        "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
    ),
    (
        "Which albums are from the year 2000?",
        "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
    ),
    (
        "How many employees are there",
        'SELECT COUNT(*) FROM "Employee"',
    ),
)

# Each example is a dict with the keys "input" (question) and "query" (SQL).
few_shot_examples = [
    {"input": question, "query": sql} for question, sql in _EXAMPLE_PAIRS
]
|
45 |
+
|
46 |
+
# System Prompt template prefix
|
47 |
+
# System-prompt prefix for the SQL agent. The "{dialect}" and "{top_k}"
# placeholders are left unformatted here and are meant to be filled in by
# the prompt template that consumes this string.
system_prefix = "\n".join(
    (
        "You are an agent designed to interact with a SQL database.",
        "Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.",
        "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.",
        "You can order the results by a relevant column to return the most interesting examples in the database.",
        "Never query for all the columns from a specific table, only ask for the relevant columns given the question.",
        "You have access to tools for interacting with the database.",
        "Only use the given tools. Only use the information returned by the tools to construct your final answer.",
        "You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.",
        "",
        "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.",
        "",
        "If the question does not seem related to the database, just return \"I don't know\" as the answer.",
        "",
        "Here are some examples of user inputs and their corresponding SQL queries:",
    )
)
|
requirements.txt
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
langchain
|
2 |
+
langchain-community
|
3 |
+
langchain-openai
|
4 |
+
python-dotenv
|
5 |
+
faiss-cpu
|
6 |
+
chainlit
|
sql_agent.py
ADDED
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dotenv import load_dotenv
|
2 |
+
from langchain_community.agent_toolkits import create_sql_agent
|
3 |
+
from langchain_community.vectorstores import FAISS
|
4 |
+
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
|
5 |
+
from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate
|
6 |
+
from langchain_openai import OpenAIEmbeddings
|
7 |
+
from langchain_openai import ChatOpenAI
|
8 |
+
from langchain_community.utilities import SQLDatabase
|
9 |
+
from prompt_templates import few_shot_examples, system_prefix
|
10 |
+
|
11 |
+
|
12 |
+
# Read the .env file into the process environment before any client is built.
load_dotenv()

# Open the SQLite database file that the agent will query.
db = SQLDatabase.from_uri("sqlite:///Chinook.db")

# Connection smoke test: print the dialect and the usable table names, then
# run a small query whose result is deliberately discarded.
print(db.dialect)
print(db.get_usable_table_names())
db.run("SELECT * FROM Artist LIMIT 10;")

# Chat model that drives the agent.
llm = ChatOpenAI(model="gpt-4o", temperature=0)

# Selects, per incoming question, the k=5 few-shot examples whose "input"
# text is semantically closest, using OpenAI embeddings stored in FAISS.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples=few_shot_examples,
    embeddings=OpenAIEmbeddings(),
    vectorstore_cls=FAISS,
    k=5,
    input_keys=["input"],
)

# How a single selected example is rendered inside the system prompt.
_example_prompt = PromptTemplate.from_template(
    "User input: {input}\nSQL query: {query}"
)

# System prompt text: the fixed prefix followed by the selected examples.
few_shot_prompt = FewShotPromptTemplate(
    prefix=system_prefix,
    example_selector=example_selector,
    example_prompt=_example_prompt,
    suffix="",
    input_variables=["input", "dialect", "top_k"],
)

# Full chat prompt: system instructions, the user's question, and a slot for
# the agent's intermediate tool calls.
_prompt_messages = [
    SystemMessagePromptTemplate(prompt=few_shot_prompt),
    ("human", "{input}"),
    MessagesPlaceholder("agent_scratchpad"),
]
full_prompt = ChatPromptTemplate.from_messages(_prompt_messages)

# The agent object that the Chainlit app imports as `SQLAgent`.
SQLAgent = create_sql_agent(
    llm=llm,
    db=db,
    prompt=full_prompt,
    agent_type="openai-tools",
    verbose=True,
)
|