ZennyKenny commited on
Commit
446ff99
·
verified ·
1 Parent(s): 4ea1ee8
Files changed (1) hide show
  1. app.py +45 -33
app.py CHANGED
@@ -5,11 +5,13 @@ from smolagents import tool, CodeAgent, HfApiModel
5
  import spaces
6
  import pandas as pd
7
  from database import engine, receipts
 
8
 
9
- # Fetch all data from the 'receipts' table
10
  def get_receipts_table():
11
  """
12
- Fetch all rows from the receipts table and return as a Pandas DataFrame.
 
 
13
  """
14
  try:
15
  with engine.connect() as con:
@@ -19,72 +21,85 @@ def get_receipts_table():
19
  if not rows:
20
  return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
21
 
22
- return pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
 
 
 
23
  except Exception as e:
24
- return pd.DataFrame({"Error": [str(e)]})
25
 
26
  @tool
27
  def sql_engine(query: str) -> str:
28
  """
29
- Executes an SQL query on the database and returns the result.
30
-
31
  Args:
32
- query (str): The SQL query to execute.
33
-
34
  Returns:
35
- str: Query result as a formatted string.
36
  """
37
  try:
38
  with engine.connect() as con:
39
  rows = con.execute(text(query)).fetchall()
40
-
41
  if not rows:
42
  return "No results found."
43
-
44
  if len(rows) == 1 and len(rows[0]) == 1:
45
- return str(rows[0][0])
46
-
47
  return "\n".join([", ".join(map(str, row)) for row in rows])
 
48
  except Exception as e:
49
  return f"Error: {str(e)}"
50
 
51
  def query_sql(user_query: str) -> str:
52
  """
53
- Converts a natural language query into an SQL statement and executes it.
54
-
55
  Args:
56
- user_query (str): A question or request in natural language to be converted into SQL.
57
-
58
  Returns:
59
- str: The execution result from the database.
60
  """
 
61
  schema_info = (
62
  "The database has a table named 'receipts' with the following schema:\n"
63
  "- receipt_id (INTEGER, primary key)\n"
64
  "- customer_name (VARCHAR(16))\n"
65
  "- price (FLOAT)\n"
66
  "- tip (FLOAT)\n"
67
- "Generate a valid SQL SELECT query using ONLY these column names."
68
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
69
  )
70
-
71
  generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
72
-
73
  if not isinstance(generated_sql, str):
74
- return f"{generated_sql}"
75
-
 
 
76
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
77
  return "Error: Only SELECT queries are allowed."
78
-
79
  result = sql_engine(generated_sql)
 
 
80
 
81
  try:
82
  float_result = float(result)
83
  return f"{float_result:.2f}"
84
  except ValueError:
85
- return result
86
 
87
  def handle_query(user_input: str) -> str:
 
 
 
 
 
 
 
88
  return query_sql(user_input)
89
 
90
  agent = CodeAgent(
@@ -95,29 +110,26 @@ agent = CodeAgent(
95
  with gr.Blocks() as demo:
96
  gr.Markdown("""
97
  ## Plain Text Query Interface
98
-
99
- This tool allows you to query a receipts database using natural language. Simply type your question into the input box, press **Run**, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
100
-
101
  ### Usage:
102
  1. Enter a question related to the receipts data in the text box.
103
- 2. Click **Run** to execute the query.
104
  3. The result will be displayed in the output box.
105
-
106
  > The current receipts table is also displayed for reference.
107
  """)
108
 
109
  with gr.Row():
110
  with gr.Column(scale=1):
111
  user_input = gr.Textbox(label="Ask a question about the data")
112
- run_button = gr.Button("Run", variant="primary") # Purple button
113
  query_output = gr.Textbox(label="Result")
114
 
115
  with gr.Column(scale=2):
116
  gr.Markdown("### Receipts Table")
117
  receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
118
 
119
- run_button.click(fn=handle_query, inputs=user_input, outputs=query_output) # Trigger only on button press
 
120
  demo.load(fn=get_receipts_table, outputs=receipts_table)
121
 
122
  if __name__ == "__main__":
123
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True)
 
5
  import spaces
6
  import pandas as pd
7
  from database import engine, receipts
8
+ import pandas as pd
9
 
 
10
  def get_receipts_table():
11
  """
12
+ Fetches all data from the 'receipts' table and returns it as a Pandas DataFrame.
13
+ Returns:
14
+ A Pandas DataFrame containing all receipt data.
15
  """
16
  try:
17
  with engine.connect() as con:
 
21
  if not rows:
22
  return pd.DataFrame(columns=["receipt_id", "customer_name", "price", "tip"])
23
 
24
+ # Convert rows into a DataFrame
25
+ df = pd.DataFrame(rows, columns=["receipt_id", "customer_name", "price", "tip"])
26
+ return df
27
+
28
  except Exception as e:
29
+ return pd.DataFrame({"Error": [str(e)]}) # Return error message in DataFrame format
30
 
31
  @tool
32
  def sql_engine(query: str) -> str:
33
  """
34
+ Executes an SQL query on the 'receipts' table and returns formatted results.
 
35
  Args:
36
+ query: The SQL query to execute.
 
37
  Returns:
38
+ Query result as a formatted string.
39
  """
40
  try:
41
  with engine.connect() as con:
42
  rows = con.execute(text(query)).fetchall()
43
+
44
  if not rows:
45
  return "No results found."
46
+
47
  if len(rows) == 1 and len(rows[0]) == 1:
48
+ return str(rows[0][0]) # Convert numerical result to string
49
+
50
  return "\n".join([", ".join(map(str, row)) for row in rows])
51
+
52
  except Exception as e:
53
  return f"Error: {str(e)}"
54
 
55
  def query_sql(user_query: str) -> str:
56
  """
57
+ Converts natural language input to an SQL query using CodeAgent
58
+ and returns the execution results.
59
  Args:
60
+ user_query: The user's request in natural language.
 
61
  Returns:
62
+ The query result from the database as a formatted string.
63
  """
64
+
65
  schema_info = (
66
  "The database has a table named 'receipts' with the following schema:\n"
67
  "- receipt_id (INTEGER, primary key)\n"
68
  "- customer_name (VARCHAR(16))\n"
69
  "- price (FLOAT)\n"
70
  "- tip (FLOAT)\n"
71
+ "Generate a valid SQL SELECT query using ONLY these column names.\n"
72
  "DO NOT explain your reasoning, and DO NOT return anything other than the SQL query itself."
73
  )
74
+
75
  generated_sql = agent.run(f"{schema_info} Convert this request into SQL: {user_query}")
76
+
77
  if not isinstance(generated_sql, str):
78
+ return f"{generated_sql}" # Handle unexpected numerical result
79
+
80
+ print(f"{generated_sql}")
81
+
82
  if not generated_sql.strip().lower().startswith(("select", "show", "pragma")):
83
  return "Error: Only SELECT queries are allowed."
84
+
85
  result = sql_engine(generated_sql)
86
+
87
+ print(f"{result}")
88
 
89
  try:
90
  float_result = float(result)
91
  return f"{float_result:.2f}"
92
  except ValueError:
93
+ return result
94
 
95
  def handle_query(user_input: str) -> str:
96
+ """
97
+ Calls query_sql, captures the output, and directly returns it to the UI.
98
+ Args:
99
+ user_input: The user's natural language question.
100
+ Returns:
101
+ The SQL query result as a plain string to be displayed in the UI.
102
+ """
103
  return query_sql(user_input)
104
 
105
  agent = CodeAgent(
 
110
  with gr.Blocks() as demo:
111
  gr.Markdown("""
112
  ## Plain Text Query Interface
113
+ This tool allows you to query a receipts database using natural language. Simply type your question into the input box, and the tool will generate and execute an SQL query to retrieve relevant data. The results will be displayed in the output box.
 
 
114
  ### Usage:
115
  1. Enter a question related to the receipts data in the text box.
116
+ 2. The tool will convert your question into an SQL query and execute it.
117
  3. The result will be displayed in the output box.
 
118
  > The current receipts table is also displayed for reference.
119
  """)
120
 
121
  with gr.Row():
122
  with gr.Column(scale=1):
123
  user_input = gr.Textbox(label="Ask a question about the data")
 
124
  query_output = gr.Textbox(label="Result")
125
 
126
  with gr.Column(scale=2):
127
  gr.Markdown("### Receipts Table")
128
  receipts_table = gr.Dataframe(value=get_receipts_table(), label="Receipts Table")
129
 
130
+ user_input.change(fn=handle_query, inputs=user_input, outputs=query_output)
131
+
132
  demo.load(fn=get_receipts_table, outputs=receipts_table)
133
 
134
  if __name__ == "__main__":
135
+ demo.launch(server_name="0.0.0.0", server_port=7860, share=True)