CosmickVisions commited on
Commit
0d26638
Β·
verified Β·
1 Parent(s): 7ec0dc1

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +133 -43
app.py CHANGED
@@ -8,7 +8,7 @@ from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier
8
  from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
9
  from sklearn.svm import SVR, SVC
10
  from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
11
- from sklearn.impute import KNNImputer
12
  from sklearn.preprocessing import RobustScaler, StandardScaler, OneHotEncoder
13
  from sklearn.compose import ColumnTransformer
14
  from sklearn.pipeline import Pipeline
@@ -16,6 +16,9 @@ from ydata_profiling import ProfileReport
16
  from streamlit_pandas_profiling import st_profile_report
17
  from io import StringIO
18
  import joblib # For saving and loading models
 
 
 
19
 
20
  # Configuration
21
  st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="πŸ§™")
@@ -30,12 +33,10 @@ st.markdown(
30
  color: #e0e0ff; /* Light text */
31
  font-family: 'Courier New', monospace; /* Monospace font */
32
  }
33
-
34
  /* Main content area */
35
  .stApp {
36
  background-color: #0a0a1a; /* Match body background */
37
  }
38
-
39
  /* Containers and blocks */
40
  .st-emotion-cache-16idsys,
41
  .st-emotion-cache-1v0mbdj,
@@ -51,44 +52,37 @@ st.markdown(
51
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
52
  color: #e0e0ff; /* Light text color */
53
  }
54
-
55
  /* Sidebar */
56
  .st-bb {
57
  background-color: #141422; /* Dark sidebar background */
58
  padding: 20px;
59
  border-radius: 10px;
60
  }
61
-
62
  /* Headers */
63
  h1, h2, h3, h4, h5, h6, .st-bb {
64
  color: #00f7ff; /* Cyan color for headers */
65
  }
66
-
67
  /* Selectboxes and Buttons */
68
  .st-cb, .st-ci, .st-cj, .st-ch {
69
  background-color: #141422; /* Dark selectbox background */
70
  color: #00f7ff !important; /* Cyan text color */
71
  border: 1px solid #00f7ff; /* Cyan border */
72
  }
73
-
74
  /* Selectbox text */
75
  .st-cv {
76
  color: #00f7ff !important; /* Cyan color for selectbox text */
77
  }
78
-
79
  /* Number input and text input */
80
  .st-cr {
81
  background-color: #141422 !important; /* Dark input background */
82
  color: #00f7ff !important; /* Cyan text color */
83
  border: 1px solid #00f7ff !important; /* Cyan border */
84
  }
85
-
86
  /* Slider */
87
  .st-cw {
88
  background-color: #141422 !important; /* Dark slider background */
89
  border: 1px solid #00f7ff !important; /* Cyan border */
90
  }
91
-
92
  /* Buttons */
93
  .st-bz, .st-b0 {
94
  background-color: #141422; /* Darker Button background */
@@ -100,7 +94,6 @@ st.markdown(
100
  background-color: #00f7ff; /* Hover color */
101
  color: #0a0a1a; /* Hover text color */
102
  }
103
-
104
  /* File uploader */
105
  .st-ae {
106
  background-color: #141422 !important; /* Dark file uploader background */
@@ -113,7 +106,6 @@ st.markdown(
113
  border-radius: 10px !important; /* Rounded corners */
114
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
115
  }
116
-
117
  /* Dataframes and tables */
118
  .dataframe {
119
  background-color: #1e1e30 !important; /* Dark table background */
@@ -136,11 +128,100 @@ st.markdown(
136
  border-radius: 10px;
137
  }
138
  /* Add more styling for other elements as needed */
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
139
  </style>
140
  """,
141
  unsafe_allow_html=True,
142
  )
143
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  # Cache decorators
145
  @st.cache_data(ttl=3600)
146
  def load_data(uploaded_file):
@@ -158,7 +239,6 @@ def load_data(uploaded_file):
158
  else:
159
  return None
160
 
161
-
162
  @st.cache_data(ttl=3600)
163
  def generate_profile(df):
164
  """Generate automated EDA report"""
@@ -187,7 +267,18 @@ app_mode = st.sidebar.radio("Navigate", [
187
  "Visualization Lab"
188
  ])
189
 
190
- # Data Upload Section
 
 
 
 
 
 
 
 
 
 
 
191
  if app_mode == "Data Upload":
192
  st.title("πŸ“€ Data Upload & Analysis")
193
 
@@ -210,10 +301,10 @@ if app_mode == "Data Upload":
210
  # Automated EDA Report
211
  with st.expander("πŸš€ Automated Data Report"):
212
  if st.button("Generate Smart Report"):
 
213
  pr = generate_profile(df)
214
  st_profile_report(pr)
215
 
216
- # Smart Cleaning Section
217
  elif app_mode == "Smart Cleaning":
218
  st.title("🧼 Intelligent Data Cleaning")
219
 
@@ -312,7 +403,6 @@ elif app_mode == "Smart Cleaning":
312
  z_scores = np.abs((df[col] - df[col].mean()) / df[col].std())
313
  df = df[z_scores <= zscore_threshold]
314
 
315
-
316
  st.session_state.cleaned_data = df
317
  st.success("Transformation applied!")
318
 
@@ -324,7 +414,6 @@ elif app_mode == "Smart Cleaning":
324
  with col2:
325
  st.write("Cleaned Data", df.head(3))
326
 
327
- # Advanced EDA Section
328
  elif app_mode == "Advanced EDA":
329
  st.title("πŸ” Advanced Exploratory Analysis")
330
 
@@ -347,12 +436,9 @@ elif app_mode == "Advanced EDA":
347
  with cols[0]:
348
  x_col = st.selectbox("X Axis", df.columns)
349
  with cols[1]:
350
- y_col = st.selectbox("Y Axis", df.columns) if plot_type in ["Scatter Plot", "Box Plot", "Violin Plot", "Time Series"] else None
351
  with cols[2]:
352
  z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
353
- if plot_type == "Time Series":
354
- time_col = x_col # rename for clarity
355
- value_col = y_col
356
 
357
  #Interactive filtering
358
  filter_col = st.selectbox("Filter Column", [None] + list(df.columns))
@@ -361,12 +447,11 @@ elif app_mode == "Advanced EDA":
361
  filter_options = st.multiselect("Filter Values", unique_values, default=unique_values)
362
  df = df[df[filter_col].isin(filter_options)]
363
 
364
-
365
  # Generate Plot
366
  if st.button("Generate Visualization"):
367
  try: # add try-except block for potential errors
368
  if plot_type == "Histogram":
369
- fig = px.histogram(df, x=x_col, nbins=30, template="plotly_dark")
370
  elif plot_type == "Scatter Plot":
371
  fig = px.scatter(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
372
  elif plot_type == "3D Scatter":
@@ -391,7 +476,6 @@ elif app_mode == "Advanced EDA":
391
  except Exception as e:
392
  st.error(f"Error generating plot: {e}")
393
 
394
- # Model Training Section
395
  elif app_mode == "Model Training":
396
  st.title("πŸ€– Model Training Studio")
397
 
@@ -454,11 +538,10 @@ elif app_mode == "Model Training":
454
  max_depth = st.slider("Max Depth", 5, 20, None) # None for unlimited
455
  param_grid = {'max_depth': [max_depth]}
456
 
457
-
458
-
459
  with col2:
460
  if st.button("Train Model"):
461
  try:
 
462
  X = df.drop(columns=[target])
463
  y = df[target]
464
 
@@ -482,7 +565,6 @@ elif app_mode == "Model Training":
482
  X_test = preprocessor.transform(X_test)
483
  st.session_state.preprocessor = preprocessor #store for prediction later
484
 
485
-
486
  # Model Training
487
  if task_type == "Regression":
488
  if model_type == "Linear Regression":
@@ -505,15 +587,17 @@ elif app_mode == "Model Training":
505
  elif model_type == "Support Vector Machine":
506
  model = SVC(probability=True) #probability=True needed for ROC AUC
507
 
508
-
509
  #Hyperparameter tuning
510
  if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]:
511
  grid_search = GridSearchCV(model, param_grid, cv=3, scoring='neg_mean_squared_error' if task_type == "Regression" else 'accuracy')
 
 
512
  grid_search.fit(X_train, y_train)
513
  model = grid_search.best_estimator_ #use best model
514
  st.write("Best Parameters:", grid_search.best_params_)
515
 
516
  else:
 
517
  model.fit(X_train, y_train)
518
 
519
  st.session_state.model = model
@@ -553,13 +637,13 @@ elif app_mode == "Model Training":
553
  st.write("Cross-Validation Scores:", scores)
554
  st.write("Mean Cross-Validation Score:", scores.mean())
555
 
556
-
557
-
558
  #Model persistence
559
  if st.checkbox("Save Model"):
560
  model_filename = st.text_input("Model Filename", "trained_model.joblib")
561
  joblib.dump((model, preprocessor), model_filename) # save both model AND preprocessor
562
  st.success(f"Model saved as {model_filename}")
 
 
563
  except Exception as e:
564
  st.error(f"Error during training: {e}")
565
 
@@ -581,7 +665,6 @@ elif app_mode == "Predictions":
581
 
582
  if st.button("Predict"):
583
  try:
584
-
585
  input_df = pd.DataFrame([input_data])
586
  # Preprocess input
587
  input_processed = preprocessor.transform(input_df)
@@ -602,7 +685,8 @@ elif app_mode == "Predictions":
602
  else:
603
  st.warning("Please train a model first.")
604
 
605
- elif app_mode == "Visualization Lab":
 
606
  st.title("πŸ“Š Advanced Visualization Lab")
607
 
608
  if st.session_state.cleaned_data is not None:
@@ -610,10 +694,9 @@ elif app_mode == "Visualization Lab":
610
 
611
  # Visualization Gallery
612
  viz_type = st.selectbox("Choose Visualization Type", [
613
- "3D Scatter Plot",
614
- "Interactive Heatmap",
615
  "Time Series Analysis",
616
- "Cluster Analysis (Coming Soon)" #Removed placeholder, keep in mind
617
  ])
618
 
619
  # Dynamic Controls
@@ -621,19 +704,24 @@ elif app_mode == "Visualization Lab":
621
  with cols[0]:
622
  x_axis = st.selectbox("X Axis", df.columns)
623
  with cols[1]:
624
- y_axis = st.selectbox("Y Axis", df.columns)
625
  with cols[2]:
626
  z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
627
 
628
  # Generate Visualization
629
- try: #Add try-except
630
  if viz_type == "3D Scatter Plot":
631
- fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=x_axis)
632
- st.plotly_chart(fig, use_container_width=True)
 
 
 
 
633
 
634
  elif viz_type == "Interactive Heatmap":
635
- corr = df.corr(numeric_only=True) #Add numeric_only=True
636
  fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
 
637
  st.plotly_chart(fig, use_container_width=True)
638
 
639
  elif viz_type == "Time Series Analysis":
@@ -641,9 +729,11 @@ elif app_mode == "Visualization Lab":
641
  time_col = st.selectbox("Time Column", df.columns)
642
  value_col = st.selectbox("Value Column", df.columns)
643
  fig = px.line(df, x=time_col, y=value_col)
 
644
  st.plotly_chart(fig, use_container_width=True)
645
 
646
- elif viz_type == "Cluster Analysis (Coming Soon)": #Removed placeholder
647
- st.write("Cluster Analysis Feature Coming Soon!") # placeholder for future development
 
648
  except Exception as e:
649
- st.error(f"Error generating visualization: {e}")
 
8
  from sklearn.ensemble import RandomForestRegressor, GradientBoostingRegressor, RandomForestClassifier
9
  from sklearn.svm import SVR, SVC
10
  from sklearn.metrics import mean_squared_error, r2_score, mean_absolute_error, accuracy_score, precision_score, recall_score, f1_score, roc_auc_score
11
+ from sklearn.impute import KNNImputer, SimpleImputer
12
  from sklearn.preprocessing import RobustScaler, StandardScaler, OneHotEncoder
13
  from sklearn.compose import ColumnTransformer
14
  from sklearn.pipeline import Pipeline
 
16
  from streamlit_pandas_profiling import st_profile_report
17
  from io import StringIO
18
  import joblib # For saving and loading models
19
+ import requests
20
+ import asyncio
21
+ from io import BytesIO
22
 
23
  # Configuration
24
  st.set_page_config(page_title="Data Wizard Pro", layout="wide", page_icon="πŸ§™")
 
33
  color: #e0e0ff; /* Light text */
34
  font-family: 'Courier New', monospace; /* Monospace font */
35
  }
 
36
  /* Main content area */
37
  .stApp {
38
  background-color: #0a0a1a; /* Match body background */
39
  }
 
40
  /* Containers and blocks */
41
  .st-emotion-cache-16idsys,
42
  .st-emotion-cache-1v0mbdj,
 
52
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5); /* Enhanced shadow */
53
  color: #e0e0ff; /* Light text color */
54
  }
 
55
  /* Sidebar */
56
  .st-bb {
57
  background-color: #141422; /* Dark sidebar background */
58
  padding: 20px;
59
  border-radius: 10px;
60
  }
 
61
  /* Headers */
62
  h1, h2, h3, h4, h5, h6, .st-bb {
63
  color: #00f7ff; /* Cyan color for headers */
64
  }
 
65
  /* Selectboxes and Buttons */
66
  .st-cb, .st-ci, .st-cj, .st-ch {
67
  background-color: #141422; /* Dark selectbox background */
68
  color: #00f7ff !important; /* Cyan text color */
69
  border: 1px solid #00f7ff; /* Cyan border */
70
  }
 
71
  /* Selectbox text */
72
  .st-cv {
73
  color: #00f7ff !important; /* Cyan color for selectbox text */
74
  }
 
75
  /* Number input and text input */
76
  .st-cr {
77
  background-color: #141422 !important; /* Dark input background */
78
  color: #00f7ff !important; /* Cyan text color */
79
  border: 1px solid #00f7ff !important; /* Cyan border */
80
  }
 
81
  /* Slider */
82
  .st-cw {
83
  background-color: #141422 !important; /* Dark slider background */
84
  border: 1px solid #00f7ff !important; /* Cyan border */
85
  }
 
86
  /* Buttons */
87
  .st-bz, .st-b0 {
88
  background-color: #141422; /* Darker Button background */
 
94
  background-color: #00f7ff; /* Hover color */
95
  color: #0a0a1a; /* Hover text color */
96
  }
 
97
  /* File uploader */
98
  .st-ae {
99
  background-color: #141422 !important; /* Dark file uploader background */
 
106
  border-radius: 10px !important; /* Rounded corners */
107
  box-shadow: 0 4px 8px rgba(0, 0, 0, 0.5) !important; /* Enhanced shadow */
108
  }
 
109
  /* Dataframes and tables */
110
  .dataframe {
111
  background-color: #1e1e30 !important; /* Dark table background */
 
128
  border-radius: 10px;
129
  }
130
  /* Add more styling for other elements as needed */
131
+ /* Style the space around the navigation menu to match the theme */
132
+ [data-testid="stSidebar"] {
133
+ background-color: #141422 !important;
134
+ }
135
+ [data-testid="stSidebarNav"] {
136
+ background-color: #141422 !important;
137
+ color: #e0e0ff !important;
138
+ }
139
+ [data-testid="stSidebarNavItems"] {
140
+ color: #e0e0ff !important;
141
+ }
142
+ /* Ensure all text is white or cyan (no black) */
143
+ .st-bb,
144
+ .st-cb, .st-ci, .st-cj, .st-ch,
145
+ .st-cv,
146
+ .st-cr,
147
+ .st-cw,
148
+ .st-ae,
149
+ .st-emotion-cache-r421ms,
150
+ .st-emotion-cache-10oheav,
151
+ .st-emotion-cache-16idsys,
152
+ .st-emotion-cache-1v0mbdj,
153
+ .st-emotion-cache-1wrcr25,
154
+ .st-emotion-cache-607q0z,
155
+ .st-emotion-cache-1v3fvcr,
156
+ .st-emotion-cache-10trblm {
157
+ color: #e0e0ff !important; /* Default to white */
158
+ }
159
+ h1, h2, h3, h4, h5, h6 {
160
+ color: #00f7ff !important; /* Headings to cyan */
161
+ }
162
+
163
+ /* Styles for loader */
164
+ .loader {
165
+ border: 5px solid #f3f3f3;
166
+ border-top: 5px solid #00f7ff; /* Cyan loader color */
167
+ border-radius: 50%;
168
+ width: 30px;
169
+ height: 30px;
170
+ animation: spin 2s linear infinite;
171
+ }
172
+
173
+ @keyframes spin {
174
+ 0% { transform: rotate(0deg); }
175
+ 100% { transform: rotate(360deg); }
176
+ }
177
  </style>
178
  """,
179
  unsafe_allow_html=True,
180
  )
181
 
182
+ # --- Image Loading ---
183
+ @st.cache_data(ttl=3600)
184
+ async def load_image(image_url):
185
+ """Loads an image from a URL asynchronously."""
186
+ try:
187
+ response = requests.get(image_url, stream=True)
188
+ response.raise_for_status() # Raise HTTPError for bad responses (4xx or 5xx)
189
+ return BytesIO(response.content) # Return image data as a BytesIO object
190
+ except requests.exceptions.RequestException as e:
191
+ st.error(f"Error loading image: {e}")
192
+ return None # Handle errors gracefully
193
+
194
+ async def set_background():
195
+ """Sets the background image."""
196
+ image_url = "https://images.unsplash.com/photo-1504821618514-8c1b6e408ca8?ixlib=rb-4.0.3&ixid=MnwxMjA3fDB8MHxwaG90by1wYWdlfHx8fGVufDB8fHx8&auto=format&fit=crop&w=1949&q=80" # Replace with actual URL
197
+ image_data = await load_image(image_url)
198
+
199
+ if image_data:
200
+ st.markdown(
201
+ f"""
202
+ <style>
203
+ .stApp {{
204
+ background-image: url(data:image/{"jpeg"};base64,{image_data.getvalue().hex()});
205
+ background-size: cover;
206
+ }}
207
+ </style>
208
+ """,
209
+ unsafe_allow_html=True
210
+ )
211
+ asyncio.run(set_background())
212
+ # --- Loader ----
213
+ def show_loader(message="Loading..."):
214
+ """Displays an animated loader."""
215
+ st.markdown(
216
+ f"""
217
+ <div style="display: flex; align-items: center; justify-content: center; margin-top: 20px;">
218
+ <div class="loader"></div>
219
+ <span style="margin-left: 10px; color: #00f7ff;">{message}</span>
220
+ </div>
221
+ """,
222
+ unsafe_allow_html=True
223
+ )
224
+
225
  # Cache decorators
226
  @st.cache_data(ttl=3600)
227
  def load_data(uploaded_file):
 
239
  else:
240
  return None
241
 
 
242
  @st.cache_data(ttl=3600)
243
  def generate_profile(df):
244
  """Generate automated EDA report"""
 
267
  "Visualization Lab"
268
  ])
269
 
270
+ # --- Progress Bar ----
271
+ def animated_progress_bar(progress_var, message="Processing..."):
272
+ """Displays an animated progress bar with a message."""
273
+ progress_bar = st.progress(0)
274
+ status_text = st.empty() # Empty element to update the status message
275
+
276
+ for i in range(progress_var): #progress will increment
277
+ status_text.text(f"{message} ({i+1}/{progress_var})")
278
+ progress_bar.progress((i+1)/progress_var) #progress incrementally.
279
+ time.sleep(0.01) # reduced sleep timer as its getting too long
280
+
281
+ # --- Main App Logic ---
282
  if app_mode == "Data Upload":
283
  st.title("πŸ“€ Data Upload & Analysis")
284
 
 
301
  # Automated EDA Report
302
  with st.expander("πŸš€ Automated Data Report"):
303
  if st.button("Generate Smart Report"):
304
+ show_loader("Generating EDA Report")
305
  pr = generate_profile(df)
306
  st_profile_report(pr)
307
 
 
308
  elif app_mode == "Smart Cleaning":
309
  st.title("🧼 Intelligent Data Cleaning")
310
 
 
403
  z_scores = np.abs((df[col] - df[col].mean()) / df[col].std())
404
  df = df[z_scores <= zscore_threshold]
405
 
 
406
  st.session_state.cleaned_data = df
407
  st.success("Transformation applied!")
408
 
 
414
  with col2:
415
  st.write("Cleaned Data", df.head(3))
416
 
 
417
  elif app_mode == "Advanced EDA":
418
  st.title("πŸ” Advanced Exploratory Analysis")
419
 
 
436
  with cols[0]:
437
  x_col = st.selectbox("X Axis", df.columns)
438
  with cols[1]:
439
+ y_col = st.selectbox("Y Axis", df.columns) if plot_type not in ["Correlation Heatmap"] else None
440
  with cols[2]:
441
  z_col = st.selectbox("Z Axis", df.columns) if plot_type == "3D Scatter" else None
 
 
 
442
 
443
  #Interactive filtering
444
  filter_col = st.selectbox("Filter Column", [None] + list(df.columns))
 
447
  filter_options = st.multiselect("Filter Values", unique_values, default=unique_values)
448
  df = df[df[filter_col].isin(filter_options)]
449
 
 
450
  # Generate Plot
451
  if st.button("Generate Visualization"):
452
  try: # add try-except block for potential errors
453
  if plot_type == "Histogram":
454
+ fig = px.histogram(df, x=x_col, y=y_col, nbins=30, template="plotly_dark")
455
  elif plot_type == "Scatter Plot":
456
  fig = px.scatter(df, x=x_col, y=y_col, color_discrete_sequence=['#00f7ff'])
457
  elif plot_type == "3D Scatter":
 
476
  except Exception as e:
477
  st.error(f"Error generating plot: {e}")
478
 
 
479
  elif app_mode == "Model Training":
480
  st.title("πŸ€– Model Training Studio")
481
 
 
538
  max_depth = st.slider("Max Depth", 5, 20, None) # None for unlimited
539
  param_grid = {'max_depth': [max_depth]}
540
 
 
 
541
  with col2:
542
  if st.button("Train Model"):
543
  try:
544
+ show_loader("Training the Model")
545
  X = df.drop(columns=[target])
546
  y = df[target]
547
 
 
565
  X_test = preprocessor.transform(X_test)
566
  st.session_state.preprocessor = preprocessor #store for prediction later
567
 
 
568
  # Model Training
569
  if task_type == "Regression":
570
  if model_type == "Linear Regression":
 
587
  elif model_type == "Support Vector Machine":
588
  model = SVC(probability=True) #probability=True needed for ROC AUC
589
 
 
590
  #Hyperparameter tuning
591
  if enable_hyperparameter_tuning and model_type in ["Random Forest", "Gradient Boosting", "Support Vector Machine", "Logistic Regression", "Decision Tree"]:
592
  grid_search = GridSearchCV(model, param_grid, cv=3, scoring='neg_mean_squared_error' if task_type == "Regression" else 'accuracy')
593
+ animated_progress_bar(50, "Performing Grid Search") #add loading for grid search
594
+
595
  grid_search.fit(X_train, y_train)
596
  model = grid_search.best_estimator_ #use best model
597
  st.write("Best Parameters:", grid_search.best_params_)
598
 
599
  else:
600
+ animated_progress_bar(80, "Fitting Model")
601
  model.fit(X_train, y_train)
602
 
603
  st.session_state.model = model
 
637
  st.write("Cross-Validation Scores:", scores)
638
  st.write("Mean Cross-Validation Score:", scores.mean())
639
 
 
 
640
  #Model persistence
641
  if st.checkbox("Save Model"):
642
  model_filename = st.text_input("Model Filename", "trained_model.joblib")
643
  joblib.dump((model, preprocessor), model_filename) # save both model AND preprocessor
644
  st.success(f"Model saved as {model_filename}")
645
+ animated_progress_bar(100, "Model Trained Succesfully")
646
+
647
  except Exception as e:
648
  st.error(f"Error during training: {e}")
649
 
 
665
 
666
  if st.button("Predict"):
667
  try:
 
668
  input_df = pd.DataFrame([input_data])
669
  # Preprocess input
670
  input_processed = preprocessor.transform(input_df)
 
685
  else:
686
  st.warning("Please train a model first.")
687
 
688
+
689
+ elif app_mode == "Visualization Lab":
690
  st.title("πŸ“Š Advanced Visualization Lab")
691
 
692
  if st.session_state.cleaned_data is not None:
 
694
 
695
  # Visualization Gallery
696
  viz_type = st.selectbox("Choose Visualization Type", [
697
+ "3D Scatter Plot"Interactive Heatmap",
 
698
  "Time Series Analysis",
699
+ "Cluster Analysis (Coming Soon)" # Removed placeholder
700
  ])
701
 
702
  # Dynamic Controls
 
704
  with cols[0]:
705
  x_axis = st.selectbox("X Axis", df.columns)
706
  with cols[1]:
707
+ y_axis = st.selectbox("Y Axis", df.columns) if viz_type not in ["Interactive Heatmap"] else None
708
  with cols[2]:
709
  z_axis = st.selectbox("Z Axis", df.columns) if viz_type == "3D Scatter Plot" else None
710
 
711
  # Generate Visualization
712
+ try: # Add try-except
713
  if viz_type == "3D Scatter Plot":
714
+ if y_axis is None or z_axis is None:
715
+ st.error("Please select Y and Z axes for 3D Scatter Plot.")
716
+ else:
717
+ fig = px.scatter_3d(df, x=x_axis, y=y_axis, z=z_axis, color=x_axis)
718
+ fig.update_layout(plot_bgcolor="#1e1e30", paper_bgcolor="#1e1e30", font_color="#e0e0ff")
719
+ st.plotly_chart(fig, use_container_width=True)
720
 
721
  elif viz_type == "Interactive Heatmap":
722
+ corr = df.corr(numeric_only=True) # Add numeric_only=True
723
  fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu')
724
+ fig.update_layout(plot_bgcolor="#1e1e30", paper_bgcolor="#1e1e30", font_color="#e0e0ff")
725
  st.plotly_chart(fig, use_container_width=True)
726
 
727
  elif viz_type == "Time Series Analysis":
 
729
  time_col = st.selectbox("Time Column", df.columns)
730
  value_col = st.selectbox("Value Column", df.columns)
731
  fig = px.line(df, x=time_col, y=value_col)
732
+ fig.update_layout(plot_bgcolor="#1e1e30", paper_bgcolor="#1e1e30", font_color="#e0e0ff")
733
  st.plotly_chart(fig, use_container_width=True)
734
 
735
+ elif viz_type == "Cluster Analysis (Coming Soon)": # Removed placeholder
736
+ st.write("Cluster Analysis Feature Coming Soon!") # placeholder for future development
737
+
738
  except Exception as e:
739
+ st.error(f"Error generating visualization: {e}")