Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -1,8 +1,9 @@
|
|
1 |
-
from flask import Flask, render_template, request
|
2 |
from flask_socketio import SocketIO
|
3 |
import threading
|
4 |
import os
|
5 |
from dotenv import load_dotenv
|
|
|
6 |
|
7 |
# LangChain and agent imports
|
8 |
from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
|
@@ -30,18 +31,19 @@ from langgraph.prebuilt import ToolNode
|
|
30 |
# Load environment variables
|
31 |
load_dotenv()
|
32 |
|
33 |
-
#
|
34 |
-
#
|
|
|
35 |
DATABASE_URI = os.getenv("DATABASE_URI", "sqlite:///employee.db")
|
36 |
|
37 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
38 |
-
os.environ["GROQ_API_KEY"] =
|
39 |
|
40 |
# Use ChatGroq LLM (which does not require a Hugging Face API token)
|
41 |
from langchain_groq import ChatGroq
|
42 |
llm = ChatGroq(model="llama3-70b-8192")
|
43 |
|
44 |
-
# Connect to the provided database URI (
|
45 |
from langchain_community.utilities import SQLDatabase
|
46 |
db = SQLDatabase.from_uri(DATABASE_URI)
|
47 |
|
@@ -50,8 +52,6 @@ from langchain_community.agent_toolkits import SQLDatabaseToolkit
|
|
50 |
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
51 |
tools = toolkit.get_tools()
|
52 |
|
53 |
-
|
54 |
-
|
55 |
# Define a custom query tool for executing SQL queries
|
56 |
@tool
|
57 |
def db_query_tool(query: str) -> str:
|
@@ -122,7 +122,10 @@ query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
|
|
122 |
|
123 |
def create_agent_app(db_path: str):
|
124 |
# Construct the SQLite URI from the given file path.
|
125 |
-
|
|
|
|
|
|
|
126 |
# Create new SQLDatabase connection using the constructed URI.
|
127 |
from langchain_community.utilities import SQLDatabase
|
128 |
db_instance = SQLDatabase.from_uri(db_uri)
|
@@ -132,9 +135,6 @@ def create_agent_app(db_path: str):
|
|
132 |
toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
|
133 |
tools_instance = toolkit_instance.get_tools()
|
134 |
|
135 |
-
# ... (rest of the logic remains unchanged)
|
136 |
-
# Define a custom query tool for executing SQL queries
|
137 |
-
|
138 |
# Define workflow nodes and fallback functions
|
139 |
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
|
140 |
return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
|
@@ -205,7 +205,6 @@ def create_agent_app(db_path: str):
|
|
205 |
workflow.add_edge("correct_query", "execute_query")
|
206 |
workflow.add_edge("execute_query", "query_gen")
|
207 |
|
208 |
-
|
209 |
# Compile and return the agent application workflow.
|
210 |
return workflow.compile()
|
211 |
|
@@ -249,15 +248,15 @@ def create_app():
|
|
249 |
return "No file uploaded", 400
|
250 |
file_path = os.path.join(UPLOAD_FOLDER, file.filename)
|
251 |
file.save(file_path)
|
252 |
-
|
253 |
-
#
|
|
|
254 |
global agent_app
|
255 |
-
agent_app = create_agent_app(
|
256 |
|
257 |
socketio.emit("log", {"message": f"[INFO]: Database file '{file.filename}' uploaded and loaded."})
|
258 |
return redirect(url_for("index")) # Go back to index page
|
259 |
|
260 |
-
|
261 |
return flask_app, socketio
|
262 |
|
263 |
###############################################################################
|
@@ -278,4 +277,4 @@ def run_agent(prompt, socketio):
|
|
278 |
app, socketio_instance = create_app()
|
279 |
|
280 |
if __name__ == "__main__":
|
281 |
-
socketio_instance.run(app, debug=True)
|
|
|
1 |
+
from flask import Flask, render_template, request, redirect, url_for
|
2 |
from flask_socketio import SocketIO
|
3 |
import threading
|
4 |
import os
|
5 |
from dotenv import load_dotenv
|
6 |
+
import sqlite3 # You can keep this import if you use sqlite3 elsewhere
|
7 |
|
8 |
# LangChain and agent imports
|
9 |
from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
|
|
|
31 |
# Load environment variables
|
32 |
load_dotenv()
|
33 |
|
34 |
+
# Set up the DB URI using an environment variable.
|
35 |
+
# In your .env file, ensure you have:
|
36 |
+
# DATABASE_URI=sqlite:///employee.db
|
37 |
DATABASE_URI = os.getenv("DATABASE_URI", "sqlite:///employee.db")
|
38 |
|
39 |
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
|
40 |
+
os.environ["GROQ_API_KEY"] = GROQ_API_KEY
|
41 |
|
42 |
# Use ChatGroq LLM (which does not require a Hugging Face API token)
|
43 |
from langchain_groq import ChatGroq
|
44 |
llm = ChatGroq(model="llama3-70b-8192")
|
45 |
|
46 |
+
# Connect to the provided database URI using SQLDatabase (which expects a URI)
|
47 |
from langchain_community.utilities import SQLDatabase
|
48 |
db = SQLDatabase.from_uri(DATABASE_URI)
|
49 |
|
|
|
52 |
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
|
53 |
tools = toolkit.get_tools()
|
54 |
|
|
|
|
|
55 |
# Define a custom query tool for executing SQL queries
|
56 |
@tool
|
57 |
def db_query_tool(query: str) -> str:
|
|
|
122 |
|
123 |
def create_agent_app(db_path: str):
|
124 |
# Construct the SQLite URI from the given file path.
|
125 |
+
# Ensure the db_path is absolute so that SQLAlchemy can locate the file.
|
126 |
+
abs_db_path = os.path.abspath(db_path)
|
127 |
+
db_uri = f"sqlite:///{abs_db_path}"
|
128 |
+
|
129 |
# Create new SQLDatabase connection using the constructed URI.
|
130 |
from langchain_community.utilities import SQLDatabase
|
131 |
db_instance = SQLDatabase.from_uri(db_uri)
|
|
|
135 |
toolkit_instance = SQLDatabaseToolkit(db=db_instance, llm=llm)
|
136 |
tools_instance = toolkit_instance.get_tools()
|
137 |
|
|
|
|
|
|
|
138 |
# Define workflow nodes and fallback functions
|
139 |
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
|
140 |
return {"messages": [AIMessage(content="", tool_calls=[{"name": "sql_db_list_tables", "args": {}, "id": "tool_abcd123"}])]}
|
|
|
205 |
workflow.add_edge("correct_query", "execute_query")
|
206 |
workflow.add_edge("execute_query", "query_gen")
|
207 |
|
|
|
208 |
# Compile and return the agent application workflow.
|
209 |
return workflow.compile()
|
210 |
|
|
|
248 |
return "No file uploaded", 400
|
249 |
file_path = os.path.join(UPLOAD_FOLDER, file.filename)
|
250 |
file.save(file_path)
|
251 |
+
|
252 |
+
# Convert the file path to an absolute path and reinitialize the agent_app
|
253 |
+
abs_file_path = os.path.abspath(file_path)
|
254 |
global agent_app
|
255 |
+
agent_app = create_agent_app(abs_file_path)
|
256 |
|
257 |
socketio.emit("log", {"message": f"[INFO]: Database file '{file.filename}' uploaded and loaded."})
|
258 |
return redirect(url_for("index")) # Go back to index page
|
259 |
|
|
|
260 |
return flask_app, socketio
|
261 |
|
262 |
###############################################################################
|
|
|
277 |
app, socketio_instance = create_app()
|
278 |
|
279 |
if __name__ == "__main__":
|
280 |
+
socketio_instance.run(app, debug=True)
|