WebashalarForML commited on
Commit
34d17cd
·
verified ·
1 Parent(s): 60bf8f3

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +224 -0
app.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from flask import Flask, render_template, request
2
+ from flask_socketio import SocketIO
3
+ import threading
4
+ import os
5
+ from dotenv import load_dotenv
6
+
7
+ # LangChain and agent imports
8
+ from langchain_community.chat_models.huggingface import ChatHuggingFace # if needed later
9
+ from langchain.agents import Tool
10
+ from langchain.agents.format_scratchpad import format_log_to_str
11
+ from langchain.agents.output_parsers import ReActJsonSingleInputOutputParser
12
+ from langchain_core.callbacks import CallbackManager, BaseCallbackHandler
13
+ from langchain_community.agent_toolkits.load_tools import load_tools # ensure correct import
14
+ from langchain_core.tools import tool
15
+ from langchain_community.agent_toolkits import PowerBIToolkit
16
+ from langchain.chains import LLMMathChain
17
+ from langchain import hub
18
+ from langchain_community.tools import DuckDuckGoSearchRun
19
+
20
+ # Agent requirements and type hints
21
+ from typing import Annotated, Literal, Sequence, TypedDict, Any
22
+ from langchain_core.messages import AIMessage, ToolMessage
23
+ from pydantic import BaseModel, Field
24
+ from typing_extensions import TypedDict
25
+ from langgraph.graph import END, StateGraph, START
26
+ from langgraph.graph.message import AnyMessage, add_messages
27
+ from langchain_core.runnables import RunnableLambda, RunnableWithFallbacks
28
+ from langgraph.prebuilt import ToolNode
29
+
30
# Load environment variables from a local .env file (if present).
load_dotenv()

# The database URI is configurable via the DATABASE_URI environment variable,
# so any single database can be plugged in without code changes.
DATABASE_URI = os.getenv("DATABASE_URI", "sqlite:///employee.db")

# Fail fast with a clear message: the original `os.environ[...] = None`
# assignment raises an opaque TypeError when the key is missing.
GROQ_API_KEY = os.getenv("GROQ_API_KEY")
if not GROQ_API_KEY:
    raise RuntimeError("GROQ_API_KEY environment variable is not set.")
os.environ["GROQ_API_KEY"] = GROQ_API_KEY

# Use the ChatGroq LLM (does not require a Hugging Face API token).
from langchain_groq import ChatGroq
llm = ChatGroq(model="llama3-70b-8192")

# Connect to the configured database URI (works with any single DB).
from langchain_community.utilities import SQLDatabase
db = SQLDatabase.from_uri(DATABASE_URI)

# Build the SQL toolkit and collect its prebuilt tools.
from langchain_community.agent_toolkits import SQLDatabaseToolkit
toolkit = SQLDatabaseToolkit(db=db, llm=llm)
tools = toolkit.get_tools()
52
+
53
# Custom tool: run a raw SQL query against the connected database.
@tool
def db_query_tool(query: str) -> str:
    """
    Execute a SQL query against the database and return the result.

    If the query is invalid or produces no output, an error string is
    returned so the agent can rewrite the query and retry.
    """
    output = db.run_no_throw(query)
    if output:
        return output
    return "Error: Query failed. Please rewrite your query and try again."
65
+
66
# Pydantic schema for the terminal tool call: the agent invokes this
# (rather than returning free text) to hand its answer back to the user.
class SubmitFinalAnswer(BaseModel):
    """Submit the final answer to the user based on the query results."""
    # Required field (Field(...)); the description is exposed to the LLM
    # as part of the tool schema.
    final_answer: str = Field(..., description="The final answer to the user")
70
+
71
# LangGraph state: a single message list; the add_messages reducer appends
# each node's output instead of overwriting the list.
class State(TypedDict):
    messages: Annotated[list[AnyMessage], add_messages]
74
+
75
# Define prompt templates for query checking and query generation.
from langchain_core.prompts import ChatPromptTemplate

# System prompt for the query-checking step: asks the model to audit a
# SQLite query for common mistakes before it is executed.
query_check_system = """You are a SQL expert with a strong attention to detail.
Double check the SQLite query for common mistakes, including:
- Using NOT IN with NULL values
- Using UNION when UNION ALL should have been used
- Using BETWEEN for exclusive ranges
- Data type mismatch in predicates
- Properly quoting identifiers
- Using the correct number of arguments for functions
- Casting to the correct data type
- Using the proper columns for joins

If there are any of the above mistakes, rewrite the query. If there are no mistakes, just reproduce the original query.

You will call the appropriate tool to execute the query after running this check."""
# Chain: prompt | llm with db_query_tool bound, so the checked query is
# executed via a tool call.
query_check_prompt = ChatPromptTemplate.from_messages([("system", query_check_system), ("placeholder", "{messages}")])
query_check = query_check_prompt | llm.bind_tools([db_query_tool])

# System prompt for the query-generation step: produce a SQLite query as
# plain text, and only call SubmitFinalAnswer to finish.
query_gen_system = """You are a SQL expert with a strong attention to detail.

Given an input question, output a syntactically correct SQLite query to run, then look at the results of the query and return the answer.

DO NOT call any tool besides SubmitFinalAnswer to submit the final answer.

When generating the query:

Output the SQL query that answers the input question without a tool call.

Unless the user specifies a specific number of examples they wish to obtain, always limit your query to at most 5 results.
You can order the results by a relevant column to return the most interesting examples in the database.
Never query for all the columns from a specific table, only ask for the relevant columns given the question.

If you get an error while executing a query, rewrite the query and try again.

If you get an empty result set, you should try to rewrite the query to get a non-empty result set.
NEVER make stuff up if you don't have enough information to answer the query... just say you don't have enough information.

If you have enough information to answer the input question, simply invoke the appropriate tool to submit the final answer to the user.

DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.) to the database. Do not return any sql query except answer."""
# Chain: prompt | llm with only SubmitFinalAnswer bound, matching the
# "no other tool calls" instruction above.
query_gen_prompt = ChatPromptTemplate.from_messages([("system", query_gen_system), ("placeholder", "{messages}")])
query_gen = query_gen_prompt | llm.bind_tools([SubmitFinalAnswer])
119
+
120
# Graph entry node: seed the conversation with a forced tool call so the
# agent always begins by listing the database tables.
def first_tool_call(state: State) -> dict[str, list[AIMessage]]:
    """Emit a hard-coded call to sql_db_list_tables as the first message."""
    forced_call = {
        "name": "sql_db_list_tables",
        "args": {},
        "id": "tool_abcd123",
    }
    return {"messages": [AIMessage(content="", tool_calls=[forced_call])]}
123
+
124
def handle_tool_error(state: State) -> dict:
    """Turn a tool-execution error into ToolMessages so the agent can retry.

    One ToolMessage is produced per pending tool call on the last message,
    each carrying the repr of the captured error.
    """
    error = state.get("error")
    replies = []
    for tc in state["messages"][-1].tool_calls:
        replies.append(
            ToolMessage(
                content=f"Error: {repr(error)}\n please fix your mistakes.",
                tool_call_id=tc["id"],
            )
        )
    return {"messages": replies}
133
+
134
def create_tool_node_with_fallback(tools_list: list) -> RunnableWithFallbacks[Any, dict]:
    """Build a ToolNode whose errors are routed to handle_tool_error."""
    node = ToolNode(tools_list)
    fallback = RunnableLambda(handle_tool_error)
    return node.with_fallbacks([fallback], exception_key="error")
136
+
137
def query_gen_node(state: State):
    """Run the query-generation chain and flag any disallowed tool calls.

    The generator is only allowed to call SubmitFinalAnswer; any other tool
    call is answered with a corrective ToolMessage so the model retries.
    """
    message = query_gen.invoke(state)
    corrections = []
    for tc in message.tool_calls or []:
        if tc["name"] == "SubmitFinalAnswer":
            continue
        corrections.append(
            ToolMessage(
                content=f"Error: The wrong tool was called: {tc['name']}. Please fix your mistakes. Remember to only call SubmitFinalAnswer to submit the final answer. Generated queries should be outputted WITHOUT a tool call.",
                tool_call_id=tc["id"],
            )
        )
    return {"messages": [message] + corrections}
151
+
152
def should_continue(state: State) -> Literal[END, "correct_query", "query_gen"]:
    """Route after query_gen: end the run, regenerate, or check the query."""
    last = state["messages"][-1]
    # Any tool call at this point ends the graph run.
    if getattr(last, "tool_calls", None):
        return END
    # An "Error:" message (from a failed execution or a corrective
    # ToolMessage) sends control back to generation; otherwise check it.
    return "query_gen" if last.content.startswith("Error:") else "correct_query"
161
+
162
def model_check_query(state: State) -> dict[str, list[AIMessage]]:
    """Double-check if the query is correct before executing it."""
    latest = state["messages"][-1]
    checked = query_check.invoke({"messages": [latest]})
    return {"messages": [checked]}
165
+
166
# Look up the prebuilt tools for listing tables and fetching schemas.
list_tables_tool = next((t for t in tools if t.name == "sql_db_list_tables"), None)
get_schema_tool = next((t for t in tools if t.name == "sql_db_schema"), None)
# Fail fast with a clear message: passing None into ToolNode/bind_tools
# below would otherwise fail far from the cause.
if list_tables_tool is None or get_schema_tool is None:
    raise RuntimeError(
        "SQLDatabaseToolkit did not provide the expected "
        "sql_db_list_tables / sql_db_schema tools."
    )

# Define the workflow (state graph):
# first_tool_call -> list tables -> model picks schemas -> fetch schema
# -> query_gen -> (correct_query -> execute_query -> query_gen) loop -> END.
workflow = StateGraph(State)
workflow.add_node("first_tool_call", first_tool_call)
workflow.add_node("list_tables_tool", create_tool_node_with_fallback([list_tables_tool]))
workflow.add_node("get_schema_tool", create_tool_node_with_fallback([get_schema_tool]))
# LLM bound to the schema tool so it chooses which tables to inspect.
model_get_schema = llm.bind_tools([get_schema_tool])
workflow.add_node("model_get_schema", lambda state: {"messages": [model_get_schema.invoke(state["messages"])]})
workflow.add_node("query_gen", query_gen_node)
workflow.add_node("correct_query", model_check_query)
workflow.add_node("execute_query", create_tool_node_with_fallback([db_query_tool]))

workflow.add_edge(START, "first_tool_call")
workflow.add_edge("first_tool_call", "list_tables_tool")
workflow.add_edge("list_tables_tool", "model_get_schema")
workflow.add_edge("model_get_schema", "get_schema_tool")
workflow.add_edge("get_schema_tool", "query_gen")
workflow.add_conditional_edges("query_gen", should_continue)
workflow.add_edge("correct_query", "execute_query")
workflow.add_edge("execute_query", "query_gen")

# Compile the workflow into a runnable agent application.
agent_app = workflow.compile()
192
+
193
# Initialize Flask and SocketIO.
# cors_allowed_origins="*" accepts WebSocket connections from any origin —
# fine for a demo; tighten for production.
flask_app = Flask(__name__)
socketio = SocketIO(flask_app, cors_allowed_origins="*")
196
+
197
# Background worker: runs the agent and streams the outcome to the client.
def run_agent(prompt):
    """Invoke the SQL agent for *prompt* and emit the result over SocketIO."""
    try:
        state = agent_app.invoke({"messages": [("user", prompt)]})
        # The run ends with a SubmitFinalAnswer tool call; pull its argument.
        final_call = state["messages"][-1].tool_calls[0]
        answer = final_call["args"]["final_answer"]
        print("final_answer------>", answer)
        socketio.emit("final", {"message": f"{answer}"})
    except Exception as e:
        # Boundary handler: report the failure to the client instead of
        # letting the worker thread die silently.
        socketio.emit("log", {"message": f"[ERROR]: {str(e)}"})
        socketio.emit("final", {"message": "Generation failed."})
208
+
209
# Serve the single-page UI (templates/index.html).
@flask_app.route("/")
def index():
    return render_template("index.html")
212
+
213
@flask_app.route("/generate", methods=["POST"])
def generate():
    """Accept a JSON prompt and start the agent in a background thread.

    Returns immediately with "OK"; progress and the final answer are pushed
    to the client over SocketIO ("log" / "final" events).
    """
    # get_json(silent=True) returns None instead of raising on a missing or
    # non-JSON body, so a malformed request cannot 500 here (request.json
    # would fail before any log event was emitted).
    data = request.get_json(silent=True) or {}
    prompt = data.get("prompt", "")
    socketio.emit("log", {"message": f"[INFO]: Received prompt: {prompt}\n"})
    # Run the agent off the request thread; daemon so a hung run never
    # blocks interpreter shutdown.
    thread = threading.Thread(target=run_agent, args=(prompt,), daemon=True)
    thread.start()
    return "OK", 200
222
+
223
# Entry point: use SocketIO's server wrapper (not flask_app.run) so
# WebSocket traffic is handled. debug=True is for development only.
if __name__ == "__main__":
    socketio.run(flask_app, debug=True)