pgurazada1 commited on
Commit
e43ed26
·
verified ·
1 Parent(s): 7db5019

Upload sql-react-agent-mcp.py

Browse files
Files changed (1) hide show
  1. sql-react-agent-mcp.py +134 -0
sql-react-agent-mcp.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import dspy
3
+ import mlflow
4
+ import asyncio
5
+
6
+ from mcp import ClientSession
7
+ from mcp.client.streamable_http import streamablehttp_client
8
+
9
+ lm = dspy.LM(
10
+ model='openai/gpt-4o-mini',
11
+ temperature=0,
12
+ api_key=os.environ['OPENAI_API_KEY'],
13
+ api_base=os.environ['OPENAI_BASE_URL']
14
+ )
15
+
16
+ mcp_url = "https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/"
17
+
18
+ # IMPORTANT: Set your Hugging Face user access token in the environment variable HF_TOKEN
19
+ HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
20
+ if not HF_TOKEN:
21
+ raise RuntimeError("Please set your Hugging Face user access token in the HF_TOKEN environment variable.")
22
+
23
+
24
+ dspy.configure(lm=lm)
25
+
26
+ mlflow.dspy.autolog()
27
+ mlflow.set_experiment('sql-react-agent-http')
28
+
29
+
30
+ class QueryResponse(dspy.Signature):
31
+ """
32
+ You are an expert AI assistant specialized in generating and executing SQLite queries against a database.
33
+ Your primary goal is to accurately answer user questions based *only* on the data retrieved. You must be methodical in exploring the database structure.
34
+
35
+ <Schema Exploration and Join Path Strategy>
36
+ 1. **List All Tables:** Always start with `sql_db_list_tables`.
37
+ 2. **Identify Potential Tables:** List tables potentially holding the requested entities (e.g., cities, merchants) and metrics (e.g., spend). Also, identify tables that might *link* these entities (often containing ID columns like `cust_id`, `CARD_ID`, `M_ID`).
38
+ 3. **Get Schemas Systematically:** Use `sql_db_schema` to get schemas for *all* tables identified in step 2. This is crucial. Do not skip potential linking tables.
39
+ 4. **Map the Join Path:**
40
+ * Explicitly identify the column containing the primary metric (e.g., `transaction.TX_AMOUNT`).
41
+ * Explicitly identify the column containing the target entity (e.g., `customer.city`).
42
+ * **CRITICAL:** Trace the connections between these tables using ID columns revealed in the schemas. Look for sequences like `tableA.ID -> tableB.tableA_ID`, `tableB.ID -> tableC.tableB_ID`.
43
+ * **Example Path:** To link transaction spend to customer city, you MUST verify the path: `transaction.CARD_ID` links to `card.card_number`, AND `card.cust_id` links to `customer.cust_id`. You **MUST** request the schema for the `card` table to confirm this.
44
+ * **State the Path:** Before writing the query, state the full join path you intend to use (e.g., "Found path: transaction JOIN card ON transaction.CARD_ID = card.card_number JOIN customer ON card.cust_id = customer.cust_id").
45
+ 5. **Verify Columns:** Double-check that *every* column used in your intended SELECT, JOIN, WHERE, GROUP BY, or ORDER BY clauses exists in the schemas you retrieved.
46
+ </Schema Exploration and Join Path Strategy>
47
+
48
+ <Query Construction and Execution>
49
+ 6. **Construct Query:** Build the SQLite query using the verified tables, columns, and the full, correct join path.
50
+ * Use explicit JOIN clauses (INNER JOIN is usually appropriate unless otherwise specified).
51
+ * Quote identifiers (like `"transaction"`) if they are keywords or contain special characters.
52
+ * Select only necessary columns. Alias columns for clarity if needed (e.g., `SUM(t.TX_AMOUNT) AS total_spend`).
53
+ * Include calculations like percentage contribution if requested. The total sum for percentage calculation should be derived correctly (e.g., `(SELECT SUM(TX_AMOUNT) FROM "transaction")`).
54
+ * Apply `GROUP BY` to the target entity column (e.g., `c.city`).
55
+ * Apply `ORDER BY` and `LIMIT 5` (unless otherwise specified).
56
+ 7. **Validate Query:** Use `sql_db_query_checker`. Revise if syntax errors occur.
57
+ 8. **Execute Query:** Use `sql_db_query`.
58
+ 9. **Formulate Answer:** Base the final answer *strictly* on the query results. If the query returns no results *after* confirming a valid join path and correct syntax, state that no data matching the criteria was found.
59
+ 10. **Handle Missing Information:** If, after thorough schema exploration (including checking potential linking tables), you cannot find the requested column (e.g., 'country') or a valid join path, *then and only then* inform the user the data is unavailable. Do not substitute unrelated columns.
60
+ 11. **Final Answer Only:** Provide the answer directly without further tool calls once results are obtained.
61
+ </Query Construction and Execution>
62
+
63
+ <General Restrictions>
64
+ 1. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.).
65
+ 2. DO NOT MAKE UP ANSWERS.
66
+ </General Restrictions>
67
+ """
68
+
69
+ query: str = dspy.InputField()
70
+ answer: str = dspy.OutputField(desc="The generated response to the customer query.")
71
+
72
+
73
+ async def respond(query):
74
+ async with streamablehttp_client(
75
+ url=mcp_url,
76
+ headers={"Authorization": f"Bearer {HF_TOKEN}"}
77
+ ) as (read, write, _):
78
+ async with ClientSession(read, write) as session:
79
+ # Initialize the connection
80
+ await session.initialize()
81
+ # List available tools
82
+ tools_output = await session.list_tools()
83
+
84
+ # Convert MCP tools to DSPy tools
85
+ dspy_tools = []
86
+ for tool in tools_output.tools:
87
+ dspy_tools.append(dspy.Tool.from_mcp_tool(session, tool))
88
+
89
+ # Create the agent
90
+ react_agent = dspy.ReAct(QueryResponse, tools=dspy_tools, max_iters=10)
91
+
92
+ output = await react_agent.acall(query=query)
93
+ return output
94
+
95
+
96
+ # Example 1
97
+
98
+ user_query = "Who are the top 5 merchants by total number of transactions?"
99
+ pred = asyncio.run(respond(user_query))
100
+
101
+ print(pred.answer)
102
+
103
+ # Example 2
104
+
105
+ user_query = "Which is the highest spend month and amount for each card type?"
106
+ pred = asyncio.run(respond(user_query))
107
+
108
+ print(pred.answer)
109
+
110
+ # Example 3
111
+
112
+ user_query = "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
113
+ pred = asyncio.run(respond(user_query))
114
+
115
+ print(pred.answer)
116
+
117
+ # Parallelism
118
+
119
+ async def main():
120
+ user_queries = [
121
+ "Who are the top 5 merchants by total transactions?",
122
+ "Which is the highest spend month and amount for each card type?",
123
+ "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
124
+ ]
125
+
126
+ tasks_to_run = [respond(query) for query in user_queries]
127
+ results = await asyncio.gather(*tasks_to_run)
128
+
129
+ return results
130
+
131
+ results = asyncio.run(main())
132
+
133
+ for result in results:
134
+ print(result.answer)