Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -2,10 +2,10 @@ import json
|
|
2 |
import gradio as gr
|
3 |
import duckdb
|
4 |
from functools import lru_cache
|
5 |
-
from transformers import pipeline
|
6 |
import pandas as pd
|
7 |
import plotly.express as px
|
8 |
import openai
|
|
|
9 |
|
10 |
# Load the Parquet dataset path
|
11 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
@@ -60,31 +60,37 @@ def load_dataset_schema():
|
|
60 |
finally:
|
61 |
con.close()
|
62 |
|
63 |
-
# Advanced Natural Language to SQL Parser using OpenAI's
|
64 |
def parse_query(nl_query):
|
65 |
"""
|
66 |
-
Converts a natural language query into SQL query using OpenAI GPT-3.
|
67 |
"""
|
68 |
-
openai.api_key = '
|
69 |
|
70 |
-
|
71 |
-
|
|
|
72 |
Schema:
|
73 |
{json.dumps(schema, indent=2)}
|
74 |
-
|
|
|
75 |
"{nl_query}"
|
76 |
"""
|
|
|
77 |
try:
|
78 |
-
response = openai.
|
79 |
-
|
80 |
-
|
|
|
|
|
|
|
81 |
temperature=0,
|
82 |
max_tokens=150,
|
83 |
top_p=1,
|
84 |
n=1,
|
85 |
stop=None
|
86 |
)
|
87 |
-
sql_query = response.choices[0].
|
88 |
return sql_query
|
89 |
except Exception as e:
|
90 |
return f"Error generating SQL query: {e}"
|
@@ -94,7 +100,10 @@ def detect_plot_intent(nl_query):
|
|
94 |
"""
|
95 |
Detects if the user's query involves plotting.
|
96 |
"""
|
97 |
-
plot_keywords = [
|
|
|
|
|
|
|
98 |
for keyword in plot_keywords:
|
99 |
if keyword in nl_query.lower():
|
100 |
return True
|
@@ -108,12 +117,13 @@ def generate_sql_and_plot_code(query):
|
|
108 |
is_plot = detect_plot_intent(query)
|
109 |
sql_query = parse_query(query)
|
110 |
plot_code = ""
|
111 |
-
if is_plot:
|
112 |
# Generate plot code based on the query
|
113 |
# For simplicity, we'll generate a basic plot code
|
114 |
plot_code = """
|
115 |
import plotly.express as px
|
116 |
-
fig = px.bar(result_df, x='x_column', y='y_column')
|
|
|
117 |
"""
|
118 |
return sql_query, plot_code
|
119 |
|
@@ -122,6 +132,9 @@ def execute_query(sql_query):
|
|
122 |
"""
|
123 |
Executes the SQL query and returns the results as a DataFrame.
|
124 |
"""
|
|
|
|
|
|
|
125 |
try:
|
126 |
con = duckdb.connect()
|
127 |
# Ensure the view is created
|
@@ -151,8 +164,8 @@ def generate_plot(plot_code, result_df):
|
|
151 |
plot_code = plot_code.replace('y_column', columns[1])
|
152 |
|
153 |
# Execute the plot code
|
154 |
-
local_vars = {'result_df': result_df}
|
155 |
-
exec(plot_code, {
|
156 |
fig = local_vars.get('fig', None)
|
157 |
if fig:
|
158 |
return fig, ""
|
|
|
2 |
import gradio as gr
|
3 |
import duckdb
|
4 |
from functools import lru_cache
|
|
|
5 |
import pandas as pd
|
6 |
import plotly.express as px
|
7 |
import openai
|
8 |
+
import os
|
9 |
|
10 |
# Load the Parquet dataset path
|
11 |
dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
|
|
|
60 |
finally:
|
61 |
con.close()
|
62 |
|
63 |
+
# Advanced Natural Language to SQL Parser using OpenAI's ChatCompletion
|
64 |
def parse_query(nl_query):
    """
    Converts a natural language query into a SQL query using OpenAI's GPT-3.5-turbo.

    Args:
        nl_query: The user's request in plain language.

    Returns:
        The generated SQL text, with surrounding whitespace and any wrapping
        Markdown code fence removed. On any failure a string starting with
        "Error generating SQL query:" is returned instead of raising, so
        callers can detect problems with ``startswith("Error")``.
    """
    # Fail fast with a clear message instead of handing a None key to the
    # client and surfacing a less obvious authentication error later.
    api_key = os.getenv('OPENAI_API_KEY')
    if not api_key:
        return "Error generating SQL query: the OPENAI_API_KEY environment variable is not set"
    openai.api_key = api_key

    system_prompt = "You are an assistant that converts natural language queries into SQL queries for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries."

    # NOTE(review): `schema` is not defined in this function; presumably it is
    # a module-level value built from the dataset schema -- confirm.
    user_prompt = f"""
    Schema:
    {json.dumps(schema, indent=2)}

    Convert the following natural language query into a SQL query:
    "{nl_query}"
    """

    try:
        response = openai.ChatCompletion.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": user_prompt}
            ],
            temperature=0,  # deterministic output: the same question yields the same SQL
            max_tokens=150,
            top_p=1,
            n=1,
            stop=None
        )
        sql_query = response.choices[0].message['content'].strip()
        # Chat models frequently wrap SQL in a Markdown fence, which would not
        # be valid SQL if executed verbatim.
        return _strip_code_fence(sql_query)
    except Exception as e:
        # Best-effort boundary: report the failure as text rather than raising.
        return f"Error generating SQL query: {e}"


def _strip_code_fence(text):
    """Return *text* without a surrounding Markdown code fence, if it has one."""
    text = text.strip()
    if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
        return text
    inner = text[3:-3]
    first_line, newline, rest = inner.partition("\n")
    # Drop an opening-fence language tag such as "sql" so only the query remains.
    if newline and first_line.strip().lower() in ("", "sql", "duckdb"):
        inner = rest
    return inner.strip()
|
|
|
100 |
"""
|
101 |
Detects if the user's query involves plotting.
|
102 |
"""
|
103 |
+
plot_keywords = [
|
104 |
+
'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
|
105 |
+
'bar chart', 'line chart', 'scatter plot', 'pie chart'
|
106 |
+
]
|
107 |
for keyword in plot_keywords:
|
108 |
if keyword in nl_query.lower():
|
109 |
return True
|
|
|
117 |
is_plot = detect_plot_intent(query)
|
118 |
sql_query = parse_query(query)
|
119 |
plot_code = ""
|
120 |
+
if is_plot and not sql_query.startswith("Error"):
|
121 |
# Generate plot code based on the query
|
122 |
# For simplicity, we'll generate a basic plot code
|
123 |
plot_code = """
|
124 |
import plotly.express as px
|
125 |
+
fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
|
126 |
+
fig.update_layout(title_x=0.5)
|
127 |
"""
|
128 |
return sql_query, plot_code
|
129 |
|
|
|
132 |
"""
|
133 |
Executes the SQL query and returns the results as a DataFrame.
|
134 |
"""
|
135 |
+
if sql_query.startswith("Error"):
|
136 |
+
return None, sql_query # Pass the error message forward
|
137 |
+
|
138 |
try:
|
139 |
con = duckdb.connect()
|
140 |
# Ensure the view is created
|
|
|
164 |
plot_code = plot_code.replace('y_column', columns[1])
|
165 |
|
166 |
# Execute the plot code
|
167 |
+
local_vars = {'result_df': result_df, 'px': px}
|
168 |
+
exec(plot_code, {}, local_vars)
|
169 |
fig = local_vars.get('fig', None)
|
170 |
if fig:
|
171 |
return fig, ""
|