LeonceNsh commited on
Commit
1fa796c
·
verified ·
1 Parent(s): e3824aa

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +29 -16
app.py CHANGED
@@ -2,10 +2,10 @@ import json
2
  import gradio as gr
3
  import duckdb
4
  from functools import lru_cache
5
- from transformers import pipeline
6
  import pandas as pd
7
  import plotly.express as px
8
  import openai
 
9
 
10
  # Load the Parquet dataset path
11
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
@@ -60,31 +60,37 @@ def load_dataset_schema():
60
  finally:
61
  con.close()
62
 
63
- # Advanced Natural Language to SQL Parser using OpenAI's GPT-3
64
  def parse_query(nl_query):
65
  """
66
- Converts a natural language query into SQL query using OpenAI GPT-3.
67
  """
68
- openai.api_key = 'YOUR_OPENAI_API_KEY' # Replace with your OpenAI API key
69
 
70
- prompt = f"""
71
- Convert the following natural language query into a SQL query for a DuckDB database. Use 'contract_data' as the table name.
 
72
  Schema:
73
  {json.dumps(schema, indent=2)}
74
- Query:
 
75
  "{nl_query}"
76
  """
 
77
  try:
78
- response = openai.Completion.create(
79
- engine="text-davinci-003",
80
- prompt=prompt,
 
 
 
81
  temperature=0,
82
  max_tokens=150,
83
  top_p=1,
84
  n=1,
85
  stop=None
86
  )
87
- sql_query = response.choices[0].text.strip()
88
  return sql_query
89
  except Exception as e:
90
  return f"Error generating SQL query: {e}"
@@ -94,7 +100,10 @@ def detect_plot_intent(nl_query):
94
  """
95
  Detects if the user's query involves plotting.
96
  """
97
- plot_keywords = ['plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram', 'bar chart', 'line chart', 'scatter plot', 'pie chart']
 
 
 
98
  for keyword in plot_keywords:
99
  if keyword in nl_query.lower():
100
  return True
@@ -108,12 +117,13 @@ def generate_sql_and_plot_code(query):
108
  is_plot = detect_plot_intent(query)
109
  sql_query = parse_query(query)
110
  plot_code = ""
111
- if is_plot:
112
  # Generate plot code based on the query
113
  # For simplicity, we'll generate a basic plot code
114
  plot_code = """
115
  import plotly.express as px
116
- fig = px.bar(result_df, x='x_column', y='y_column')
 
117
  """
118
  return sql_query, plot_code
119
 
@@ -122,6 +132,9 @@ def execute_query(sql_query):
122
  """
123
  Executes the SQL query and returns the results as a DataFrame.
124
  """
 
 
 
125
  try:
126
  con = duckdb.connect()
127
  # Ensure the view is created
@@ -151,8 +164,8 @@ def generate_plot(plot_code, result_df):
151
  plot_code = plot_code.replace('y_column', columns[1])
152
 
153
  # Execute the plot code
154
- local_vars = {'result_df': result_df}
155
- exec(plot_code, {'px': px}, local_vars)
156
  fig = local_vars.get('fig', None)
157
  if fig:
158
  return fig, ""
 
2
  import gradio as gr
3
  import duckdb
4
  from functools import lru_cache
 
5
  import pandas as pd
6
  import plotly.express as px
7
  import openai
8
+ import os
9
 
10
  # Load the Parquet dataset path
11
  dataset_path = 'sample_contract_df.parquet' # Update with your Parquet file path
 
60
  finally:
61
  con.close()
62
 
63
+ # Advanced Natural Language to SQL Parser using OpenAI's ChatCompletion
64
  def parse_query(nl_query):
65
  """
66
+ Converts a natural language query into a SQL query using OpenAI's GPT-3.5-turbo.
67
  """
68
+ openai.api_key = os.getenv('OPENAI_API_KEY') # It's recommended to set your API key as an environment variable
69
 
70
+ system_prompt = "You are an assistant that converts natural language queries into SQL queries for a DuckDB database named 'contract_data'. Use the provided schema to form accurate SQL queries."
71
+
72
+ user_prompt = f"""
73
  Schema:
74
  {json.dumps(schema, indent=2)}
75
+
76
+ Convert the following natural language query into a SQL query:
77
  "{nl_query}"
78
  """
79
+
80
  try:
81
+ response = openai.ChatCompletion.create(
82
+ model="gpt-3.5-turbo",
83
+ messages=[
84
+ {"role": "system", "content": system_prompt},
85
+ {"role": "user", "content": user_prompt}
86
+ ],
87
  temperature=0,
88
  max_tokens=150,
89
  top_p=1,
90
  n=1,
91
  stop=None
92
  )
93
+ sql_query = response.choices[0].message['content'].strip()
94
  return sql_query
95
  except Exception as e:
96
  return f"Error generating SQL query: {e}"
 
100
  """
101
  Detects if the user's query involves plotting.
102
  """
103
+ plot_keywords = [
104
+ 'plot', 'graph', 'chart', 'distribution', 'visualize', 'histogram',
105
+ 'bar chart', 'line chart', 'scatter plot', 'pie chart'
106
+ ]
107
  for keyword in plot_keywords:
108
  if keyword in nl_query.lower():
109
  return True
 
117
  is_plot = detect_plot_intent(query)
118
  sql_query = parse_query(query)
119
  plot_code = ""
120
+ if is_plot and not sql_query.startswith("Error"):
121
  # Generate plot code based on the query
122
  # For simplicity, we'll generate a basic plot code
123
  plot_code = """
124
  import plotly.express as px
125
+ fig = px.bar(result_df, x='x_column', y='y_column', title='Generated Plot')
126
+ fig.update_layout(title_x=0.5)
127
  """
128
  return sql_query, plot_code
129
 
 
132
  """
133
  Executes the SQL query and returns the results as a DataFrame.
134
  """
135
+ if sql_query.startswith("Error"):
136
+ return None, sql_query # Pass the error message forward
137
+
138
  try:
139
  con = duckdb.connect()
140
  # Ensure the view is created
 
164
  plot_code = plot_code.replace('y_column', columns[1])
165
 
166
  # Execute the plot code
167
+ local_vars = {'result_df': result_df, 'px': px}
168
+ exec(plot_code, {}, local_vars)
169
  fig = local_vars.get('fig', None)
170
  if fig:
171
  return fig, ""