Spaces:
Sleeping
Sleeping
File size: 3,877 Bytes
6a0ec6a 6c37e10 6a0ec6a 4f04c00 20e319d 6c37e10 1767e22 d39bf30 20e319d 4a3afbe 20e319d d39bf30 20e319d d39bf30 20e319d 042246b 4a3afbe 042246b d39bf30 042246b d39bf30 10e2935 d39bf30 042246b 7306c07 35cddc5 d39bf30 61d9b40 35cddc5 d39bf30 61d9b40 d39bf30 2443195 d39bf30 f8c651a d39bf30 2e81bab 2443195 10e2935 2443195 10e2935 d39bf30 1f7ee11 215368b 1df3c5d 1f7ee11 1df3c5d 1f7ee11 1767e22 18bb121 d39bf30 18bb121 15f10e9 18bb121 d39bf30 18bb121 15f10e9 18bb121 1767e22 1df3c5d b9e14b1 d39bf30 b9e14b1 1767e22 1df3c5d b9e14b1 1767e22 d39bf30 8be3748 6a0ec6a 0380e03 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 |
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces
import pandas as pd
from database import engine, receipts
# Fetch all data from the 'receipts' table
def get_receipts_table():
    """
    Fetch all rows from the receipts table and return them as a pandas DataFrame.

    The columns are selected explicitly (rather than ``SELECT *``) so the values
    always line up with the DataFrame's column labels, even if the table's
    physical column order changes or extra columns are added later.

    Returns:
        pd.DataFrame: one row per receipt with columns receipt_id, customer_name,
        price and tip. The frame is empty (but still has those columns) when the
        table has no rows. On any exception, a one-column DataFrame named
        ``Error`` holding the exception message is returned instead, so the UI
        shows the problem rather than crashing.
    """
    columns = ["receipt_id", "customer_name", "price", "tip"]
    try:
        with engine.connect() as con:
            rows = con.execute(
                text("SELECT receipt_id, customer_name, price, tip FROM receipts")
            ).fetchall()
        # An empty `rows` list still yields a frame with the expected columns,
        # so no special case is needed for an empty table.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:  # best-effort: surface the error in the UI table
        return pd.DataFrame({"Error": [str(e)]})
@tool
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query on the database and returns the result.
    Only statements starting with SELECT, SHOW or PRAGMA are executed; anything else is rejected.

    Args:
        query (str): The SQL query to execute.

    Returns:
        str: Query result as a formatted string.
    """
    # The agent invokes this tool directly with SQL it writes itself, so the
    # read-only guard must live here; a check on the agent's final answer alone
    # would not stop a DROP/DELETE issued during the agent's own tool calls.
    if not query.strip().lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    try:
        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
    except Exception as e:
        # Errors are returned as text so the agent can read them and retry.
        return f"Error: {str(e)}"
    if not rows:
        return "No results found."
    # A single scalar (e.g. COUNT/SUM) is returned bare so callers can parse it as a number.
    if len(rows) == 1 and len(rows[0]) == 1:
        return str(rows[0][0])
    # Otherwise: one line per row, values comma-separated.
    return "\n".join(", ".join(map(str, row)) for row in rows)
def _strip_code_fences(sql: str) -> str:
    """Return *sql* without a surrounding Markdown code fence such as ```sql ... ```.

    LLMs frequently wrap generated SQL in a fence; left in place, the leading
    backticks would make the read-only prefix check in query_sql reject a valid query.
    """
    cleaned = sql.strip()
    if len(cleaned) >= 6 and cleaned.startswith("```") and cleaned.endswith("```"):
        inner = cleaned[3:-3]
        first_line, newline, rest = inner.partition("\n")
        # The opening fence line may carry only a language tag (e.g. "sql").
        if newline and first_line.strip().lower() in ("", "sql", "sqlite"):
            inner = rest
        cleaned = inner.strip()
    return cleaned


def _format_result(result: str) -> str:
    """Format a single non-integer number to two decimals; return anything else unchanged.

    Integer results (e.g. a COUNT or an id) are returned as-is instead of being
    rendered as "3.00". Multi-row/text results and "Error: ..." messages do not
    parse as numbers and pass through untouched.
    """
    try:
        int(result)
    except ValueError:
        pass
    else:
        return result
    try:
        return f"{float(result):.2f}"
    except ValueError:
        return result


def query_sql(user_query: str) -> str:
    """
    Answer a natural-language question about the receipts table.

    The module-level ``agent`` is asked to translate *user_query* into SQL. The
    returned text is stripped of any Markdown fence and checked to start with a
    read-only keyword. It is then executed via ``sql_engine``, and a single
    numeric result is formatted for display.

    Args:
        user_query: The user's question in plain English.

    Returns:
        The query result as a string, or an ``"Error: ..."`` message. If the
        agent returns a non-string value, its string form is returned as-is.
    """
    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        # Trailing space matters: adjacent literals are concatenated, and without
        # it the two sentences ran together as "names.DO NOT".
        "Generate a valid SQL SELECT query using ONLY these column names. "
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    if not isinstance(generated_sql, str):
        # Presumably the agent already ran the query through its sql_engine tool
        # and returned a value (e.g. a number) instead of SQL text; pass it back.
        return f"{generated_sql}"
    sql = _strip_code_fences(generated_sql)
    if not sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    return _format_result(sql_engine(sql))
def handle_query(user_input: str) -> str:
    """Gradio click handler: answer *user_input* and return the text for the Result box."""
    answer = query_sql(user_input)
    return answer
# Code-writing agent used by query_sql. It is given the sql_engine tool, so it can
# run SQL itself while answering; the model is served through smolagents' HfApiModel.
_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"
agent = CodeAgent(
    model=HfApiModel(model_id=_MODEL_ID),
    tools=[sql_engine],
)
# Gradio UI: question box, Run button and result on the left; the receipts table on the right.
with gr.Blocks() as demo:
    # Static usage instructions rendered above the controls.
    gr.Markdown("""
## Plain Text Query Interface
This tool allows you to query a receipts database using natural language. Simply type your question into the input box, press **Run**, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
### Usage:
1. Enter a question related to the receipts data in the text box.
2. Click **Run** to execute the query.
3. The result will be displayed in the output box.
> The current receipts table is also displayed for reference.
""")
    with gr.Row():
        # Left column (scale=1): input, Run button and result output.
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="Ask a question about the data")
            run_button = gr.Button("Run", variant="primary")  # "primary" styling; the actual colour depends on the Gradio theme
            query_output = gr.Textbox(label="Result")
        # Right column (scale=2, i.e. twice as wide as the left one): current table contents.
        with gr.Column(scale=2):
            gr.Markdown("### Receipts Table")
            # NOTE: get_receipts_table() runs once here, when the module is imported;
            # demo.load below re-fetches the table when the page is loaded.
            receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
    run_button.click(fn=handle_query, inputs=user_input, outputs=query_output)  # Trigger only on button press
    # Re-fetch the table on page load so it reflects the current database contents.
    demo.load(fn=get_receipts_table, outputs=receipts_table)
if __name__ == "__main__":
    # Listen on all interfaces at port 7860 (Gradio's default port).
    # NOTE(review): share=True also requests a public share link when run outside
    # Spaces; since the app runs LLM-generated SQL, confirm that exposure is intended.
    demo.launch(
        share=True,
        server_port=7860,
        server_name="0.0.0.0",
    )
|