Ari commited on
Commit
a01caad
·
verified ·
1 Parent(s): 9e9d1c1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -5
app.py CHANGED
@@ -3,7 +3,6 @@ import streamlit as st
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
6
- # Removed unused import: from langchain_community.utilities import SQLDatabase
7
  import sqlparse
8
  import logging
9
  from sql_metadata import Parser
@@ -40,8 +39,12 @@ def validate_sql(query, valid_columns):
40
  """Validates the SQL query by ensuring it references only valid columns."""
41
  parser = Parser(query)
42
  columns_in_query = parser.columns
 
 
 
 
43
  for column in columns_in_query:
44
- if column not in valid_columns:
45
  st.write(f"Invalid column detected: {column}")
46
  return False
47
  return True
@@ -86,9 +89,6 @@ def process_input():
86
  'columns': columns
87
  })
88
 
89
- # Debug: Display generated SQL query for inspection
90
- # st.write(f"Generated SQL Query:\n{generated_sql}")
91
-
92
  # Validate SQL query
93
  if not validate_sql_with_sqlparse(generated_sql):
94
  assistant_response = "Generated SQL is not valid."
 
3
  import pandas as pd
4
  import sqlite3
5
  from langchain import OpenAI, LLMChain, PromptTemplate
 
6
  import sqlparse
7
  import logging
8
  from sql_metadata import Parser
 
39
  """Validates the SQL query by ensuring it references only valid columns."""
40
  parser = Parser(query)
41
  columns_in_query = parser.columns
42
+
43
+ # Convert valid columns to lowercase for case-insensitive comparison
44
+ valid_columns_lower = [col.lower() for col in valid_columns]
45
+
46
  for column in columns_in_query:
47
+ if column.lower() not in valid_columns_lower:
48
  st.write(f"Invalid column detected: {column}")
49
  return False
50
  return True
 
89
  'columns': columns
90
  })
91
 
 
 
 
92
  # Validate SQL query
93
  if not validate_sql_with_sqlparse(generated_sql):
94
  assistant_response = "Generated SQL is not valid."