WebashalarForML committed on
Commit
b5158ae
·
verified ·
1 Parent(s): 92d3c39

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +53 -203
app.py CHANGED
@@ -1,185 +1,86 @@
1
- from flask import Flask, render_template, request, redirect, url_for, send_from_directory, flash
2
  from flask_socketio import SocketIO
3
- import threading
4
  import os
5
  from dotenv import load_dotenv
6
- import sqlite3
7
  from werkzeug.utils import secure_filename
8
 
9
  # LangChain and agent imports
10
- from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
11
- from langchain.agents import Tool
12
- from langchain.agents.format_scratchpad import format_log_to_str
13
- from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
14
- from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
15
- from langchain_community.agent_toolkits.load_tools import load_tools
16
- from langchain_core.tools import tool
17
- from langchain_community.agent_toolkits import PowerBIToolkit
18
- from langchain.chains import LLMMathChain
19
- from langchain import hub
20
- from langchain_community.tools import DuckDuckGoSearchRun
21
-
22
- # Agent requirements and type hints
23
- from typing import Annotated, Literal, TypedDict, Any
24
  from langchain_core.messages import AIMessage, ToolMessage
25
  from pydantic import BaseModel, Field
26
  from typing_extensions import TypedDict
27
- from langgraph.graph import END, StateGraph, START
28
  from langgraph.graph.message import AnyMessage, add_messages
29
  from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
30
  from langgraph.prebuilt import ToolNode
31
-
 
 
 
32
  import traceback
33
 
34
  # Load environment variables
35
  load_dotenv()
36
 
37
-
38
  # Global configuration variables
39
  UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
 
40
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
41
- DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
42
- print("DATABASE URI:", DATABASE_URI)
43
 
44
  # API Keys from .env file
45
- GROQ_API_KEY = os.getenv("GROQ_API_KEY")
46
- MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
47
- os.environ["GROQ_API_KEY"] = GROQ_API_KEY
48
- os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
49
 
50
- # Global variables for dynamic agent and DB file path; initially None.
 
 
 
 
 
51
  agent_app = None
52
  abs_file_path = None
53
- db_path = None
54
-
55
- print(traceback.format_exc())
56
-
57
- # =============================================================================
58
- # create_agent_app: Given a database path, initialize the agent workflow.
59
- # =============================================================================
60
 
61
  def create_agent_app(db_path: str):
62
- # Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
63
  from langchain_groq import ChatGroq
64
  llm = ChatGroq(model="llama3-70b-8192")
65
 
66
- # -------------------------------------------------------------------------
67
- # Define a tool for executing SQL queries.
68
- # -------------------------------------------------------------------------
69
  @tool
70
  def db_query_tool(query: str) -> str:
71
- """
72
- Executes a SQL query on the connected SQLite database.
73
-
74
- Parameters:
75
- query (str): A SQL query string to be executed.
76
-
77
- Returns:
78
- str: The result from the database if successful, or an error message if not.
79
- """
80
  result = db_instance.run_no_throw(query)
81
- return result if result else "Error: Query failed. Please rewrite your query and try again."
82
 
83
- # -------------------------------------------------------------------------
84
- # Pydantic model for final answer
85
- # -------------------------------------------------------------------------
86
  class SubmitFinalAnswer(BaseModel):
87
- final_answer: str = Field(..., description="The final answer to the user")
88
 
89
- # -------------------------------------------------------------------------
90
- # Define state type for our workflow.
91
- # -------------------------------------------------------------------------
92
  class State(TypedDict):
93
  messages: Annotated[list[AnyMessage], add_messages]
94
 
95
- # -------------------------------------------------------------------------
96
- # Set up prompt templates (using langchain_core.prompts) for query checking
97
- # and query generation.
98
- # -------------------------------------------------------------------------
99
- from langchain_core.prompts import ChatPromptTemplate
100
-
101
- query_check_system = (
102
- "You are a SQL expert with a strong attention to detail.\n"
103
- "Double check the SQLite query for common mistakes, including:\n"
104
- "- Using NOT IN with NULL values\n"
105
- "- Using UNION when UNION ALL should have been used\n"
106
- "- Using BETWEEN for exclusive ranges\n"
107
- "- Data type mismatch in predicates\n"
108
- "- Properly quoting identifiers\n"
109
- "- Using the correct number of arguments for functions\n"
110
- "- Casting to the correct data type\n"
111
- "- Using the proper columns for joins\n\n"
112
- "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n"
113
- "You will call the appropriate tool to execute the query after running this check."
114
- )
115
- query_check_prompt = ChatPromptTemplate.from_messages([
116
- ("system", query_check_system),
117
  ("placeholder", "{messages}")
118
- ])
119
- query_check = query_check_prompt | llm.bind_tools([db_query_tool])
120
 
121
- query_gen_system = (
122
- "You are a SQL expert with a strong attention to detail.\n\n"
123
- "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\n"
124
- "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n\n"
125
- "When generating the query:\n"
126
- "Output the SQL query that answers the input question without a tool call.\n"
127
- "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n"
128
- "You can order the results by a relevant column to return the most interesting examples in the database.\n"
129
- "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\n"
130
- "If you get an error while executing a query, rewrite the query and try again.\n"
131
- "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n"
132
- "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n\n"
133
- "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n"
134
- "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any SQL query except answer."
135
- )
136
- query_gen_prompt = ChatPromptTemplate.from_messages([
137
- ("system", query_gen_system),
138
  ("placeholder", "{messages}")
139
- ])
140
- query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
141
 
142
- # -------------------------------------------------------------------------
143
- # Update database URI and file path, create SQLDatabase connection.
144
- # -------------------------------------------------------------------------
145
-
146
- abs_db_path_local = os.path.abspath(db_path)
147
- global DATABASE_URI
148
- DATABASE_URI = abs_db_path_local
149
- db_uri = f"sqlite:///{abs_db_path_local}"
150
- print("db_uri", db_uri)
151
- # Uncomment if flash is needed; ensure you have flask.flash imported if so.
152
- # flash(f"db_uri:{db_uri}", "warning")
153
 
154
- from langchain_community.utilities import SQLDatabase
155
- db_instance = SQLDatabase.from_uri(db_uri)
156
- print("db_instance----->", db_instance)
157
- # flash(f"db_instance:{db_instance}", "warning")
158
 
159
- # -------------------------------------------------------------------------
160
- # Create SQL toolkit.
161
- # -------------------------------------------------------------------------
162
-
163
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
164
- toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
165
- tools_instance = toolkit_instance.get_tools()
166
-
167
- # -------------------------------------------------------------------------
168
- # Define workflow nodes and fallback functions.
169
- # -------------------------------------------------------------------------
170
-
171
- def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
172
- return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
173
-
174
- def handle_tool_error(state: State) -> dict:
175
- error = state.get("error")
176
  tool_calls = state["messages"][-1].tool_calls
177
  return {"messages": [
178
- ToolMessage(content=f"Error: {repr(error)}. Please fix your mistakes.", tool_call_id=tc["id"])
179
- for tc in tool_calls
180
  ]}
181
 
182
- def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
183
  return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
184
 
185
  def query_gen_node(state: State):
@@ -189,36 +90,31 @@ def create_agent_app(db_path: str):
189
  for tc in message.tool_calls:
190
  if tc["name"] != "SubmitFinalAnswer":
191
  tool_messages.append(ToolMessage(
192
- content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes.",
193
  tool_call_id=tc["id"]
194
  ))
195
  return {"messages": [message] + tool_messages}
196
 
197
- def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
198
- messages = state["messages"]
199
- last_message = messages[-1]
200
  if getattr(last_message, "tool_calls", None):
201
  return END
202
  if last_message.content.startswith("Error:"):
203
  return "query_gen"
204
  return "correct_query"
205
 
206
- def model_check_query(state: State) -> dict[str, list[AIMessage]]:
207
  return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
208
 
209
- # -------------------------------------------------------------------------
210
- # Get tools for listing tables and fetching schema.
211
- # -------------------------------------------------------------------------
212
-
213
- list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
214
- get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
215
 
216
  workflow = StateGraph(State)
217
  workflow.add_node("first_tool_call", first_tool_call)
218
- workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
219
- workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
220
- model_get_schema = llm.bind_tools([get_schema_tool])
221
- workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
222
  workflow.add_node("query_gen", query_gen_node)
223
  workflow.add_node("correct_query", model_check_query)
224
  workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
@@ -232,37 +128,15 @@ def create_agent_app(db_path: str):
232
  workflow.add_edge("correct_query", "execute_query")
233
  workflow.add_edge("execute_query", "query_gen")
234
 
235
- # Return compiled workflow
236
  return workflow.compile()
237
 
238
-
239
- # =============================================================================
240
- # create_app: The application factory.
241
- # =============================================================================
242
-
243
- # Flask and SocketIO setup
244
- flask_app = Flask(__name__)
245
- flask_app.config['UPLOAD_FOLDER'] = "uploaded_files"
246
- os.makedirs(flask_app.config['UPLOAD_FOLDER'], exist_ok=True)
247
-
248
- socketio = SocketIO(flask_app, cors_allowed_origins="*")
249
-
250
- # Global variables to manage state
251
- agent_app = None
252
- db_path = None
253
- abs_file_path = None
254
-
255
- # ----------------------------
256
- # ROUTES
257
- # ----------------------------
258
-
259
  @flask_app.route("/", methods=["GET"])
260
  def index():
261
  return render_template("index.html")
262
 
263
  @flask_app.route("/upload", methods=["GET", "POST"])
264
  def upload():
265
- global abs_file_path, db_path, agent_app
266
  try:
267
  if request.method == "POST":
268
  file = request.files.get("file")
@@ -271,69 +145,45 @@ def upload():
271
 
272
  filename = secure_filename(file.filename)
273
  if filename.endswith('.db'):
274
- db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
275
- file.save(db_path)
276
-
277
- abs_file_path = os.path.abspath(db_path)
278
- agent_app = None # reset agent so it gets re-initialized on next query
279
-
280
- print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
281
- socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
282
-
283
  return redirect(url_for("index"))
284
  return render_template("upload.html")
285
  except Exception as e:
286
- print(f"[ERROR]: {str(e)}")
287
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
288
  return render_template("upload.html")
289
 
290
-
291
- # ----------------------------
292
- # AGENT INVOCATION
293
- # ----------------------------
294
-
295
  @socketio.on("user_input")
296
  def handle_user_input(data):
297
  prompt = data.get("message")
298
  if not prompt:
299
- socketio.emit("log", {"message": "[ERROR]: Empty prompt received."})
300
  return
301
-
302
  run_agent(prompt)
303
 
304
-
305
  def run_agent(prompt):
306
  global agent_app, abs_file_path
307
  if not abs_file_path:
308
- socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
309
- socketio.emit("final", {"message": "No database available. Please upload one and try again."})
310
  return
311
-
312
  try:
313
  if agent_app is None:
314
- print("[INFO]: Initializing agent for the first time...")
315
  agent_app = create_agent_app(abs_file_path)
316
  socketio.emit("log", {"message": "[INFO]: Agent initialized."})
317
 
318
  query = {"messages": [("user", prompt)]}
319
  result = agent_app.invoke(query)
320
-
321
  try:
322
  result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
323
  except Exception:
324
  result = "Query failed or no valid answer found."
325
-
326
- print("final_answer------>", result)
327
  socketio.emit("final", {"message": result})
328
  except Exception as e:
329
- print(f"[ERROR]: {str(e)}")
330
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
331
  socketio.emit("final", {"message": "Generation failed."})
332
 
333
-
334
- # ----------------------------
335
- # MAIN
336
- # ----------------------------
337
-
338
  if __name__ == "__main__":
339
- socketio.run(flask_app, debug=True)
 
1
+ from flask import Flask, render_template, request, redirect, url_for
2
  from flask_socketio import SocketIO
 
3
  import os
4
  from dotenv import load_dotenv
 
5
  from werkzeug.utils import secure_filename
6
 
7
  # LangChain and agent imports
8
+ from typing import Annotated, Literal
 
 
 
 
 
 
 
 
 
 
 
 
 
9
  from langchain_core.messages import AIMessage, ToolMessage
10
  from pydantic import BaseModel, Field
11
  from typing_extensions import TypedDict
12
+ from langgraph.graph import END, START, StateGraph
13
  from langgraph.graph.message import AnyMessage, add_messages
14
  from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
15
  from langgraph.prebuilt import ToolNode
16
+ from langchain_core.prompts import ChatPromptTemplate
17
+ from langchain_community.utilities import SQLDatabase
18
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
19
+ from langchain_core.tools import tool
20
  import traceback
21
 
22
  # Load environment variables
23
  load_dotenv()
24
 
 
25
  # Global configuration variables
26
  UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
27
+ os.makedirs(UPLOAD_FOLDER, exist_ok=True)
28
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
 
 
29
 
30
  # API Keys from .env file
31
+ os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
32
+ os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY")
 
 
33
 
34
+ # Flask and SocketIO setup
35
+ flask_app = Flask(__name__)
36
+ flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
37
+ socketio = SocketIO(flask_app, cors_allowed_origins="*")
38
+
39
+ # Global state
40
  agent_app = None
41
  abs_file_path = None
 
 
 
 
 
 
 
42
 
43
  def create_agent_app(db_path: str):
 
44
  from langchain_groq import ChatGroq
45
  llm = ChatGroq(model="llama3-70b-8192")
46
 
47
+ abs_db_path = os.path.abspath(db_path)
48
+ db_instance = SQLDatabase.from_uri(f"sqlite:///{abs_db_path}")
49
+
50
  @tool
51
  def db_query_tool(query: str) -> str:
 
 
 
 
 
 
 
 
 
52
  result = db_instance.run_no_throw(query)
53
+ return result or "Error: Query failed. Please rewrite your query and try again."
54
 
 
 
 
55
  class SubmitFinalAnswer(BaseModel):
56
+ final_answer: str = Field(...)
57
 
 
 
 
58
  class State(TypedDict):
59
  messages: Annotated[list[AnyMessage], add_messages]
60
 
61
+ query_check = ChatPromptTemplate.from_messages([
62
+ ("system", "You are a SQL expert. Fix common issues in SQLite queries."),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
  ("placeholder", "{messages}")
64
+ ]) | llm.bind_tools([db_query_tool])
 
65
 
66
+ query_gen = ChatPromptTemplate.from_messages([
67
+ ("system", "You are a SQL expert. Generate SQLite query and return answer using SubmitFinalAnswer tool."),
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
68
  ("placeholder", "{messages}")
69
+ ]) | llm.bind_tools([SubmitFinalAnswer])
 
70
 
71
+ toolkit = SQLDatabaseToolkit(db=db_instance, llm=llm)
72
+ tools_instance = toolkit.get_tools()
 
 
 
 
 
 
 
 
 
73
 
74
+ def first_tool_call(state: State):
75
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
 
 
76
 
77
+ def handle_tool_error(state: State):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  tool_calls = state["messages"][-1].tool_calls
79
  return {"messages": [
80
+ ToolMessage(content="Error occurred. Please revise.", tool_call_id=tc["id"]) for tc in tool_calls
 
81
  ]}
82
 
83
+ def create_tool_node_with_fallback(tools_list):
84
  return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
85
 
86
  def query_gen_node(state: State):
 
90
  for tc in message.tool_calls:
91
  if tc["name"] != "SubmitFinalAnswer":
92
  tool_messages.append(ToolMessage(
93
+ content=f"Error: Wrong tool called: {tc['name']}",
94
  tool_call_id=tc["id"]
95
  ))
96
  return {"messages": [message] + tool_messages}
97
 
98
+ def should_continue(state: State):
99
+ last_message = state["messages"][-1]
 
100
  if getattr(last_message, "tool_calls", None):
101
  return END
102
  if last_message.content.startswith("Error:"):
103
  return "query_gen"
104
  return "correct_query"
105
 
106
+ def model_check_query(state: State):
107
  return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
108
 
109
+ list_tool = next((t for t in tools_instance if t.name == "sql_db_list_tables"), None)
110
+ schema_tool = next((t for t in tools_instance if t.name == "sql_db_schema"), None)
111
+ model_get_schema = llm.bind_tools([schema_tool])
 
 
 
112
 
113
  workflow = StateGraph(State)
114
  workflow.add_node("first_tool_call", first_tool_call)
115
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tool]))
116
+ workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
117
+ workflow.add_node("model_get_schema", lambda s: {"messages": [model_get_schema.invoke(s["messages"])]})
 
118
  workflow.add_node("query_gen", query_gen_node)
119
  workflow.add_node("correct_query", model_check_query)
120
  workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
 
128
  workflow.add_edge("correct_query", "execute_query")
129
  workflow.add_edge("execute_query", "query_gen")
130
 
 
131
  return workflow.compile()
132
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  @flask_app.route("/", methods=["GET"])
134
  def index():
135
  return render_template("index.html")
136
 
137
  @flask_app.route("/upload", methods=["GET", "POST"])
138
  def upload():
139
+ global abs_file_path, agent_app
140
  try:
141
  if request.method == "POST":
142
  file = request.files.get("file")
 
145
 
146
  filename = secure_filename(file.filename)
147
  if filename.endswith('.db'):
148
+ save_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
149
+ file.save(save_path)
150
+ abs_file_path = os.path.abspath(save_path)
151
+ agent_app = None
152
+ socketio.emit("log", {"message": f"Database '{filename}' uploaded."})
 
 
 
 
153
  return redirect(url_for("index"))
154
  return render_template("upload.html")
155
  except Exception as e:
 
156
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
157
  return render_template("upload.html")
158
 
 
 
 
 
 
159
  @socketio.on("user_input")
160
  def handle_user_input(data):
161
  prompt = data.get("message")
162
  if not prompt:
163
+ socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
164
  return
 
165
  run_agent(prompt)
166
 
 
167
  def run_agent(prompt):
168
  global agent_app, abs_file_path
169
  if not abs_file_path:
170
+ socketio.emit("final", {"message": "No DB uploaded."})
 
171
  return
 
172
  try:
173
  if agent_app is None:
 
174
  agent_app = create_agent_app(abs_file_path)
175
  socketio.emit("log", {"message": "[INFO]: Agent initialized."})
176
 
177
  query = {"messages": [("user", prompt)]}
178
  result = agent_app.invoke(query)
 
179
  try:
180
  result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
181
  except Exception:
182
  result = "Query failed or no valid answer found."
 
 
183
  socketio.emit("final", {"message": result})
184
  except Exception as e:
 
185
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
186
  socketio.emit("final", {"message": "Generation failed."})
187
 
 
 
 
 
 
188
  if __name__ == "__main__":
189
+ socketio.run(flask_app, debug=True)