CosmickVisions commited on
Commit
9e431a9
·
verified ·
1 Parent(s): bc938fb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +300 -26
app.py CHANGED
@@ -8,20 +8,14 @@ from streamlit_pandas_profiling import st_profile_report
8
  import os
9
  import requests
10
  import json
11
- from datetime import datetime
12
  import re
13
- import tempfile
14
  from scipy import stats
15
- from sklearn.impute import SimpleImputer
16
  from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
17
  from sklearn.decomposition import PCA
18
- import streamlit.components.v1 as components
19
- from io import StringIO
20
  from dotenv import load_dotenv
21
  from flask import Flask, request, jsonify
22
  from openai import OpenAI
23
  import threading
24
- from sentence_transformers import SentenceTransformer
25
 
26
  # Load environment variables
27
  load_dotenv()
@@ -30,13 +24,6 @@ load_dotenv()
30
  flask_app = Flask(__name__)
31
  FLASK_PORT = 5000 # Internal port for Flask, not exposed externally
32
 
33
- # Initialize OpenAI client
34
- api_key = os.getenv("OPENAI_API_KEY")
35
- if not api_key:
36
- st.error("OPENAI_API_KEY not set. Please configure it in the Hugging Face Space secrets.")
37
- st.stop()
38
- client = OpenAI(api_key=api_key)
39
-
40
  # Flask RAG Endpoint
41
  @flask_app.route('/rag_chat', methods=['POST'])
42
  def rag_chat():
@@ -45,7 +32,6 @@ def rag_chat():
45
  app_mode = data.get('app_mode', 'Data Upload')
46
  dataset_text = data.get('dataset_text', '')
47
 
48
- # RAG Logic: Use dataset_text as retrieval context
49
  system_prompt = (
50
  "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
51
  "The app has three pages:\n"
@@ -71,7 +57,7 @@ def rag_chat():
71
  {"role": "system", "content": system_prompt},
72
  {"role": "user", "content": user_input}
73
  ],
74
- max_tokens=100, # Increased for RAG context
75
  temperature=0.7
76
  )
77
  return jsonify({"response": response.choices[0].message.content})
@@ -82,7 +68,6 @@ def rag_chat():
82
  def run_flask():
83
  flask_app.run(host='0.0.0.0', port=FLASK_PORT, debug=False, use_reloader=False)
84
 
85
- # Start Flask thread
86
  flask_thread = threading.Thread(target=run_flask, daemon=True)
87
  flask_thread.start()
88
 
@@ -95,11 +80,11 @@ def update_cleaned_data(df):
95
  if 'data_versions' not in st.session_state:
96
  st.session_state.data_versions = [st.session_state.raw_data.copy()]
97
  st.session_state.data_versions.append(df.copy())
 
98
  st.success("✅ Action completed successfully!")
99
  st.rerun()
100
 
101
  def convert_csv_to_json_and_text(df):
102
- """Convert DataFrame to JSON and then to plain text."""
103
  json_data = df.to_json(orient="records")
104
  data_dict = json.loads(json_data)
105
  text_summary = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
@@ -115,7 +100,6 @@ def convert_csv_to_json_and_text(df):
115
  return text_summary
116
 
117
  def get_chatbot_response(user_input, app_mode, dataset_text=""):
118
- """Send request to internal Flask RAG endpoint."""
119
  payload = {
120
  "user_input": user_input,
121
  "app_mode": app_mode,
@@ -128,8 +112,88 @@ def get_chatbot_response(user_input, app_mode, dataset_text=""):
128
  except requests.exceptions.RequestException as e:
129
  return f"Error: Could not connect to RAG server. {str(e)}"
130
 
131
- # Streamlit App
132
- # Sidebar Navigation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
  with st.sidebar:
134
  st.title("🔮 Data-Vision Pro")
135
  st.markdown("Your AI-powered data analysis suite with RAG.")
@@ -145,6 +209,13 @@ with st.sidebar:
145
  st.info("🧹 Clean and preprocess your data using various tools.")
146
  elif app_mode == "EDA":
147
  st.info("🔍 Explore your data visually and statistically.")
 
 
 
 
 
 
 
148
 
149
  st.markdown("---")
150
  st.markdown("**Note**: Requires dependencies in `requirements.txt`.")
@@ -159,15 +230,29 @@ with st.sidebar:
159
  st.markdown("Created by Calvin Allen-Crawford")
160
  st.markdown("v1.0 | © 2025")
161
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
162
  # Main App Pages
163
  if app_mode == "Data Upload":
164
  st.title("📤 Data Upload & Profiling")
165
  st.header("Upload Your Dataset")
166
  st.write("Supported formats: CSV, XLSX")
167
-
168
  if 'raw_data' not in st.session_state:
169
  st.info("It looks like no dataset has been uploaded yet. Would you like to upload a CSV or XLSX file?")
170
-
171
  uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
172
  if uploaded_file:
173
  st.session_state.pop('raw_data', None)
@@ -182,6 +267,7 @@ if app_mode == "Data Upload":
182
  st.error("Uploaded file is empty.")
183
  st.stop()
184
  st.session_state.raw_data = df
 
185
  st.session_state.dataset_text = convert_csv_to_json_and_text(df)
186
  if 'data_versions' not in st.session_state:
187
  st.session_state.data_versions = [df.copy()]
@@ -226,6 +312,92 @@ elif app_mode == "Data Cleaning":
226
  st.session_state.dataset_text = convert_csv_to_json_and_text(st.session_state.cleaned_data)
227
  st.rerun()
228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
229
  elif app_mode == "EDA":
230
  st.title("🔍 Interactive Data Explorer")
231
  if 'cleaned_data' not in st.session_state:
@@ -242,10 +414,109 @@ elif app_mode == "EDA":
242
  col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
243
  col4.metric("Duplicates", df.duplicated().sum())
244
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
245
  # Chatbot Section
246
  st.markdown("---")
247
  st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
248
- st.info("Ask me about the app or your data! Try: 'What can I do here?' or 'What’s in the dataset?'")
249
  if "chat_history" not in st.session_state:
250
  st.session_state.chat_history = []
251
 
@@ -258,10 +529,13 @@ if user_input:
258
  st.session_state.chat_history.append({"role": "user", "content": user_input})
259
  with st.chat_message("user"):
260
  st.markdown(user_input)
261
-
262
- with st.spinner("Thinking with RAG..."):
263
  dataset_text = st.session_state.get("dataset_text", "")
264
- response = get_chatbot_response(user_input, app_mode, dataset_text)
 
 
 
 
265
  st.session_state.chat_history.append({"role": "assistant", "content": response})
266
  with st.chat_message("assistant"):
267
  st.markdown(response)
 
8
  import os
9
  import requests
10
  import json
 
11
  import re
 
12
  from scipy import stats
 
13
  from sklearn.preprocessing import StandardScaler, LabelEncoder, OneHotEncoder
14
  from sklearn.decomposition import PCA
 
 
15
  from dotenv import load_dotenv
16
  from flask import Flask, request, jsonify
17
  from openai import OpenAI
18
  import threading
 
19
 
20
  # Load environment variables
21
  load_dotenv()
 
24
  flask_app = Flask(__name__)
25
  FLASK_PORT = 5000 # Internal port for Flask, not exposed externally
26
 
 
 
 
 
 
 
 
27
  # Flask RAG Endpoint
28
  @flask_app.route('/rag_chat', methods=['POST'])
29
  def rag_chat():
 
32
  app_mode = data.get('app_mode', 'Data Upload')
33
  dataset_text = data.get('dataset_text', '')
34
 
 
35
  system_prompt = (
36
  "You are an AI assistant in Data-Vision Pro, a data analysis app with RAG capabilities. "
37
  "The app has three pages:\n"
 
57
  {"role": "system", "content": system_prompt},
58
  {"role": "user", "content": user_input}
59
  ],
60
+ max_tokens=100,
61
  temperature=0.7
62
  )
63
  return jsonify({"response": response.choices[0].message.content})
 
68
  def run_flask():
69
  flask_app.run(host='0.0.0.0', port=FLASK_PORT, debug=False, use_reloader=False)
70
 
 
71
  flask_thread = threading.Thread(target=run_flask, daemon=True)
72
  flask_thread.start()
73
 
 
80
  if 'data_versions' not in st.session_state:
81
  st.session_state.data_versions = [st.session_state.raw_data.copy()]
82
  st.session_state.data_versions.append(df.copy())
83
+ st.session_state.dataset_text = convert_csv_to_json_and_text(df)
84
  st.success("✅ Action completed successfully!")
85
  st.rerun()
86
 
87
  def convert_csv_to_json_and_text(df):
 
88
  json_data = df.to_json(orient="records")
89
  data_dict = json.loads(json_data)
90
  text_summary = f"Dataset Summary: {df.shape[0]} rows, {df.shape[1]} columns\n"
 
100
  return text_summary
101
 
102
  def get_chatbot_response(user_input, app_mode, dataset_text=""):
 
103
  payload = {
104
  "user_input": user_input,
105
  "app_mode": app_mode,
 
112
  except requests.exceptions.RequestException as e:
113
  return f"Error: Could not connect to RAG server. {str(e)}"
114
 
115
# Command Functions for LLM
def drop_columns(columns):
    """Drop one or more columns from the cleaned dataset.

    Parameters
    ----------
    columns : str
        Comma-separated column names, e.g. "age, income".

    Returns
    -------
    str
        Human-readable status message for the chatbot.

    NOTE(review): update_cleaned_data() calls st.rerun(), which aborts the
    current script run, so the success message below likely never reaches
    the caller — confirm the intended chat UX.
    """
    if 'cleaned_data' not in st.session_state:
        return "No dataset loaded."
    df = st.session_state.cleaned_data.copy()
    # Split on commas; discard empty fragments from "a,,b" or trailing commas.
    requested = [col.strip() for col in columns.split(',') if col.strip()]
    valid_columns = [col for col in requested if col in df.columns]
    if not valid_columns:
        return "No valid columns found to drop."
    df = df.drop(columns=valid_columns)
    update_cleaned_data(df)
    return f"Dropped columns: {', '.join(valid_columns)}"
128
+
129
# LLM-Driven EDA Commands
def generate_scatter_plot(params):
    """Render a Plotly scatter plot from a "X vs Y" command string.

    Parameters
    ----------
    params : str
        Free text containing two column names separated by "vs".

    Returns
    -------
    str
        Status message for the chatbot.

    Side effects: draws the chart via st.plotly_chart and records the plot
    metadata in st.session_state.last_plot for analyze_plot().
    """
    # Bug fix: the original dereferenced st.session_state.cleaned_data
    # unguarded and raised AttributeError when no dataset was loaded.
    if 'cleaned_data' not in st.session_state:
        return "No dataset loaded."
    df = st.session_state.cleaned_data
    match = re.search(r"([\w\s]+)\s+vs\s+([\w\s]+)", params)
    if not match:
        return "Invalid columns for scatter plot."
    x_axis, y_axis = match.group(1).strip(), match.group(2).strip()
    if x_axis in df.columns and y_axis in df.columns:
        fig = px.scatter(df, x=x_axis, y=y_axis, title=f'Scatter Plot of {x_axis} vs {y_axis}')
        st.plotly_chart(fig)
        st.session_state.last_plot = {"type": "Scatter Plot", "x": x_axis, "y": y_axis, "data": df[[x_axis, y_axis]].to_json()}
        return f"Generated scatter plot of {x_axis} vs {y_axis}"
    return "Invalid columns for scatter plot."
141
+
142
def generate_histogram(params):
    """Render a Plotly histogram of a single column named in *params*.

    Parameters
    ----------
    params : str
        A column name (surrounding whitespace is ignored).

    Returns
    -------
    str
        Status message for the chatbot.

    Side effects: draws the chart via st.plotly_chart and records the plot
    metadata in st.session_state.last_plot for analyze_plot().
    """
    # Bug fix: guard against no dataset being loaded (the original raised
    # AttributeError on st.session_state.cleaned_data).
    if 'cleaned_data' not in st.session_state:
        return "No dataset loaded."
    df = st.session_state.cleaned_data
    x_axis = params.strip()
    if x_axis not in df.columns:
        return "Invalid column for histogram."
    fig = px.histogram(df, x=x_axis, title=f'Histogram of {x_axis}')
    st.plotly_chart(fig)
    st.session_state.last_plot = {"type": "Histogram", "x": x_axis, "data": df[[x_axis]].to_json()}
    return f"Generated histogram of {x_axis}"
151
+
152
# Inference from Plotted Data
def analyze_plot():
    """Produce a one-sentence statistical reading of the last rendered plot.

    Reads st.session_state.last_plot (written by the generate_* helpers):
    for a scatter plot, reports the Pearson correlation and its strength and
    direction; for a histogram, reports the skewness. Other plot types get a
    generic "not available" message.

    Returns
    -------
    str
        The inference message for the chatbot.
    """
    if "last_plot" not in st.session_state:
        return "No plot available to analyze."
    # Local import: the module-level `from io import StringIO` was removed
    # in this commit, so keep the dependency scoped here.
    from io import StringIO
    plot_info = st.session_state.last_plot
    # pandas >= 2.1 deprecates passing a raw JSON string to read_json;
    # wrap the stored JSON payload in a text buffer instead.
    df = pd.read_json(StringIO(plot_info["data"]))
    plot_type = plot_info["type"]
    x_col = plot_info["x"]
    y_col = plot_info.get("y")  # only scatter plots store a y column

    if plot_type == "Scatter Plot" and y_col:
        correlation = df[x_col].corr(df[y_col])
        # Conventional |r| thresholds: >0.7 strong, >0.3 moderate, else weak.
        strength = "strong" if abs(correlation) > 0.7 else "moderate" if abs(correlation) > 0.3 else "weak"
        direction = "positive" if correlation > 0 else "negative"
        return f"The scatter plot of {x_col} vs {y_col} shows a {strength} {direction} correlation (Pearson r = {correlation:.2f})."
    elif plot_type == "Histogram":
        skewness = df[x_col].skew()
        skew_desc = "positively skewed" if skewness > 1 else "negatively skewed" if skewness < -1 else "approximately symmetric"
        return f"The histogram of {x_col} is {skew_desc} (skewness = {skewness:.2f})."
    return "Inference not available for this plot type."
172
+
173
# Parse Chatbot Commands
def parse_command(command):
    """Map a chat message to a (handler, params) pair.

    Parameters
    ----------
    command : str
        The raw chat message.

    Returns
    -------
    tuple
        (callable, params_str_or_None) when a known command matches,
        otherwise (None, help_text).

    Bug fix: the original lowercased the whole message before extracting
    parameters, so column names like "Age" were passed as "age" and never
    matched df.columns. Matching is now case-insensitive while the extracted
    parameters keep their original case.
    """
    text = command.strip()
    lowered = text.lower()

    def _strip_prefix(*prefixes):
        # Remove the first matching command phrase (case-insensitively)
        # from the original-case text, preserving the remainder verbatim.
        for prefix in prefixes:
            idx = lowered.find(prefix)
            if idx != -1:
                return (text[:idx] + text[idx + len(prefix):]).strip()
        return text

    if "drop columns" in lowered or "drop column" in lowered:
        return drop_columns, _strip_prefix("drop columns", "drop column")
    if "show a scatter plot" in lowered or "scatter plot of" in lowered:
        return generate_scatter_plot, _strip_prefix("show a scatter plot of", "scatter plot of")
    if "show a histogram" in lowered or "histogram of" in lowered:
        return generate_histogram, _strip_prefix("show a histogram of", "histogram of")
    if "analyze plot" in lowered:
        return lambda _: analyze_plot(), None
    return None, "Command not recognized. Try 'drop columns X, Y', 'scatter plot of X vs Y', or 'analyze plot'."
188
+
189
# Dataset Preview Function
def display_dataset_preview():
    """Show the first ten rows of the cleaned dataset, if one is loaded.

    No-op when st.session_state has no 'cleaned_data' entry.
    """
    if 'cleaned_data' not in st.session_state:
        return
    st.subheader("Current Dataset Preview")
    st.dataframe(st.session_state.cleaned_data.head(10), use_container_width=True)
    st.write("---")
195
+
196
+ # Sidebar Navigation with API Key Input
197
  with st.sidebar:
198
  st.title("🔮 Data-Vision Pro")
199
  st.markdown("Your AI-powered data analysis suite with RAG.")
 
209
  st.info("🧹 Clean and preprocess your data using various tools.")
210
  elif app_mode == "EDA":
211
  st.info("🔍 Explore your data visually and statistically.")
212
+
213
+ # API Key Input Field
214
+ api_key_input = st.text_input(
215
+ "Enter your API key (optional)",
216
+ type="password",
217
+ help="Enter your API key to override the default. Leave blank to use the app's default key."
218
+ )
219
 
220
  st.markdown("---")
221
  st.markdown("**Note**: Requires dependencies in `requirements.txt`.")
 
230
  st.markdown("Created by Calvin Allen-Crawford")
231
  st.markdown("v1.0 | © 2025")
232
 
233
+ # Determine which API key to use
234
+ if api_key_input:
235
+ api_key = api_key_input # Use the user-provided API key from the sidebar
236
+ else:
237
+ api_key = st.secrets.get("OPENAI_API_KEY", os.getenv("OPENAI_API_KEY")) # Fall back to secret or environment variable
238
+
239
+ if not api_key:
240
+ st.error("API key is required. Please provide it in the sidebar or ensure it’s set in the app’s secrets.")
241
+ st.stop()
242
+
243
+ # Initialize OpenAI client with the selected API key
244
+ client = OpenAI(api_key=api_key)
245
+
246
+ # Display dataset preview at the top of each page
247
+ display_dataset_preview()
248
+
249
  # Main App Pages
250
  if app_mode == "Data Upload":
251
  st.title("📤 Data Upload & Profiling")
252
  st.header("Upload Your Dataset")
253
  st.write("Supported formats: CSV, XLSX")
 
254
  if 'raw_data' not in st.session_state:
255
  st.info("It looks like no dataset has been uploaded yet. Would you like to upload a CSV or XLSX file?")
 
256
  uploaded_file = st.file_uploader("Choose a file", type=["csv", "xlsx"], key="file_uploader")
257
  if uploaded_file:
258
  st.session_state.pop('raw_data', None)
 
267
  st.error("Uploaded file is empty.")
268
  st.stop()
269
  st.session_state.raw_data = df
270
+ st.session_state.cleaned_data = df.copy()
271
  st.session_state.dataset_text = convert_csv_to_json_and_text(df)
272
  if 'data_versions' not in st.session_state:
273
  st.session_state.data_versions = [df.copy()]
 
312
  st.session_state.dataset_text = convert_csv_to_json_and_text(st.session_state.cleaned_data)
313
  st.rerun()
314
 
315
+ with st.expander("🛠️ Data Cleaning Operations", expanded=True):
316
+ enhance_section_title("🔍 Missing Values Treatment")
317
+ missing_cols = df.columns[df.isna().any()].tolist()
318
+ if missing_cols:
319
+ cols = st.multiselect("Select columns with missing values", missing_cols)
320
+ method = st.selectbox("Choose imputation method", [
321
+ "Drop Missing Values", "Fill with Mean/Median", "Fill with Custom Value", "Forward Fill", "Backward Fill"
322
+ ])
323
+ if method == "Fill with Custom Value":
324
+ custom_val = st.text_input("Enter custom value:")
325
+ if st.button("Apply Missing Value Treatment"):
326
+ new_df = df.copy()
327
+ if method == "Drop Missing Values":
328
+ new_df = new_df.dropna(subset=cols)
329
+ elif method == "Fill with Mean/Median":
330
+ for col in cols:
331
+ if pd.api.types.is_numeric_dtype(new_df[col]):
332
+ new_df[col] = new_df[col].fillna(new_df[col].median())
333
+ else:
334
+ new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
335
+ elif method == "Fill with Custom Value" and custom_val:
336
+ new_df[cols] = new_df[cols].fillna(custom_val)
337
+ elif method == "Forward Fill":
338
+ new_df[cols] = new_df[cols].ffill()
339
+ elif method == "Backward Fill":
340
+ new_df[cols] = new_df[cols].bfill()
341
+ update_cleaned_data(new_df)
342
+ else:
343
+ st.success("✨ No missing values detected!")
344
+
345
+ enhance_section_title("🔄 Data Type Conversion")
346
+ col_to_convert = st.selectbox("Select column to convert", df.columns)
347
+ new_type = st.selectbox("Select new data type", ["String", "Integer", "Float", "Boolean", "Datetime"])
348
+ if new_type == "Datetime":
349
+ date_format = st.text_input("Enter date format (e.g., %Y-%m-%d):", "%Y-%m-%d")
350
+ if st.button("Convert Data Type"):
351
+ new_df = df.copy()
352
+ if new_type == "String":
353
+ new_df[col_to_convert] = new_df[col_to_convert].astype(str)
354
+ elif new_type == "Integer":
355
+ new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce').astype('Int64')
356
+ elif new_type == "Float":
357
+ new_df[col_to_convert] = pd.to_numeric(new_df[col_to_convert], errors='coerce')
358
+ elif new_type == "Boolean":
359
+ new_df[col_to_convert] = new_df[col_to_convert].astype(bool)
360
+ elif new_type == "Datetime":
361
+ new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
362
+ update_cleaned_data(new_df)
363
+
364
+ enhance_section_title("🗑️ Drop Columns")
365
+ columns_to_drop = st.multiselect("Select columns to remove", df.columns)
366
+ if columns_to_drop and st.button("Confirm Column Removal"):
367
+ new_df = df.copy()
368
+ new_df = new_df.drop(columns=columns_to_drop)
369
+ update_cleaned_data(new_df)
370
+
371
+ enhance_section_title("🔢 Encoding Options")
372
+ encoding_method = st.radio("Choose encoding method", ("Label Encoding", "One-Hot Encoding"))
373
+ data_to_encode = st.multiselect("Select columns to encode", df.select_dtypes(include='object').columns)
374
+ if data_to_encode and st.button("Apply Encoding"):
375
+ new_df = df.copy()
376
+ if encoding_method == "Label Encoding":
377
+ for col in data_to_encode:
378
+ le = LabelEncoder()
379
+ new_df[col] = le.fit_transform(new_df[col].astype(str))
380
+ elif encoding_method == "One-Hot Encoding":
381
+ new_df = pd.get_dummies(new_df, columns=data_to_encode, drop_first=True, dtype=int)
382
+ update_cleaned_data(new_df)
383
+
384
+ enhance_section_title("📏 StandardScaler")
385
+ scale_cols = st.multiselect("Select numerical columns to scale", df.select_dtypes(include=np.number).columns)
386
+ if scale_cols and st.button("Apply StandardScaler"):
387
+ new_df = df.copy()
388
+ scaler = StandardScaler()
389
+ new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
390
+ update_cleaned_data(new_df)
391
+
392
+ enhance_section_title("🕵️ Pattern-Based Cleaning")
393
+ selected_col = st.selectbox("Select text column for pattern cleaning", df.select_dtypes(include='object').columns)
394
+ pattern = st.text_input("Enter regex pattern:")
395
+ replacement = st.text_input("Enter replacement value:")
396
+ if st.button("Apply Pattern Replacement"):
397
+ new_df = df.copy()
398
+ new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
399
+ update_cleaned_data(new_df)
400
+
401
  elif app_mode == "EDA":
402
  st.title("🔍 Interactive Data Explorer")
403
  if 'cleaned_data' not in st.session_state:
 
414
  col3.metric("Missing Values", f"{df.isna().sum().sum()} ({missing_percentage:.1f}%)")
415
  col4.metric("Duplicates", df.duplicated().sum())
416
 
417
+ tab1, tab2, tab3 = st.tabs(["Quick Preview", "Column Types", "Missing Matrix"])
418
+ with tab1:
419
+ st.write("First few rows of the dataset:")
420
+ st.dataframe(df.head(), use_container_width=True)
421
+ with tab2:
422
+ st.write("Column Data Types:")
423
+ type_counts = df.dtypes.value_counts().reset_index()
424
+ type_counts.columns = ['Type', 'Count']
425
+ st.dataframe(type_counts, use_container_width=True)
426
+ with tab3:
427
+ st.write("Missing Values Matrix:")
428
+ fig_missing = px.imshow(df.isna(), color_continuous_scale=['#e0e0e0', '#66c2a5'])
429
+ fig_missing.update_layout(coloraxis_colorscale=[[0, 'lightgrey'], [1, '#FF4B4B']])
430
+ st.plotly_chart(fig_missing, use_container_width=True)
431
+
432
+ enhance_section_title("Interactive Visualization Builder")
433
+ with st.container():
434
+ col1, col2 = st.columns([1, 3])
435
+ with col1:
436
+ plot_type = st.selectbox("Choose visualization type", [
437
+ "Scatter Plot", "Histogram", "Box Plot", "Violin Plot", "Line Chart", "Bar Chart",
438
+ "Correlation Matrix", "Heatmap", "3D Scatter", "Parallel Categories", "Segmented Bar Chart",
439
+ "Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot", "Count Plot", "Lollipop Chart"
440
+ ])
441
+ x_axis = st.selectbox("X-axis", df.columns) if plot_type != "Correlation Matrix" else None
442
+ y_axis = st.selectbox("Y-axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Line Chart", "Heatmap", "Swarm Plot", "Ridge Plot", "Bubble Plot", "Density Plot", "Lollipop Chart"] else None
443
+ z_axis = st.selectbox("Z-axis", df.columns) if plot_type == "3D Scatter" else None
444
+ color_by = st.selectbox("Color encoding", ["None"] + df.columns.tolist(), format_func=lambda x: "No color" if x == "None" else x) if plot_type != "Correlation Matrix" else None
445
+ if plot_type == "Parallel Categories":
446
+ dimensions = st.multiselect("Dimensions", df.columns.tolist(), default=df.columns[:3].tolist())
447
+ elif plot_type == "Segmented Bar Chart":
448
+ segment_col = st.selectbox("Segment Column (Categorical)", df.select_dtypes(exclude=np.number).columns)
449
+ elif plot_type == "Bubble Plot":
450
+ size_col = st.selectbox("Size Column", df.columns)
451
+
452
+ with col2:
453
+ try:
454
+ fig = None
455
+ if plot_type == "Scatter Plot" and x_axis and y_axis:
456
+ fig = px.scatter(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, trendline="lowess", title=f'Scatter Plot of {x_axis} vs {y_axis}')
457
+ elif plot_type == "Histogram" and x_axis:
458
+ fig = px.histogram(df, x=x_axis, color=color_by if color_by != "None" else None, nbins=30, marginal="box", title=f'Histogram of {x_axis}')
459
+ elif plot_type == "Box Plot" and x_axis and y_axis:
460
+ fig = px.box(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Box Plot of {x_axis} vs {y_axis}')
461
+ elif plot_type == "Violin Plot" and x_axis and y_axis:
462
+ fig = px.violin(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, box=True, title=f'Violin Plot of {x_axis} vs {y_axis}')
463
+ elif plot_type == "Line Chart" and x_axis and y_axis:
464
+ fig = px.line(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Line Chart of {x_axis} vs {y_axis}')
465
+ elif plot_type == "Bar Chart" and x_axis:
466
+ fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Bar Chart of {x_axis}')
467
+ elif plot_type == "Correlation Matrix":
468
+ numeric_df = df.select_dtypes(include=np.number)
469
+ if len(numeric_df.columns) > 1:
470
+ corr = numeric_df.corr()
471
+ fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r', zmin=-1, zmax=1, title='Correlation Matrix')
472
+ elif plot_type == "Heatmap" and x_axis and y_axis:
473
+ fig = px.density_heatmap(df, x=x_axis, y=y_axis, facet_col=color_by if color_by != "None" else None, title=f'Heatmap of {x_axis} vs {y_axis}')
474
+ elif plot_type == "3D Scatter" and x_axis and y_axis and z_axis:
475
+ fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=color_by if color_by != "None" else None, title=f'3D Scatter Plot of {x_axis} vs {y_axis} vs {z_axis}')
476
+ elif plot_type == "Parallel Categories" and dimensions:
477
+ fig = px.parallel_categories(df, dimensions=dimensions, color=color_by if color_by != "None" else None, title='Parallel Categories Plot')
478
+ elif plot_type == "Segmented Bar Chart" and x_axis and segment_col:
479
+ segment_counts = df.groupby([x_axis, segment_col]).size().reset_index(name='counts')
480
+ fig = px.bar(segment_counts, x=x_axis, y='counts', color=segment_col, title=f'Segmented Bar Chart of {x_axis} by {segment_col}')
481
+ fig.update_layout(yaxis_title="Count")
482
+ elif plot_type == "Swarm Plot" and x_axis and y_axis:
483
+ fig = px.strip(df, x=x_axis, y=y_axis, color=color_by if color_by != "None" else None, title=f'Swarm Plot of {x_axis} vs {y_axis}')
484
+ elif plot_type == "Ridge Plot" and x_axis and y_axis:
485
+ fig = px.histogram(df, x=x_axis, color=y_axis, marginal="rug", title=f'Ridge Plot of {x_axis} by {y_axis}')
486
+ elif plot_type == "Bubble Plot" and x_axis and y_axis and size_col:
487
+ fig = px.scatter(df, x=x_axis, y=y_axis, size=size_col, color=color_by if color_by != "None" else None, title=f'Bubble Plot of {x_axis} vs {y_axis}')
488
+ elif plot_type == "Density Plot" and x_axis and y_axis:
489
+ fig = px.density_heatmap(df, x=x_axis, y=y_axis, color_continuous_scale="Viridis", title=f'Density Plot of {x_axis} vs {y_axis}')
490
+ elif plot_type == "Count Plot" and x_axis:
491
+ fig = px.bar(df, x=x_axis, color=color_by if color_by != "None" else None, title=f'Count Plot of {x_axis}')
492
+ fig.update_layout(yaxis_title="Count")
493
+ elif plot_type == "Lollipop Chart" and x_axis and y_axis:
494
+ fig = go.Figure()
495
+ fig.add_trace(go.Scatter(x=df[x_axis], y=df[y_axis], mode='markers', marker=dict(size=10)))
496
+ for i in range(len(df)):
497
+ fig.add_trace(go.Scatter(x=[df[x_axis].iloc[i], df[x_axis].iloc[i]], y=[0, df[y_axis].iloc[i]], mode='lines', line=dict(color='gray')))
498
+ fig.update_layout(showlegend=False, title=f'Lollipop Chart of {x_axis} vs {y_axis}')
499
+
500
+ if fig:
501
+ fig.update_layout(template="plotly_white")
502
+ st.plotly_chart(fig, use_container_width=True)
503
+ st.session_state.last_plot = {
504
+ "type": plot_type,
505
+ "x": x_axis,
506
+ "y": y_axis,
507
+ "z": z_axis,
508
+ "color": color_by if color_by != "None" else None,
509
+ "data": df[[x_axis, y_axis] + ([z_axis] if z_axis else [])].to_json() if x_axis and y_axis else df[[x_axis]].to_json()
510
+ }
511
+ else:
512
+ st.error("Please provide required inputs for the selected plot type.")
513
+ except Exception as e:
514
+ st.error(f"Couldn't create visualization: {str(e)}")
515
+
516
  # Chatbot Section
517
  st.markdown("---")
518
  st.subheader("💬 AI Chatbot Assistant (RAG Enabled)")
519
+ st.info("Ask me about the app or your data! Try: 'drop columns X, Y', 'scatter plot of X vs Y', or 'analyze plot'")
520
  if "chat_history" not in st.session_state:
521
  st.session_state.chat_history = []
522
 
 
529
  st.session_state.chat_history.append({"role": "user", "content": user_input})
530
  with st.chat_message("user"):
531
  st.markdown(user_input)
532
+ with st.spinner("Processing..."):
 
533
  dataset_text = st.session_state.get("dataset_text", "")
534
+ func, param = parse_command(user_input)
535
+ if func:
536
+ response = func(param) if param else func(None)
537
+ else:
538
+ response = get_chatbot_response(user_input, app_mode, dataset_text)
539
  st.session_state.chat_history.append({"role": "assistant", "content": response})
540
  with st.chat_message("assistant"):
541
  st.markdown(response)