JulsdL commited on
Commit
3b5f5c4
·
1 Parent(s): bb2dcc7

Enhance sql_agent.py by adding AST and regex imports, implementing query_as_list function for database querying, and integrating a new retriever tool for proper noun search using FAISS vector store.

Browse files
Files changed (1) hide show
  1. sql_agent.py +28 -1
sql_agent.py CHANGED
@@ -1,5 +1,8 @@
 
 
1
  from dotenv import load_dotenv
2
  from langchain_community.agent_toolkits import create_sql_agent
 
3
  from langchain_community.vectorstores import FAISS
4
  from langchain_core.example_selectors import SemanticSimilarityExampleSelector
5
  from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate
@@ -23,6 +26,29 @@ db.run("SELECT * FROM Artist LIMIT 10;")
23
  # Initialize the LLM
24
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
25
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  # Example selector will dynamically select examples based on the input question
28
  example_selector = SemanticSimilarityExampleSelector.from_examples(
@@ -57,7 +83,8 @@ full_prompt = ChatPromptTemplate.from_messages(
57
  SQLAgent = create_sql_agent(
58
  llm=llm,
59
  db=db,
 
60
  prompt=full_prompt,
61
- verbose=True,
62
  agent_type="openai-tools",
 
63
  )
 
1
+ import ast
2
+ import re
3
  from dotenv import load_dotenv
4
  from langchain_community.agent_toolkits import create_sql_agent
5
+ from langchain.tools.retriever import create_retriever_tool
6
  from langchain_community.vectorstores import FAISS
7
  from langchain_core.example_selectors import SemanticSimilarityExampleSelector
8
  from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate
 
26
  # Initialize the LLM
27
  llm = ChatOpenAI(model="gpt-4o", temperature=0)
28
 
29
+ # Function to query database and get list of elements
30
+ def query_as_list(db, query):
31
+ res = db.run(query)
32
+ res = [el for sub in ast.literal_eval(res) for el in sub if el]
33
+ res = [re.sub(r"\b\d+\b", "", string).strip() for string in res]
34
+ return list(set(res))
35
+
36
+ # Create lists of artists and albums
37
+ artists = query_as_list(db, "SELECT Name FROM Artist")
38
+ albums = query_as_list(db, "SELECT Title FROM Album")
39
+
40
+ # Create a vector store and use it as a retriever
41
+ vector_db = FAISS.from_texts(artists + albums, OpenAIEmbeddings())
42
+ retriever = vector_db.as_retriever(search_kwargs={"k": 5})
43
+
44
+ # Create a search proper nouns tool
45
+ description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
46
+ valid proper nouns. Use the noun most similar to the search."""
47
+ retriever_tool = create_retriever_tool(
48
+ retriever,
49
+ name="search_proper_nouns",
50
+ description=description,
51
+ )
52
 
53
  # Example selector will dynamically select examples based on the input question
54
  example_selector = SemanticSimilarityExampleSelector.from_examples(
 
83
  SQLAgent = create_sql_agent(
84
  llm=llm,
85
  db=db,
86
+ extra_tools=[retriever_tool],
87
  prompt=full_prompt,
 
88
  agent_type="openai-tools",
89
+ verbose=True,
90
  )