JulsdL commited on
Commit
cea0ce1
·
1 Parent(s): f0e847a

Add initial implementation of SQL agent with few-shot learning and ChainLit integration for dynamic SQL query generation and execution based on user input.

Browse files
Files changed (6) hide show
  1. .tables +0 -0
  2. SELECT +0 -0
  3. app.py +30 -0
  4. prompt_templates.py +60 -0
  5. requirements.txt +6 -0
  6. sql_agent.py +63 -0
.tables ADDED
File without changes
SELECT ADDED
File without changes
app.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import chainlit as cl
2
+ from langchain.schema.runnable.config import RunnableConfig
3
+ from sql_agent import SQLAgent
4
+
5
+ # Test the agent
6
+ # agent.invoke({"input": "How many artists are there?"})
7
+
8
+ # ChainLit Integration
9
+ @cl.on_chat_start
10
+ async def on_chat_start():
11
+ cl.user_session.set("agent", SQLAgent)
12
+
13
+ @cl.on_message
14
+ async def on_message(message: cl.Message):
15
+ agent = cl.user_session.get("agent") # Get the agent from the session
16
+ cb = cl.AsyncLangchainCallbackHandler(stream_final_answer=True)
17
+ config = RunnableConfig(callbacks=[cb])
18
+
19
+ result = await agent.ainvoke(message.content, config=config)
20
+
21
+ msg = cl.Message(content="")
22
+
23
+ async for chunk in result:
24
+ await msg.stream_token(chunk)
25
+
26
+ await msg.send()
27
+
28
+ # Run the app
29
+ if __name__ == "__main__":
30
+ cl.run()
prompt_templates.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Few-shot examples
2
+ few_shot_examples = [
3
+ {"input": "List all artists.", "query": "SELECT * FROM Artist;"},
4
+ {
5
+ "input": "Find all albums for the artist 'AC/DC'.",
6
+ "query": "SELECT * FROM Album WHERE ArtistId = (SELECT ArtistId FROM Artist WHERE Name = 'AC/DC');",
7
+ },
8
+ {
9
+ "input": "List all tracks in the 'Rock' genre.",
10
+ "query": "SELECT * FROM Track WHERE GenreId = (SELECT GenreId FROM Genre WHERE Name = 'Rock');",
11
+ },
12
+ {
13
+ "input": "Find the total duration of all tracks.",
14
+ "query": "SELECT SUM(Milliseconds) FROM Track;",
15
+ },
16
+ {
17
+ "input": "List all customers from Canada.",
18
+ "query": "SELECT * FROM Customer WHERE Country = 'Canada';",
19
+ },
20
+ {
21
+ "input": "How many tracks are there in the album with ID 5?",
22
+ "query": "SELECT COUNT(*) FROM Track WHERE AlbumId = 5;",
23
+ },
24
+ {
25
+ "input": "Find the total number of invoices.",
26
+ "query": "SELECT COUNT(*) FROM Invoice;",
27
+ },
28
+ {
29
+ "input": "List all tracks that are longer than 5 minutes.",
30
+ "query": "SELECT * FROM Track WHERE Milliseconds > 300000;",
31
+ },
32
+ {
33
+ "input": "Who are the top 5 customers by total purchase?",
34
+ "query": "SELECT CustomerId, SUM(Total) AS TotalPurchase FROM Invoice GROUP BY CustomerId ORDER BY TotalPurchase DESC LIMIT 5;",
35
+ },
36
+ {
37
+ "input": "Which albums are from the year 2000?",
38
+ "query": "SELECT * FROM Album WHERE strftime('%Y', ReleaseDate) = '2000';",
39
+ },
40
+ {
41
+ "input": "How many employees are there",
42
+ "query": 'SELECT COUNT(*) FROM "Employee"',
43
+ },
44
+ ]
45
+
46
+ # System Prompt template prefix
47
+ system_prefix = """You are an agent designed to interact with a SQL database.
48
+ Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
49
+ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
50
+ You can order the results by a relevant column to return the most interesting examples in the database.
51
+ Never query for all the columns from a specific table, only ask for the relevant columns given the question.
52
+ You have access to tools for interacting with the database.
53
+ Only use the given tools. Only use the information returned by the tools to construct your final answer.
54
+ You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.
55
+
56
+ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.
57
+
58
+ If the question does not seem related to the database, just return "I don't know" as the answer.
59
+
60
+ Here are some examples of user inputs and their corresponding SQL queries:"""
requirements.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ langchain
2
+ langchain-community
3
+ langchain-openai
4
+ python-dotenv
5
+ faiss-cpu
6
+ chainlit
sql_agent.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dotenv import load_dotenv
2
+ from langchain_community.agent_toolkits import create_sql_agent
3
+ from langchain_community.vectorstores import FAISS
4
+ from langchain_core.example_selectors import SemanticSimilarityExampleSelector
5
+ from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate
6
+ from langchain_openai import OpenAIEmbeddings
7
+ from langchain_openai import ChatOpenAI
8
+ from langchain_community.utilities import SQLDatabase
9
+ from prompt_templates import few_shot_examples, system_prefix
10
+
11
+
12
+ # Load the .env file
13
+ load_dotenv()
14
+
15
+ # Initialize the SQL database
16
+ db = SQLDatabase.from_uri("sqlite:///Chinook.db")
17
+
18
+ # Check the database connection
19
+ print(db.dialect)
20
+ print(db.get_usable_table_names())
21
+ db.run("SELECT * FROM Artist LIMIT 10;")
22
+
23
+ # Initialize the LLM
24
+ llm = ChatOpenAI(model="gpt-4o", temperature=0)
25
+
26
+
27
+ # Example selector will dynamically select examples based on the input question
28
+ example_selector = SemanticSimilarityExampleSelector.from_examples(
29
+ few_shot_examples,
30
+ OpenAIEmbeddings(),
31
+ FAISS,
32
+ k=5,
33
+ input_keys=["input"],
34
+ )
35
+
36
+ # Few-shot prompt template
37
+ few_shot_prompt = FewShotPromptTemplate(
38
+ example_selector=example_selector,
39
+ example_prompt=PromptTemplate.from_template(
40
+ "User input: {input}\nSQL query: {query}"
41
+ ),
42
+ input_variables=["input", "dialect", "top_k"],
43
+ prefix=system_prefix,
44
+ suffix="",
45
+ )
46
+
47
+ # Full prompt template
48
+ full_prompt = ChatPromptTemplate.from_messages(
49
+ [
50
+ SystemMessagePromptTemplate(prompt=few_shot_prompt),
51
+ ("human", "{input}"),
52
+ MessagesPlaceholder("agent_scratchpad"),
53
+ ]
54
+ )
55
+
56
+ # Create the SQL agent
57
+ SQLAgent = create_sql_agent(
58
+ llm=llm,
59
+ db=db,
60
+ prompt=full_prompt,
61
+ verbose=True,
62
+ agent_type="openai-tools",
63
+ )