WebashalarForML commited on
Commit
ae9b429
·
verified ·
1 Parent(s): 1b44a17

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +222 -216
app.py CHANGED
@@ -1,24 +1,36 @@
1
- from flask import Flask, render_template, request, redirect, url_for, flash, send_from_directory
2
  from flask_socketio import SocketIO
3
- import os
4
  import threading
 
5
  from dotenv import load_dotenv
 
6
  from werkzeug.utils import secure_filename
 
7
 
8
  # LangChain and agent imports
9
- from typing import Annotated, Literal
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  from langchain_core.messages import AIMessage, ToolMessage
11
  from pydantic import BaseModel, Field
12
  from typing_extensions import TypedDict
13
- from langgraph.graph import END, START, StateGraph
14
  from langgraph.graph.message import AnyMessage, add_messages
15
  from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
16
  from langgraph.prebuilt import ToolNode
17
  from langchain_core.prompts import ChatPromptTemplate
18
  from langchain_community.utilities import SQLDatabase
19
- from langchain_community.agent_toolkits import SQLDatabaseToolkit
20
- from langchain_core.tools import tool
21
- import traceback
22
 
23
  # Load environment variables
24
  load_dotenv()
@@ -27,121 +39,123 @@ load_dotenv()
27
  UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
28
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
29
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
 
 
30
 
31
  # API Keys from .env file
 
32
  os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
33
  os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY")
34
 
35
- # Flask and SocketIO setup
36
- flask_app = Flask(__name__)
37
- flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
38
- # Set secret key for flash messages:
39
- flask_app.config['SECRET_KEY'] = os.getenv("FLASK_SECRET_KEY", "mysecretkey")
40
- socketio = SocketIO(flask_app, cors_allowed_origins="*")
41
-
42
- # Global state
43
  agent_app = None
44
  abs_file_path = None
 
45
 
 
 
 
46
  def create_agent_app(db_path: str):
47
- try:
48
- from langchain_groq import ChatGroq
49
- llm = ChatGroq(model="llama3-70b-8192")
50
- except Exception as e:
51
- flash(f"[ERROR]: Failed to initialize ChatGroq: {e}", "error")
52
- raise
53
-
54
- abs_db_path = os.path.abspath(db_path)
55
- try:
56
- db_instance = SQLDatabase.from_uri(f"sqlite:///{abs_db_path}")
57
- except Exception as e:
58
- flash(f"[ERROR]: Failed to connect to DB: {e}", "error")
59
- raise
60
-
61
- @tool
62
  def db_query_tool(query: str) -> str:
63
  """
64
- Execute a SQL query against the database and return the result.
65
- If the query is invalid or returns no result, an error message will be returned.
66
- In case of an error, the user is advised to rewrite the query and try again.
67
  """
68
  try:
69
  result = db_instance.run_no_throw(query)
70
- return result or "Error: Query failed. Please rewrite your query and try again."
71
  except Exception as e:
72
- flash(f"[ERROR]: Exception during query execution: {e}", "error")
73
  return f"Error: {str(e)}"
74
 
 
 
 
75
  class SubmitFinalAnswer(BaseModel):
76
- final_answer: str = Field(...)
77
 
 
 
 
78
  class State(TypedDict):
79
  messages: Annotated[list[AnyMessage], add_messages]
80
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
81
  try:
82
- query_check_system = """You are a SQL expert with a strong attention to detail.
83
- Double check the SQLite query for common mistakes, including:
84
- - Using NOT IN with NULL values
85
- - Using UNION when UNION ALL should have been used
86
- - Using BETWEEN for exclusive ranges
87
- - Data type mismatch in predicates
88
- - Properly quoting identifiers
89
- - Using the correct number of arguments for functions
90
- - Casting to the correct data type
91
- - Using the proper columns for joins
92
-
93
- If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.
94
-
95
- You will call the appropriate tool to execute the query after running this check.
96
- """
97
-
98
- query_check = ChatPromptTemplate.from_messages([
99
- ("system", query_check_system),
100
- ("placeholder", "{messages}")
101
- ]) | llm.bind_tools([db_query_tool])
102
-
103
- query_gen_system = """You are a SQL expert with a strong attention to detail.
104
-
105
- Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.
106
-
107
- DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.
108
-
109
- When generating the query:
110
-
111
- Output the SQL query that answers the input question without a tool call.
112
-
113
- Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
114
- You can order the results by a relevant column to return the most interesting examples in the database.
115
- Never query for all the columns from a specific table, only ask for the relevant columns given the question.
116
-
117
- If you get an error while executing a query, rewrite the query and try again.
118
-
119
- If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
120
- NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.
121
-
122
- If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.
123
-
124
- DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer.
125
- """
126
-
127
-
128
- query_gen = ChatPromptTemplate.from_messages([
129
- ("system", query_gen_system),
130
- ("placeholder", "{messages}")
131
- ]) | llm.bind_tools([SubmitFinalAnswer])
132
- except Exception as e:
133
- flash(f"[ERROR]: Failed to create prompt templates: {e}", "error")
134
- raise
135
-
136
- try:
137
- toolkit = SQLDatabaseToolkit(db=db_instance, llm=llm)
138
- tools_instance = toolkit.get_tools()
139
  except Exception as e:
140
- flash(f"[ERROR]: Failed to initialize SQL toolkit: {e}", "error")
141
- raise
142
-
 
 
 
 
 
 
 
 
 
143
  def first_tool_call(state: State):
144
- return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
145
 
146
  def handle_tool_error(state: State):
147
  tool_calls = state["messages"][-1].tool_calls
@@ -156,8 +170,7 @@ def create_agent_app(db_path: str):
156
  try:
157
  message = query_gen.invoke(state)
158
  except Exception as e:
159
- flash(f"[ERROR]: Exception in query_gen_node: {e}", "error")
160
- raise
161
  tool_messages = []
162
  if message.tool_calls:
163
  for tc in message.tool_calls:
@@ -168,8 +181,9 @@ def create_agent_app(db_path: str):
168
  ))
169
  return {"messages": [message] + tool_messages}
170
 
171
- def should_continue(state: State):
172
- last_message = state["messages"][-1]
 
173
  if getattr(last_message, "tool_calls", None):
174
  return END
175
  if last_message.content.startswith("Error:"):
@@ -179,16 +193,18 @@ def create_agent_app(db_path: str):
179
  def model_check_query(state: State):
180
  return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
181
 
182
- list_tool = next((t for t in tools_instance if t.name == "sql_db_list_tables"), None)
 
 
 
183
  schema_tool = next((t for t in tools_instance if t.name == "sql_db_schema"), None)
184
  model_get_schema = llm.bind_tools([schema_tool])
185
 
186
  workflow = StateGraph(State)
187
  workflow.add_node("first_tool_call", first_tool_call)
188
- workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tool]))
189
  workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
190
- # Fixed unterminated string literal:
191
- workflow.add_node("model_get_schema", lambda s: {"messages": [model_get_schema.invoke(s["messages"])]})
192
  workflow.add_node("query_gen", query_gen_node)
193
  workflow.add_node("correct_query", model_check_query)
194
  workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
@@ -204,121 +220,111 @@ def create_agent_app(db_path: str):
204
 
205
  return workflow.compile()
206
 
207
- @flask_app.route("/files/<path:filename>")
208
- def uploaded_file(filename):
209
- try:
210
- return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
211
- except Exception as e:
212
- flash(f"[ERROR]: Could not send file: {str(e)}", "error")
213
- return redirect(url_for("index"))
214
-
215
- # -------------------------------------------------------------------------
216
- # Helper: run_agent runs the agent with the given prompt.
217
- # -------------------------------------------------------------------------
218
- def run_agent(prompt, socketio):
219
- global agent_app, abs_file_path
220
- if not abs_file_path:
221
- socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
222
- socketio.emit("final", {"message": "No database available. Please upload one and try again."})
223
- flash("No database available. Please upload one and try again.", "error")
224
- return
225
- try:
226
- # Lazy agent initialization: use the previously uploaded DB.
227
- if agent_app is None:
228
- socketio.emit("log", {"message": "[INFO]: Initializing agent for the first time..."})
229
- agent_app = create_agent_app(abs_file_path)
230
- socketio.emit("log", {"message": "[INFO]: Agent initialized."})
231
- flash("Agent initialized.", "info")
232
- query = {"messages": [("user", prompt)]}
233
- result = agent_app.invoke(query)
234
  try:
235
- result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
236
  except Exception as e:
237
- result = "Query failed or no valid answer found."
238
- flash("Query failed or no valid answer found.", "warning")
239
- socketio.emit("final", {"message": result})
240
- except Exception as e:
241
- error_message = f"Generation failed: {str(e)}"
242
- socketio.emit("log", {"message": f"[ERROR]: {error_message}"})
243
- socketio.emit("final", {"message": "Generation failed."})
244
- flash(error_message, "error")
245
- traceback.print_exc()
246
-
247
- # -------------------------------------------------------------------------
248
- # Route: index page.
249
- # -------------------------------------------------------------------------
250
- @flask_app.route("/")
251
- def index():
252
- return render_template("index.html")
253
-
254
- # -------------------------------------------------------------------------
255
- # Route: generate (POST) – receives a prompt and runs the agent.
256
- # -------------------------------------------------------------------------
257
- @flask_app.route("/generate", methods=["POST"])
258
- def generate():
259
- try:
260
- socketio.emit("log", {"message": "[STEP]: Entering query generation..."})
261
- data = request.json
262
- prompt = data.get("prompt", "")
263
- socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
264
- thread = threading.Thread(target=run_agent, args=(prompt, socketio))
265
- socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
266
- thread.start()
267
- flash("Query submitted successfully.", "info")
268
- return "OK", 200
269
- except Exception as e:
270
- error_message = f"[ERROR]: {str(e)}"
271
- socketio.emit("log", {"message": error_message})
272
- flash(error_message, "error")
273
- return "ERROR", 500
274
-
275
- # -------------------------------------------------------------------------
276
- # Route: upload (GET/POST) handles uploading the SQLite DB file.
277
- # -------------------------------------------------------------------------
278
- @flask_app.route("/upload", methods=["GET", "POST"])
279
- def upload():
280
- global abs_file_path, agent_app
281
- try:
282
- if request.method == "POST":
283
- file = request.files.get("file")
284
- if not file:
285
- flash("No file uploaded.", "error")
286
- return "No file uploaded", 400
287
- filename = secure_filename(file.filename)
288
- if filename.endswith('.db'):
289
- db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
290
- try:
 
 
 
 
 
 
 
 
 
 
 
 
 
291
  file.save(db_path)
292
- abs_file_path = os.path.abspath(db_path) # Save it here; agent init will occur on first query.
293
- agent_app = None # Reset agent on upload.
294
- flash(f"Database file '{filename}' uploaded successfully.", "info")
295
  socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
 
296
  return redirect(url_for("index"))
297
- except Exception as save_err:
298
- flash(f"Error saving file: {save_err}", "error")
299
- socketio.emit("log", {"message": f"[ERROR]: Error saving file: {save_err}"})
300
- return render_template("upload.html")
301
- else:
302
- flash("Only .db files are allowed.", "error")
303
- return render_template("upload.html")
304
- return render_template("upload.html")
305
- except Exception as e:
306
- error_message = f"[ERROR]: {str(e)}"
307
- flash(error_message, "error")
308
- socketio.emit("log", {"message": error_message})
309
- return render_template("upload.html")
310
-
311
- @socketio.on("user_input")
312
- def handle_user_input(data):
313
- prompt = data.get("message")
314
- if not prompt:
315
- socketio.emit("log", {"message": "[ERROR]: Empty prompt."})
316
- flash("Empty prompt.", "error")
317
- return
318
- run_agent(prompt, socketio)
319
-
320
- # Expose the Flask app as "app" for Gunicorn
321
- app = flask_app
322
 
323
  if __name__ == "__main__":
324
- socketio.run(app, debug=True)
 
1
+ from flask import Flask, render_template, request, redirect, url_for, send_from_directory, flash
2
  from flask_socketio import SocketIO
 
3
  import threading
4
+ import os
5
  from dotenv import load_dotenv
6
+ import sqlite3
7
  from werkzeug.utils import secure_filename
8
+ import traceback
9
 
10
  # LangChain and agent imports
11
+ from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
12
+ from langchain.agents import Tool
13
+ from langchain.agents.format_scratchpad import format_log_to_str
14
+ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
15
+ from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
16
+ from langchain_community.agent_toolkits.load_tools import load_tools
17
+ from langchain_core.tools import tool
18
+ from langchain_community.agent_toolkits import SQLDatabaseToolkit
19
+ from langchain.chains import LLMMathChain
20
+ from langchain import hub
21
+ from langchain_community.tools import DuckDuckGoSearchRun
22
+
23
+ # Agent requirements and type hints
24
+ from typing import Annotated, Literal, TypedDict, Any
25
  from langchain_core.messages import AIMessage, ToolMessage
26
  from pydantic import BaseModel, Field
27
  from typing_extensions import TypedDict
28
+ from langgraph.graph import END, StateGraph, START
29
  from langgraph.graph.message import AnyMessage, add_messages
30
  from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
31
  from langgraph.prebuilt import ToolNode
32
  from langchain_core.prompts import ChatPromptTemplate
33
  from langchain_community.utilities import SQLDatabase
 
 
 
34
 
35
  # Load environment variables
36
  load_dotenv()
 
39
  UPLOAD_FOLDER = os.path.join(os.getcwd(), "uploads")
40
  os.makedirs(UPLOAD_FOLDER, exist_ok=True)
41
  BASE_DIR = os.path.abspath(os.path.dirname(__file__))
42
+ DATABASE_URI = f"sqlite:///{os.path.join(BASE_DIR, 'data', 'mydb.db')}"
43
+ print("DATABASE URI:", DATABASE_URI)
44
 
45
  # API Keys from .env file
46
+ import os
47
  os.environ["GROQ_API_KEY"] = os.getenv("GROQ_API_KEY")
48
  os.environ["MISTRAL_API_KEY"] = os.getenv("MISTRAL_API_KEY")
49
 
50
+ # Global variables for dynamic agent and DB file path; initially None.
 
 
 
 
 
 
 
51
  agent_app = None
52
  abs_file_path = None
53
+ db_path = None
54
 
55
+ # =============================================================================
56
+ # create_agent_app: Given a database path, initialize the agent workflow.
57
+ # =============================================================================
58
  def create_agent_app(db_path: str):
59
+ # Use ChatGroq as our LLM here; you can swap to ChatMistralAI if preferred.
60
+ from langchain_groq import ChatGroq
61
+ llm = ChatGroq(model="llama3-70b-8192")
62
+
63
+ # -------------------------------------------------------------------------
64
+ # Define a tool for executing SQL queries, with an explicit description.
65
+ # -------------------------------------------------------------------------
66
+ @tool(description="Executes a SQL query on the connected SQLite database and returns the result.")
 
 
 
 
 
 
 
67
  def db_query_tool(query: str) -> str:
68
  """
69
+ Executes a SQL query on the connected SQLite database.
 
 
70
  """
71
  try:
72
  result = db_instance.run_no_throw(query)
73
+ return result if result else "Error: Query failed. Please rewrite your query and try again."
74
  except Exception as e:
 
75
  return f"Error: {str(e)}"
76
 
77
+ # -------------------------------------------------------------------------
78
+ # Pydantic model for final answer.
79
+ # -------------------------------------------------------------------------
80
  class SubmitFinalAnswer(BaseModel):
81
+ final_answer: str = Field(..., description="The final answer to the user")
82
 
83
+ # -------------------------------------------------------------------------
84
+ # Define state type for our workflow.
85
+ # -------------------------------------------------------------------------
86
  class State(TypedDict):
87
  messages: Annotated[list[AnyMessage], add_messages]
88
 
89
+ # -------------------------------------------------------------------------
90
+ # Set up prompt templates for query checking and query generation.
91
+ # -------------------------------------------------------------------------
92
+ query_check_system = (
93
+ "You are a SQL expert with a strong attention to detail.\n"
94
+ "Double check the SQLite query for common mistakes, including:\n"
95
+ "- Using NOT IN with NULL values\n"
96
+ "- Using UNION when UNION ALL should have been used\n"
97
+ "- Using BETWEEN for exclusive ranges\n"
98
+ "- Data type mismatch in predicates\n"
99
+ "- Properly quoting identifiers\n"
100
+ "- Using the correct number of arguments for functions\n"
101
+ "- Casting to the correct data type\n"
102
+ "- Using the proper columns for joins\n\n"
103
+ "If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.\n"
104
+ "You will call the appropriate tool to execute the query after running this check."
105
+ )
106
+ query_check_prompt = ChatPromptTemplate.from_messages([
107
+ ("system", query_check_system),
108
+ ("placeholder", "{messages}")
109
+ ])
110
+ query_check = query_check_prompt | llm.bind_tools([db_query_tool])
111
+
112
+ query_gen_system = (
113
+ "You are a SQL expert with a strong attention to detail.\n\n"
114
+ "Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.\n\n"
115
+ "DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.\n\n"
116
+ "When generating the query:\n"
117
+ "Output the SQL query that answers the input question without a tool call.\n"
118
+ "Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.\n"
119
+ "You can order the results by a relevant column to return the most interesting examples in the database.\n"
120
+ "Never query for all the columns from a specific table, only ask for the relevant columns given the question.\n\n"
121
+ "If you get an error while executing a query, rewrite the query and try again.\n"
122
+ "If you get an empty result set, you should try to rewrite the query to get a non-empty result set.\n"
123
+ "NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.\n\n"
124
+ "If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.\n"
125
+ "DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any SQL query except answer."
126
+ )
127
+ query_gen_prompt = ChatPromptTemplate.from_messages([
128
+ ("system", query_gen_system),
129
+ ("placeholder", "{messages}")
130
+ ])
131
+ query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
132
+
133
+ # -------------------------------------------------------------------------
134
+ # Update database URI, create SQLDatabase connection.
135
+ # -------------------------------------------------------------------------
136
+ abs_db_path_local = os.path.abspath(db_path)
137
+ global DATABASE_URI
138
+ DATABASE_URI = abs_db_path_local
139
+ db_uri = f"sqlite:///{abs_db_path_local}"
140
+ print("db_uri", db_uri)
141
+
142
  try:
143
+ db_instance = SQLDatabase.from_uri(db_uri)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  except Exception as e:
145
+ raise Exception(f"Failed to create SQLDatabase connection: {e}")
146
+ print("db_instance----->", db_instance)
147
+
148
+ # -------------------------------------------------------------------------
149
+ # Create SQL toolkit.
150
+ # -------------------------------------------------------------------------
151
+ toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
152
+ tools_instance = toolkit_instance.get_tools()
153
+
154
+ # -------------------------------------------------------------------------
155
+ # Define workflow nodes and fallback functions.
156
+ # -------------------------------------------------------------------------
157
  def first_tool_call(state: State):
158
+ return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
159
 
160
  def handle_tool_error(state: State):
161
  tool_calls = state["messages"][-1].tool_calls
 
170
  try:
171
  message = query_gen.invoke(state)
172
  except Exception as e:
173
+ raise Exception(f"Exception in query_gen_node: {e}")
 
174
  tool_messages = []
175
  if message.tool_calls:
176
  for tc in message.tool_calls:
 
181
  ))
182
  return {"messages": [message] + tool_messages}
183
 
184
+ def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
185
+ messages = state["messages"]
186
+ last_message = messages[-1]
187
  if getattr(last_message, "tool_calls", None):
188
  return END
189
  if last_message.content.startswith("Error:"):
 
193
  def model_check_query(state: State):
194
  return {"messages": [query_check.invoke({"messages": [state["messages"][-1]]})]}
195
 
196
+ # -------------------------------------------------------------------------
197
+ # Get tools for listing tables and fetching schema.
198
+ # -------------------------------------------------------------------------
199
+ list_tables_tool = next((t for t in tools_instance if t.name == "sql_db_list_tables"), None)
200
  schema_tool = next((t for t in tools_instance if t.name == "sql_db_schema"), None)
201
  model_get_schema = llm.bind_tools([schema_tool])
202
 
203
  workflow = StateGraph(State)
204
  workflow.add_node("first_tool_call", first_tool_call)
205
+ workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
206
  workflow.add_node("get_schema_tool", create_tool_node_with_fallback([schema_tool]))
207
+ workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])]})
 
208
  workflow.add_node("query_gen", query_gen_node)
209
  workflow.add_node("correct_query", model_check_query)
210
  workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))
 
220
 
221
  return workflow.compile()
222
 
223
+ # =============================================================================
224
+ # create_app: The application factory.
225
+ # =============================================================================
226
+ def create_app():
227
+ flask_app = Flask(__name__, static_url_path='/uploads', static_folder='uploads')
228
+ socketio = SocketIO(flask_app, cors_allowed_origins="*")
229
+
230
+ if not os.path.exists(UPLOAD_FOLDER):
231
+ os.makedirs(UPLOAD_FOLDER)
232
+ flask_app.config['UPLOAD_FOLDER'] = UPLOAD_FOLDER
233
+ flask_app.config['SECRET_KEY'] = os.getenv("FLASK_SECRET_KEY", "mysecretkey")
234
+
235
+ @flask_app.route("/files/<path:filename>")
236
+ def uploaded_file(filename):
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  try:
238
+ return send_from_directory(flask_app.config['UPLOAD_FOLDER'], filename)
239
  except Exception as e:
240
+ flash(f"Could not send file: {str(e)}", "error")
241
+ return redirect(url_for("index"))
242
+
243
+ def run_agent(prompt, socketio):
244
+ global agent_app, abs_file_path, db_path
245
+ if not abs_file_path:
246
+ socketio.emit("log", {"message": "[ERROR]: No DB file uploaded."})
247
+ socketio.emit("final", {"message": "No database available. Please upload one and try again."})
248
+ flash("No database available. Please upload one and try again.", "error")
249
+ return
250
+ try:
251
+ if agent_app is None:
252
+ socketio.emit("log", {"message": "[INFO]: Initializing agent for the first time..."})
253
+ agent_app = create_agent_app(abs_file_path)
254
+ socketio.emit("log", {"message": "[INFO]: Agent initialized."})
255
+ flash("Agent initialized.", "info")
256
+ query = {"messages": [("user", prompt)]}
257
+ result = agent_app.invoke(query)
258
+ try:
259
+ result = result["messages"][-1].tool_calls[0]["args"]["final_answer"]
260
+ except Exception as e:
261
+ result = "Query failed or no valid answer found."
262
+ flash("Query failed or no valid answer found.", "warning")
263
+ print("final_answer------>", result)
264
+ socketio.emit("final", {"message": result})
265
+ except Exception as e:
266
+ error_message = f"Generation failed: {str(e)}"
267
+ socketio.emit("log", {"message": f"[ERROR]: {error_message}"})
268
+ socketio.emit("final", {"message": "Generation failed."})
269
+ flash(error_message, "error")
270
+ traceback.print_exc()
271
+
272
+ @flask_app.route("/")
273
+ def index():
274
+ return render_template("index.html")
275
+
276
+ @flask_app.route("/generate", methods=["POST"])
277
+ def generate():
278
+ try:
279
+ socketio.emit("log", {"message": "[STEP]: Entering query generation..."})
280
+ data = request.json
281
+ prompt = data.get("prompt", "")
282
+ socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}"})
283
+ thread = threading.Thread(target=run_agent, args=(prompt, socketio))
284
+ socketio.emit("log", {"message": f"[INFO]: Starting thread: {thread}"})
285
+ thread.start()
286
+ flash("Query submitted successfully.", "info")
287
+ return "OK", 200
288
+ except Exception as e:
289
+ error_message = f"[ERROR]: {str(e)}"
290
+ socketio.emit("log", {"message": error_message})
291
+ flash(error_message, "error")
292
+ return "ERROR", 500
293
+
294
+ @flask_app.route("/upload", methods=["GET", "POST"])
295
+ def upload():
296
+ global abs_file_path, agent_app, db_path
297
+ try:
298
+ if request.method == "POST":
299
+ file = request.files.get("file")
300
+ if not file:
301
+ flash("No file uploaded.", "error")
302
+ return "No file uploaded", 400
303
+ filename = secure_filename(file.filename)
304
+ if filename.endswith('.db'):
305
+ db_path = os.path.join(flask_app.config['UPLOAD_FOLDER'], "uploaded.db")
306
+ print("Saving file to:", db_path)
307
  file.save(db_path)
308
+ abs_file_path = os.path.abspath(db_path)
309
+ agent_app = None # Reset the agent so it is lazily reinitialized on next query.
310
+ print(f"[INFO]: File '{filename}' uploaded. Agent will be initialized on first query.")
311
  socketio.emit("log", {"message": f"[INFO]: Database file '{filename}' uploaded."})
312
+ flash(f"Database file '{filename}' uploaded successfully.", "info")
313
  return redirect(url_for("index"))
314
+ return render_template("upload.html")
315
+ except Exception as e:
316
+ error_message = f"[ERROR]: {str(e)}"
317
+ print(error_message)
318
+ flash(error_message, "error")
319
+ socketio.emit("log", {"message": error_message})
320
+ return render_template("upload.html")
321
+
322
+ return flask_app, socketio
323
+
324
+ # =============================================================================
325
+ # Create the app for Gunicorn compatibility.
326
+ # =============================================================================
327
+ app, socketio_instance = create_app()
 
 
 
 
 
 
 
 
 
 
 
328
 
329
  if __name__ == "__main__":
330
+ socketio_instance.run(app, debug=True)