# Captured page header (not program code): Spaces: Paused / Paused
import ast | |
import re | |
from dotenv import load_dotenv | |
from langchain_community.agent_toolkits import create_sql_agent | |
from langchain.tools.retriever import create_retriever_tool | |
from langchain_community.vectorstores import FAISS | |
from langchain_core.example_selectors import SemanticSimilarityExampleSelector | |
from langchain_core.prompts import ChatPromptTemplate, FewShotPromptTemplate, MessagesPlaceholder, PromptTemplate, SystemMessagePromptTemplate | |
from langchain_openai import OpenAIEmbeddings | |
from langchain_openai import ChatOpenAI | |
from langchain_community.utilities import SQLDatabase | |
from smartquery.prompt_templates import few_shot_examples, system_prefix | |
# Populate os.environ from a local .env file (presumably the OpenAI API key
# needed by the OpenAI clients created in this module).
load_dotenv()

# Open the Chinook sample database; the relative path is resolved against the
# current working directory.
db = SQLDatabase.from_uri("sqlite:///database/Chinook.db")


def _check_connection(database):
    """Print the SQL dialect and usable table names, then run a sample query.

    The sample query's rows are not used; it only has to execute without
    raising for the check to pass.
    """
    print(database.dialect)
    print(database.get_usable_table_names())
    database.run("SELECT * FROM Artist LIMIT 10;")


_check_connection(db)

# Chat model driving the SQL agent; temperature=0 minimizes sampling randomness.
llm = ChatOpenAI(model="gpt-4o", temperature=0)
# Function to query database and get list of elements | |
def query_as_list(db, query): | |
res = db.run(query) | |
res = [el for sub in ast.literal_eval(res) for el in sub if el] | |
res = [re.sub(r"\b\d+\b", "", string).strip() for string in res] | |
return list(set(res)) | |
# Distinct artist names and album titles: the vocabulary for proper-noun lookup.
artists = query_as_list(db, "SELECT Name FROM Artist")
albums = query_as_list(db, "SELECT Title FROM Album")

# One embeddings client, shared by the proper-noun index and the few-shot
# example selector below.
embeddings = OpenAIEmbeddings()

# Index the nouns in FAISS. An artist and an album can share the same name, so
# the combined list is de-duplicated (order preserved). Otherwise the same
# noun could occupy several of the k=5 retrieval slots.
vector_db = FAISS.from_texts(list(dict.fromkeys(artists + albums)), embeddings)
retriever = vector_db.as_retriever(search_kwargs={"k": 5})

# Tool the agent can call to map an approximately spelled proper noun to the
# values actually stored in the database.
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)

# Selects the 5 few-shot examples whose "input" is most semantically similar
# to the incoming question.
example_selector = SemanticSimilarityExampleSelector.from_examples(
    few_shot_examples,
    embeddings,
    FAISS,
    k=5,
    input_keys=["input"],
)

# System prompt: the shared prefix followed by the selected examples, each
# rendered as a "User input / SQL query" pair.
few_shot_prompt = FewShotPromptTemplate(
    example_selector=example_selector,
    example_prompt=PromptTemplate.from_template(
        "User input: {input}\nSQL query: {query}"
    ),
    input_variables=["input", "dialect", "top_k"],
    prefix=system_prefix,
    suffix="",
)

# Chat prompt: few-shot system message, the user's question, then the agent's
# scratchpad of intermediate steps.
full_prompt = ChatPromptTemplate.from_messages(
    [
        SystemMessagePromptTemplate(prompt=few_shot_prompt),
        ("human", "{input}"),
        MessagesPlaceholder("agent_scratchpad"),
    ]
)

# SQL agent using OpenAI tool calling, with the proper-noun search tool passed
# as an extra tool.
SQLAgent = create_sql_agent(
    llm=llm,
    db=db,
    extra_tools=[retriever_tool],
    prompt=full_prompt,
    agent_type="openai-tools",
    verbose=True,
)