CosmickVisions committed · verified · commit 7ec0dc1 · 1 parent: 41390aa

Update app.py

Files changed (1): app.py (+394, -152)
app.py CHANGED
Old version of app.py (lines removed by this commit are marked with "-"):

@@ -2,15 +2,20 @@ import streamlit as st
  import pandas as pd
  import numpy as np
  import plotly.express as px
- from sklearn.model_selection import train_test_split
- from sklearn.linear_model import LinearRegression
- from sklearn.tree import DecisionTreeRegressor
- from sklearn.metrics import mean_squared_error, r2_score
  from sklearn.impute import KNNImputer
- from sklearn.preprocessing import RobustScaler
  from ydata_profiling import ProfileReport
  from streamlit_pandas_profiling import st_profile_report
  from io import StringIO

  # Configuration
  st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="🧙")
@@ -25,12 +30,12 @@ st.markdown(
      color: #e0e0ff; /* Light text */
      font-family: 'Courier New', monospace; /* Monospace font */
  }
-
  /* Main content area */
  .stApp {
      background-color: #0a0a1a; /* Match body background */
  }
-
  /* Containers and blocks */
  .st-emotion-cache-16idsys,
  .st-emotion-cache-1v0mbdj,
@@ -46,44 +51,44 @@ st.markdown(
      box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
      color: #e0e0ff; /* Light text color */
  }
-
  /* Sidebar */
  .st-bb {
      background-color: #141422; /* Dark sidebar background */
      padding: 20px;
      border-radius: 10px;
  }
-
  /* Headers */
  h1, h2, h3, h4, h5, h6, .st-bb {
      color: #00f7ff; /* Cyan color for headers */
  }
-
  /* Selectboxes and Buttons */
  .st-cb, .st-ci, .st-cj, .st-ch {
      background-color: #141422; /* Dark selectbox background */
      color: #00f7ff !important; /* Cyan text color */
      border: 1px solid #00f7ff; /* Cyan border */
  }
-
  /* Selectbox text */
  .st-cv {
      color: #00f7ff !important; /* Cyan color for selectbox text */
  }
-
  /* Number input and text input */
  .st-cr {
      background-color: #141422 !important; /* Dark input background */
      color: #00f7ff !important; /* Cyan text color */
      border: 1px solid #00f7ff !important; /* Cyan border */
  }
-
  /* Slider */
  .st-cw {
      background-color: #141422 !important; /* Dark slider background */
      border: 1px solid #00f7ff !important; /* Cyan border */
  }
-
  /* Buttons */
  .st-bz, .st-b0 {
      background-color: #141422; /* Darker Button background */
@@ -95,7 +100,7 @@ st.markdown(
      background-color: #00f7ff; /* Hover color */
      color: #0a0a1a; /* Hover text color */
  }
-
  /* File uploader */
  .st-ae {
      background-color: #141422 !important; /* Dark file uploader background */
@@ -103,24 +108,21 @@ st.markdown(
      border: 1px solid #00f7ff !important; /* Cyan border */
      border-radius: 10px; /* Rounded corners */
  }
-
  /* Metric */
  .st-emotion-cache-10trblm {
      border-radius: 10px !important; /* Rounded corners */
      box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
  }
-
  /* Dataframes and tables */
  .dataframe {
      background-color: #1e1e30 !important; /* Dark table background */
      color: #e0e0ff !important; /* Light text in tables */
      border: 1px solid #00f7ff !important; /* Cyan border for tables */
  }
-
  .dataframe tr:nth-child(odd) {
      background-color: #141422 !important; /* Alternating row color */
  }
-
  /* Expanders*/
  .st-emotion-cache-10oheav {
      color: #00f7ff !important; /* Cyan text color */
@@ -142,10 +144,20 @@ st.markdown(
  # Cache decorators
  @st.cache_data(ttl=3600)
  def load_data(uploaded_file):
-     """Load and cache dataset"""
-     if uploaded_file.name.endswith('.csv'):
-         return pd.read_csv(uploaded_file)
-     return pd.read_excel(uploaded_file)

  @st.cache_data(ttl=3600)
  def generate_profile(df):
@@ -161,12 +173,14 @@ if 'train_test' not in st.session_state:
      st.session_state.train_test = {}
  if 'model' not in st.session_state:
      st.session_state.model = None

  # Sidebar Navigation
  st.sidebar.title("🔮 Data Wizard Pro")
  app_mode = st.sidebar.radio("Navigate", [
-     "Data Upload",
-     "Smart Cleaning",
      "Advanced EDA",
      "Model Training",
      "Predictions",
@@ -176,35 +190,36 @@ app_mode = st.sidebar.radio("Navigate", [
  # Data Upload Section
  if app_mode == "Data Upload":
      st.title("📤 Data Upload & Analysis")
-
      uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx"])
      if uploaded_file:
          df = load_data(uploaded_file)
-         st.session_state.raw_data = df
-         st.session_state.cleaned_data = df.copy()
-
-         # Data Overview Cards
-         col1, col2, col3 = st.columns(3)
-         with col1:
-             st.metric("Rows", df.shape[0])
-         with col2:
-             st.metric("Columns", df.shape[1])
-         with col3:
-             st.metric("Missing Values", df.isna().sum().sum())
-
-         # Automated EDA Report
-         with st.expander("🚀 Automated Data Report"):
-             if st.button("Generate Smart Report"):
-                 pr = generate_profile(df)
-                 st_profile_report(pr)

  # Smart Cleaning Section
  elif app_mode == "Smart Cleaning":
      st.title("🧼 Intelligent Data Cleaning")
-
      if st.session_state.raw_data is not None:
          df = st.session_state.cleaned_data
-
          # Cleaning Toolkit
          col1, col2 = st.columns([1, 3])
          with col1:
@@ -213,9 +228,10 @@ elif app_mode == "Smart Cleaning":
                  "Handle Missing Values",
                  "Remove Duplicates",
                  "Normalize Data",
-                 "Encode Categories"
              ])
-
              if clean_action == "Handle Missing Values":
                  method = st.selectbox("Imputation Method", [
                      "KNN Imputation",
@@ -223,160 +239,383 @@ elif app_mode == "Smart Cleaning":
                      "Mean Fill",
                      "Drop Missing"
                  ])
-
          with col2:
              if st.button("Apply Transformation"):
                  with st.spinner("Applying changes..."):
                      if clean_action == "Handle Missing Values":
-                         if method == "KNN Imputation":
-                             imputer = KNNImputer()
-                             df = pd.DataFrame(imputer.fit_transform(df), columns=df.columns)
-                         elif method == "Median Fill":
-                             df = df.fillna(df.median())
-                         elif method == "Mean Fill":
-                             df = df.fillna(df.mean())
                          else:
-                             df = df.dropna()
                      elif clean_action == "Remove Duplicates":
                          df = df.drop_duplicates()
                      elif clean_action == "Normalize Data":
-                         scaler = RobustScaler()
-                         numerical_cols = df.select_dtypes(include=np.number).columns
-                         df[numerical_cols] = scaler.fit_transform(df[numerical_cols])
-
                      st.session_state.cleaned_data = df
                  st.success("Transformation applied!")
-
          # Data Comparison
          st.subheader("Data Version Comparison")
          col1, col2 = st.columns(2)
          with col1:
-             st.write("Original Data", st.session_state.raw_data.head(3))
          with col2:
              st.write("Cleaned Data", df.head(3))

  # Advanced EDA Section
  elif app_mode == "Advanced EDA":
      st.title("🔍 Advanced Exploratory Analysis")
-
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
-
          # Visualization Selector
          plot_type = st.selectbox("Choose Visualization", [
-             "Histogram",
              "Scatter Plot",
              "Box Plot",
              "Correlation Heatmap",
-             "3D Scatter"
          ])
-
          # Dynamic Axis Selection
          cols = st.columns(3)
          with cols[0]:
              x_col = st.selectbox("X Axis", df.columns)
          with cols[1]:
-             y_col = st.selectbox("Y Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot"] else None
          with cols[2]:
              z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
-
          # Generate Plot
          if st.button("Generate Visualization"):
-             if plot_type == "Histogram":
-                 fig = px.histogram(df, x=x_col, nbins=30, template="plotly_dark")
-             elif plot_type == "Scatter Plot":
-                 fig = px.scatter(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
-             elif plot_type == "3D Scatter":
-                 fig = px.scatter_3d(df, x=x_col, y=y_col, z=z_col, color=x_col)
-             elif plot_type == "Correlation Heatmap":
-                 corr = df.corr()
-                 fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
-             elif plot_type == "Box Plot":
-                 fig = px.box(df,x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
-
-             fig.update_layout(
-                 plot_bgcolor="#1e1e30",
-                 paper_bgcolor="#1e1e30",
-                 font_color="#e0e0ff"
-             )
-
-             st.plotly_chart(fig, use_container_width=True)

  # Model Training Section
  elif app_mode == "Model Training":
      st.title("🤖 Model Training Studio")
-
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
-
          # Model Setup
          col1, col2 = st.columns([1, 3])
          with col1:
-             model_type = st.selectbox("Choose Model", [
-                 "Linear Regression",
-                 "Decision Tree"
-             ])
-
              test_size = st.slider("Test Size", 0.1, 0.5, 0.2)
              target = st.selectbox("Target Variable", df.columns)
-
          with col2:
              if st.button("Train Model"):
-                 X = df.drop(columns=[target])
-                 y = df[target]
-
-                 X_train, X_test, y_train, y_test = train_test_split(
-                     X, y, test_size=test_size, random_state=42
-                 )
-
-                 if model_type == "Linear Regression":
-                     model = LinearRegression()
-                 elif model_type == "Decision Tree":
-                     model = DecisionTreeRegressor()
-
-                 model.fit(X_train, y_train)
-                 st.session_state.model = model
-                 st.session_state.train_test = {
-                     'X_test': X_test,
-                     'y_test': y_test
-                 }
-
-                 # Evaluation Metrics
-                 y_pred = model.predict(X_test)
-                 st.metric("R² Score", round(r2_score(y_test, y_pred), 2))
-                 st.metric("MSE", round(mean_squared_error(y_test, y_pred), 2))

  # Predictions Section
  elif app_mode == "Predictions":
      st.title("🔮 Make Predictions")
-
-     if st.session_state.model is not None:
-         model = st.session_state.model
-
          # Prediction Interface
          input_data = {}
-         for col in st.session_state.train_test['X_test'].columns:
-             input_data[col] = st.number_input(col, value=0.0)
-
          if st.button("Predict"):
-             input_df = pd.DataFrame([input_data])
-             prediction = model.predict(input_df)
-             st.success(f"Predicted Value: {prediction[0]:.2f}")

  elif app_mode == "Visualization Lab":
      st.title("📊 Advanced Visualization Lab")
-
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
-
          # Visualization Gallery
          viz_type = st.selectbox("Choose Visualization Type", [
              "3D Scatter Plot",
              "Interactive Heatmap",
              "Time Series Analysis",
-             "Cluster Analysis"
          ])
-
          # Dynamic Controls
          cols = st.columns(3)
          with cols[0]:
@@ -385,23 +624,26 @@ elif app_mode == "Visualization Lab":
              y_axis = st.selectbox("Y Axis", df.columns)
          with cols[2]:
              z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
-
          # Generate Visualization
-         if viz_type == "3D Scatter Plot":
-             fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=x_axis)
-             st.plotly_chart(fig, use_container_width=True)
-
-         elif viz_type == "Interactive Heatmap":
-             corr = df.corr()
-             fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
-             st.plotly_chart(fig, use_container_width=True)
-
-         elif viz_type == "Time Series Analysis":
-             # Basic time series plot
-             time_col = st.selectbox("Time Column", df.columns)
-             value_col = st.selectbox("Value Column", df.columns)
-             fig = px.line(df, x=time_col, y=value_col)
-             st.plotly_chart(fig, use_container_width=True)
-
-         elif viz_type == "Cluster Analysis":
-             st.write("Cluster Analysis Feature Coming Soon!") # placeholder for future development
New version of app.py (lines added by this commit are marked with "+"):

  import pandas as pd
  import numpy as np
  import plotly.express as px
+ from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
+ from sklearn.linear_model import LinearRegression, LogisticRegression
+ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
+ from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
+ from sklearn.svm import SVR, SVC
+ from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
  from sklearn.impute import KNNImputer
+ from sklearn.preprocessing import RobustScaler, StandardScaler, OneHotEncoder
+ from sklearn.compose import ColumnTransformer
+ from sklearn.pipeline import Pipeline
  from ydata_profiling import ProfileReport
  from streamlit_pandas_profiling import st_profile_report
  from io import StringIO
+ import joblib # For saving and loading models

  # Configuration
  st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="🧙")

      color: #e0e0ff; /* Light text */
      font-family: 'Courier New', monospace; /* Monospace font */
  }
+
  /* Main content area */
  .stApp {
      background-color: #0a0a1a; /* Match body background */
  }
+
  /* Containers and blocks */
  .st-emotion-cache-16idsys,
  .st-emotion-cache-1v0mbdj,

      box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
      color: #e0e0ff; /* Light text color */
  }
+
  /* Sidebar */
  .st-bb {
      background-color: #141422; /* Dark sidebar background */
      padding: 20px;
      border-radius: 10px;
  }
+
  /* Headers */
  h1, h2, h3, h4, h5, h6, .st-bb {
      color: #00f7ff; /* Cyan color for headers */
  }
+
  /* Selectboxes and Buttons */
  .st-cb, .st-ci, .st-cj, .st-ch {
      background-color: #141422; /* Dark selectbox background */
      color: #00f7ff !important; /* Cyan text color */
      border: 1px solid #00f7ff; /* Cyan border */
  }
+
  /* Selectbox text */
  .st-cv {
      color: #00f7ff !important; /* Cyan color for selectbox text */
  }
+
  /* Number input and text input */
  .st-cr {
      background-color: #141422 !important; /* Dark input background */
      color: #00f7ff !important; /* Cyan text color */
      border: 1px solid #00f7ff !important; /* Cyan border */
  }
+
  /* Slider */
  .st-cw {
      background-color: #141422 !important; /* Dark slider background */
      border: 1px solid #00f7ff !important; /* Cyan border */
  }
+
  /* Buttons */
  .st-bz, .st-b0 {
      background-color: #141422; /* Darker Button background */

      background-color: #00f7ff; /* Hover color */
      color: #0a0a1a; /* Hover text color */
  }
+
  /* File uploader */
  .st-ae {
      background-color: #141422 !important; /* Dark file uploader background */

      border: 1px solid #00f7ff !important; /* Cyan border */
      border-radius: 10px; /* Rounded corners */
  }
  /* Metric */
  .st-emotion-cache-10trblm {
      border-radius: 10px !important; /* Rounded corners */
      box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
  }
+
  /* Dataframes and tables */
  .dataframe {
      background-color: #1e1e30 !important; /* Dark table background */
      color: #e0e0ff !important; /* Light text in tables */
      border: 1px solid #00f7ff !important; /* Cyan border for tables */
  }
  .dataframe tr:nth-child(odd) {
      background-color: #141422 !important; /* Alternating row color */
  }
  /* Expanders*/
  .st-emotion-cache-10oheav {
      color: #00f7ff !important; /* Cyan text color */

  # Cache decorators
  @st.cache_data(ttl=3600)
  def load_data(uploaded_file):
+     """Load and cache dataset, with file type validation."""
+     if uploaded_file is not None:
+         file_extension = uploaded_file.name.split(".")[-1].lower()
+
+         if file_extension == "csv":
+             return pd.read_csv(uploaded_file)
+         elif file_extension in ["xlsx", "xls"]:
+             return pd.read_excel(uploaded_file)
+         else:
+             st.error("Unsupported file type. Please upload a CSV or Excel file.")
+             return None
+     else:
+         return None
+

  @st.cache_data(ttl=3600)
  def generate_profile(df):

      st.session_state.train_test = {}
  if 'model' not in st.session_state:
      st.session_state.model = None
+ if 'preprocessor' not in st.session_state:
+     st.session_state.preprocessor = None # to store the column transformer

  # Sidebar Navigation
  st.sidebar.title("🔮 Data Wizard Pro")
  app_mode = st.sidebar.radio("Navigate", [
+     "Data Upload",
+     "Smart Cleaning",
      "Advanced EDA",
      "Model Training",
      "Predictions",

  # Data Upload Section
  if app_mode == "Data Upload":
      st.title("📤 Data Upload & Analysis")
+
      uploaded_file = st.file_uploader("Upload Dataset", type=["csv", "xlsx"])
      if uploaded_file:
          df = load_data(uploaded_file)
+         if df is not None: # only proceed if load_data returned a valid dataframe
+             st.session_state.raw_data = df
+             st.session_state.cleaned_data = df.copy()
+
+             # Data Overview Cards
+             col1, col2, col3 = st.columns(3)
+             with col1:
+                 st.metric("Rows", df.shape[0])
+             with col2:
+                 st.metric("Columns", df.shape[1])
+             with col3:
+                 st.metric("Missing Values", df.isna().sum().sum())
+
+             # Automated EDA Report
+             with st.expander("🚀 Automated Data Report"):
+                 if st.button("Generate Smart Report"):
+                     pr = generate_profile(df)
+                     st_profile_report(pr)

  # Smart Cleaning Section
  elif app_mode == "Smart Cleaning":
      st.title("🧼 Intelligent Data Cleaning")
+
      if st.session_state.raw_data is not None:
          df = st.session_state.cleaned_data
+
          # Cleaning Toolkit
          col1, col2 = st.columns([1, 3])
          with col1:

                  "Handle Missing Values",
                  "Remove Duplicates",
                  "Normalize Data",
+                 "Encode Categories",
+                 "Outlier Removal"
              ])
+
              if clean_action == "Handle Missing Values":
                  method = st.selectbox("Imputation Method", [
                      "KNN Imputation",

                      "Mean Fill",
                      "Drop Missing"
                  ])
+                 impute_cols = st.multiselect("Columns to Impute", df.columns)
+
+             elif clean_action == "Normalize Data":
+                 scaler_type = st.selectbox("Scaler Type", ["RobustScaler", "StandardScaler"])
+                 normalize_cols = st.multiselect("Columns to Normalize", df.select_dtypes(include=np.number).columns.tolist())
+
+             elif clean_action == "Encode Categories":
+                 encode_cols = st.multiselect("Columns to Encode", df.select_dtypes(include='object').columns.tolist())
+                 encoding_method = st.selectbox("Encoding Method", ["OneHotEncoder"]) # Add more if needed
+
+             elif clean_action == "Outlier Removal":
+                 outlier_cols = st.multiselect("Columns to Remove Outliers From", df.select_dtypes(include=np.number).columns.tolist())
+                 outlier_method = st.selectbox("Outlier Removal Method", ["IQR", "Z-score"])
+                 if outlier_method == "IQR":
+                     iqr_threshold = st.slider("IQR Threshold", 1.0, 3.0, 1.5)
+                 else:
+                     zscore_threshold = st.slider("Z-score Threshold", 2.0, 4.0, 3.0)
+
          with col2:
              if st.button("Apply Transformation"):
                  with st.spinner("Applying changes..."):
                      if clean_action == "Handle Missing Values":
+                         if not impute_cols:
+                             st.warning("Please select columns to impute.")
                          else:
+                             if method == "KNN Imputation":
+                                 imputer = KNNImputer()
+                                 df[impute_cols] = imputer.fit_transform(df[impute_cols])
+                             elif method == "Median Fill":
+                                 df[impute_cols] = df[impute_cols].fillna(df[impute_cols].median())
+                             elif method == "Mean Fill":
+                                 df[impute_cols] = df[impute_cols].fillna(df[impute_cols].mean())
+                             else:
+                                 df = df.dropna(subset=impute_cols)
                      elif clean_action == "Remove Duplicates":
                          df = df.drop_duplicates()
                      elif clean_action == "Normalize Data":
+                         if not normalize_cols:
+                             st.warning("Please select columns to normalize.")
+                         else:
+                             if scaler_type == "RobustScaler":
+                                 scaler = RobustScaler()
+                             else:
+                                 scaler = StandardScaler()
+                             df[normalize_cols] = scaler.fit_transform(df[normalize_cols])
+
+                     elif clean_action == "Encode Categories":
+                         if not encode_cols:
+                             st.warning("Please select columns to encode.")
+                         else:
+                             if encoding_method == "OneHotEncoder":
+                                 encoder = OneHotEncoder(handle_unknown='ignore', sparse_output=False)
+                                 encoded_data = encoder.fit_transform(df[encode_cols])
+                                 encoded_df = pd.DataFrame(encoded_data, columns=encoder.get_feature_names_out(encode_cols))
+                                 df = pd.concat([df.drop(columns=encode_cols), encoded_df], axis=1)
+
+                     elif clean_action == "Outlier Removal":
+                         if not outlier_cols:
+                             st.warning("Please select columns to remove outliers from.")
+                         else:
+                             for col in outlier_cols:
+                                 if outlier_method == "IQR":
+                                     Q1 = df[col].quantile(0.25)
+                                     Q3 = df[col].quantile(0.75)
+                                     IQR = Q3 - Q1
+                                     lower_bound = Q1 - iqr_threshold * IQR
+                                     upper_bound = Q3 + iqr_threshold * IQR
+                                     df = df[(df[col] >= lower_bound) & (df[col] <= upper_bound)]
+                                 else: # Z-score
+                                     z_scores = np.abs((df[col] - df[col].mean()) / df[col].std())
+                                     df = df[z_scores <= zscore_threshold]
+
                      st.session_state.cleaned_data = df
                  st.success("Transformation applied!")
+
          # Data Comparison
          st.subheader("Data Version Comparison")
          col1, col2 = st.columns(2)
          with col1:
+             st.write("Original Data", st.session_state.raw_data.head(3) if st.session_state.raw_data is not None else "No data uploaded")
          with col2:
              st.write("Cleaned Data", df.head(3))

  # Advanced EDA Section
  elif app_mode == "Advanced EDA":
      st.title("🔍 Advanced Exploratory Analysis")
+
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
+
          # Visualization Selector
          plot_type = st.selectbox("Choose Visualization", [
+             "Histogram",
              "Scatter Plot",
              "Box Plot",
              "Correlation Heatmap",
+             "3D Scatter",
+             "Violin Plot",
+             "Time Series"
          ])
+
          # Dynamic Axis Selection
          cols = st.columns(3)
          with cols[0]:
              x_col = st.selectbox("X Axis", df.columns)
          with cols[1]:
+             y_col = st.selectbox("Y Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Time Series"] else None
          with cols[2]:
              z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
+         if plot_type == "Time Series":
+             time_col = x_col # rename for clarity
+             value_col = y_col
+
+         #Interactive filtering
+         filter_col = st.selectbox("Filter Column", [None] + list(df.columns))
+         if filter_col:
+             unique_values = df[filter_col].unique()
+             filter_options = st.multiselect("Filter Values", unique_values, default=unique_values)
+             df = df[df[filter_col].isin(filter_options)]
+
          # Generate Plot
          if st.button("Generate Visualization"):
+             try: # add try-except block for potential errors
+                 if plot_type == "Histogram":
+                     fig = px.histogram(df, x=x_col, nbins=30, template="plotly_dark")
+                 elif plot_type == "Scatter Plot":
+                     fig = px.scatter(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
+                 elif plot_type == "3D Scatter":
+                     fig = px.scatter_3d(df, x=x_col, y=y_col, z=z_col, color=x_col)
+                 elif plot_type == "Correlation Heatmap":
+                     corr = df.corr(numeric_only=True) #handle non-numeric cols
+                     fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
+                 elif plot_type == "Box Plot":
+                     fig = px.box(df,x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
+                 elif plot_type == "Violin Plot":
+                     fig = px.violin(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
+                 elif plot_type == "Time Series":
+                     fig = px.line(df, x=time_col, y=value_col)
+
+                 fig.update_layout(
+                     plot_bgcolor="#1e1e30",
+                     paper_bgcolor="#1e1e30",
+                     font_color="#e0e0ff"
+                 )
+
+                 st.plotly_chart(fig, use_container_width=True)
+             except Exception as e:
+                 st.error(f"Error generating plot: {e}")

  # Model Training Section
  elif app_mode == "Model Training":
      st.title("🤖 Model Training Studio")
+
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
+
+         # Check for missing values before proceeding
+         if df.isnull().sum().sum() > 0:
+             st.error("Data contains missing values. Please handle them in the 'Smart Cleaning' section before training.")
+             st.stop()
+
          # Model Setup
          col1, col2 = st.columns([1, 3])
          with col1:
+             task_type = st.selectbox("Choose Task", ["Regression", "Classification"])
+             if task_type == "Regression":
+                 model_type = st.selectbox("Choose Model", [
+                     "Linear Regression",
+                     "Decision Tree",
+                     "Random Forest",
+                     "Gradient Boosting"
+                 ])
+             else: # Classification
+                 model_type = st.selectbox("Choose Model", [
+                     "Logistic Regression",
+                     "Decision Tree",
+                     "Random Forest",
+                     "Support Vector Machine" #SVC
+                 ])
+
              test_size = st.slider("Test Size", 0.1, 0.5, 0.2)
              target = st.selectbox("Target Variable", df.columns)
+             features = [col for col in df.columns if col != target] #Exclude target
+             numeric_features = df[features].select_dtypes(include=np.number).columns.tolist()
+             categorical_features = [col for col in features if col not in numeric_features]
+
+             # Hyperparameter tuning options (example for RandomForest)
+             enable_hyperparameter_tuning = st.checkbox("Enable Hyperparameter Tuning")
+             if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]: # Add more models later
+                 st.write("Hyperparameter Tuning Options:")
+                 if model_type == "Random Forest":
+                     n_estimators = st.slider("Number of Estimators", 50, 200, 100)
+                     max_depth = st.slider("Max Depth", 5, 20, None) #None for unlimited
+                     param_grid = {'n_estimators': [n_estimators], 'max_depth': [max_depth]}
+
+                 elif model_type == "Gradient Boosting":
+                     n_estimators = st.slider("Number of Estimators", 50, 200, 100, key = "gb_n_estimators")
+                     learning_rate = st.slider("Learning Rate", 0.01, 0.1, 0.05, key = "gb_learning_rate")
+                     max_depth = st.slider("Max Depth", 3, 10, 5, key = "gb_max_depth")
+                     param_grid = {'n_estimators': [n_estimators], 'learning_rate': [learning_rate], 'max_depth': [max_depth]}
+
+                 elif model_type == "Support Vector Machine": #SVC/SVR
+                     kernel = st.selectbox("Kernel", ['linear', 'rbf', 'poly'])
+                     C = st.slider("C (Regularization)", 0.1, 1.0, 0.5)
+                     param_grid = {'kernel': [kernel], 'C': [C]}
+                 elif model_type == "Logistic Regression":
+                     C = st.slider("C (Regularization)", 0.1, 1.0, 0.5)
+                     param_grid = {'C': [C]} # add more as needed
+                 elif model_type == "Decision Tree":
+                     max_depth = st.slider("Max Depth", 5, 20, None) # None for unlimited
+                     param_grid = {'max_depth': [max_depth]}
+
          with col2:
              if st.button("Train Model"):
+                 try:
+                     X = df.drop(columns=[target])
+                     y = df[target]
+
+                     X_train, X_test, y_train, y_test = train_test_split(
+                         X, y, test_size=test_size, random_state=42
+                     )
+
+                     # Preprocessing
+                     numeric_transformer = StandardScaler() #StandardScaler or other scalers
+                     categorical_transformer = OneHotEncoder(handle_unknown='ignore', sparse_output=False) #sparse=False for array output
+
+                     preprocessor = ColumnTransformer(
+                         transformers=[
+                             ('num', numeric_transformer, numeric_features),
+                             ('cat', categorical_transformer, categorical_features)
+                         ],
+                         remainder='passthrough' # or 'drop' if you want to drop untransformed cols
+                     )
+
+                     X_train = preprocessor.fit_transform(X_train)
+                     X_test = preprocessor.transform(X_test)
+                     st.session_state.preprocessor = preprocessor #store for prediction later
+
+                     # Model Training
+                     if task_type == "Regression":
+                         if model_type == "Linear Regression":
+                             model = LinearRegression()
+                         elif model_type == "Decision Tree":
+                             model = DecisionTreeRegressor()
+                         elif model_type == "Random Forest":
+                             model = RandomForestRegressor()
+                         elif model_type == "Gradient Boosting":
+                             model = GradientBoostingRegressor()
+                         elif model_type == "Support Vector Machine":
+                             model = SVR()
+                     else: #Classification
+                         if model_type == "Logistic Regression":
+                             model = LogisticRegression(max_iter=1000) #increase max_iter if needed
+                         elif model_type == "Decision Tree":
+                             model = DecisionTreeClassifier()
+                         elif model_type == "Random Forest":
+                             model = RandomForestClassifier()
+                         elif model_type == "Support Vector Machine":
+                             model = SVC(probability=True) #probability=True needed for ROC AUC
+
+                     #Hyperparameter tuning
+                     if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]:
+                         grid_search = GridSearchCV(model, param_grid, cv=3, scoring='neg_mean_squared_error' if task_type == "Regression" else 'accuracy')
+                         grid_search.fit(X_train, y_train)
+                         model = grid_search.best_estimator_ #use best model
+                         st.write("Best Parameters:", grid_search.best_params_)
+
+                     else:
+                         model.fit(X_train, y_train)
+
+                     st.session_state.model = model
+                     st.session_state.train_test = {
+                         'X_test': X_test,
+                         'y_test': y_test,
+                         'task': task_type #Store task for eval
+                     }
+
+                     # Evaluation Metrics
+                     y_pred = model.predict(X_test)
+
+                     if task_type == "Regression":
+                         r2 = r2_score(y_test, y_pred)
+                         mse = mean_squared_error(y_test, y_pred)
+                         mae = mean_absolute_error(y_test, y_pred) #ADDED
+                         st.metric("R² Score", round(r2, 2))
+                         st.metric("MSE", round(mse, 2))
+                         st.metric("MAE", round(mae, 2)) #ADDED
+                     else: #Classification
+                         accuracy = accuracy_score(y_test, y_pred)
+                         precision = precision_score(y_test, y_pred, average='weighted', zero_division=0)
+                         recall = recall_score(y_test, y_pred, average='weighted', zero_division=0)
+                         f1 = f1_score(y_test, y_pred, average='weighted', zero_division=0)
+                         try:
+                             roc_auc = roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]) #requires probabilities
+                             st.metric("ROC AUC", round(roc_auc, 2))
+                         except:
+                             st.warning("ROC AUC score not available for this classifier.")
+                         st.metric("Accuracy", round(accuracy, 2))
+                         st.metric("Precision", round(precision, 2))
+                         st.metric("Recall", round(recall, 2))
+                         st.metric("F1 Score", round(f1, 2))
+
+                     #Cross Validation
+                     scores = cross_val_score(model, X_train, y_train, cv=5, scoring='neg_mean_squared_error' if task_type == "Regression" else 'accuracy') # use appropriate scoring
+                     st.write("Cross-Validation Scores:", scores)
+                     st.write("Mean Cross-Validation Score:", scores.mean())
+
+                     #Model persistence
+                     if st.checkbox("Save Model"):
+                         model_filename = st.text_input("Model Filename", "trained_model.joblib")
+                         joblib.dump((model, preprocessor), model_filename) # save both model AND preprocessor
+                         st.success(f"Model saved as {model_filename}")
+                 except Exception as e:
+                     st.error(f"Error during training: {e}")

  # Predictions Section
  elif app_mode == "Predictions":
      st.title("🔮 Make Predictions")
+
+     if st.session_state.model is not None and st.session_state.preprocessor is not None:
+         model, preprocessor = st.session_state.model, st.session_state.preprocessor
+         X_test_cols = st.session_state.train_test['X_test'].shape[1] #get the number of input cols
+
          # Prediction Interface
          input_data = {}
+         X_test_columns = [f"feature_{i}" for i in range(X_test_cols)] # Generate placeholder column names
+         input_data = {}
+         for i in range(X_test_cols):
+             input_data[f"feature_{i}"] = st.number_input(f"Feature {i+1}", value=0.0)
+         #for col in st.session_state.train_test['X_test'].columns: # causes error since its preprocessed
+
          if st.button("Predict"):
+             try:
+
+                 input_df = pd.DataFrame([input_data])
+                 # Preprocess input
+                 input_processed = preprocessor.transform(input_df)
+                 prediction = model.predict(input_processed)
+
+                 if st.session_state.train_test['task'] == "Regression":
+                     st.success(f"Predicted Value: {prediction[0]:.2f}")
+                 else:
+                     st.success(f"Predicted Class: {prediction[0]}")
+                     # Show probabilities if it's a classifier
+                     if hasattr(model, "predict_proba"):
+                         proba = model.predict_proba(input_processed)[0]
+                         for i, p in enumerate(proba):
+                             st.write(f"Probability of class {i}: {p:.2f}")
+
+             except Exception as e:
+                 st.error(f"Error during prediction: {e}")
+     else:
+         st.warning("Please train a model first.")

  elif app_mode == "Visualization Lab":
      st.title("📊 Advanced Visualization Lab")
+
      if st.session_state.cleaned_data is not None:
          df = st.session_state.cleaned_data
+
          # Visualization Gallery
          viz_type = st.selectbox("Choose Visualization Type", [
              "3D Scatter Plot",
              "Interactive Heatmap",
              "Time Series Analysis",
+             "Cluster Analysis (Coming Soon)" #Removed placeholder, keep in mind
          ])
+
          # Dynamic Controls
          cols = st.columns(3)
          with cols[0]:

              y_axis = st.selectbox("Y Axis", df.columns)
          with cols[2]:
              z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
+
          # Generate Visualization
+         try: #Add try-except
+             if viz_type == "3D Scatter Plot":
+                 fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=x_axis)
+                 st.plotly_chart(fig, use_container_width=True)
+
+             elif viz_type == "Interactive Heatmap":
+                 corr = df.corr(numeric_only=True) #Add numeric_only=True
+                 fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
+                 st.plotly_chart(fig, use_container_width=True)
+
+             elif viz_type == "Time Series Analysis":
+                 # Basic time series plot
+                 time_col = st.selectbox("Time Column", df.columns)
+                 value_col = st.selectbox("Value Column", df.columns)
+                 fig = px.line(df, x=time_col, y=value_col)
+                 st.plotly_chart(fig, use_container_width=True)
+
+             elif viz_type == "Cluster Analysis (Coming Soon)": #Removed placeholder
+                 st.write("Cluster Analysis Feature Coming Soon!") # placeholder for future development
+         except Exception as e:
+             st.error(f"Error generating visualization: {e}")