Spaces:
Running
Running
File size: 3,436 Bytes
6a0ec6a 6c37e10 6a0ec6a 4f04c00 6a0ec6a 79f396e 6c37e10 042246b 215368b 042246b 215368b 042246b 6a0ec6a 7306c07 edb7e14 7306c07 edb7e14 7306c07 edb7e14 7306c07 61d9b40 35cddc5 61d9b40 35cddc5 61d9b40 62d0b77 61d9b40 215368b 1f7ee11 215368b edb7e14 215368b edb7e14 215368b 1f7ee11 215368b 0380e03 215368b 7306c07 6a0ec6a 7306c07 c6d6658 6a0ec6a 0380e03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces
# Import the persistent database
from database import engine, receipts
@tool
def sql_engine(query: str) -> str:
    """
    Executes an SQL query on the 'receipts' table and returns formatted results.

    Args:
        query: The SQL query to execute.

    Returns:
        Query result as a formatted string.
    """
    # The docstring wording is kept as-is: smolagents' @tool builds the tool
    # description and argument schema from it (Google-style "Args:" section).
    try:
        # The connection is only needed while fetching; the rows are
        # materialised by fetchall() before the context manager closes it.
        with engine.connect() as connection:
            fetched = connection.execute(text(query)).fetchall()

        if fetched:
            # One output line per row, columns separated by ", ".
            rendered = []
            for record in fetched:
                rendered.append(", ".join(str(value) for value in record))
            return "\n".join(rendered)

        return "No results found."
    except Exception as exc:
        # Failures are reported as text so the caller always receives a string.
        return "Error: " + str(exc)
@tool
def query_sql(user_query: str) -> str:
    """
    Converts natural language input to an SQL query using CodeAgent
    and returns the execution results.

    Args:
        user_query: The user's request in natural language.

    Returns:
        The query result from the database as a formatted string, or an error message when the generated statement is not a read-only query.
    """
    # Statement prefixes accepted as read-only. This is a coarse prefix guard,
    # not a full SQL validator.
    read_only_prefixes = ("select", "show", "pragma")

    # Provide the AI with the correct schema and strict instructions.
    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        "Generate a valid SQL SELECT query using ONLY these column names.\n"
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )

    # `agent` is a module-level object created further down the file; it is
    # looked up when this function is called, not when it is defined.
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")

    # Log the generated SQL for debugging.
    print(f"Generated SQL: {generated_sql}")

    # The agent's final answer may not be a string, so normalise it before
    # inspecting it; calling .strip() on a non-string would raise.
    statement = str(generated_sql).strip()

    # The SQL is model-generated from untrusted user text: refuse anything
    # that does not start like a read-only statement instead of executing it.
    if not statement.lower().startswith(read_only_prefixes):
        return "Error: Only SELECT queries are allowed."

    # Execute the SQL query and log the outcome.
    result = sql_engine(statement)
    print(f"SQL Query Result: {result}")

    # Return only the final query result, NOT the generated SQL.
    return result
def handle_query(user_input: str) -> str:
    """
    Gradio callback that forwards the user's question to query_sql.

    Args:
        user_input: The user's natural language question.

    Returns:
        Whatever query_sql returns for that question, passed through unchanged for display in the UI.
    """
    # No post-processing happens here; the result is handed to the UI as-is.
    ui_text = query_sql(user_input)
    return ui_text
# Build the hosted model first, then create the CodeAgent that turns natural
# language into SQL. sql_engine is the only tool registered with the agent.
_sql_model = HfApiModel(model_id="Qwen/Qwen2.5-Coder-32B-Instruct")
agent = CodeAgent(tools=[sql_engine], model=_sql_model)
# Gradio UI: one textbox in, one textbox out. handle_query (not query_sql) is
# the callback wired to the interface.
_question_box = gr.Textbox(label="Enter your query in plain English")
_answer_box = gr.Textbox(label="Query Result")

demo = gr.Interface(
    fn=handle_query,
    inputs=_question_box,
    outputs=_answer_box,
    title="Natural Language to SQL Executor",
    description="Enter a plain English request, and the AI will generate an SQL query and return the results.",
    flagging_mode="never",
)
if __name__ == "__main__":
    # Only start the server when run as a script: bind to all interfaces on
    # port 7860 and pass share=True, exactly as before.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": True,
    }
    demo.launch(**launch_options)
|