WebashalarForML commited on
Commit
e585eee
·
verified ·
1 Parent(s): d263957

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -218
app.py CHANGED
@@ -3,7 +3,8 @@ from flask_socketio import SocketIO
3
  import threading
4
  import os
5
  from dotenv import load_dotenv
6
- import sqlite3
 
7
 
8
  # LangChain and agent imports
9
  from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
@@ -19,7 +20,7 @@ from langchain import hub
19
  from langchain_community.tools import DuckDuckGoSearchRun
20
 
21
  # Agent requirements and type hints
22
- from typing import Annotated, Literal, Sequence, TypedDict, Any
23
  from langchain_core.messages import AIMessage, ToolMessage
24
  from pydantic import BaseModel, Field
25
  from typing_extensions import TypedDict
@@ -31,284 +32,269 @@ from langgraph.prebuilt import ToolNode
31
  # Load environment variables
32
  load_dotenv()
33
 
34
- UPLOAD_FOLDER = "uploads/"
35
-
36
- # In your .env file, ensure you have:
37
- # DATABASE_URI=sqlite:///employee.db
38
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
39
  DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
40
  print("DATABASE URI:", DATABASE_URI)
41
 
42
-
43
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
44
  MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
45
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
46
  os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
47
 
48
- # Use ChatGroq LLM (which does not require a Hugging Face API token)
49
- from langchain_groq import ChatGroq
50
- from langchain_mistralai.chat_models import ChatMistralAI
51
-
52
- ###############################################################################
53
- # Application Factory: create_app()
54
- #
55
- # This function sets up the Flask application, SocketIO, routes, and initializes
56
- # the global agent_app using the default DATABASE_URI. It returns the Flask app.
57
- ###############################################################################
58
- # --- Application Factory ---
59
  abs_file_path = None
60
- agent_app= None
61
 
62
- def create_app():
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63
 
64
- def create_agent_app(db_path: str):
65
- # Construct the SQLite URI from the given file path.
66
- # Ensure the db_path is absolute so that SQLAlchemy can locate the file.
67
- #llm = ChatMistralAI(model="mistral-large-latest")
68
- llm = ChatGroq(model="llama3-70b-8192")
69
-
70
- @tool
71
- def db_query_tool(query: str) -> str:
72
- """
73
- Execute a SQL query against the database and return the result.
74
- If the query is invalid or returns no result, an error message will be returned.
75
- In case of an error, the user is advised to rewrite the query and try again.
76
- """
77
- result = db_instance.run_no_throw(query)
78
- if not result:
79
- return "Error: Query failed. Please rewrite your query and try again."
80
- return result
81
-
82
- # Define a Pydantic model for submitting the final answer
83
- class SubmitFinalAnswer(BaseModel):
84
- """Submit the final answer to the user based on the query results."""
85
- final_answer: str = Field(..., description="The final answer to the user")
86
-
87
- # Define the state type
88
- class State(TypedDict):
89
- messages: Annotated[list[AnyMessage], add_messages]
90
-
91
- # Define prompt templates for query checking and query generation
92
- from langchain_core.prompts import ChatPromptTemplate
93
-
94
- query_check_system = """You are a SQL expert with a strong attention to detail.
95
- Double check the SQLite query for common mistakes, including:
96
- - Using NOT IN with NULL values
97
- - Using UNION when UNION ALL should have been used
98
- - Using BETWEEN for exclusive ranges
99
- - Data type mismatch in predicates
100
- - Properly quoting identifiers
101
- - Using the correct number of arguments for functions
102
- - Casting to the correct data type
103
- - Using the proper columns for joins
104
-
105
- If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
106
-
107
- You will call the appropriate tool to execute the query after running this check."""
108
- query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
109
- query_check = query_check_prompt | llm.bind_tools([db_query_tool])
110
-
111
- query_gen_system = """You are a SQL expert with a strong attention to detail.
112
-
113
- Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
114
-
115
- DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
116
-
117
- When generating the query:
118
-
119
- Output the SQL query that answers the input question without a tool call.
120
-
121
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
122
- You can order the results by a relevant column to return the most interesting examples in the database.
123
- Never query for all the columns from a specific table, only ask for the relevant columns given the question.
124
-
125
- If you get an error while executing a query, rewrite the query and try again.
126
-
127
- If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
128
- NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
129
-
130
- If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
131
-
132
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""
133
- query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
134
- query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
135
-
136
- abs_db_path = os.path.abspath(db_path)
137
- global DATABASE_URI
138
- DATABASE_URI = abs_db_path
139
- db_uri = f"sqlite:///{abs_db_path}"
140
- print("db_uri",db_uri)
141
-
142
- # Create new SQLDatabase connection using the constructed URI.
143
- from langchain_community.utilities import SQLDatabase
144
- db_instance = SQLDatabase.from_uri(db_uri)
145
- print("db_instance----->",db_instance)
146
- print("db_uri----->",db_uri)
147
-
148
-
149
- # Create SQL toolkit and get the tools.
150
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
151
- toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
152
- tools_instance = toolkit_instance.get_tools()
153
-
154
- # Define workflow nodes and fallback functions
155
- def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
156
- return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
157
-
158
- def handle_tool_error(state: State) -> dict:
159
- error = state.get("error")
160
- tool_calls = state["messages"][-1].tool_calls
161
- return {
162
- "messages": [
163
- ToolMessage(content=f"Error: {repr(error)}\n please fix your mistakes.", tool_call_id=tc["id"])
164
- for tc in tool_calls
165
- ]
166
- }
167
-
168
- def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
169
- return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
170
-
171
- def query_gen_node(state: State):
172
- message = query_gen.invoke(state)
173
- # Check for incorrect tool calls
174
- tool_messages = []
175
- if message.tool_calls:
176
- for tc in message.tool_calls:
177
- if tc["name"] != "SubmitFinalAnswer":
178
- tool_messages.append(
179
- ToolMessage(
180
- content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
181
- tool_call_id=tc["id"],
182
- )
183
- )
184
- return {"messages": [message] + tool_messages}
185
-
186
- def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
187
- messages = state["messages"]
188
- last_message = messages[-1]
189
- if getattr(last_message, "tool_calls", None):
190
- return END
191
- if last_message.content.startswith("Error:"):
192
- return "query_gen"
193
- else:
194
- return "correct_query"
195
-
196
- def model_check_query(state: State) -> dict[str, list[AIMessage]]:
197
- """Double-check if the query is correct before executing it."""
198
- return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
199
-
200
- # Get tools for listing tables and fetching schema
201
- list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
202
- get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
203
-
204
- # Define the workflow (state graph)
205
- workflow = StateGraph(State)
206
- workflow.add_node("first_tool_call", first_tool_call)
207
- workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
208
- workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
209
- model_get_schema = llm.bind_tools([get_schema_tool])
210
- workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
211
- workflow.add_node("query_gen", query_gen_node)
212
- workflow.add_node("correct_query", model_check_query)
213
- workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
214
-
215
- workflow.add_edge(START, "first_tool_call")
216
- workflow.add_edge("first_tool_call", "list_tables_tool")
217
- workflow.add_edge("list_tables_tool", "model_get_schema")
218
- workflow.add_edge("model_get_schema", "get_schema_tool")
219
- workflow.add_edge("get_schema_tool", "query_gen")
220
- workflow.add_conditional_edges("query_gen", should_continue)
221
- workflow.add_edge("correct_query", "execute_query")
222
- workflow.add_edge("execute_query", "query_gen")
223
-
224
- # Compile and return the agent application workflow.
225
- return workflow.compile()
226
-
227
- # Option: configure static files from uploads folder as well.
228
  flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
229
  socketio = SocketIO(flask_app, cors_allowed_origins="*")
230
-
231
- # Set up uploads folder
232
- UPLOAD_FOLDER_LOCAL = os.path.join(os.getcwd(), "uploads")
233
- if not os.path.exists(UPLOAD_FOLDER_LOCAL):
234
- os.makedirs(UPLOAD_FOLDER_LOCAL)
235
- flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER_LOCAL
236
-
237
- # Static route: option if you want a custom route to serve files:
238
  @flask_app.route("/files/<path:filename>")
239
  def uploaded_file(filename):
240
  return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
241
 
242
- # Helper function to run the agent; uses the global agent_app.
 
 
243
  def run_agent(prompt, socketio):
244
  global agent_app
245
  if agent_app is None:
246
- socketio.emit("log", {"message": "[ERROR]: No database has been uploaded. Please upload a database file first."})
247
- socketio.emit("final", {"message": "No database available. Upload a database and try again."})
248
  return
249
  try:
250
  query = {"messages": [("user", prompt)]}
251
- agent_app = create_agent_app(abs_file_path)
252
  result = agent_app.invoke(query)
253
  try:
254
  result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
255
  except Exception:
256
  result = "Query failed or no valid answer found."
257
-
258
  print("final_answer------>", result)
259
- socketio.emit("final", {"message": f"{result}"})
260
-
261
  except Exception as e:
262
  print(f"[ERROR]: {str(e)}")
263
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
264
  socketio.emit("final", {"message": "Generation failed."})
265
 
 
 
 
266
  @flask_app.route("/")
267
  def index():
268
  return render_template("index.html")
269
 
 
 
 
270
  @flask_app.route("/generate", methods=["POST"])
271
  def generate():
272
  try:
273
  socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
274
  data = request.json
275
  prompt = data.get("prompt", "")
276
- socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}\n"})
277
- # Run the agent in a separate thread
278
  thread = threading.Thread(target=run_agent, args=(prompt, socketio))
279
- socketio.emit("log", {"message": f"[INFO]: thread info: {thread}\n"})
280
- socketio.emit("log", {"message": f"[INFO]: DB PATH: {abs_file_path}\n"})
281
  thread.start()
282
  return "OK", 200
283
  except Exception as e:
284
  print(f"[ERROR]: {str(e)}")
285
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
 
286
 
287
- @flask_app.route("/upload", methods=["POST", "GET"])
 
 
 
288
  def upload():
 
289
  try:
290
- if request.method == 'POST':
291
- file = request.files.get('file')
292
  if not file:
293
  print("No file uploaded")
294
  return "No file uploaded", 400
295
- if file and file.filename.endswith('.db'):
296
- # Save file using flask_app.config
297
- db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], 'uploaded.db')
298
- socketio.emit("log", {"message": f"[INFO]: Saving file to: {db_path}\n"})
299
  print("Saving file to:", db_path)
300
  file.save(db_path)
301
-
302
- # Reinitialize the agent_app with the new database file
303
- global abs_file_path
304
  abs_file_path = os.path.abspath(db_path)
305
- global agent_app
306
  agent_app = create_agent_app(abs_file_path)
307
-
308
- print(f"[INFO_PRINT]: Database file '{file.filename}' uploaded and loaded.")
309
- socketio.emit("log", {"message": f"[INFO]: Database file '{file.filename}' uploaded and loaded."})
310
  return redirect(url_for("index"))
311
- # For GET, render upload form:
312
  return render_template("upload.html")
313
  except Exception as e:
314
  print(f"[ERROR]: {str(e)}")
@@ -317,7 +303,9 @@ def create_app():
317
 
318
  return flask_app, socketio
319
 
 
320
  # Create the app for Gunicorn compatibility.
 
321
  app, socketio_instance = create_app()
322
 
323
  if __name__ == "__main__":
 
3
  import threading
4
  import os
5
  from dotenv import load_dotenv
6
+ import sqlite3
7
+ from werkzeug.utils import secure_filename
8
 
9
  # LangChain and agent imports
10
  from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
 
20
  from langchain_community.tools import DuckDuckGoSearchRun
21
 
22
  # Agent requirements and type hints
23
+ from typing import Annotated, Literal, TypedDict, Any
24
  from langchain_core.messages import AIMessage, ToolMessage
25
  from pydantic import BaseModel, Field
26
  from typing_extensions import TypedDict
 
32
  # Load environment variables
33
  load_dotenv()
34
 
35
+ # Global configuration variables
36
+ UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
 
 
37
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
38
  DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
39
  print("DATABASE URI:", DATABASE_URI)
40
 
41
+ # API Keys from .env file
42
  GROQ_API_KEY = os.getenv("GROQ_API_KEY")
43
  MISTRAL_API_KEY = os.getenv("MISTRAL_API_KEY")
44
  os.environ["GROQ_API_KEY"] = GROQ_API_KEY
45
  os.environ["MISTRAL_API_KEY"] = MISTRAL_API_KEY
46
 
47
+ # Global variables for dynamic agent and DB file path; initially None.
48
+ agent_app = None
 
 
 
 
 
 
 
 
 
49
  abs_file_path = None
 
50
 
51
+ # =============================================================================
52
+ # create_agent_app: Given a database path, initialize the agent workflow.
53
+ # =============================================================================
54
+ def create_agent_app(db_path: str):
55
+ # Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
56
+ from langchain_groq import ChatGroq
57
+ llm = ChatGroq(model="llama3-70b-8192")
58
+
59
+ # -------------------------------------------------------------------------
60
+ # Define a tool for executing SQL queries.
61
+ # -------------------------------------------------------------------------
62
+ @tool
63
+ def db_query_tool(query: str) -> str:
64
+ result = db_instance.run_no_throw(query)
65
+ return result if result else "Error: Query failed. Please rewrite your query and try again."
66
+
67
+ # -------------------------------------------------------------------------
68
+ # Pydantic model for final answer
69
+ # -------------------------------------------------------------------------
70
+ class SubmitFinalAnswer(BaseModel):
71
+ final_answer: str = Field(..., description="The final answer to the user")
72
+
73
+ # -------------------------------------------------------------------------
74
+ # Define state type for our workflow.
75
+ # -------------------------------------------------------------------------
76
+ class State(TypedDict):
77
+ messages: Annotated[list[AnyMessage], add_messages]
78
+
79
+ # -------------------------------------------------------------------------
80
+ # Set up prompt templates (using langchain_core.prompts) for query checking
81
+ # and query generation.
82
+ # -------------------------------------------------------------------------
83
+ from langchain_core.prompts import ChatPromptTemplate
84
+
85
+ query_check_system = (
86
+ "You are a SQL expert with a strong attention to detail.\n"
87
+ "Double check the SQLite query for common mistakes, including:\n"
88
+ "- Using NOT IN with NULL values\n"
89
+ "- Using UNION when UNION ALL should have been used\n"
90
+ "- Using BETWEEN for exclusive ranges\n"
91
+ "- Data type mismatch in predicates\n"
92
+ "- Properly quoting identifiers\n"
93
+ "- Using the correct number of arguments for functions\n"
94
+ "- Casting to the correct data type\n"
95
+ "- Using the proper columns for joins\n\n"
96
+ "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n"
97
+ "You will call the appropriate tool to execute the query after running this check."
98
+ )
99
+ query_check_prompt = ChatPromptTemplate.from_messages([
100
+ ("system", query_check_system),
101
+ ("placeholder", "{messages}")
102
+ ])
103
+ query_check = query_check_prompt | llm.bind_tools([db_query_tool])
104
+
105
+ query_gen_system = (
106
+ "You are a SQL expert with a strong attention to detail.\n\n"
107
+ "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\n"
108
+ "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n\n"
109
+ "When generating the query:\n"
110
+ "Output the SQL query that answers the input question without a tool call.\n"
111
+ "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n"
112
+ "You can order the results by a relevant column to return the most interesting examples in the database.\n"
113
+ "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\n"
114
+ "If you get an error while executing a query, rewrite the query and try again.\n"
115
+ "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n"
116
+ "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n\n"
117
+ "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n"
118
+ "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."
119
+ )
120
+ query_gen_prompt = ChatPromptTemplate.from_messages([
121
+ ("system", query_gen_system),
122
+ ("placeholder", "{messages}")
123
+ ])
124
+ query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
125
+
126
+ # Update database URI and file path
127
+ abs_db_path_local = os.path.abspath(db_path)
128
+ global DATABASE_URI
129
+ DATABASE_URI = abs_db_path_local
130
+ db_uri = f"sqlite:///{abs_db_path_local}"
131
+ print("db_uri", db_uri)
132
+
133
+ # Create SQLDatabase connection using langchain utility.
134
+ from langchain_community.utilities import SQLDatabase
135
+ db_instance = SQLDatabase.from_uri(db_uri)
136
+ print("db_instance----->", db_instance)
137
+
138
+ # Create SQL toolkit.
139
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
140
+ toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
141
+ tools_instance = toolkit_instance.get_tools()
142
+
143
+ # Define workflow nodes and fallback functions.
144
+ def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
145
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
146
 
147
+ def handle_tool_error(state: State) -> dict:
148
+ error = state.get("error")
149
+ tool_calls = state["messages"][-1].tool_calls
150
+ return {"messages": [
151
+ ToolMessage(content=f"Error: {repr(error)}. Please fix your mistakes.", tool_call_id=tc["id"])
152
+ for tc in tool_calls
153
+ ]}
154
+
155
+ def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
156
+ return ToolNode(tools_list).with_fallbacks([RunnableLambda(handle_tool_error)], exception_key="error")
157
+
158
+ def query_gen_node(state: State):
159
+ message = query_gen.invoke(state)
160
+ tool_messages = []
161
+ if message.tool_calls:
162
+ for tc in message.tool_calls:
163
+ if tc["name"] != "SubmitFinalAnswer":
164
+ tool_messages.append(ToolMessage(
165
+ content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes.",
166
+ tool_call_id=tc["id"]
167
+ ))
168
+ return {"messages": [message] + tool_messages}
169
+
170
+ def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
171
+ messages = state["messages"]
172
+ last_message = messages[-1]
173
+ if getattr(last_message, "tool_calls", None):
174
+ return END
175
+ if last_message.content.startswith("Error:"):
176
+ return "query_gen"
177
+ return "correct_query"
178
+
179
+ def model_check_query(state: State) -> dict[str, list[AIMessage]]:
180
+ return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
181
+
182
+ # Get table listing and schema tools.
183
+ list_tables_tool = next((tool for tool in tools_instance if tool.name == "sql_db_list_tables"), None)
184
+ get_schema_tool = next((tool for tool in tools_instance if tool.name == "sql_db_schema"), None)
185
+
186
+ workflow = StateGraph(State)
187
+ workflow.add_node("first_tool_call", first_tool_call)
188
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
189
+ workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
190
+ model_get_schema = llm.bind_tools([get_schema_tool])
191
+ workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])],})
192
+ workflow.add_node("query_gen", query_gen_node)
193
+ workflow.add_node("correct_query", model_check_query)
194
+ workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
195
+
196
+ workflow.add_edge(START, "first_tool_call")
197
+ workflow.add_edge("first_tool_call", "list_tables_tool")
198
+ workflow.add_edge("list_tables_tool", "model_get_schema")
199
+ workflow.add_edge("model_get_schema", "get_schema_tool")
200
+ workflow.add_edge("get_schema_tool", "query_gen")
201
+ workflow.add_conditional_edges("query_gen", should_continue)
202
+ workflow.add_edge("correct_query", "execute_query")
203
+ workflow.add_edge("execute_query", "query_gen")
204
+
205
+ # Return compiled workflow
206
+ return workflow.compile()
207
+
208
+ # =============================================================================
209
+ # create_app: The application factory.
210
+ # =============================================================================
211
+ def create_app():
212
+ # Configure static folder for uploads.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
213
  flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
214
  socketio = SocketIO(flask_app, cors_allowed_origins="*")
215
+
216
+ # Ensure uploads folder exists.
217
+ if not os.path.exists(UPLOAD_FOLDER):
218
+ os.makedirs(UPLOAD_FOLDER)
219
+ flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
220
+
221
+ # Serve uploaded files via a custom route.
 
222
  @flask_app.route("/files/<path:filename>")
223
  def uploaded_file(filename):
224
  return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
225
 
226
+ # -------------------------------------------------------------------------
227
+ # Helper: run_agent runs the agent with the given prompt.
228
+ # -------------------------------------------------------------------------
229
  def run_agent(prompt, socketio):
230
  global agent_app
231
  if agent_app is None:
232
+ socketio.emit("log", {"message": "[ERROR]: No database has been uploaded. Upload a database file first."})
233
+ socketio.emit("final", {"message": "No database available. Upload one and try again."})
234
  return
235
  try:
236
  query = {"messages": [("user", prompt)]}
 
237
  result = agent_app.invoke(query)
238
  try:
239
  result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
240
  except Exception:
241
  result = "Query failed or no valid answer found."
 
242
  print("final_answer------>", result)
243
+ socketio.emit("final", {"message": result})
 
244
  except Exception as e:
245
  print(f"[ERROR]: {str(e)}")
246
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
247
  socketio.emit("final", {"message": "Generation failed."})
248
 
249
+ # -------------------------------------------------------------------------
250
+ # Route: index page
251
+ # -------------------------------------------------------------------------
252
  @flask_app.route("/")
253
  def index():
254
  return render_template("index.html")
255
 
256
+ # -------------------------------------------------------------------------
257
+ # Route: generate (POST) – receives a prompt, runs the agent.
258
+ # -------------------------------------------------------------------------
259
  @flask_app.route("/generate", methods=["POST"])
260
  def generate():
261
  try:
262
  socketio.emit("log", {"message": "[STEP]: Entering query_gen..."})
263
  data = request.json
264
  prompt = data.get("prompt", "")
265
+ socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
 
266
  thread = threading.Thread(target=run_agent, args=(prompt, socketio))
267
+ socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
 
268
  thread.start()
269
  return "OK", 200
270
  except Exception as e:
271
  print(f"[ERROR]: {str(e)}")
272
  socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
273
+ return "ERROR", 500
274
 
275
+ # -------------------------------------------------------------------------
276
+ # Route: upload (GET/POST) – handles uploading the SQLite DB file.
277
+ # -------------------------------------------------------------------------
278
+ @flask_app.route("/upload", methods=["GET", "POST"])
279
  def upload():
280
+ global abs_file_path, agent_app
281
  try:
282
+ if request.method == "POST":
283
+ file = request.files.get("file")
284
  if not file:
285
  print("No file uploaded")
286
  return "No file uploaded", 400
287
+ # Secure the filename to avoid path traversal issues.
288
+ filename = secure_filename(file.filename)
289
+ if filename.endswith('.db'):
290
+ db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
291
  print("Saving file to:", db_path)
292
  file.save(db_path)
 
 
 
293
  abs_file_path = os.path.abspath(db_path)
 
294
  agent_app = create_agent_app(abs_file_path)
295
+ print(f"[INFO]: Database file '{filename}' uploaded and loaded.")
296
+ socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded and loaded."})
 
297
  return redirect(url_for("index"))
 
298
  return render_template("upload.html")
299
  except Exception as e:
300
  print(f"[ERROR]: {str(e)}")
 
303
 
304
  return flask_app, socketio
305
 
306
+ # =============================================================================
307
  # Create the app for Gunicorn compatibility.
308
+ # =============================================================================
309
  app, socketio_instance = create_app()
310
 
311
  if __name__ == "__main__":