DrishtiSharma committed on
Commit
9abae49
·
verified ·
1 Parent(s): 9bd334d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -24
app.py CHANGED
@@ -20,10 +20,10 @@ from langchain_community.utilities.sql_database import SQLDatabase
20
  from datasets import load_dataset
21
  import tempfile
22
 
23
- # Setup API key
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
- # Callback handler for logging
27
  class LLMCallbackHandler(BaseCallbackHandler):
28
  def __init__(self, log_path: Path):
29
  self.log_path = log_path
@@ -44,33 +44,36 @@ llm = ChatGroq(
44
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
45
  )
46
 
47
- st.title("SQL-RAG using CrewAI 🚀")
48
  st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
49
 
50
- # File upload or Hugging Face dataset input
51
- option = st.radio("Choose your input method:", ["Upload a CSV file", "Enter Hugging Face dataset name"])
 
 
52
 
53
- if option == "Upload a CSV file":
54
- uploaded_file = st.file_uploader("Upload your dataset (CSV format)", type=["csv"])
55
- if uploaded_file:
56
- df = pd.read_csv(uploaded_file)
57
- st.success("File uploaded successfully!")
58
- else:
59
- dataset_name = st.text_input("Enter Hugging Face dataset name:", placeholder="e.g., imdb, ag_news")
60
- if dataset_name:
61
- try:
62
  dataset = load_dataset(dataset_name, split="train")
63
  df = pd.DataFrame(dataset)
64
  st.success(f"Dataset '{dataset_name}' loaded successfully!")
65
- except Exception as e:
66
- st.error(f"Error loading Hugging Face dataset: {e}")
67
- df = None
68
-
69
- if 'df' in locals() and not df.empty:
70
- st.write("### Dataset Preview:")
71
- st.dataframe(df.head())
 
 
 
 
 
72
 
73
- # Create a temporary SQLite database
 
74
  temp_dir = tempfile.TemporaryDirectory()
75
  db_path = os.path.join(temp_dir.name, "data.db")
76
  connection = sqlite3.connect(db_path)
@@ -146,7 +149,7 @@ if 'df' in locals() and not df.empty:
146
  memory=False,
147
  )
148
 
149
- query = st.text_input("Enter your query:", placeholder="e.g., 'What are the top 5 highest salaries?'")
150
  if query:
151
  with st.spinner("Processing your query..."):
152
  inputs = {"query": query}
@@ -156,4 +159,4 @@ if 'df' in locals() and not df.empty:
156
 
157
  temp_dir.cleanup()
158
  else:
159
- st.warning("Please upload a valid file or provide a correct Hugging Face dataset name.")
 
20
  from datasets import load_dataset
21
  import tempfile
22
 
23
+ # Setup API Key
24
  os.environ["GROQ_API_KEY"] = st.secrets.get("GROQ_API_KEY", "")
25
 
26
+ # LLM Logging
27
  class LLMCallbackHandler(BaseCallbackHandler):
28
  def __init__(self, log_path: Path):
29
  self.log_path = log_path
 
44
  callbacks=[LLMCallbackHandler(Path("prompts.jsonl"))],
45
  )
46
 
47
+ st.title("SQL-RAG Using CrewAI 🚀")
48
  st.write("Analyze and summarize data using natural language queries with SQL-based retrieval.")
49
 
50
+ # Primary Option: Hugging Face Dataset
51
+ st.subheader("Option 1: Use a Hugging Face Dataset")
52
+ default_dataset = "Einstellung/demo-salaries"
53
+ dataset_name = st.text_input("Enter Hugging Face dataset name:", value=default_dataset)
54
 
55
+ df = None
56
+ if dataset_name:
57
+ try:
58
+ with st.spinner("Loading Hugging Face dataset..."):
 
 
 
 
 
59
  dataset = load_dataset(dataset_name, split="train")
60
  df = pd.DataFrame(dataset)
61
  st.success(f"Dataset '{dataset_name}' loaded successfully!")
62
+ st.dataframe(df.head())
63
+ except Exception as e:
64
+ st.error(f"Error loading Hugging Face dataset: {e}")
65
+
66
+ # Secondary Option: File Upload
67
+ st.subheader("Option 2: Upload Your CSV File")
68
+ uploaded_file = st.file_uploader("Upload your dataset (CSV format):", type=["csv"])
69
+ if uploaded_file and df is None:
70
+ with st.spinner("Loading uploaded file..."):
71
+ df = pd.read_csv(uploaded_file)
72
+ st.success("File uploaded successfully!")
73
+ st.dataframe(df.head())
74
 
75
+ if df is not None:
76
+ # Create SQLite database
77
  temp_dir = tempfile.TemporaryDirectory()
78
  db_path = os.path.join(temp_dir.name, "data.db")
79
  connection = sqlite3.connect(db_path)
 
149
  memory=False,
150
  )
151
 
152
+ query = st.text_input("Enter your query:", placeholder="e.g., 'What is the average salary by experience level?'")
153
  if query:
154
  with st.spinner("Processing your query..."):
155
  inputs = {"query": query}
 
159
 
160
  temp_dir.cleanup()
161
  else:
162
+ st.warning("Please load a Hugging Face dataset or upload a CSV file to proceed.")