Upload sql-react-agent-mcp.py
Browse files- sql-react-agent-mcp.py +134 -0
sql-react-agent-mcp.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import dspy
|
3 |
+
import mlflow
|
4 |
+
import asyncio
|
5 |
+
|
6 |
+
from mcp import ClientSession
|
7 |
+
from mcp.client.streamable_http import streamablehttp_client
|
8 |
+
|
9 |
+
lm = dspy.LM(
|
10 |
+
model='openai/gpt-4o-mini',
|
11 |
+
temperature=0,
|
12 |
+
api_key=os.environ['OPENAI_API_KEY'],
|
13 |
+
api_base=os.environ['OPENAI_BASE_URL']
|
14 |
+
)
|
15 |
+
|
16 |
+
mcp_url = "https://pgurazada1-credit-card-database-mcp-server.hf.space/mcp/"
|
17 |
+
|
18 |
+
# IMPORTANT: Set your Hugging Face user access token in the environment variable HF_TOKEN
|
19 |
+
HF_TOKEN = os.environ.get("HUGGINGFACE_API_KEY")
|
20 |
+
if not HF_TOKEN:
|
21 |
+
raise RuntimeError("Please set your Hugging Face user access token in the HF_TOKEN environment variable.")
|
22 |
+
|
23 |
+
|
24 |
+
dspy.configure(lm=lm)
|
25 |
+
|
26 |
+
mlflow.dspy.autolog()
|
27 |
+
mlflow.set_experiment('sql-react-agent-http')
|
28 |
+
|
29 |
+
|
30 |
+
class QueryResponse(dspy.Signature):
|
31 |
+
"""
|
32 |
+
You are an expert AI assistant specialized in generating and executing SQLite queries against a database.
|
33 |
+
Your primary goal is to accurately answer user questions based *only* on the data retrieved. You must be methodical in exploring the database structure.
|
34 |
+
|
35 |
+
<Schema Exploration and Join Path Strategy>
|
36 |
+
1. **List All Tables:** Always start with `sql_db_list_tables`.
|
37 |
+
2. **Identify Potential Tables:** List tables potentially holding the requested entities (e.g., cities, merchants) and metrics (e.g., spend). Also, identify tables that might *link* these entities (often containing ID columns like `cust_id`, `CARD_ID`, `M_ID`).
|
38 |
+
3. **Get Schemas Systematically:** Use `sql_db_schema` to get schemas for *all* tables identified in step 2. This is crucial. Do not skip potential linking tables.
|
39 |
+
4. **Map the Join Path:**
|
40 |
+
* Explicitly identify the column containing the primary metric (e.g., `transaction.TX_AMOUNT`).
|
41 |
+
* Explicitly identify the column containing the target entity (e.g., `customer.city`).
|
42 |
+
* **CRITICAL:** Trace the connections between these tables using ID columns revealed in the schemas. Look for sequences like `tableA.ID -> tableB.tableA_ID`, `tableB.ID -> tableC.tableB_ID`.
|
43 |
+
* **Example Path:** To link transaction spend to customer city, you MUST verify the path: `transaction.CARD_ID` links to `card.card_number`, AND `card.cust_id` links to `customer.cust_id`. You **MUST** request the schema for the `card` table to confirm this.
|
44 |
+
* **State the Path:** Before writing the query, state the full join path you intend to use (e.g., "Found path: transaction JOIN card ON transaction.CARD_ID = card.card_number JOIN customer ON card.cust_id = customer.cust_id").
|
45 |
+
5. **Verify Columns:** Double-check that *every* column used in your intended SELECT, JOIN, WHERE, GROUP BY, or ORDER BY clauses exists in the schemas you retrieved.
|
46 |
+
</Schema Exploration and Join Path Strategy>
|
47 |
+
|
48 |
+
<Query Construction and Execution>
|
49 |
+
6. **Construct Query:** Build the SQLite query using the verified tables, columns, and the full, correct join path.
|
50 |
+
* Use explicit JOIN clauses (INNER JOIN is usually appropriate unless otherwise specified).
|
51 |
+
* Quote identifiers (like `"transaction"`) if they are keywords or contain special characters.
|
52 |
+
* Select only necessary columns. Alias columns for clarity if needed (e.g., `SUM(t.TX_AMOUNT) AS total_spend`).
|
53 |
+
* Include calculations like percentage contribution if requested. The total sum for percentage calculation should be derived correctly (e.g., `(SELECT SUM(TX_AMOUNT) FROM "transaction")`).
|
54 |
+
* Apply `GROUP BY` to the target entity column (e.g., `c.city`).
|
55 |
+
* Apply `ORDER BY` and `LIMIT 5` (unless otherwise specified).
|
56 |
+
7. **Validate Query:** Use `sql_db_query_checker`. Revise if syntax errors occur.
|
57 |
+
8. **Execute Query:** Use `sql_db_query`.
|
58 |
+
9. **Formulate Answer:** Base the final answer *strictly* on the query results. If the query returns no results *after* confirming a valid join path and correct syntax, state that no data matching the criteria was found.
|
59 |
+
10. **Handle Missing Information:** If, after thorough schema exploration (including checking potential linking tables), you cannot find the requested column (e.g., 'country') or a valid join path, *then and only then* inform the user the data is unavailable. Do not substitute unrelated columns.
|
60 |
+
11. **Final Answer Only:** Provide the answer directly without further tool calls once results are obtained.
|
61 |
+
</Query Construction and Execution>
|
62 |
+
|
63 |
+
<General Restrictions>
|
64 |
+
1. DO NOT make any DML statements (INSERT, UPDATE, DELETE, DROP etc.).
|
65 |
+
2. DO NOT MAKE UP ANSWERS.
|
66 |
+
</General Restrictions>
|
67 |
+
"""
|
68 |
+
|
69 |
+
query: str = dspy.InputField()
|
70 |
+
answer: str = dspy.OutputField(desc="The generated response to the customer query.")
|
71 |
+
|
72 |
+
|
73 |
+
async def respond(query):
|
74 |
+
async with streamablehttp_client(
|
75 |
+
url=mcp_url,
|
76 |
+
headers={"Authorization": f"Bearer {HF_TOKEN}"}
|
77 |
+
) as (read, write, _):
|
78 |
+
async with ClientSession(read, write) as session:
|
79 |
+
# Initialize the connection
|
80 |
+
await session.initialize()
|
81 |
+
# List available tools
|
82 |
+
tools_output = await session.list_tools()
|
83 |
+
|
84 |
+
# Convert MCP tools to DSPy tools
|
85 |
+
dspy_tools = []
|
86 |
+
for tool in tools_output.tools:
|
87 |
+
dspy_tools.append(dspy.Tool.from_mcp_tool(session, tool))
|
88 |
+
|
89 |
+
# Create the agent
|
90 |
+
react_agent = dspy.ReAct(QueryResponse, tools=dspy_tools, max_iters=10)
|
91 |
+
|
92 |
+
output = await react_agent.acall(query=query)
|
93 |
+
return output
|
94 |
+
|
95 |
+
|
96 |
+
# Example 1
|
97 |
+
|
98 |
+
user_query = "Who are the top 5 merchants by total number of transactions?"
|
99 |
+
pred = asyncio.run(respond(user_query))
|
100 |
+
|
101 |
+
print(pred.answer)
|
102 |
+
|
103 |
+
# Example 2
|
104 |
+
|
105 |
+
user_query = "Which is the highest spend month and amount for each card type?"
|
106 |
+
pred = asyncio.run(respond(user_query))
|
107 |
+
|
108 |
+
print(pred.answer)
|
109 |
+
|
110 |
+
# Example 3
|
111 |
+
|
112 |
+
user_query = "Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
|
113 |
+
pred = asyncio.run(respond(user_query))
|
114 |
+
|
115 |
+
print(pred.answer)
|
116 |
+
|
117 |
+
# Parallelism
|
118 |
+
|
119 |
+
async def main():
|
120 |
+
user_queries = [
|
121 |
+
"Who are the top 5 merchants by total transactions?",
|
122 |
+
"Which is the highest spend month and amount for each card type?",
|
123 |
+
"Which are the top 5 cities with the highest spend and what is their percentage contribution to overall spends?"
|
124 |
+
]
|
125 |
+
|
126 |
+
tasks_to_run = [respond(query) for query in user_queries]
|
127 |
+
results = await asyncio.gather(*tasks_to_run)
|
128 |
+
|
129 |
+
return results
|
130 |
+
|
131 |
+
results = asyncio.run(main())
|
132 |
+
|
133 |
+
for result in results:
|
134 |
+
print(result.answer)
|