WebashalarForML commited on
Commit
522caaa
·
verified ·
1 Parent(s): a73d493

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +162 -161
app.py CHANGED
@@ -49,180 +49,181 @@ os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
49
  from langchain_groq import ChatGroq
50
  from langchain_mistralai.chat_models import ChatMistralAI
51
 
52
- def create_agent_app(db_path: str):
53
- # Construct the SQLite URI from the given file path.
54
- # Ensure the db_path is absolute so that SQLAlchemy can locate the file.
55
- #llm = ChatMistralAI(model="mistral-large-latest")
56
- llm = ChatGroq(model="llama3-70b-8192")
57
-
58
- @tool
59
- def db_query_tool(query: str) -> str:
60
- """
61
- Execute a SQL query against the database and return the result.
62
- If the query is invalid or returns no result, an error message will be returned.
63
- In case of an error, the user is advised to rewrite the query and try again.
64
- """
65
- result = db_instance.run_no_throw(query)
66
- if not result:
67
- return "Error: Query failed. Please rewrite your query and try again."
68
- return result
69
-
70
- # Define a Pydantic model for submitting the final answer
71
- class SubmitFinalAnswer(BaseModel):
72
- """Submit the final answer to the user based on the query results."""
73
- final_answer: str = Field(..., description="The final answer to the user")
74
-
75
- # Define the state type
76
- class State(TypedDict):
77
- messages: Annotated[list[AnyMessage], add_messages]
78
-
79
- # Define prompt templates for query checking and query generation
80
- from langchain_core.prompts import ChatPromptTemplate
81
-
82
- query_check_system = """You are a SQL expert with a strong attention to detail.
83
- Double check the SQLite query for common mistakes, including:
84
- - Using NOT IN with NULL values
85
- - Using UNION when UNION ALL should have been used
86
- - Using BETWEEN for exclusive ranges
87
- - Data type mismatch in predicates
88
- - Properly quoting identifiers
89
- - Using the correct number of arguments for functions
90
- - Casting to the correct data type
91
- - Using the proper columns for joins
92
-
93
- If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
 
95
- You will call the appropriate tool to execute the query after running this check."""
96
- query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
97
- query_check = query_check_prompt | llm.bind_tools([db_query_tool])
98
 
99
- query_gen_system = """You are a SQL expert with a strong attention to detail.
 
 
 
100
 
101
- Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
 
 
102
 
103
- DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
 
 
 
 
 
 
 
 
104
 
105
- When generating the query:
 
106
 
107
- Output the SQL query that answers the input question without a tool call.
 
 
 
 
 
 
 
 
 
 
 
 
 
108
 
109
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
110
- You can order the results by a relevant column to return the most interesting examples in the database.
111
- Never query for all the columns from a specific table, only ask for the relevant columns given the question.
 
 
 
 
 
 
112
 
113
- If you get an error while executing a query, rewrite the query and try again.
 
 
114
 
115
- If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
116
- NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
 
117
 
118
- If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
 
 
 
 
 
 
 
 
 
119
 
120
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""
121
- query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
122
- query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
 
 
 
 
 
123
 
124
- abs_db_path = os.path.abspath(db_path)
125
- global DATABASE_URI
126
- DATABASE_URI = abs_db_path
127
- db_uri = f"sqlite:///{abs_db_path}"
128
- print("db_uri",db_uri)
129
 
130
- # Create new SQLDatabase connection using the constructed URI.
131
- from langchain_community.utilities import SQLDatabase
132
- db_instance = SQLDatabase.from_uri(db_uri)
133
- print("db_instance----->",db_instance)
134
- print("db_uri----->",db_uri)
135
-
136
-
137
- # Create SQL toolkit and get the tools.
138
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
139
- toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
140
- tools_instance = toolkit_instance.get_tools()
141
-
142
- # Define workflow nodes and fallback functions
143
- def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
144
- return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
145
-
146
- def handle_tool_error(state: State) -> dict:
147
- error = state.get("error")
148
- tool_calls = state["messages"][-1].tool_calls
149
- return {
150
- "messages": [
151
- ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
152
- for tc in tool_calls
153
- ]
154
- }
155
-
156
- def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
157
- return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
158
-
159
- def query_gen_node(state: State):
160
- message = query_gen.invoke(state)
161
- # Check for incorrect tool calls
162
- tool_messages = []
163
- if message.tool_calls:
164
- for tc in message.tool_calls:
165
- if tc["name"] != "SubmitFinalAnswer":
166
- tool_messages.append(
167
- ToolMessage(
168
- content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
169
- tool_call_id=tc["id"],
170
- )
171
- )
172
- return {"messages": [message] + tool_messages}
173
-
174
- def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
175
- messages = state["messages"]
176
- last_message = messages[-1]
177
- if getattr(last_message, "tool_calls", None):
178
- return END
179
- if last_message.content.startswith("Error:"):
180
- return "query_gen"
181
- else:
182
- return "correct_query"
183
-
184
- def model_check_query(state: State) -> dict[str, list[AIMessage]]:
185
- """Double-check if the query is correct before executing it."""
186
- return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
187
-
188
- # Get tools for listing tables and fetching schema
189
- list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
190
- get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
191
-
192
- # Define the workflow (state graph)
193
- workflow = StateGraph(State)
194
- workflow.add_node("first_tool_call", first_tool_call)
195
- workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
196
- workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
197
- model_get_schema = llm.bind_tools([get_schema_tool])
198
- workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
199
- workflow.add_node("query_gen", query_gen_node)
200
- workflow.add_node("correct_query", model_check_query)
201
- workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
202
-
203
- workflow.add_edge(START, "first_tool_call")
204
- workflow.add_edge("first_tool_call", "list_tables_tool")
205
- workflow.add_edge("list_tables_tool", "model_get_schema")
206
- workflow.add_edge("model_get_schema", "get_schema_tool")
207
- workflow.add_edge("get_schema_tool", "query_gen")
208
- workflow.add_conditional_edges("query_gen", should_continue)
209
- workflow.add_edge("correct_query", "execute_query")
210
- workflow.add_edge("execute_query", "query_gen")
211
-
212
- # Compile and return the agent application workflow.
213
- return workflow.compile()
214
-
215
- ###############################################################################
216
- # Application Factory: create_app()
217
- #
218
- # This function sets up the Flask application, SocketIO, routes, and initializes
219
- # the global agent_app using the default DATABASE_URI. It returns the Flask app.
220
- ###############################################################################
221
- # --- Application Factory ---
222
- abs_file_path = None
223
- agent_app= None
224
-
225
- def create_app():
226
  # Option: configure static files from uploads folder as well.
227
  flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
228
  socketio = SocketIO(flask_app, cors_allowed_origins="*")
 
49
  from langchain_groq import ChatGroq
50
  from langchain_mistralai.chat_models import ChatMistralAI
51
 
52
+ ###############################################################################
53
+ # Application Factory: create_app()
54
+ #
55
+ # This function sets up the Flask application, SocketIO, routes, and initializes
56
+ # the global agent_app using the default DATABASE_URI. It returns the Flask app.
57
+ ###############################################################################
58
+ # --- Application Factory ---
59
+ abs_file_path = None
60
+ agent_app= None
61
+
62
+ def create_app():
63
+
64
+ def create_agent_app(db_path: str):
65
+ # Construct the SQLite URI from the given file path.
66
+ # Ensure the db_path is absolute so that SQLAlchemy can locate the file.
67
+ #llm = ChatMistralAI(model="mistral-large-latest")
68
+ llm = ChatGroq(model="llama3-70b-8192")
69
+
70
+ @tool
71
+ def db_query_tool(query: str) -> str:
72
+ """
73
+ Execute a SQL query against the database and return the result.
74
+ If the query is invalid or returns no result, an error message will be returned.
75
+ In case of an error, the user is advised to rewrite the query and try again.
76
+ """
77
+ result = db_instance.run_no_throw(query)
78
+ if not result:
79
+ return "Error: Query failed. Please rewrite your query and try again."
80
+ return result
81
+
82
+ # Define a Pydantic model for submitting the final answer
83
+ class SubmitFinalAnswer(BaseModel):
84
+ """Submit the final answer to the user based on the query results."""
85
+ final_answer: str = Field(..., description="The final answer to the user")
86
+
87
+ # Define the state type
88
+ class State(TypedDict):
89
+ messages: Annotated[list[AnyMessage], add_messages]
90
+
91
+ # Define prompt templates for query checking and query generation
92
+ from langchain_core.prompts import ChatPromptTemplate
93
+
94
+ query_check_system = """You are a SQL expert with a strong attention to detail.
95
+ Double check the SQLite query for common mistakes, including:
96
+ - Using NOT IN with NULL values
97
+ - Using UNION when UNION ALL should have been used
98
+ - Using BETWEEN for exclusive ranges
99
+ - Data type mismatch in predicates
100
+ - Properly quoting identifiers
101
+ - Using the correct number of arguments for functions
102
+ - Casting to the correct data type
103
+ - Using the proper columns for joins
104
+
105
+ If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
106
+
107
+ You will call the appropriate tool to execute the query after running this check."""
108
+ query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
109
+ query_check = query_check_prompt | llm.bind_tools([db_query_tool])
110
+
111
+ query_gen_system = """You are a SQL expert with a strong attention to detail.
112
+
113
+ Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
114
+
115
+ DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
116
+
117
+ When generating the query:
118
+
119
+ Output the SQL query that answers the input question without a tool call.
120
+
121
+ Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
122
+ You can order the results by a relevant column to return the most interesting examples in the database.
123
+ Never query for all the columns from a specific table, only ask for the relevant columns given the question.
124
+
125
+ If you get an error while executing a query, rewrite the query and try again.
126
+
127
+ If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
128
+ NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
129
+
130
+ If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
131
+
132
+ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""
133
+ query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
134
+ query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
135
+
136
+ abs_db_path = os.path.abspath(db_path)
137
+ global DATABASE_URI
138
+ DATABASE_URI = abs_db_path
139
+ db_uri = f"sqlite:///{abs_db_path}"
140
+ print("db_uri",db_uri)
141
+
142
+ # Create new SQLDatabase connection using the constructed URI.
143
+ from langchain_community.utilities import SQLDatabase
144
+ db_instance = SQLDatabase.from_uri(db_uri)
145
+ print("db_instance----->",db_instance)
146
+ print("db_uri----->",db_uri)
147
 
 
 
 
148
 
149
+ # Create SQL toolkit and get the tools.
150
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
151
+ toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
152
+ tools_instance = toolkit_instance.get_tools()
153
 
154
+ # Define workflow nodes and fallback functions
155
+ def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
156
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
157
 
158
+ def handle_tool_error(state: State) -> dict:
159
+ error = state.get("error")
160
+ tool_calls = state["messages"][-1].tool_calls
161
+ return {
162
+ "messages": [
163
+ ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
164
+ for tc in tool_calls
165
+ ]
166
+ }
167
 
168
+ def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
169
+ return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
170
 
171
+ def query_gen_node(state: State):
172
+ message = query_gen.invoke(state)
173
+ # Check for incorrect tool calls
174
+ tool_messages = []
175
+ if message.tool_calls:
176
+ for tc in message.tool_calls:
177
+ if tc["name"] != "SubmitFinalAnswer":
178
+ tool_messages.append(
179
+ ToolMessage(
180
+ content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
181
+ tool_call_id=tc["id"],
182
+ )
183
+ )
184
+ return {"messages": [message] + tool_messages}
185
 
186
+ def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
187
+ messages = state["messages"]
188
+ last_message = messages[-1]
189
+ if getattr(last_message, "tool_calls", None):
190
+ return END
191
+ if last_message.content.startswith("Error:"):
192
+ return "query_gen"
193
+ else:
194
+ return "correct_query"
195
 
196
+ def model_check_query(state: State) -> dict[str, list[AIMessage]]:
197
+ """Double-check if the query is correct before executing it."""
198
+ return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
199
 
200
+ # Get tools for listing tables and fetching schema
201
+ list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
202
+ get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
203
 
204
+ # Define the workflow (state graph)
205
+ workflow = StateGraph(State)
206
+ workflow.add_node("first_tool_call", first_tool_call)
207
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
208
+ workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
209
+ model_get_schema = llm.bind_tools([get_schema_tool])
210
+ workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
211
+ workflow.add_node("query_gen", query_gen_node)
212
+ workflow.add_node("correct_query", model_check_query)
213
+ workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
214
 
215
+ workflow.add_edge(START, "first_tool_call")
216
+ workflow.add_edge("first_tool_call", "list_tables_tool")
217
+ workflow.add_edge("list_tables_tool", "model_get_schema")
218
+ workflow.add_edge("model_get_schema", "get_schema_tool")
219
+ workflow.add_edge("get_schema_tool", "query_gen")
220
+ workflow.add_conditional_edges("query_gen", should_continue)
221
+ workflow.add_edge("correct_query", "execute_query")
222
+ workflow.add_edge("execute_query", "query_gen")
223
 
224
+ # Compile and return the agent application workflow.
225
+ return workflow.compile()
 
 
 
226
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
227
  # Option: configure static files from uploads folder as well.
228
  flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
229
  socketio = SocketIO(flask_app, cors_allowed_origins="*")