File size: 3,877 Bytes
6a0ec6a
 
6c37e10
6a0ec6a
4f04c00
20e319d
6c37e10
1767e22
d39bf30
20e319d
4a3afbe
 
 
20e319d
 
 
 
 
 
 
 
d39bf30
20e319d
d39bf30
20e319d
042246b
 
4a3afbe
 
 
 
 
 
 
 
 
042246b
 
 
d39bf30
042246b
 
d39bf30
10e2935
d39bf30
 
042246b
 
 
 
7306c07
35cddc5
 
 
 
 
 
d39bf30
61d9b40
35cddc5
d39bf30
61d9b40
d39bf30
2443195
d39bf30
 
f8c651a
 
d39bf30
2e81bab
2443195
10e2935
 
2443195
10e2935
d39bf30
1f7ee11
215368b
1df3c5d
1f7ee11
 
1df3c5d
1f7ee11
 
 
1767e22
18bb121
 
 
d39bf30
18bb121
15f10e9
18bb121
d39bf30
18bb121
15f10e9
 
18bb121
1767e22
 
1df3c5d
b9e14b1
d39bf30
b9e14b1
1767e22
1df3c5d
b9e14b1
1767e22
 
d39bf30
8be3748
6a0ec6a
 
0380e03
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
import gradio as gr
from sqlalchemy import text
from smolagents import tool, CodeAgent, HfApiModel
import spaces
import pandas as pd
from database import engine, receipts

# Fetch all data from the 'receipts' table
def get_receipts_table():
    """
    Fetch all rows from the receipts table and return them as a pandas DataFrame.

    Column names are read from the query result instead of a hard-coded list,
    so the frame stays correctly labelled if the table's columns change or are
    reordered.

    Returns:
        pd.DataFrame: One row per receipt; an empty frame that still carries
        the column headers when the table has no rows. On any exception, a
        single-row frame with an "Error" column holding the exception message
        is returned instead, so the UI table shows the problem.
    """
    try:
        with engine.connect() as con:
            result = con.execute(text("SELECT * FROM receipts"))
            # Capture the column names while the connection is still open.
            columns = list(result.keys())
            rows = result.fetchall()
        # An empty `rows` list still yields an empty frame with these headers.
        return pd.DataFrame(rows, columns=columns)
    except Exception as e:
        # Deliberate best-effort: surface the error as table content.
        return pd.DataFrame({"Error": [str(e)]})

@tool
def sql_engine(query: str) -> str:
    """
    Executes a read-only SQL query on the database and returns the result.
    Only statements beginning with SELECT, SHOW or PRAGMA are run; any other statement is refused.

    Args:
        query (str): The SQL query to execute.
    
    Returns:
        str: Query result as a formatted string.
    """
    try:
        # The agent calls this tool directly with LLM-generated SQL, so the
        # read-only policy must be enforced here too, not only on the final
        # query in query_sql; otherwise a crafted prompt could make the agent
        # run DELETE/DROP/UPDATE through the tool.
        # NOTE(review): a prefix check does not stop stacked statements
        # ("SELECT 1; DROP ...") on drivers that accept several statements per
        # call; a read-only database role/connection would be a stronger guard.
        if not query.strip().lower().startswith(("select", "show", "pragma")):
            return "Error: Only SELECT queries are allowed."

        with engine.connect() as con:
            rows = con.execute(text(query)).fetchall()
        
        if not rows:
            return "No results found."
        
        # A single scalar (e.g. an aggregate) is returned bare so callers can
        # try to parse it as a number.
        if len(rows) == 1 and len(rows[0]) == 1:
            return str(rows[0][0])
        
        # Otherwise: one line per row, columns separated by ", ".
        return "\n".join([", ".join(map(str, row)) for row in rows])
    except Exception as e:
        # Errors are returned as text so the agent can read them and retry.
        return f"Error: {str(e)}"

def query_sql(user_query: str) -> str:
    """
    Answer a natural-language question about the receipts table.

    The agent is asked to translate ``user_query`` into one SQL query. The
    query is unwrapped from any Markdown code fence, checked to be read-only,
    and executed through ``sql_engine``.

    Args:
        user_query: The user's question in plain language.

    Returns:
        The answer as a string:
        - a result that parses as a number is formatted with two decimals;
        - other results are returned exactly as ``sql_engine`` formats them;
        - a non-read-only query yields
          "Error: Only SELECT queries are allowed.";
        - if the agent's final answer is not a string, that value is returned
          as text unchanged.
    """
    schema_info = (
        "The database has a table named 'receipts' with the following schema:\n"
        "- receipt_id (INTEGER, primary key)\n"
        "- customer_name (VARCHAR(16))\n"
        "- price (FLOAT)\n"
        "- tip (FLOAT)\n"
        # Trailing space keeps this sentence separate from the next literal.
        "Generate a valid SQL SELECT query using ONLY these column names. "
        "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
    )
    
    generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
    
    # The agent may finish with a computed value instead of SQL text.
    if not isinstance(generated_sql, str):
        return f"{generated_sql}"
    
    sql = generated_sql.strip()
    # Models often wrap their answer in a Markdown fence (```sql ... ```).
    # Unwrap it so the keyword check and the execution see the bare statement.
    if sql.startswith("```") and sql.endswith("```"):
        sql = sql[3:-3].strip()
        if sql[:3].lower() == "sql":
            sql = sql[3:].strip()
    
    if not sql.lower().startswith(("select", "show", "pragma")):
        return "Error: Only SELECT queries are allowed."
    
    result = sql_engine(sql)
    
    # Scalar numeric results (e.g. SUM/AVG) are shown with two decimals.
    try:
        float_result = float(result)
        return f"{float_result:.2f}"
    except ValueError:
        return result

def handle_query(user_input: str) -> str:
    """
    Gradio callback for the Run button: answer ``user_input`` via query_sql.

    Blank input (empty, whitespace-only or None) is answered with a prompt
    message instead of being forwarded to the LLM agent, which would
    otherwise spend a model call on an empty request.
    """
    if not user_input or not user_input.strip():
        return "Please enter a question."
    return query_sql(user_input)

# Model used by the code-writing agent (passed to HfApiModel below).
_AGENT_MODEL_ID = "Qwen/Qwen2.5-Coder-32B-Instruct"

# The agent is given the sql_engine tool so it can run queries itself;
# query_sql drives it via agent.run(...).
agent = CodeAgent(tools=[sql_engine], model=HfApiModel(model_id=_AGENT_MODEL_ID))

# Gradio UI layout:
# - left column: the question box, the Run button and the result box;
# - right column: the current receipts table.
# Components are rendered in the order they are created below.
with gr.Blocks() as demo:
    gr.Markdown("""
## Plain Text Query Interface

This tool allows you to query a receipts database using natural language. Simply type your question into the input box, press **Run**, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.

### Usage:
1. Enter a question related to the receipts data in the text box.
2. Click **Run** to execute the query.
3. The result will be displayed in the output box.

> The current receipts table is also displayed for reference.
""")

    with gr.Row():
        # Left column (scale=1): question input, Run button, query result.
        with gr.Column(scale=1):
            user_input = gr.Textbox(label="Ask a question about the data")
            run_button = gr.Button("Run", variant="primary")  # "primary" variant; the actual colour comes from the Gradio theme
            query_output = gr.Textbox(label="Result")
        
        # Right column (scale=2): the receipts table, shown for reference.
        with gr.Column(scale=2):
            gr.Markdown("### Receipts Table")
            # Populated once while the UI is being built; the demo.load handler below fetches it again on page load.
            receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")

    run_button.click(fn=handle_query, inputs=user_input, outputs=query_output)  # Trigger only on button press (no Enter/submit handler is wired)
    # Re-read the table whenever the page is loaded.
    demo.load(fn=get_receipts_table, outputs=receipts_table)

if __name__ == "__main__":
    # Listen on all interfaces on port 7860 and request a Gradio share link.
    # NOTE(review): share=True exposes the app (which executes LLM-generated
    # SQL) via a public URL — confirm this is intended.
    launch_kwargs = dict(server_name="0.0.0.0", server_port=7860, share=True)
    demo.launch(**launch_kwargs)