Ari committed on
Commit
b9a3a14
·
verified ·
1 Parent(s): 2b5a0ae

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +48 -13
app.py CHANGED
@@ -6,11 +6,15 @@ from langchain import OpenAI, LLMChain, PromptTemplate
6
  from langchain_community.utilities import SQLDatabase
7
  import sqlparse
8
  import logging
9
- from sql_metadata import Parser # Added import
10
 
11
  # OpenAI API key (ensure it is securely stored)
12
  openai_api_key = os.getenv("OPENAI_API_KEY")
13
 
 
 
 
 
14
  # Step 1: Upload CSV data file (or use default)
15
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
16
  if csv_file is None:
@@ -31,6 +35,29 @@ data.to_sql(table_name, conn, index=False, if_exists='replace')
31
  valid_columns = list(data.columns)
32
  st.write(f"Valid columns: {valid_columns}")
33
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
34
  # Step 3: Define SQL validation helpers
35
  def validate_sql(query, valid_columns):
36
  """Validates the SQL query by ensuring it references only valid columns."""
@@ -38,7 +65,7 @@ def validate_sql(query, valid_columns):
38
  columns_in_query = parser.columns
39
  for column in columns_in_query:
40
  if column not in valid_columns:
41
- st.write(f"Invalid column detected: {column}")
42
  return False
43
  return True
44
 
@@ -62,32 +89,40 @@ SQL Query:
62
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
63
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
64
 
 
 
 
 
 
65
  # Step 5: Generate SQL query based on user input
66
  user_prompt = st.text_input("Enter your natural language prompt:")
67
  if user_prompt:
68
  try:
69
- # Step 6: Adjust the logic to handle "what are the columns" query
70
- if "columns" in user_prompt.lower():
71
- # Custom logic to return columns
72
- st.write(f"The columns are: {', '.join(valid_columns)}")
73
  else:
74
  columns = ', '.join(valid_columns)
75
  generated_sql = sql_generation_chain.run({'question': user_prompt, 'table_name': table_name, 'columns': columns})
76
 
77
- # Debug: Display generated SQL query for inspection
78
- st.write(f"Generated SQL Query:\n{generated_sql}")
79
 
80
  # Step 7: Validate SQL query
81
  if not validate_sql_with_sqlparse(generated_sql):
82
- st.write("Generated SQL is not valid.")
83
  elif not validate_sql(generated_sql, valid_columns):
84
- st.write("Generated SQL references invalid columns.")
85
  else:
86
  # Step 8: Execute SQL query
87
  result = pd.read_sql_query(generated_sql, conn)
88
- st.write("Query Results:")
89
- st.dataframe(result)
90
 
91
  except Exception as e:
92
  logging.error(f"An error occurred: {e}")
93
- st.write(f"Error: {e}")
 
 
 
 
 
6
  from langchain_community.utilities import SQLDatabase
7
  import sqlparse
8
  import logging
9
+ from sql_metadata import Parser
10
 
11
  # OpenAI API key (ensure it is securely stored)
12
  openai_api_key = os.getenv("OPENAI_API_KEY")
13
 
14
+ # Initialize conversation history
15
+ if 'conversation' not in st.session_state:
16
+ st.session_state.conversation = []
17
+
18
  # Step 1: Upload CSV data file (or use default)
19
  csv_file = st.file_uploader("Upload your CSV file", type=["csv"])
20
  if csv_file is None:
 
35
  valid_columns = list(data.columns)
36
  st.write(f"Valid columns: {valid_columns}")
37
 
38
+ # Function to extract column names from the question
39
def extract_column_name(question, valid_columns):
    """Return the column from *valid_columns* that is mentioned in *question*.

    Matching is a case-insensitive substring test. When more than one
    column name occurs in the question, the one mentioned earliest wins;
    if two names start at the same position (for example ``price`` and
    ``price_usd``), the longer name wins so that a short name which is
    only a prefix of the intended column is not picked by accident.
    Remaining ties keep the order of *valid_columns*.

    Args:
        question: Free-text question typed by the user.
        valid_columns: Iterable of column labels to look for. Labels are
            compared via ``str()`` so non-string labels do not raise.

    Returns:
        The matching label exactly as it appears in *valid_columns*, or
        ``None`` when no label occurs in the question.
    """
    lowered_question = question.lower()  # invariant, so computed once
    best_match = None
    best_key = None
    for column in valid_columns:
        name = str(column).lower()
        if not name:
            # "" is a substring of every string and would always match.
            continue
        position = lowered_question.find(name)
        if position == -1:
            continue
        # Earlier mention first; at the same position prefer the longer name.
        key = (position, -len(name))
        if best_key is None or key < best_key:
            best_match, best_key = column, key
    return best_match
44
+
45
+ # Function to generate statistical insights
46
def generate_statistical_insights(question, data):
    """Answer a simple statistics question and append the reply to the thread.

    The first keyword found in *question* (checked in the order mean,
    median, mode, std) selects the statistic. The column is chosen by
    ``extract_column_name`` from the columns of *data* itself, so the
    function does not depend on a module-level column list. Exactly one
    message is appended to ``st.session_state.conversation`` when a keyword
    matches; nothing is appended when none does.

    Args:
        question: Free-text question typed by the user.
        data: Tabular object supporting ``data.columns`` and
            ``data[column]`` (presumably a pandas DataFrame; confirm
            against the loading code, which is outside this view).

    Returns:
        None. Exceptions raised while computing the statistic (for example
        on a non-numeric column) are not caught here and propagate to the
        caller.
    """
    lowered_question = question.lower()

    # (keyword searched in the question, label used in the reply, computation)
    statistics = (
        ("mean", "Mean", lambda series: series.mean()),
        ("median", "Median", lambda series: series.median()),
        # mode() can return several values; show all of them.
        ("mode", "Mode", lambda series: ", ".join(map(str, series.mode().tolist()))),
        ("std", "Standard deviation", lambda series: series.std()),
    )

    for keyword, label, compute in statistics:
        if keyword not in lowered_question:
            continue
        column = extract_column_name(question, list(data.columns))
        # Compare against None so a falsy label such as 0 is still accepted.
        if column is None:
            message = "Could not find a valid column in the question."
        else:
            message = f"{label} of {column}: {compute(data[column])}"
        st.session_state.conversation.append(message)
        return
60
+
61
  # Step 3: Define SQL validation helpers
62
  def validate_sql(query, valid_columns):
63
  """Validates the SQL query by ensuring it references only valid columns."""
 
65
  columns_in_query = parser.columns
66
  for column in columns_in_query:
67
  if column not in valid_columns:
68
+ st.session_state.conversation.append(f"Invalid column detected: {column}")
69
  return False
70
  return True
71
 
 
89
  prompt = PromptTemplate(template=template, input_variables=['question', 'table_name', 'columns'])
90
  sql_generation_chain = LLMChain(llm=OpenAI(temperature=0), prompt=prompt)
91
 
92
+ # Display conversation history like a text thread
93
+ st.write("### Conversation Thread")
94
+ for chat in st.session_state.conversation:
95
+ st.write(f"User: {chat}")
96
+
97
  # Step 5: Generate SQL query based on user input
98
  user_prompt = st.text_input("Enter your natural language prompt:")
99
  if user_prompt:
100
  try:
101
+ # Step 6: Handle statistical insights or generate SQL
102
+ if any(stat_term in user_prompt.lower() for stat_term in ["mean", "median", "mode", "std"]):
103
+ generate_statistical_insights(user_prompt, data)
 
104
  else:
105
  columns = ', '.join(valid_columns)
106
  generated_sql = sql_generation_chain.run({'question': user_prompt, 'table_name': table_name, 'columns': columns})
107
 
108
+ # Display generated SQL query in the conversation thread
109
+ st.session_state.conversation.append(f"Generated SQL Query: {generated_sql}")
110
 
111
  # Step 7: Validate SQL query
112
  if not validate_sql_with_sqlparse(generated_sql):
113
+ st.session_state.conversation.append("Generated SQL is not valid.")
114
  elif not validate_sql(generated_sql, valid_columns):
115
+ st.session_state.conversation.append("Generated SQL references invalid columns.")
116
  else:
117
  # Step 8: Execute SQL query
118
  result = pd.read_sql_query(generated_sql, conn)
119
+ st.session_state.conversation.append("Query Results:")
120
+ st.session_state.conversation.append(result.to_string())
121
 
122
  except Exception as e:
123
  logging.error(f"An error occurred: {e}")
124
+ st.session_state.conversation.append(f"Error: {e}")
125
+
126
+ # Display the text input box below the conversation thread
127
+ user_input = st.text_input("Enter a question to ask the data:", key="user_input")
128
+