ZennyKenny commited on
Commit
d39bf30
·
verified ·
1 Parent(s): 15f10e9

Support on-demand run rather than real time

Browse files
Files changed (1) hide show
  1. app.py +18 -61
app.py CHANGED
@@ -5,15 +5,9 @@ from smolagents import tool, CodeAgent, HfApiModel
5
  import spaces
6
  import pandas as pd
7
  from database import engine, receipts
8
- import pandas as pd
9
 
 
10
  def get_receipts_table():
11
- """
12
- Fetches all data from the 'receipts' table and returns it as a Pandas DataFrame.
13
-
14
- Returns:
15
- A Pandas DataFrame containing all receipt data.
16
- """
17
  try:
18
  with engine.connect() as con:
19
  result = con.execute(text("SELECT * FROM receipts"))
@@ -22,91 +16,54 @@ def get_receipts_table():
22
  if not rows:
23
  return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
24
 
25
- # Convert rows into a DataFrame
26
- df = pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
27
- return df
28
-
29
  except Exception as e:
30
- return pd.DataFrame({"Error": [str(e)]}) # Return error message in DataFrame format
31
 
32
  @tool
33
  def sql_engine(query: str) -> str:
34
- """
35
- Executes an SQL query on the 'receipts' table and returns formatted results.
36
-
37
- Args:
38
- query: The SQL query to execute.
39
-
40
- Returns:
41
- Query result as a formatted string.
42
- """
43
  try:
44
  with engine.connect() as con:
45
  rows = con.execute(text(query)).fetchall()
46
-
47
  if not rows:
48
  return "No results found."
49
-
50
  if len(rows) == 1 and len(rows[0]) == 1:
51
- return str(rows[0][0]) # Convert numerical result to string
52
-
53
  return "\n".join([", ".join(map(str, row)) for row in rows])
54
-
55
  except Exception as e:
56
  return f"Error: {str(e)}"
57
 
58
  def query_sql(user_query: str) -> str:
59
- """
60
- Converts natural language input to an SQL query using CodeAgent
61
- and returns the execution results.
62
-
63
- Args:
64
- user_query: The user's request in natural language.
65
-
66
- Returns:
67
- The query result from the database as a formatted string.
68
- """
69
-
70
  schema_info = (
71
  "The database has a table named 'receipts' with the following schema:\n"
72
  "- receipt_id (INTEGER, primary key)\n"
73
  "- customer_name (VARCHAR(16))\n"
74
  "- price (FLOAT)\n"
75
  "- tip (FLOAT)\n"
76
- "Generate a valid SQL SELECT query using ONLY these column names.\n"
77
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
78
  )
79
-
80
  generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
81
-
82
  if not isinstance(generated_sql, str):
83
- return f"{generated_sql}" # Handle unexpected numerical result
84
-
85
- print(f"{generated_sql}")
86
-
87
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
88
  return "Error: Only SELECT queries are allowed."
89
-
90
  result = sql_engine(generated_sql)
91
-
92
- print(f"{result}")
93
 
94
  try:
95
  float_result = float(result)
96
  return f"{float_result:.2f}"
97
  except ValueError:
98
- return result
99
 
100
  def handle_query(user_input: str) -> str:
101
- """
102
- Calls query_sql, captures the output, and directly returns it to the UI.
103
-
104
- Args:
105
- user_input: The user's natural language question.
106
-
107
- Returns:
108
- The SQL query result as a plain string to be displayed in the UI.
109
- """
110
  return query_sql(user_input)
111
 
112
  agent = CodeAgent(
@@ -118,11 +75,11 @@ with gr.Blocks() as demo:
118
  gr.Markdown("""
119
  ## Plain Text Query Interface
120
 
121
- This tool allows you to query a receipts database using natural language. Simply type your question into the input box, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
122
 
123
  ### Usage:
124
  1. Enter a question related to the receipts data in the text box.
125
- 2. The tool will convert your question into an SQL query and execute it.
126
  3. The result will be displayed in the output box.
127
 
128
  > The current receipts table is also displayed for reference.
@@ -131,14 +88,14 @@ This tool allows you to query a receipts database using natural language. Simply
131
  with gr.Row():
132
  with gr.Column(scale=1):
133
  user_input = gr.Textbox(label="Ask a question about the data")
 
134
  query_output = gr.Textbox(label="Result")
135
 
136
  with gr.Column(scale=2):
137
  gr.Markdown("### Receipts Table")
138
  receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
139
 
140
- user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
141
-
142
  demo.load(fn=get_receipts_table, outputs=receipts_table)
143
 
144
  if __name__ == "__main__":
 
5
  import spaces
6
  import pandas as pd
7
  from database import engine, receipts
 
8
 
9
+ # Fetch all data from the 'receipts' table
10
  def get_receipts_table():
 
 
 
 
 
 
11
  try:
12
  with engine.connect() as con:
13
  result = con.execute(text("SELECT * FROM receipts"))
 
16
  if not rows:
17
  return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
18
 
19
+ return pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
 
 
 
20
  except Exception as e:
21
+ return pd.DataFrame({"Error": [str(e)]})
22
 
23
  @tool
24
  def sql_engine(query: str) -> str:
 
 
 
 
 
 
 
 
 
25
  try:
26
  with engine.connect() as con:
27
  rows = con.execute(text(query)).fetchall()
28
+
29
  if not rows:
30
  return "No results found."
31
+
32
  if len(rows) == 1 and len(rows[0]) == 1:
33
+ return str(rows[0][0])
34
+
35
  return "\n".join([", ".join(map(str, row)) for row in rows])
 
36
  except Exception as e:
37
  return f"Error: {str(e)}"
38
 
39
  def query_sql(user_query: str) -> str:
 
 
 
 
 
 
 
 
 
 
 
40
  schema_info = (
41
  "The database has a table named 'receipts' with the following schema:\n"
42
  "- receipt_id (INTEGER, primary key)\n"
43
  "- customer_name (VARCHAR(16))\n"
44
  "- price (FLOAT)\n"
45
  "- tip (FLOAT)\n"
46
+ "Generate a valid SQL SELECT query using ONLY these column names."
47
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
48
  )
49
+
50
  generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
51
+
52
  if not isinstance(generated_sql, str):
53
+ return f"{generated_sql}"
54
+
 
 
55
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
56
  return "Error: Only SELECT queries are allowed."
57
+
58
  result = sql_engine(generated_sql)
 
 
59
 
60
  try:
61
  float_result = float(result)
62
  return f"{float_result:.2f}"
63
  except ValueError:
64
+ return result
65
 
66
  def handle_query(user_input: str) -> str:
 
 
 
 
 
 
 
 
 
67
  return query_sql(user_input)
68
 
69
  agent = CodeAgent(
 
75
  gr.Markdown("""
76
  ## Plain Text Query Interface
77
 
78
+ This tool allows you to query a receipts database using natural language. Simply type your question into the input box, press **Run**, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
79
 
80
  ### Usage:
81
  1. Enter a question related to the receipts data in the text box.
82
+ 2. Click **Run** to execute the query.
83
  3. The result will be displayed in the output box.
84
 
85
  > The current receipts table is also displayed for reference.
 
88
  with gr.Row():
89
  with gr.Column(scale=1):
90
  user_input = gr.Textbox(label="Ask a question about the data")
91
+ run_button = gr.Button("Run", variant="primary") # Purple button
92
  query_output = gr.Textbox(label="Result")
93
 
94
  with gr.Column(scale=2):
95
  gr.Markdown("### Receipts Table")
96
  receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
97
 
98
+ run_button.click(fn=handle_query, inputs=user_input, outputs=query_output) # Trigger only on button press
 
99
  demo.load(fn=get_receipts_table, outputs=receipts_table)
100
 
101
  if __name__ == "__main__":