WebashalarForML commited on
Commit
f648e72
·
verified ·
1 Parent(s): 0175c10

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +143 -121
app.py CHANGED
@@ -50,6 +50,19 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
50
  toolkit = SQLDatabaseToolkit(db=db, llm=llm)
51
  tools = toolkit.get_tools()
52
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
  # Define a custom query tool for executing SQL queries
54
  @tool
55
  def db_query_tool(query: str) -> str:
@@ -117,111 +130,134 @@ DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the databa
117
  query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
118
  query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
119
 
120
- # Define nodes and fallback functions for the workflow
121
- def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
122
- return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
123
-
124
- def handle_tool_error(state: State) -> dict:
125
- error = state.get("error")
126
- tool_calls = state["messages"][-1].tool_calls
127
- return {
128
- "messages": [
129
- ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
130
- for tc in tool_calls
131
- ]
132
- }
133
-
134
- def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
135
- return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
136
-
137
- def query_gen_node(state: State):
138
- message = query_gen.invoke(state)
139
- # Check for incorrect tool calls
140
- tool_messages = []
141
- if message.tool_calls:
142
- for tc in message.tool_calls:
143
- if tc["name"] != "SubmitFinalAnswer":
144
- tool_messages.append(
145
- ToolMessage(
146
- content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
147
- tool_call_id=tc["id"],
 
148
  )
149
- )
150
- return {"messages": [message] + tool_messages}
151
-
152
- def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
153
- messages = state["messages"]
154
- last_message = messages[-1]
155
- if getattr(last_message, "tool_calls", None):
156
- return END
157
- if last_message.content.startswith("Error:"):
158
- return "query_gen"
159
- else:
160
- return "correct_query"
161
-
162
- def model_check_query(state: State) -> dict[str, list[AIMessage]]:
163
- """Double-check if the query is correct before executing it."""
164
- return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
165
-
166
- # Get tools for listing tables and fetching schema
167
- list_tables_tool = next((tool for tool in tools if tool.name == "sql_db_list_tables"), None)
168
- get_schema_tool = next((tool for tool in tools if tool.name == "sql_db_schema"), None)
169
-
170
- # Define the workflow (state graph)
171
- workflow = StateGraph(State)
172
- workflow.add_node("first_tool_call", first_tool_call)
173
- workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
174
- workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
175
- model_get_schema = llm.bind_tools([get_schema_tool])
176
- workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
177
- workflow.add_node("query_gen", query_gen_node)
178
- workflow.add_node("correct_query", model_check_query)
179
- workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
180
-
181
- workflow.add_edge(START, "first_tool_call")
182
- workflow.add_edge("first_tool_call", "list_tables_tool")
183
- workflow.add_edge("list_tables_tool", "model_get_schema")
184
- workflow.add_edge("model_get_schema", "get_schema_tool")
185
- workflow.add_edge("get_schema_tool", "query_gen")
186
- workflow.add_conditional_edges("query_gen", should_continue)
187
- workflow.add_edge("correct_query", "execute_query")
188
- workflow.add_edge("execute_query", "query_gen")
189
-
190
- # Compile the workflow into an agent application.
191
- agent_app = workflow.compile()
192
-
193
- # Initialize Flask and SocketIO
194
- flask_app = Flask(__name__)
195
- socketio = SocketIO(flask_app, cors_allowed_origins="*")
196
-
197
- # Set up an uploads directory
198
- UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
199
- if not os.path.exists(UPLOAD_FOLDER):
200
- os.makedirs(UPLOAD_FOLDER)
201
-
202
- # Create a global agent_app using the default DATABASE_URI
203
- agent_app = create_agent_app(DATABASE_URI)
204
-
205
- # Endpoint for uploading a DB file
206
- @flask_app.route("/upload", methods=["GET", "POST"])
207
- def upload():
208
- if request.method == "POST":
209
- file = request.files.get("file")
210
- if not file:
211
- return "No file uploaded", 400
212
- file_path = os.path.join(UPLOAD_FOLDER, file.filename)
213
- file.save(file_path)
214
- # Build a new URI (for SQLite, use absolute path)
215
- new_db_uri = f"sqlite:///{file_path}"
216
- # Reinitialize the agent_app with the new DB
217
- global agent_app
218
- agent_app = create_agent_app(new_db_uri)
219
- socketio.emit("log", {"message": f"[INFO]: Database file '{file.filename}' uploaded and loaded."})
220
- return redirect(url_for("index"))
221
- return render_template("upload.html")
222
-
223
- # Function to run the agent in a separate thread
224
- def run_agent(prompt):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
225
  try:
226
  query = {"messages": [("user", prompt)]}
227
  result = agent_app.invoke(query)
@@ -232,22 +268,8 @@ def run_agent(prompt):
232
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
233
  socketio.emit("final", {"message": "Generation failed."})
234
 
235
- @flask_app.route("/")
236
- def index():
237
- return render_template("index.html")
238
-
239
- @flask_app.route("/generate", methods=["POST"])
240
- def generate():
241
- data = request.json
242
- prompt = data.get("prompt", "")
243
- socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}\n"})
244
- # Run the agent in a separate thread
245
- thread = threading.Thread(target=run_agent, args=(prompt,))
246
- thread.start()
247
- return "OK", 200
248
-
249
- # Assign the Flask app to "app" for gunicorn
250
- app = flask_app
251
 
252
  if __name__ == "__main__":
253
- socketio.run(app, debug=True)
 
50
  toolkit = SQLDatabaseToolkit(db=db, llm=llm)
51
  tools = toolkit.get_tools()
52
 
53
+
54
+ def create_agent_app(db_uri: str):
55
+ # Create new SQLDatabase connection
56
+ from langchain_community.utilities import SQLDatabase
57
+ db_instance = SQLDatabase.from_uri(db_uri)
58
+
59
+ # Create SQL toolkit and get the tools
60
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
61
+ toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
62
+ tools_instance = toolkit_instance.get_tools()
63
+
64
+ # Define a custom query tool for executing SQL queries
65
+
66
  # Define a custom query tool for executing SQL queries
67
  @tool
68
  def db_query_tool(query: str) -> str:
 
130
  query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
131
  query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
132
 
133
+ # Define workflow nodes and fallback functions
134
+ def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
135
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
136
+
137
+ def handle_tool_error(state: State) -> dict:
138
+ error = state.get("error")
139
+ tool_calls = state["messages"][-1].tool_calls
140
+ return {
141
+ "messages": [
142
+ ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
143
+ for tc in tool_calls
144
+ ]
145
+ }
146
+
147
+ def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
148
+ return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
149
+
150
+ def query_gen_node(state: State):
151
+ message = query_gen.invoke(state)
152
+ # Check for incorrect tool calls
153
+ tool_messages = []
154
+ if message.tool_calls:
155
+ for tc in message.tool_calls:
156
+ if tc["name"] != "SubmitFinalAnswer":
157
+ tool_messages.append(
158
+ ToolMessage(
159
+ content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
160
+ tool_call_id=tc["id"],
161
+ )
162
  )
163
+ return {"messages": [message] + tool_messages}
164
+
165
+ def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
166
+ messages = state["messages"]
167
+ last_message = messages[-1]
168
+ if getattr(last_message, "tool_calls", None):
169
+ return END
170
+ if last_message.content.startswith("Error:"):
171
+ return "query_gen"
172
+ else:
173
+ return "correct_query"
174
+
175
+ def model_check_query(state: State) -> dict[str, list[AIMessage]]:
176
+ """Double-check if the query is correct before executing it."""
177
+ return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
178
+
179
+ # Get tools for listing tables and fetching schema
180
+ list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
181
+ get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
182
+
183
+ # Define the workflow (state graph)
184
+ workflow = StateGraph(State)
185
+ workflow.add_node("first_tool_call", first_tool_call)
186
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
187
+ workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
188
+ model_get_schema = llm.bind_tools([get_schema_tool])
189
+ workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
190
+ workflow.add_node("query_gen", query_gen_node)
191
+ workflow.add_node("correct_query", model_check_query)
192
+ workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
193
+
194
+ workflow.add_edge(START, "first_tool_call")
195
+ workflow.add_edge("first_tool_call", "list_tables_tool")
196
+ workflow.add_edge("list_tables_tool", "model_get_schema")
197
+ workflow.add_edge("model_get_schema", "get_schema_tool")
198
+ workflow.add_edge("get_schema_tool", "query_gen")
199
+ workflow.add_conditional_edges("query_gen", should_continue)
200
+ workflow.add_edge("correct_query", "execute_query")
201
+ workflow.add_edge("execute_query", "query_gen")
202
+
203
+ # Compile and return the agent application workflow
204
+ return workflow.compile()
205
+
206
+ ###############################################################################
207
+ # Application Factory: create_app()
208
+ #
209
+ # This function sets up the Flask application, SocketIO, routes, and initializes
210
+ # the global agent_app using the default DATABASE_URI. It returns the Flask app.
211
+ ###############################################################################
212
+ def create_app():
213
+ flask_app = Flask(__name__)
214
+ socketio = SocketIO(flask_app, cors_allowed_origins="*")
215
+
216
+ # Set up an uploads directory (for DB file uploads)
217
+ UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
218
+ if not os.path.exists(UPLOAD_FOLDER):
219
+ os.makedirs(UPLOAD_FOLDER)
220
+
221
+ # Create a global agent_app using the default DATABASE_URI
222
+ global agent_app
223
+ agent_app = create_agent_app(DATABASE_URI)
224
+
225
+ @flask_app.route("/")
226
+ def index():
227
+ return render_template("index.html")
228
+
229
+ @flask_app.route("/generate", methods=["POST"])
230
+ def generate():
231
+ data = request.json
232
+ prompt = data.get("prompt", "")
233
+ socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}\n"})
234
+ # Run the agent in a separate thread
235
+ thread = threading.Thread(target=run_agent, args=(prompt, socketio))
236
+ thread.start()
237
+ return "OK", 200
238
+
239
+ @flask_app.route("/upload", methods=["GET", "POST"])
240
+ def upload():
241
+ if request.method == "POST":
242
+ file = request.files.get("file")
243
+ if not file:
244
+ return "No file uploaded", 400
245
+ file_path = os.path.join(UPLOAD_FOLDER, file.filename)
246
+ file.save(file_path)
247
+ # For SQLite, use the absolute file path in the URI
248
+ new_db_uri = f"sqlite:///{file_path}"
249
+ global agent_app
250
+ agent_app = create_agent_app(new_db_uri)
251
+ socketio.emit("log", {"message": f"[INFO]: Database file '{file.filename}' uploaded and loaded."})
252
+ return redirect(url_for("index"))
253
+ return render_template("upload.html")
254
+
255
+ return flask_app, socketio
256
+
257
+ ###############################################################################
258
+ # Helper function to run the agent; uses the global agent_app.
259
+ ###############################################################################
260
+ def run_agent(prompt, socketio):
261
  try:
262
  query = {"messages": [("user", prompt)]}
263
  result = agent_app.invoke(query)
 
268
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
269
  socketio.emit("final", {"message": "Generation failed."})
270
 
271
+ # Create the app and assign to "app" for Gunicorn compatibility.
272
+ app, socketio_instance = create_app()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
273
 
274
  if __name__ == "__main__":
275
+ socketio_instance.run(app, debug=True)