File size: 4,772 Bytes
e08de36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b13c6ff
e08de36
 
f464845
21705f0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e08de36
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
from langchain_community.utilities import SQLDatabase
from langchain_core.callbacks import BaseCallbackHandler
from typing import TYPE_CHECKING, Any, Optional, TypeVar, Union
from uuid import UUID
from langchain_community.agent_toolkits import create_sql_agent
from langchain_openai import ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain_core.example_selectors import SemanticSimilarityExampleSelector
from langchain_openai import OpenAIEmbeddings
from langchain.agents.agent_toolkits import create_retriever_tool
from langchain_core.output_parsers import JsonOutputParser
import os
from langchain_core.prompts import (
    ChatPromptTemplate,
    FewShotPromptTemplate,
    MessagesPlaceholder,
    PromptTemplate,
    SystemMessagePromptTemplate,
)
import ast
from fewshot import examples
import re

def query_as_list(db, query):
    """Run ``query`` on ``db`` and return the distinct, cleaned string values.

    The result of ``db.run`` is parsed as a Python literal (a list of row
    tuples), flattened, stripped of standalone digit runs and surrounding
    whitespace, and de-duplicated.

    Args:
        db: Object exposing ``run(query)`` that returns the rows rendered as
            a string (here a ``SQLDatabase``).
        query: SQL text to execute.

    Returns:
        A list of unique non-empty strings, in first-seen order.
    """
    raw = db.run(query)
    # An empty result cannot be parsed by ast.literal_eval; treat it as
    # "no rows" instead of raising.
    if not raw:
        return []
    values = [el for row in ast.literal_eval(raw) for el in row if el]
    cleaned = (re.sub(r"\b\d+\b", "", value).strip() for value in values)
    # dict.fromkeys de-duplicates while keeping a deterministic order; values
    # that became empty after cleaning are dropped because they are useless
    # as lookup entries.
    return list(dict.fromkeys(value for value in cleaned if value))


parser = JsonOutputParser()
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0, api_key=os.environ['API_KEY'])

# One embeddings client shared by the example selector and the name index.
embeddings = OpenAIEmbeddings(model="text-embedding-3-small", api_key=os.environ['API_KEY'])

example_selector = SemanticSimilarityExampleSelector.from_examples(
    examples,
    embeddings,
    # NOTE(review): from_examples expects a vector-store *class* here. Passing
    # an instance appears to work only because from_texts is a classmethod;
    # persist_directory="data" is then presumably not applied to the store
    # that actually holds the examples. Confirm before relying on persistence.
    Chroma(persist_directory="data"),
    k=5,
    input_keys=["input"],
)

db = SQLDatabase.from_uri("sqlite:///attendance_system.db")

# query_as_list is defined above: calling it before its definition raised
# NameError at import time.
employee = query_as_list(db, "SELECT FullName FROM Employee")

# A dedicated collection name keeps employee names apart from the few-shot
# examples, which otherwise land in the same default Chroma collection.
vector_db = Chroma.from_texts(
    employee,
    embeddings,
    collection_name="employee_names",
)
retriever = vector_db.as_retriever(search_kwargs={"k": 15})
description = """Use to look up values to filter on. Input is an approximate spelling of the proper noun, output is \
valid proper nouns. Use the noun most similar to the search."""
retriever_tool = create_retriever_tool(
    retriever,
    name="search_proper_nouns",
    description=description,
)


def get_answer(user_query):
    """Answer a natural-language question by running a SQL agent.

    Renders a system prompt from the few-shot examples selected for
    ``user_query``, appends the proper-noun lookup instructions, builds a SQL
    agent over the module-level ``db`` with ``retriever_tool`` as an extra
    tool, and invokes it.

    Args:
        user_query: The user's question.

    Returns:
        The ``'output'`` entry of the agent's result.
    """
    system_prefix = """You are an agent designed to interact with a SQL database.
    Given an input question, create a syntactically correct {dialect} query to run, then look at the results of the query and return the answer.
    Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most {top_k} results.
    You can order the results by a relevant column to return the most interesting examples in the database.
    Never query for all the columns from a specific table, only ask for the relevant columns given the question.
    You have access to tools for interacting with the database.
    Only use the given tools. Only use the information returned by the tools to construct your final answer.
    You MUST double check your query before executing it. If you get an error while executing a query, rewrite the query and try again.

    DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database.

    If the question does not seem related to the database, just return "I don't know" as the answer.

    Here are some examples of user inputs and their corresponding SQL queries:"""

    few_shot_prompt = FewShotPromptTemplate(
        example_selector=example_selector,
        example_prompt=PromptTemplate.from_template(
            "User input: {input}\nSQL query: {query}"
        ),
        input_variables=["input", "dialect", "top_k"],
        prefix=system_prefix,
        suffix="",
    )

    # {table_names} is intentionally left as a template variable; it is not
    # supplied in this function.
    system_unique_name_prompt = """
    If you need to filter on a proper noun, you must ALWAYS first look up the filter value using the "search_proper_nouns" tool!

    You have access to the following tables: {table_names}

    If the question does not seem related to the database, just return "I don't know" as the answer.

    """

    prompt_val = few_shot_prompt.invoke(
        {
            "input": user_query,
            "top_k": 5,
            "dialect": "SQLite",
        }
    )

    # The rendered few-shot text is reused as a template below, so any literal
    # brace in it (e.g. from an example) must be escaped or it would be parsed
    # as a template variable.
    rendered_examples = (
        prompt_val.to_string().replace("{", "{{").replace("}", "}}")
    )
    final_prompt = rendered_examples + '\n' + system_unique_name_prompt
    full_prompt = ChatPromptTemplate.from_messages(
        [
            ("system", final_prompt),
            ("human", "{input}"),
            MessagesPlaceholder("agent_scratchpad"),
        ]
    )

    agent = create_sql_agent(
        llm=llm,
        db=db,
        max_iterations=40,
        extra_tools=[retriever_tool],
        prompt=full_prompt,
        agent_type="openai-tools",
        verbose=True,
    )

    result = agent.invoke({'input': user_query})

    return result['output']