CosmickVisions committed on
Commit
f0536a5
·
verified ·
1 Parent(s): f8de685

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +990 -128
app.py CHANGED
@@ -1,14 +1,17 @@
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
 
4
  import numpy as np
5
  from sklearn.model_selection import train_test_split
6
  from sklearn.neural_network import MLPClassifier, MLPRegressor
7
- from sklearn.cluster import KMeans
 
8
  from sklearn.metrics import accuracy_score, r2_score, silhouette_score, confusion_matrix, classification_report, mean_squared_error
9
  from sklearn.preprocessing import StandardScaler
10
  from ydata_profiling import ProfileReport
11
  from streamlit_pandas_profiling import st_profile_report
 
12
  from groq import Groq
13
  from langchain_community.vectorstores import FAISS
14
  from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -18,6 +21,11 @@ from langchain_community.tools.tavily_search import TavilySearchResults
18
  import os
19
  from dotenv import load_dotenv
20
  import tempfile
 
 
 
 
 
21
 
22
  # Load environment variables
23
  load_dotenv()
@@ -141,6 +149,103 @@ st.markdown("""
141
  transform: translateY(-2px);
142
  box-shadow: 0 6px 16px rgba(59, 130, 246, 0.4);
143
  }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
144
  </style>
145
  """, unsafe_allow_html=True)
146
 
@@ -234,165 +339,922 @@ def plot_clusters(X, labels):
234
  fig = px.scatter(X, x=X.columns[0], y=X.columns[1], color=labels, title="Cluster Visualization")
235
  return fig
236
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
237
  # Pages
238
  def data_upload_page():
239
- st.header("📤 Data Upload & Analysis")
240
- uploaded_file = st.file_uploader("Upload Dataset", type=["csv"])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
241
 
242
  if uploaded_file:
243
- df = pd.read_csv(uploaded_file)
244
- st.session_state.df = df
245
- st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
246
- st.session_state.metrics = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
247
 
248
- st.subheader("Dataset Health Check")
249
- col1, col2, col3 = st.columns(3)
250
- col1.metric("Total Samples", df.shape[0])
251
- col2.metric("Features", df.shape[1])
252
- col3.metric("Missing Values", df.isna().sum().sum())
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
253
 
254
- if st.button("Generate Full EDA Report"):
255
- with st.spinner("Generating comprehensive analysis..."):
256
- profile = ProfileReport(df, explorative=True)
257
- st_profile_report(profile)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
258
 
259
  def model_training_page():
260
- st.header("🧠 Neural Network Training Studio")
 
 
 
 
 
 
 
 
 
 
 
 
261
 
262
  if 'df' not in st.session_state:
263
- st.warning("Upload data first!")
 
264
  return
265
 
266
  df = st.session_state.df
267
- problem_type = st.selectbox("Select Problem Type", ["Classification", "Regression", "Clustering"])
268
- mode = st.selectbox("Domain Specialization", ["Legal", "Financial", "Academic", "Technical"])
269
 
270
- if problem_type != "Clustering":
271
- target = st.selectbox("Select Target Variable", df.columns)
272
- X = df.drop(columns=[target])
273
- y = df[target]
274
- else:
275
- X = df
276
- y = None
277
 
278
- if st.button("Train Neural Network"):
279
- with st.spinner("Training in progress..."):
280
- X_scaled = StandardScaler().fit_transform(X)
281
- X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42) if y is not None else (X_scaled, None, None, None)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
282
 
283
- if problem_type == "Classification":
284
- model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
285
- model.fit(X_train, y_train)
286
- y_pred = model.predict(X_test)
287
- st.session_state.metrics = {
288
- "Accuracy": accuracy_score(y_test, y_pred),
289
- "Classification Report": classification_report(y_test, y_pred, output_dict=True)
290
- }
291
- elif problem_type == "Regression":
292
- model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
293
- model.fit(X_train, y_train)
294
- y_pred = model.predict(X_test)
295
- st.session_state.metrics = {
296
- "R2 Score": r2_score(y_test, y_pred),
297
- "Mean Squared Error": mean_squared_error(y_test, y_pred)
298
- }
299
- else: # Clustering
300
- model = KMeans(n_clusters=3, random_state=42)
301
- labels = model.fit_predict(X_scaled)
302
- st.session_state.metrics = {
303
- "Silhouette Score": silhouette_score(X_scaled, labels)
304
- }
305
 
306
- st.session_state.best_model = model
307
- st.session_state.X_test = X_test
308
- st.session_state.y_test = y_test
309
- st.session_state.y_pred = y_pred if y is not None else labels
310
- st.session_state.problem_type = problem_type
311
- st.success(f"Model trained successfully in {mode} mode!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
312
 
313
  def visualization_page():
314
- st.header("🔍 Neural Network Evaluation Center")
 
 
 
 
 
 
 
 
 
 
 
 
315
 
316
  if 'best_model' not in st.session_state:
317
- st.warning("Train a model first!")
 
318
  return
319
 
320
- st.subheader("Performance Analysis")
321
- if st.session_state.problem_type == "Classification":
322
- st.plotly_chart(plot_confusion_matrix(st.session_state.y_test, st.session_state.y_pred))
323
- st.plotly_chart(plot_feature_importance(st.session_state.best_model, pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns[:-1])))
324
- elif st.session_state.problem_type == "Regression":
325
- st.plotly_chart(plot_residuals(st.session_state.y_test, st.session_state.y_pred))
326
- st.plotly_chart(plot_feature_importance(st.session_state.best_model, pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns[:-1])))
327
- else: # Clustering
328
- st.plotly_chart(plot_clusters(pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns), st.session_state.y_pred))
329
 
330
- st.subheader("Metrics")
331
- st.write(st.session_state.metrics)
332
 
333
- # Chatbot Interface
334
  def ai_assistant():
 
335
  st.markdown('<div class="chat-container">', unsafe_allow_html=True)
336
- st.subheader("🧠 Neural Insight Assistant (RAG + Web Search)")
337
 
338
- use_web_search = st.checkbox("Enable Tavily Web Search", value=False)
339
- mode = st.selectbox("Domain Mode", ["Legal", "Financial", "Academic", "Technical"], key="chat_mode")
340
 
341
- for msg in st.session_state.chat_history:
342
- with st.chat_message(msg["role"]):
343
- st.markdown(f'<div class="{msg["role"]}-message">{msg["content"]}</div>', unsafe_allow_html=True)
344
-
345
- if prompt := st.chat_input("Ask about data, models, or web insights..."):
346
- st.session_state.chat_history.append({"role": "user", "content": prompt})
347
- with st.chat_message("user"):
348
- st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True)
349
-
350
- with st.spinner("Processing..."):
351
- response = get_groq_response(prompt, mode, use_web_search)
352
- st.session_state.chat_history.append({"role": "assistant", "content": response})
353
-
354
- with st.chat_message("assistant"):
355
- st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
 
 
 
 
 
 
 
 
 
 
 
 
356
 
357
  st.markdown('</div>', unsafe_allow_html=True)
358
 
359
- # Main App Layout
360
- st.markdown("""
361
- <div class="header">
362
- <h1 class="header-title">Neural-Vision Enhanced</h1>
363
- <div class="header-subtitle">Neural Network Development for Domain-Specialized Analysis</div>
364
- </div>
365
- """, unsafe_allow_html=True)
366
 
367
- with st.sidebar:
368
- st.title("🔮 Neural-Vision Enhanced")
369
- page = st.selectbox("Navigation", [
370
- "Data Upload & Analysis",
371
- "Neural Network Training Studio",
372
- "Neural Network Evaluation Center"
373
- ])
374
- st.session_state.active_page = page
375
- st.markdown("---")
376
- st.markdown("**Environment Setup**")
377
-
378
- # Tavily API Key Input and Submit Button
379
- tavily_api_input = st.text_input("Tavily API Key", type="password", help="Enter your Tavily API key for web search functionality")
380
- if st.button("Submit API Key"):
381
- if tavily_api_input:
382
- st.session_state.tavily_api_key = tavily_api_input
383
- st.success("Tavily API Key submitted successfully!")
384
- else:
385
- st.warning("Please enter a valid API key.")
386
-
387
- st.markdown("---")
388
- st.markdown("v5.0 | © 2025 Neural-Vision")
389
 
390
- # Page Routing
391
- if "Data Upload & Analysis" in page:
392
- data_upload_page()
393
- elif "Neural Network Training Studio" in page:
394
- model_training_page()
395
- else:
396
- visualization_page()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
397
 
398
- ai_assistant()
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
  import plotly.express as px
4
+ import plotly.graph_objects as go
5
  import numpy as np
6
  from sklearn.model_selection import train_test_split
7
  from sklearn.neural_network import MLPClassifier, MLPRegressor
8
+ from sklearn.cluster import KMeans, DBSCAN, AgglomerativeClustering
9
+ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor
10
  from sklearn.metrics import accuracy_score, r2_score, silhouette_score, confusion_matrix, classification_report, mean_squared_error
11
  from sklearn.preprocessing import StandardScaler
12
  from ydata_profiling import ProfileReport
13
  from streamlit_pandas_profiling import st_profile_report
14
+ from streamlit_lottie import st_lottie
15
  from groq import Groq
16
  from langchain_community.vectorstores import FAISS
17
  from langchain.text_splitter import RecursiveCharacterTextSplitter
 
21
  import os
22
  from dotenv import load_dotenv
23
  import tempfile
24
+ import datetime
25
+ import time
26
+ import matplotlib.pyplot as plt
27
+ import shap
28
+ import xgboost as xgb
29
 
30
  # Load environment variables
31
  load_dotenv()
 
149
  transform: translateY(-2px);
150
  box-shadow: 0 6px 16px rgba(59, 130, 246, 0.4);
151
  }
152
+
153
+ /* Card styles */
154
+ .card {
155
+ background: var(--white);
156
+ border-radius: 16px;
157
+ box-shadow: 0 4px 16px rgba(0,0,0,0.1);
158
+ padding: 20px;
159
+ margin-bottom: 25px;
160
+ transition: all 0.3s ease;
161
+ }
162
+ .card:hover {
163
+ box-shadow: 0 8px 24px rgba(0,0,0,0.15);
164
+ transform: translateY(-2px);
165
+ }
166
+
167
+ /* Step header styles */
168
+ .step-header {
169
+ display: flex;
170
+ align-items: center;
171
+ margin-bottom: 15px;
172
+ }
173
+ .step-counter {
174
+ background: var(--primary-blue);
175
+ color: var(--white);
176
+ width: 36px;
177
+ height: 36px;
178
+ border-radius: 50%;
179
+ display: flex;
180
+ align-items: center;
181
+ justify-content: center;
182
+ font-weight: bold;
183
+ margin-right: 15px;
184
+ box-shadow: 0 4px 10px rgba(59, 130, 246, 0.3);
185
+ }
186
+ .step-title {
187
+ font-size: 1.5rem;
188
+ font-weight: 700;
189
+ color: var(--dark-blue);
190
+ }
191
+
192
+ /* Notification styles */
193
+ .notification {
194
+ display: flex;
195
+ align-items: center;
196
+ background: #ECFDF5;
197
+ border-left: 4px solid #059669;
198
+ color: #065F46;
199
+ padding: 15px;
200
+ border-radius: 8px;
201
+ margin: 15px 0;
202
+ box-shadow: 0 2px 8px rgba(5, 150, 105, 0.1);
203
+ transition: transform 0.2s ease;
204
+ }
205
+ .notification:hover {
206
+ transform: translateY(-2px);
207
+ }
208
+ .notification-icon {
209
+ background: #059669;
210
+ color: white;
211
+ width: 24px;
212
+ height: 24px;
213
+ border-radius: 50%;
214
+ display: flex;
215
+ align-items: center;
216
+ justify-content: center;
217
+ margin-right: 15px;
218
+ font-weight: bold;
219
+ }
220
+
221
+ /* Metrics card styles */
222
+ .metrics-card {
223
+ background: var(--white);
224
+ border-radius: 12px;
225
+ padding: 15px;
226
+ box-shadow: 0 2px 10px rgba(0,0,0,0.08);
227
+ text-align: center;
228
+ transition: all 0.3s ease;
229
+ height: 100%;
230
+ display: flex;
231
+ flex-direction: column;
232
+ justify-content: center;
233
+ }
234
+ .metrics-card:hover {
235
+ transform: translateY(-3px);
236
+ box-shadow: 0 6px 16px rgba(0,0,0,0.12);
237
+ }
238
+ .metrics-value {
239
+ font-size: 2rem;
240
+ font-weight: 700;
241
+ color: var(--primary-blue);
242
+ margin-bottom: 5px;
243
+ }
244
+ .metrics-label {
245
+ color: var(--medium-grey);
246
+ font-size: 0.9rem;
247
+ font-weight: 500;
248
+ }
249
  </style>
250
  """, unsafe_allow_html=True)
251
 
 
339
  fig = px.scatter(X, x=X.columns[0], y=X.columns[1], color=labels, title="Cluster Visualization")
340
  return fig
341
 
342
def plot_learning_curve(model, X, y, cv=5):
    """Plot cross-validated training and testing scores against training-set size.

    Args:
        model: scikit-learn estimator handed to ``learning_curve``.
        X: Feature matrix; must support ``len``.
        y: Target values aligned with ``X``.
        cv: Cross-validation setting forwarded to ``learning_curve``.

    Returns:
        A ``plotly.graph_objects.Figure`` showing the mean training and testing
        score at ten training sizes, each with a shaded band of +/- one
        standard deviation across folds.
    """
    from sklearn.base import is_classifier
    from sklearn.model_selection import learning_curve

    # Pick the metric from the estimator type. The earlier test,
    # hasattr(y, 'nunique'), is true for every pandas Series, so regression
    # targets were scored with 'accuracy' and NumPy classification targets
    # with 'r2'.
    scoring = 'accuracy' if is_classifier(model) else 'r2'

    train_sizes, train_scores, test_scores = learning_curve(
        model, X, y, cv=cv, scoring=scoring,
        n_jobs=-1, train_sizes=np.linspace(0.1, 1.0, 10))

    # learning_curve reports absolute sample counts; express them relative to
    # the number of rows in X for the x-axis.
    size_pct = train_sizes / len(X) * 100

    fig = go.Figure()

    curves = (
        ('Training Score', train_scores, 'blue', 'rgba(0, 0, 255, 0.1)'),
        ('Testing Score', test_scores, 'red', 'rgba(255, 0, 0, 0.1)'),
    )
    for label, fold_scores, line_color, band_color in curves:
        mean = np.mean(fold_scores, axis=1)
        std = np.std(fold_scores, axis=1)

        # Mean score line.
        fig.add_trace(go.Scatter(
            x=size_pct,
            y=mean,
            mode='lines+markers',
            name=label,
            line=dict(color=line_color, width=2),
            marker=dict(size=8)
        ))
        # Invisible upper bound, then the lower bound filled up to it
        # ('tonexty' fills towards the previously added trace).
        fig.add_trace(go.Scatter(
            x=size_pct,
            y=mean + std,
            mode='lines',
            line=dict(width=0),
            showlegend=False
        ))
        fig.add_trace(go.Scatter(
            x=size_pct,
            y=mean - std,
            mode='lines',
            line=dict(width=0),
            fill='tonexty',
            fillcolor=band_color,
            showlegend=False
        ))

    fig.update_layout(
        title='Learning Curve',
        xaxis_title='Training Set Size (%)',
        yaxis_title='Score',
        hovermode='x unified',
        width=700,
        height=400,
        legend=dict(
            orientation="h",
            yanchor="bottom",
            y=1.02,
            xanchor="right",
            x=1
        )
    )

    return fig
439
+
440
def plot_shap_summary(model, X):
    """Create a SHAP summary plot for model explainability.

    Args:
        model: Fitted model passed to ``shap.Explainer``.
        X: Feature data the SHAP values are computed for and plotted against.

    Returns:
        The current matplotlib figure holding the summary plot, or ``None``
        when SHAP fails; in that case a Streamlit warning with the error
        message is shown instead of raising.
    """
    try:
        # The original branched on hasattr(model, 'predict_proba') but built
        # the same explainer in both arms, so a single call is equivalent.
        explainer = shap.Explainer(model)

        # Calculate SHAP values
        shap_values = explainer(X)

        # Draw into a fresh figure; show=False keeps the plot on the figure
        # so it can be returned to the caller.
        plt.figure(figsize=(10, 8))
        shap.summary_plot(shap_values, X, show=False)
        fig = plt.gcf()
        plt.tight_layout()

        return fig
    except Exception as e:
        # Best-effort: explainability is optional, so report and carry on.
        st.warning(f"Could not generate SHAP plot: {e}")
        return None
462
+
463
+ # Data Exploration and Insights Generation
464
def generate_data_insights(df):
    """Generate summary insights about a dataset.

    Args:
        df: pandas DataFrame to analyse.

    Returns:
        dict with the keys:
            shape: ``df.shape``.
            missing_values: total number of missing cells.
            duplicate_rows: number of duplicated rows.
            numeric_columns: names of numeric columns.
            categorical_columns: names of object/category/bool columns.
            datetime_columns: names of datetime-typed columns plus non-numeric
                columns in which at least one value parses as a date.
            skewed_columns: ``(column, skew)`` for numeric columns with
                ``|skew| > 1``.
            correlated_features: ``(col_a, col_b, |corr|)`` for numeric pairs
                with ``|corr| > 0.7``, strongest first; always present, empty
                when there are fewer than two numeric columns.
            high_cardinality_features: ``(column, n_unique)`` for categorical
                columns with more than 10 distinct values.
            missing_patterns: ``(column, percent_missing)`` for columns with
                any missing value.
            outlier_columns: ``(column, count, percent)`` for numeric columns
                with values outside 1.5 * IQR of the quartiles.
    """
    insights = {}

    # Basic statistics
    insights['shape'] = df.shape
    insights['missing_values'] = df.isna().sum().sum()
    insights['duplicate_rows'] = df.duplicated().sum()

    # Column types
    numeric_columns = list(df.select_dtypes(include=['number']).columns)
    categorical_columns = list(df.select_dtypes(include=['object', 'category', 'bool']).columns)
    insights['numeric_columns'] = numeric_columns
    insights['categorical_columns'] = categorical_columns

    datetime_columns = []
    for col in df.columns:
        series = df[col]
        if pd.api.types.is_datetime64_any_dtype(series):
            datetime_columns.append(col)
            continue
        # Numbers (and bools) are skipped: to_datetime reads them as epoch
        # offsets, which used to flag every numeric column as a datetime.
        if pd.api.types.is_numeric_dtype(series):
            continue
        try:
            parsed = pd.to_datetime(series, errors='coerce')
        except (TypeError, ValueError, OverflowError):
            # Column type cannot be interpreted as dates at all.
            continue
        if parsed.notna().any():
            datetime_columns.append(col)
    insights['datetime_columns'] = datetime_columns

    # Distribution statistics
    skewed_columns = []
    for col in numeric_columns:
        skew = df[col].skew()
        if abs(skew) > 1.0:
            skewed_columns.append((col, skew))
    insights['skewed_columns'] = skewed_columns

    # Correlation analysis: walk the lower triangle so each pair is seen once.
    corr_pairs = []
    if len(numeric_columns) > 1:
        corr_matrix = df[numeric_columns].corr().abs()
        corr_cols = corr_matrix.columns
        for i, later in enumerate(corr_cols):
            for j, earlier in enumerate(corr_cols[:i]):
                strength = corr_matrix.iloc[i, j]
                if strength > 0.7:  # Strong correlation threshold
                    corr_pairs.append((later, earlier, strength))
    # Always set the key so consumers can rely on it being there.
    insights['correlated_features'] = sorted(corr_pairs, key=lambda pair: pair[2], reverse=True)

    # Categorical feature analysis
    high_cardinality = []
    for col in categorical_columns:
        n_unique = df[col].nunique()
        if n_unique > 10:
            high_cardinality.append((col, n_unique))
    insights['high_cardinality_features'] = high_cardinality

    # Missing value patterns
    missing_patterns = []
    for col in df.columns:
        missing_pct = df[col].isna().mean() * 100
        if missing_pct > 0:
            missing_patterns.append((col, missing_pct))
    insights['missing_patterns'] = missing_patterns

    # Outlier detection using the 1.5 * IQR rule
    outlier_columns = []
    for col in numeric_columns:
        q1 = df[col].quantile(0.25)
        q3 = df[col].quantile(0.75)
        iqr = q3 - q1
        below = df[col] < (q1 - 1.5 * iqr)
        above = df[col] > (q3 + 1.5 * iqr)
        outliers_count = (below | above).sum()
        if outliers_count > 0:
            outlier_columns.append((col, outliers_count, outliers_count / len(df) * 100))
    insights['outlier_columns'] = outlier_columns

    return insights
524
+
525
+ # Enhanced Model Selection and Training Functions
526
def get_model_options(problem_type):
    """Return the candidate estimators offered for ``problem_type``.

    Args:
        problem_type: "Classification" or "Regression"; any other value is
            treated as clustering.

    Returns:
        dict mapping a display name to a freshly constructed, unfitted
        estimator instance.
    """
    seed = 42

    if problem_type == "Classification":
        candidates = [
            ("Neural Network", MLPClassifier(max_iter=1000, random_state=seed)),
            ("Random Forest", RandomForestClassifier(random_state=seed)),
            ("Gradient Boosting", GradientBoostingClassifier(random_state=seed)),
            ("XGBoost", xgb.XGBClassifier(random_state=seed)),
        ]
        return dict(candidates)

    if problem_type == "Regression":
        candidates = [
            ("Neural Network", MLPRegressor(max_iter=1000, random_state=seed)),
            ("Random Forest", RandomForestRegressor(random_state=seed)),
            ("Gradient Boosting", GradientBoostingRegressor(random_state=seed)),
            ("XGBoost", xgb.XGBRegressor(random_state=seed)),
        ]
        return dict(candidates)

    # Anything else falls through to the clustering algorithms.
    candidates = [
        ("K-Means", KMeans(random_state=seed)),
        ("DBSCAN", DBSCAN()),
        ("Agglomerative", AgglomerativeClustering()),
    ]
    return dict(candidates)
548
+
549
def _basic_param_grid(model, problem_type):
    """Return the small hyperparameter grid for ``model``, or ``{}`` if none applies."""
    if problem_type in ["Classification", "Regression"]:
        if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
            return {
                'n_estimators': [100, 200],
                'max_depth': [None, 10, 20]
            }
        if isinstance(model, (GradientBoostingClassifier, GradientBoostingRegressor)):
            return {
                'n_estimators': [100, 200],
                'learning_rate': [0.01, 0.1]
            }
        if isinstance(model, (MLPClassifier, MLPRegressor)):
            return {
                'hidden_layer_sizes': [(100,), (100, 50)],
                'alpha': [0.0001, 0.001]
            }
        # XGBoost estimators are recognised by class name.
        if "XGB" in str(model.__class__):
            return {
                'n_estimators': [100, 200],
                'learning_rate': [0.01, 0.1],
                'max_depth': [3, 6]
            }
    elif problem_type == "Clustering":
        if isinstance(model, KMeans):
            return {
                'n_clusters': [3, 4, 5, 6]
            }
        if isinstance(model, DBSCAN):
            return {
                'eps': [0.3, 0.5, 0.7],
                'min_samples': [5, 10, 15]
            }
        if isinstance(model, AgglomerativeClustering):
            return {
                'n_clusters': [3, 4, 5, 6],
                'linkage': ['ward', 'complete', 'average']
            }
    return {}


def train_model_with_optimization(model, X_train, X_test, y_train, y_test, problem_type, optimization_level="basic"):
    """Train ``model``, optionally tuning it with a small grid search.

    Args:
        model: Unfitted estimator to train.
        X_train: Training features.
        X_test: Features the returned predictions are made on.
        y_train: Training targets (unused by the clustering grid search).
        y_test: Accepted for interface symmetry; not used here.
        problem_type: "Classification", "Regression" or "Clustering".
        optimization_level: "none" fits the model as given. "basic" runs a
            grid search when a grid is defined for the model. Any other value
            (e.g. "advanced") currently falls back to the basic grid, because
            no dedicated advanced search is implemented yet.

    Returns:
        Tuple ``(best_model, y_pred, training_time)`` where ``training_time``
        is the elapsed wall-clock time of fitting in seconds.
    """
    start_time = time.time()

    # TODO: Implement advanced optimization with more parameters,
    # RandomizedSearchCV or BayesianOptimization. Until then every level other
    # than "none" uses the basic grid; previously "advanced" left best_model
    # undefined and crashed below.
    if optimization_level == "none":
        param_grid = {}
    else:
        param_grid = _basic_param_grid(model, problem_type)

    if param_grid:
        # Imported here for both branches; it used to be imported only inside
        # the clustering branch, so supervised searches failed on this name.
        from sklearn.model_selection import GridSearchCV

        if problem_type == "Clustering":
            from sklearn.metrics import silhouette_score

            # GridSearchCV accepts a callable with the (estimator, X, y)
            # signature directly; make_scorer is for (y_true, y_pred) metrics
            # and must not wrap it.
            def silhouette_scorer(estimator, X, y=None):
                labels = estimator.fit_predict(X)
                if len(set(labels)) <= 1:  # Check if all points are in one cluster
                    return -1
                return silhouette_score(X, labels)

            grid_search = GridSearchCV(
                estimator=model,
                param_grid=param_grid,
                scoring=silhouette_scorer,
                cv=3,
                n_jobs=-1
            )
            grid_search.fit(X_train)
        else:
            # For classification and regression
            scoring = 'accuracy' if problem_type == "Classification" else 'r2'
            grid_search = GridSearchCV(
                estimator=model,
                param_grid=param_grid,
                scoring=scoring,
                cv=5,
                n_jobs=-1
            )
            grid_search.fit(X_train, y_train)

        best_model = grid_search.best_estimator_
    else:
        # No optimization requested, or no grid defined for this model.
        model.fit(X_train, y_train)
        best_model = model

    # Calculate training time
    training_time = time.time() - start_time

    # Get predictions for evaluation. Clusterers without predict() are
    # re-fitted on the evaluation data to obtain labels.
    if problem_type == "Clustering" and not hasattr(best_model, 'predict'):
        y_pred = best_model.fit_predict(X_test)
    else:
        y_pred = best_model.predict(X_test)

    return best_model, y_pred, training_time
656
+
657
  # Pages
658
  def data_upload_page():
659
+ """Enhanced data upload & analysis page"""
660
+ st.markdown('<div class="card">', unsafe_allow_html=True)
661
+
662
+ # Create a header with animation
663
+ col1, col2 = st.columns([1, 3])
664
+ with col1:
665
+ st_lottie(lottie_upload, height=150, key="upload_animation")
666
+ with col2:
667
+ st.markdown('<div class="step-header">', unsafe_allow_html=True)
668
+ st.markdown('<div class="step-counter">1</div>', unsafe_allow_html=True)
669
+ st.markdown('<div class="step-title">Data Upload & Exploratory Analysis</div>', unsafe_allow_html=True)
670
+ st.markdown('</div>', unsafe_allow_html=True)
671
+ st.markdown("Upload your dataset and get comprehensive insights before model training.")
672
+
673
+ # File uploader with enhanced UI
674
+ uploaded_file = st.file_uploader("Upload Dataset (CSV, Excel, or JSON)",
675
+ type=["csv", "xlsx", "json"],
676
+ help="Upload your data file to start analysis")
677
 
678
  if uploaded_file:
679
+ # Provide feedback during loading
680
+ with st.spinner('Reading and analyzing your dataset...'):
681
+ # Determine file type and read
682
+ if uploaded_file.name.endswith('csv'):
683
+ df = pd.read_csv(uploaded_file)
684
+ elif uploaded_file.name.endswith('xlsx'):
685
+ df = pd.read_excel(uploaded_file)
686
+ elif uploaded_file.name.endswith('json'):
687
+ df = pd.read_json(uploaded_file)
688
+
689
+ # Store in session state
690
+ st.session_state.df = df
691
+ st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
692
+ st.session_state.metrics = {}
693
+
694
+ # Generate insights
695
+ st.session_state.dataset_insights = generate_data_insights(df)
696
+
697
+ # Success notification
698
+ if 'notification' not in st.session_state or st.session_state.notification is None:
699
+ st.session_state.notification = "Data successfully loaded! 🎉"
700
+
701
+ # Display a notification
702
+ st.markdown(f"""
703
+ <div class="notification">
704
+ <div class="notification-icon">✓</div>
705
+ <div>{st.session_state.notification}</div>
706
+ </div>
707
+ """, unsafe_allow_html=True)
708
+ st.session_state.notification = None
709
+
710
+ # Create tabs for different data views
711
+ data_tabs = st.tabs(["📊 Overview", "🔍 Data Explorer", "📈 Visualizations", "📋 Profile Report"])
712
+
713
+ with data_tabs[0]:
714
+ st.subheader("Dataset Overview")
715
+ col1, col2, col3 = st.columns(3)
716
+
717
+ with col1:
718
+ st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
719
+ st.markdown(f'<div class="metrics-value">{df.shape[0]:,}</div>', unsafe_allow_html=True)
720
+ st.markdown('<div class="metrics-label">Total Samples</div>', unsafe_allow_html=True)
721
+ st.markdown('</div>', unsafe_allow_html=True)
722
+
723
+ with col2:
724
+ st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
725
+ st.markdown(f'<div class="metrics-value">{df.shape[1]}</div>', unsafe_allow_html=True)
726
+ st.markdown('<div class="metrics-label">Features</div>', unsafe_allow_html=True)
727
+ st.markdown('</div>', unsafe_allow_html=True)
728
+
729
+ with col3:
730
+ missing_pct = df.isna().sum().sum() / (df.shape[0] * df.shape[1]) * 100
731
+ st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
732
+ st.markdown(f'<div class="metrics-value">{missing_pct:.1f}%</div>', unsafe_allow_html=True)
733
+ st.markdown('<div class="metrics-label">Missing Values</div>', unsafe_allow_html=True)
734
+ st.markdown('</div>', unsafe_allow_html=True)
735
+
736
+ # Data Types Breakdown
737
+ st.subheader("Data Types")
738
+ dtype_counts = df.dtypes.value_counts().reset_index()
739
+ dtype_counts.columns = ['Data Type', 'Count']
740
+ fig = px.pie(dtype_counts, values='Count', names='Data Type', hole=0.4,
741
+ color_discrete_sequence=px.colors.qualitative.Bold)
742
+ fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
743
+ st.plotly_chart(fig, use_container_width=True)
744
+
745
+ # Key Insights
746
+ st.subheader("Key Insights")
747
+ insights = st.session_state.dataset_insights
748
+
749
+ insight_cols = st.columns(2)
750
+
751
+ with insight_cols[0]:
752
+ st.markdown("**Data Quality Issues**")
753
+ if insights['missing_patterns']:
754
+ st.markdown("🔹 **Missing Values:**")
755
+ for col, pct in sorted(insights['missing_patterns'], key=lambda x: x[1], reverse=True)[:5]:
756
+ st.markdown(f" • *{col}*: {pct:.1f}% missing")
757
+
758
+ if insights['outlier_columns']:
759
+ st.markdown("🔹 **Outliers Detected:**")
760
+ for col, count, pct in sorted(insights['outlier_columns'], key=lambda x: x[2], reverse=True)[:5]:
761
+ st.markdown(f" • *{col}*: {count} outliers ({pct:.1f}%)")
762
+
763
+ with insight_cols[1]:
764
+ st.markdown("**Feature Relationships**")
765
+ if 'correlated_features' in insights and insights['correlated_features']:
766
+ st.markdown("🔹 **Highly Correlated Features:**")
767
+ for col1, col2, corr in insights['correlated_features'][:5]:
768
+ st.markdown(f" • *{col1}* & *{col2}*: {corr:.2f} correlation")
769
+
770
+ if insights['skewed_columns']:
771
+ st.markdown("🔹 **Skewed Distributions:**")
772
+ for col, skew in sorted(insights['skewed_columns'], key=lambda x: abs(x[1]), reverse=True)[:5]:
773
+ direction = "right" if skew > 0 else "left"
774
+ st.markdown(f" • *{col}*: {direction}-skewed ({skew:.2f})")
775
 
776
+ with data_tabs[1]:
777
+ st.subheader("Interactive Data Explorer")
778
+
779
+ # Filter and search options
780
+ col1, col2 = st.columns([2, 3])
781
+ with col1:
782
+ search_term = st.text_input("Search columns", "")
783
+ with col2:
784
+ selected_dtypes = st.multiselect(
785
+ "Filter by data type",
786
+ options=['numeric', 'object', 'datetime', 'category', 'bool'],
787
+ default=['numeric', 'object']
788
+ )
789
+
790
+ # Apply filters
791
+ filtered_cols = df.columns
792
+ if search_term:
793
+ filtered_cols = [col for col in filtered_cols if search_term.lower() in col.lower()]
794
+
795
+ if selected_dtypes:
796
+ dtype_map = {
797
+ 'numeric': 'number',
798
+ 'object': 'object',
799
+ 'datetime': 'datetime',
800
+ 'category': 'category',
801
+ 'bool': 'bool'
802
+ }
803
+ dtype_filtered = []
804
+ for dtype in selected_dtypes:
805
+ dtype_filtered.extend(df.select_dtypes(include=[dtype_map[dtype]]).columns)
806
+ filtered_cols = [col for col in filtered_cols if col in dtype_filtered]
807
+
808
+ # Display filtered dataframe
809
+ if filtered_cols:
810
+ st.dataframe(df[filtered_cols], height=400)
811
+
812
+ # Column statistics
813
+ selected_column = st.selectbox("Select column for detailed statistics", options=filtered_cols)
814
+
815
+ col1, col2 = st.columns(2)
816
+ with col1:
817
+ st.subheader(f"Statistics for: {selected_column}")
818
+ if pd.api.types.is_numeric_dtype(df[selected_column]):
819
+ stats = df[selected_column].describe()
820
+ for stat, value in stats.items():
821
+ st.markdown(f"**{stat}:** {value:.4f}")
822
+
823
+ # Additional stats
824
+ st.markdown(f"**Skewness:** {df[selected_column].skew():.4f}")
825
+ st.markdown(f"**Kurtosis:** {df[selected_column].kurtosis():.4f}")
826
+ else:
827
+ st.markdown(f"**Unique Values:** {df[selected_column].nunique()}")
828
+ st.markdown(f"**Most Common:** {df[selected_column].value_counts().index[0]}")
829
+ st.markdown(f"**Least Common:** {df[selected_column].value_counts().index[-1]}")
830
+ st.markdown(f"**Missing Values:** {df[selected_column].isna().sum()} ({df[selected_column].isna().mean()*100:.2f}%)")
831
+
832
+ with col2:
833
+ st.subheader("Distribution")
834
+ if pd.api.types.is_numeric_dtype(df[selected_column]):
835
+ # Histogram for numeric
836
+ fig = px.histogram(df, x=selected_column, histnorm='probability density',
837
+ marginal='box', color_discrete_sequence=['#3B82F6'])
838
+ fig.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=0))
839
+ else:
840
+ # Bar chart for categorical
841
+ value_counts = df[selected_column].value_counts().reset_index()
842
+ value_counts.columns = [selected_column, 'Count']
843
+ value_counts = value_counts.head(15) # Limit to top 15
844
+ fig = px.bar(value_counts, x=selected_column, y='Count',
845
+ color_discrete_sequence=['#3B82F6'])
846
+ fig.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=0))
847
+
848
+ st.plotly_chart(fig, use_container_width=True)
849
+ else:
850
+ st.warning("No columns match your filters.")
851
 
852
+ with data_tabs[2]:
853
+ st.subheader("Data Visualizations")
854
+
855
+ viz_type = st.selectbox(
856
+ "Select Visualization Type",
857
+ options=["Scatter Plot", "Correlation Matrix", "Pair Plot", "Box Plot", "Violin Plot", "Line Chart"]
858
+ )
859
+
860
+ if viz_type == "Scatter Plot":
861
+ col1, col2, col3 = st.columns(3)
862
+ with col1:
863
+ x_col = st.selectbox("X-axis", options=df.select_dtypes(include=['number']).columns)
864
+ with col2:
865
+ y_col = st.selectbox("Y-axis", options=df.select_dtypes(include=['number']).columns,
866
+ index=min(1, len(df.select_dtypes(include=['number']).columns)-1))
867
+ with col3:
868
+ color_col = st.selectbox("Color by", options=["None"] + list(df.columns), index=0)
869
+
870
+ # Create plot
871
+ if color_col == "None":
872
+ fig = px.scatter(df, x=x_col, y=y_col, title=f"{x_col} vs {y_col}",
873
+ opacity=0.7, color_discrete_sequence=['#3B82F6'])
874
+ else:
875
+ fig = px.scatter(df, x=x_col, y=y_col, color=color_col, title=f"{x_col} vs {y_col} by {color_col}",
876
+ opacity=0.7)
877
+
878
+ fig.update_layout(height=500)
879
+ st.plotly_chart(fig, use_container_width=True)
880
+
881
+ # Add regression line option
882
+ if st.checkbox("Add regression line"):
883
+ fig = px.scatter(df, x=x_col, y=y_col, trendline="ols",
884
+ title=f"{x_col} vs {y_col} with Regression Line",
885
+ opacity=0.7, color_discrete_sequence=['#3B82F6'])
886
+ fig.update_layout(height=500)
887
+ st.plotly_chart(fig, use_container_width=True)
888
+
889
+ elif viz_type == "Correlation Matrix":
890
+ # Select columns for correlation
891
+ numeric_cols = df.select_dtypes(include=['number']).columns
892
+ selected_corr_cols = st.multiselect(
893
+ "Select columns for correlation matrix",
894
+ options=numeric_cols,
895
+ default=list(numeric_cols)[:min(8, len(numeric_cols))]
896
+ )
897
+
898
+ if selected_corr_cols:
899
+ # Correlation matrix
900
+ corr = df[selected_corr_cols].corr()
901
+ mask = np.triu(np.ones_like(corr, dtype=bool))
902
+
903
+ # Create plot
904
+ fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r',
905
+ zmin=-1, zmax=1, aspect="auto")
906
+ fig.update_layout(height=600)
907
+
908
+ st.plotly_chart(fig, use_container_width=True)
909
+ else:
910
+ st.warning("Please select at least one numeric column.")
911
+
912
+ elif viz_type == "Pair Plot":
913
+ # Select columns for pair plot
914
+ numeric_cols = df.select_dtypes(include=['number']).columns
915
+ selected_pair_cols = st.multiselect(
916
+ "Select columns for pair plot (limit 4-5 for readability)",
917
+ options=numeric_cols,
918
+ default=list(numeric_cols)[:min(4, len(numeric_cols))]
919
+ )
920
+
921
+ if selected_pair_cols:
922
+ if len(selected_pair_cols) > 6:
923
+ st.warning("Too many columns may make the plot hard to read. Consider selecting 4-5 columns.")
924
+
925
+ # Color option
926
+ color_col = st.selectbox("Color by (categorical)",
927
+ options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
928
+ index=0)
929
+
930
+ # Create pair plot
931
+ if color_col == "None":
932
+ fig = px.scatter_matrix(df, dimensions=selected_pair_cols, opacity=0.7)
933
+ else:
934
+ fig = px.scatter_matrix(df, dimensions=selected_pair_cols, color=color_col, opacity=0.7)
935
+
936
+ fig.update_layout(height=700)
937
+ st.plotly_chart(fig, use_container_width=True)
938
+ else:
939
+ st.warning("Please select at least one numeric column.")
940
+
941
+ elif viz_type == "Box Plot":
942
+ col1, col2 = st.columns(2)
943
+ with col1:
944
+ y_col = st.selectbox("Value column", options=df.select_dtypes(include=['number']).columns)
945
+ with col2:
946
+ x_col = st.selectbox("Category column",
947
+ options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
948
+ index=0)
949
+
950
+ # Create plot
951
+ if x_col == "None":
952
+ fig = px.box(df, y=y_col, title=f"Distribution of {y_col}",
953
+ color_discrete_sequence=['#3B82F6'])
954
+ else:
955
+ fig = px.box(df, x=x_col, y=y_col, title=f"Distribution of {y_col} by {x_col}")
956
+
957
+ fig.update_layout(height=500)
958
+ st.plotly_chart(fig, use_container_width=True)
959
+
960
+ elif viz_type == "Violin Plot":
961
+ col1, col2 = st.columns(2)
962
+ with col1:
963
+ y_col = st.selectbox("Value column", options=df.select_dtypes(include=['number']).columns, key="violin_y")
964
+ with col2:
965
+ x_col = st.selectbox("Category column",
966
+ options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
967
+ index=0, key="violin_x")
968
+
969
+ # Create plot
970
+ if x_col == "None":
971
+ fig = px.violin(df, y=y_col, box=True, title=f"Distribution of {y_col}",
972
+ color_discrete_sequence=['#3B82F6'])
973
+ else:
974
+ fig = px.violin(df, x=x_col, y=y_col, box=True, title=f"Distribution of {y_col} by {x_col}")
975
+
976
+ fig.update_layout(height=500)
977
+ st.plotly_chart(fig, use_container_width=True)
978
+
979
+ elif viz_type == "Line Chart":
980
+ # Identify potential date columns
981
+ date_cols = []
982
+ for col in df.columns:
983
+ try:
984
+ if pd.to_datetime(df[col], errors='coerce').notna().all():
985
+ date_cols.append(col)
986
+ except:
987
+ pass
988
+
989
+ if date_cols:
990
+ col1, col2 = st.columns(2)
991
+ with col1:
992
+ x_col = st.selectbox("Time axis", options=date_cols)
993
+ # Convert to datetime if not already
994
+ df[x_col] = pd.to_datetime(df[x_col])
995
+ with col2:
996
+ y_cols = st.multiselect("Value columns", options=df.select_dtypes(include=['number']).columns,
997
+ default=[df.select_dtypes(include=['number']).columns[0]])
998
+
999
+ if y_cols:
1000
+ # Create line chart
1001
+ fig = go.Figure()
1002
+ for y_col in y_cols:
1003
+ fig.add_trace(go.Scatter(x=df[x_col], y=df[y_col], mode='lines', name=y_col))
1004
+
1005
+ fig.update_layout(
1006
+ title=f"Time Series Plot",
1007
+ xaxis_title=x_col,
1008
+ yaxis_title="Values",
1009
+ height=500,
1010
+ legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
1011
+ )
1012
+
1013
+ st.plotly_chart(fig, use_container_width=True)
1014
+ else:
1015
+ st.warning("Please select at least one value column.")
1016
+ else:
1017
+ st.warning("No datetime columns detected. Please ensure you have columns with date/time values.")
1018
+
1019
+ with data_tabs[3]:
1020
+ st.subheader("Comprehensive Profiling Report")
1021
+
1022
+ profile_options = st.columns(3)
1023
+ with profile_options[0]:
1024
+ minimal = st.checkbox("Minimal Report (Faster)", value=True)
1025
+ with profile_options[1]:
1026
+ sample_data = st.checkbox("Use Sample (Faster for large datasets)", value=True)
1027
+ with profile_options[2]:
1028
+ report_percent = st.slider("Sample Size %", min_value=10, max_value=100, value=50, step=10)
1029
+
1030
+ if st.button("Generate Profile Report"):
1031
+ with st.spinner("Generating comprehensive profile report..."):
1032
+ if sample_data and len(df) > 1000:
1033
+ profile_df = df.sample(int(len(df) * report_percent/100))
1034
+ else:
1035
+ profile_df = df
1036
+
1037
+ if minimal:
1038
+ profile = ProfileReport(profile_df, minimal=True, title="Dataset Profile Report")
1039
+ else:
1040
+ profile = ProfileReport(profile_df, explorative=True, title="Dataset Profile Report")
1041
+
1042
+ st_profile_report(profile)
1043
+
1044
+ st.markdown('</div>', unsafe_allow_html=True)
1045
 
1046
def model_training_page():
    """Render the "Neural Network Training Studio" page.

    Draws the page header and the four workflow tabs, then fills the
    "Setup" tab: problem type, domain specialization, experiment name,
    target variable and feature subset.  The columns chosen for training
    are written to ``st.session_state.selected_columns``.

    Returns early (after showing a warning) when no dataset has been
    stored in ``st.session_state.df`` yet.
    """
    st.markdown('<div class="card">', unsafe_allow_html=True)

    # Header with animation
    col1, col2 = st.columns([1, 3])
    with col1:
        # load_lottie_url() returns None when the download fails; skip the
        # animation rather than handing None to st_lottie.
        if lottie_neural is not None:
            st_lottie(lottie_neural, height=150, key="neural_animation")
    with col2:
        st.markdown('<div class="step-header">', unsafe_allow_html=True)
        st.markdown('<div class="step-counter">2</div>', unsafe_allow_html=True)
        st.markdown('<div class="step-title">Neural Network Training Studio</div>', unsafe_allow_html=True)
        st.markdown('</div>', unsafe_allow_html=True)
        st.markdown("Train advanced models with automated optimization and hyperparameter tuning.")

    if 'df' not in st.session_state:
        st.warning("Please upload data first!")
        st.markdown('</div>', unsafe_allow_html=True)
        return

    df = st.session_state.df

    # Create multiple tabs for the workflow.
    # NOTE(review): only the "Setup" tab is populated in this function; the
    # other three tabs are rendered empty.
    train_tabs = st.tabs(["⚙️ Setup", "🔄 Preprocessing", "🧠 Training", "📊 Results"])

    with train_tabs[0]:
        st.subheader("Model Configuration")

        # Problem type selection
        problem_type = st.selectbox(
            "Select Problem Type",
            ["Classification", "Regression", "Clustering"],
            help="Classification: predict categories, Regression: predict continuous values, Clustering: group similar data points"
        )

        # Domain specialization
        domain_col1, domain_col2 = st.columns(2)
        with domain_col1:
            # NOTE(review): ai_assistant() reads st.session_state['mode'], but
            # this selection is not stored there in the visible code -
            # confirm whether it should be.
            mode = st.selectbox(
                "Domain Specialization",
                ["General", "Legal", "Financial", "Medical", "Technical", "Academic"],
                help="Optimize the model for your specific domain"
            )

        with domain_col2:
            experiment_name = st.text_input(
                "Experiment Name",
                value=f"{problem_type}_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}",
                help="Name your experiment for reference"
            )

        # Target variable selection
        if problem_type != "Clustering":
            target_col1, target_col2 = st.columns(2)

            with target_col1:
                target = st.selectbox("Select Target Variable", df.columns)

            with target_col2:
                if problem_type == "Classification":
                    st.info(f"Class Distribution: {df[target].value_counts().to_dict()}")
                else:
                    st.info(f"Target Range: {df[target].min()} to {df[target].max()}")

            # Feature selection
            st.subheader("Feature Selection")
            select_features = st.checkbox("Select specific features", value=False)

            if select_features:
                available_features = [col for col in df.columns if col != target]
                selected_features = st.multiselect(
                    "Select features to include",
                    options=available_features,
                    default=available_features
                )
                # The target is kept alongside the chosen features.
                st.session_state.selected_columns = selected_features + [target]
            else:
                st.session_state.selected_columns = df.columns.tolist()
        else:
            # For clustering, all columns are features (there is no target).
            # The original branch had no body, which is a syntax error and
            # left selected_columns unset for clustering runs.
            st.session_state.selected_columns = df.columns.tolist()

    # Close the card opened at the top, as the early-return path does.
    st.markdown('</div>', unsafe_allow_html=True)
1126
 
1127
def visualization_page():
    """Render the "Neural Network Evaluation Center" page.

    Draws the page header and the three evaluation tabs.  Returns early
    (after showing a warning) when ``st.session_state`` holds no
    ``best_model`` entry, i.e. nothing has been trained yet.
    """
    st.markdown('<div class="card">', unsafe_allow_html=True)

    # Header with animation
    col1, col2 = st.columns([1, 3])
    with col1:
        # load_lottie_url() returns None when the download fails; skip the
        # animation rather than handing None to st_lottie.
        if lottie_visualization is not None:
            st_lottie(lottie_visualization, height=150, key="viz_animation")
    with col2:
        st.markdown('<div class="step-header">', unsafe_allow_html=True)
        st.markdown('<div class="step-counter">3</div>', unsafe_allow_html=True)
        st.markdown('<div class="step-title">Neural Network Evaluation Center</div>', unsafe_allow_html=True)
        st.markdown('</div>', unsafe_allow_html=True)
        st.markdown("Visualize, interpret, and validate your trained neural networks.")

    if 'best_model' not in st.session_state:
        st.warning("Please train a model first!")
        st.markdown('</div>', unsafe_allow_html=True)
        return

    # Evaluation tabs for different analyses.
    # NOTE(review): the tabs are created but nothing is rendered inside them yet.
    eval_tabs = st.tabs(["📊 Model Performance", "🔍 Model Interpretation", "🧪 Test Predictions"])

    # Tabs content would go here

    st.markdown('</div>', unsafe_allow_html=True)
 
1153
 
 
1154
def ai_assistant():
    """Render the assistant panel and answer the user's question.

    Shows a text area, a web-search toggle and a button.  When the button
    is pressed with a non-empty question, the question is appended to
    ``st.session_state.chat_history``, the history is re-rendered, and the
    reply from ``get_groq_response`` is appended and shown.  If that call
    raises, the error is reported and a canned apology is stored and shown
    instead.
    """
    from html import escape  # stdlib; imported locally to keep the block self-contained

    st.markdown('<div class="chat-container">', unsafe_allow_html=True)
    st.subheader("📚 Neural Network Development Assistant")

    user_input = st.text_area("Ask a question about your data or neural network development:", "")
    use_web_search = st.checkbox("Enable web search for up-to-date information", value=False)

    if st.button("Get AI Guidance"):
        if not user_input.strip():
            # Previously an empty submission was silently ignored.
            st.warning("Please enter a question first.")
        else:
            with st.spinner("Analyzing your question..."):
                # The history may already be initialised elsewhere in the
                # file; create it here if it is not, so the append cannot fail.
                if "chat_history" not in st.session_state:
                    st.session_state.chat_history = []

                # Add user message to chat history
                st.session_state.chat_history.append({"role": "user", "content": user_input})
                for msg in st.session_state.chat_history:
                    if msg["role"] == "user":
                        # User text goes into raw HTML, so escape it: otherwise
                        # "<" / "&" in a question is parsed as markup.
                        safe_text = escape(str(msg["content"]), quote=False)
                        st.markdown(f'<div class="user-message">{safe_text}</div>', unsafe_allow_html=True)
                    else:
                        # NOTE(review): model output is rendered as raw HTML here;
                        # consider sanitising it if web-search content can reach it.
                        st.markdown(f'<div class="bot-message">{msg["content"]}</div>', unsafe_allow_html=True)

                # Generate response
                try:
                    ai_response = get_groq_response(user_input, st.session_state.get('mode', 'General'), use_web_search)
                    st.session_state.chat_history.append({"role": "assistant", "content": ai_response})

                    st.markdown(f'<div class="bot-message">{ai_response}</div>', unsafe_allow_html=True)
                except Exception as e:
                    st.error(f"Error getting AI response: {str(e)}")
                    st.info("Falling back to alternative model...")
                    try:
                        # No second model is actually called; a fixed apology
                        # is recorded and displayed.
                        ai_response = "I'm sorry, I couldn't generate a proper response. Please try rephrasing your question."
                        st.session_state.chat_history.append({"role": "assistant", "content": ai_response})
                        st.markdown(f'<div class="bot-message">{ai_response}</div>', unsafe_allow_html=True)
                    except Exception:
                        st.error("Both primary and fallback AI services failed. Please try again later.")

    st.markdown('</div>', unsafe_allow_html=True)
1191
 
1192
# Make sure the "notification" slot exists in the session state, defaulting to None.
if "notification" not in st.session_state:
    st.session_state["notification"] = None
 
 
 
 
1195
 
1196
+ # Import and initialize Lottie animations
1197
def load_lottie_url(url, timeout=10):
    """Download a Lottie animation and return it as parsed JSON.

    Args:
        url: Address of the Lottie JSON file.
        timeout: Seconds to wait for the server before giving up.  Without
            a timeout a stalled server would block the app at start-up.

    Returns:
        The decoded JSON payload, or ``None`` when ``requests`` is not
        installed, the request fails, the response status is not 200, or
        the body is not valid JSON.
    """
    try:
        import requests
    except ImportError:
        return None

    try:
        response = requests.get(url, timeout=timeout)
    except (requests.RequestException, ValueError):
        # Covers malformed URLs, connection errors and timeouts.
        return None

    if response.status_code != 200:
        return None

    try:
        return response.json()
    except ValueError:
        # The body was not valid JSON.
        return None
1208
+
1209
# Lottie animations for the page headers, fetched in the order upload, neural,
# visualization.  load_lottie_url() yields None for any download that fails.
lottie_upload, lottie_neural, lottie_visualization = (
    load_lottie_url(animation_url)
    for animation_url in (
        "https://assets9.lottiefiles.com/packages/lf20_grdj1jti.json",
        "https://assets8.lottiefiles.com/private_files/lf30_8uvz2gcg.json",
        "https://assets5.lottiefiles.com/packages/lf20_usmfx6bp.json",
    )
)
 
 
 
 
 
1213
 
1214
+ # Main function to run the app
1215
def main():
    """Draw the banner and sidebar, then render the selected page and the assistant.

    The sidebar selection is stored in ``st.session_state.active_page`` and a
    submitted Tavily key in ``st.session_state.tavily_api_key``.  The
    assistant panel is rendered below whichever page is active.
    """
    # Banner
    banner_html = """
    <div class="header">
        <h1 class="header-title">Neural-Vision Enhanced</h1>
        <div class="header-subtitle">Neural Network Development for Domain-Specialized Analysis</div>
    </div>
    """
    st.markdown(banner_html, unsafe_allow_html=True)

    nav_options = [
        "Data Upload & Analysis",
        "Neural Network Training Studio",
        "Neural Network Evaluation Center"
    ]

    with st.sidebar:
        st.title("🔮 Neural-Vision Enhanced")
        st.session_state.active_page = st.selectbox("Navigation", nav_options)
        st.markdown("---")
        st.markdown("**Environment Setup**")

        # Tavily key entry; it is only stored once the button is pressed.
        key_entry = st.text_input("Tavily API Key", type="password", help="Enter your Tavily API key for web search functionality")
        if st.button("Submit API Key"):
            if not key_entry:
                st.warning("Please enter a valid API key.")
            else:
                st.session_state.tavily_api_key = key_entry
                st.success("Tavily API Key submitted successfully!")

        st.markdown("---")
        st.markdown("v5.0 | © 2025 Neural-Vision")

    # Page routing: any page name not listed here falls through to the
    # evaluation page.
    page_renderers = {
        "Data Upload & Analysis": data_upload_page,
        "Neural Network Training Studio": model_training_page,
    }
    render_page = page_renderers.get(st.session_state.active_page, visualization_page)
    render_page()

    ai_assistant()
1257
 
1258
# Run the app: build the UI only when this file is executed as the entry
# script, not when it is imported as a module.
if __name__ == "__main__":
    main()