Update app.py
app.py
CHANGED
@@ -1,17 +1,14 @@
 import streamlit as st
 import pandas as pd
 import plotly.express as px
-import plotly.graph_objects as go
 import numpy as np
 from sklearn.model_selection import train_test_split
 from sklearn.neural_network import MLPClassifier, MLPRegressor
-from sklearn.cluster import KMeans
-from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor
+from sklearn.cluster import KMeans
 from sklearn.metrics import accuracy_score, r2_score, silhouette_score, confusion_matrix, classification_report, mean_squared_error
 from sklearn.preprocessing import StandardScaler
 from ydata_profiling import ProfileReport
 from streamlit_pandas_profiling import st_profile_report
-from streamlit_lottie import st_lottie
 from groq import Groq
 from langchain_community.vectorstores import FAISS
 from langchain.text_splitter import RecursiveCharacterTextSplitter
@@ -21,11 +18,6 @@ from langchain_community.tools.tavily_search import TavilySearchResults
 import os
 from dotenv import load_dotenv
 import tempfile
-import datetime
-import time
-import matplotlib.pyplot as plt
-import shap
-import xgboost as xgb
 
 # Load environment variables
 load_dotenv()
@@ -149,102 +141,6 @@ st.markdown("""
     transform: translateY(-2px);
     box-shadow: 0 6px 16px rgba(59, 130, 246, 0.4);
 }
-/* Card styles */
-.card {
-    background: var(--white);
-    border-radius: 16px;
-    box-shadow: 0 4px 16px rgba(0,0,0,0.1);
-    padding: 20px;
-    margin-bottom: 25px;
-    transition: all 0.3s ease;
-}
-.card:hover {
-    box-shadow: 0 8px 24px rgba(0,0,0,0.15);
-    transform: translateY(-2px);
-}
-
-/* Step header styles */
-.step-header {
-    display: flex;
-    align-items: center;
-    margin-bottom: 15px;
-}
-.step-counter {
-    background: var(--primary-blue);
-    color: var(--white);
-    width: 36px;
-    height: 36px;
-    border-radius: 50%;
-    display: flex;
-    align-items: center;
-    justify-content: center;
-    font-weight: bold;
-    margin-right: 15px;
-    box-shadow: 0 4px 10px rgba(59, 130, 246, 0.3);
-}
-.step-title {
-    font-size: 1.5rem;
-    font-weight: 700;
-    color: var(--dark-blue);
-}
-
-/* Notification styles */
-.notification {
-    display: flex;
-    align-items: center;
-    background: #ECFDF5;
-    border-left: 4px solid #059669;
-    color: #065F46;
-    padding: 15px;
-    border-radius: 8px;
-    margin: 15px 0;
-    box-shadow: 0 2px 8px rgba(5, 150, 105, 0.1);
-    transition: transform 0.2s ease;
-}
-.notification:hover {
-    transform: translateY(-2px);
-}
-.notification-icon {
-    background: #059669;
-    color: white;
-    width: 24px;
-    height: 24px;
-    border-radius: 50%;
-    display: flex;
-    align-items: center;
-    justify-content: center;
-    margin-right: 15px;
-    font-weight: bold;
-}
-
-/* Metrics card styles */
-.metrics-card {
-    background: var(--white);
-    border-radius: 12px;
-    padding: 15px;
-    box-shadow: 0 2px 10px rgba(0,0,0,0.08);
-    text-align: center;
-    transition: all 0.3s ease;
-    height: 100%;
-    display: flex;
-    flex-direction: column;
-    justify-content: center;
-}
-.metrics-card:hover {
-    transform: translateY(-3px);
-    box-shadow: 0 6px 16px rgba(0,0,0,0.12);
-}
-.metrics-value {
-    font-size: 2rem;
-    font-weight: 700;
-    color: var(--primary-blue);
-    margin-bottom: 5px;
-}
-.metrics-label {
-    color: var(--medium-grey);
-    font-size: 0.9rem;
-    font-weight: 500;
-}
 </style>
 """, unsafe_allow_html=True)
 
@@ -338,932 +234,165 @@ def plot_clusters(X, labels):
     fig = px.scatter(X, x=X.columns[0], y=X.columns[1], color=labels, title="Cluster Visualization")
     return fig
 
-def plot_learning_curve(model, X, y, cv=5):
-    """Plot learning curve to show model performance with increasing data"""
-    from sklearn.model_selection import learning_curve
-
-    train_sizes, train_scores, test_scores = learning_curve(
-        model, X, y, cv=cv, scoring='accuracy' if hasattr(y, 'nunique') else 'r2',
-        n_jobs=-1, train_sizes=np.linspace(0.1, 1.0, 10))
-
-    train_mean = np.mean(train_scores, axis=1)
-    train_std = np.std(train_scores, axis=1)
-    test_mean = np.mean(test_scores, axis=1)
-    test_std = np.std(test_scores, axis=1)
-
-    # Create DataFrame for plotting
-    df_curve = pd.DataFrame({
-        'Training Size (%)': train_sizes / len(X) * 100,
-        'Training Score': train_mean,
-        'Training Upper': train_mean + train_std,
-        'Training Lower': train_mean - train_std,
-        'Testing Score': test_mean,
-        'Testing Upper': test_mean + test_std,
-        'Testing Lower': test_mean - test_std
-    })
-
-    # Create the plot
-    fig = go.Figure()
-
-    # Add training data with confidence interval
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Training Score'],
-        mode='lines+markers',
-        name='Training Score',
-        line=dict(color='blue', width=2),
-        marker=dict(size=8)
-    ))
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Training Upper'],
-        mode='lines',
-        line=dict(width=0),
-        showlegend=False
-    ))
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Training Lower'],
-        mode='lines',
-        line=dict(width=0),
-        fill='tonexty',
-        fillcolor='rgba(0, 0, 255, 0.1)',
-        showlegend=False
-    ))
-
-    # Add testing data with confidence interval
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Testing Score'],
-        mode='lines+markers',
-        name='Testing Score',
-        line=dict(color='red', width=2),
-        marker=dict(size=8)
-    ))
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Testing Upper'],
-        mode='lines',
-        line=dict(width=0),
-        showlegend=False
-    ))
-    fig.add_trace(go.Scatter(
-        x=df_curve['Training Size (%)'],
-        y=df_curve['Testing Lower'],
-        mode='lines',
-        line=dict(width=0),
-        fill='tonexty',
-        fillcolor='rgba(255, 0, 0, 0.1)',
-        showlegend=False
-    ))
-
-    # Update layout
-    fig.update_layout(
-        title='Learning Curve',
-        xaxis_title='Training Set Size (%)',
-        yaxis_title='Score',
-        hovermode='x unified',
-        width=700,
-        height=400,
-        legend=dict(
-            orientation="h",
-            yanchor="bottom",
-            y=1.02,
-            xanchor="right",
-            x=1
-        )
-    )
-
-    return fig
-
-def plot_shap_summary(model, X):
-    """Create SHAP summary plot for model explainability"""
-    try:
-        # Create explainer based on model type
-        if hasattr(model, 'predict_proba'):
-            explainer = shap.Explainer(model)
-        else:
-            explainer = shap.Explainer(model)
-
-        # Calculate SHAP values
-        shap_values = explainer(X)
-
-        # Create the SHAP summary plot
-        plt.figure(figsize=(10, 8))
-        shap.summary_plot(shap_values, X, show=False)
-        fig = plt.gcf()
-        plt.tight_layout()
-
-        return fig
-    except Exception as e:
-        st.warning(f"Could not generate SHAP plot: {e}")
-        return None
-
-# Data Exploration and Insights Generation
-def generate_data_insights(df):
-    """Generate comprehensive insights about the dataset"""
-    insights = {}
-
-    # Basic statistics
-    insights['shape'] = df.shape
-    insights['missing_values'] = df.isna().sum().sum()
-    insights['duplicate_rows'] = df.duplicated().sum()
-
-    # Column types
-    insights['numeric_columns'] = list(df.select_dtypes(include=['number']).columns)
-    insights['categorical_columns'] = list(df.select_dtypes(include=['object', 'category', 'bool']).columns)
-    insights['datetime_columns'] = []
-    for col in df.columns:
-        try:
-            if pd.to_datetime(df[col], errors='coerce').notna().any():
-                insights['datetime_columns'].append(col)
-        except:
-            pass
-
-    # Distribution statistics
-    insights['skewed_columns'] = []
-    for col in insights['numeric_columns']:
-        if abs(df[col].skew()) > 1.0:
-            insights['skewed_columns'].append((col, df[col].skew()))
-
-    # Correlation analysis
-    if len(insights['numeric_columns']) > 1:
-        corr_matrix = df[insights['numeric_columns']].corr().abs()
-        corr_pairs = []
-        for i in range(len(corr_matrix.columns)):
-            for j in range(i):
-                if corr_matrix.iloc[i, j] > 0.7:  # Strong correlation threshold
-                    corr_pairs.append((corr_matrix.columns[i], corr_matrix.columns[j], corr_matrix.iloc[i, j]))
-        insights['correlated_features'] = sorted(corr_pairs, key=lambda x: x[2], reverse=True)
-
-    # Categorical feature analysis
-    insights['high_cardinality_features'] = []
-    for col in insights['categorical_columns']:
-        if df[col].nunique() > 10:
-            insights['high_cardinality_features'].append((col, df[col].nunique()))
-
-    # Missing value patterns
-    insights['missing_patterns'] = []
-    for col in df.columns:
-        missing_pct = df[col].isna().mean() * 100
-        if missing_pct > 0:
-            insights['missing_patterns'].append((col, missing_pct))
-
-    # Outlier detection
-    insights['outlier_columns'] = []
-    for col in insights['numeric_columns']:
-        Q1 = df[col].quantile(0.25)
-        Q3 = df[col].quantile(0.75)
-        IQR = Q3 - Q1
-        outliers_count = ((df[col] < (Q1 - 1.5 * IQR)) | (df[col] > (Q3 + 1.5 * IQR))).sum()
-        if outliers_count > 0:
-            insights['outlier_columns'].append((col, outliers_count, outliers_count/len(df)*100))
-
-    return insights
-
-# Enhanced Model Selection and Training Functions
-def get_model_options(problem_type):
-    """Get appropriate models for the selected problem type"""
-    if problem_type == "Classification":
-        return {
-            "Neural Network": MLPClassifier(max_iter=1000, random_state=42),
-            "Random Forest": RandomForestClassifier(random_state=42),
-            "Gradient Boosting": GradientBoostingClassifier(random_state=42),
-            "XGBoost": xgb.XGBClassifier(random_state=42)
-        }
-    elif problem_type == "Regression":
-        return {
-            "Neural Network": MLPRegressor(max_iter=1000, random_state=42),
-            "Random Forest": RandomForestRegressor(random_state=42),
-            "Gradient Boosting": GradientBoostingRegressor(random_state=42),
-            "XGBoost": xgb.XGBRegressor(random_state=42)
-        }
-    else:  # Clustering
-        return {
-            "K-Means": KMeans(random_state=42),
-            "DBSCAN": DBSCAN(),
-            "Agglomerative": AgglomerativeClustering()
-        }
-
-def train_model_with_optimization(model, X_train, X_test, y_train, y_test, problem_type, optimization_level="basic"):
-    """Train model with optional hyperparameter optimization"""
-    start_time = time.time()
-
-    if optimization_level == "none":
-        # Simple fit without optimization
-        model.fit(X_train, y_train)
-        best_model = model
-    elif optimization_level == "basic":
-        # Basic parameter grid
-        param_grid = {}
-
-        if problem_type in ["Classification", "Regression"]:
-            if isinstance(model, (RandomForestClassifier, RandomForestRegressor)):
-                param_grid = {
-                    'n_estimators': [100, 200],
-                    'max_depth': [None, 10, 20]
-                }
-            elif isinstance(model, (GradientBoostingClassifier, GradientBoostingRegressor)):
-                param_grid = {
-                    'n_estimators': [100, 200],
-                    'learning_rate': [0.01, 0.1]
-                }
-            elif isinstance(model, (MLPClassifier, MLPRegressor)):
-                param_grid = {
-                    'hidden_layer_sizes': [(100,), (100, 50)],
-                    'alpha': [0.0001, 0.001]
-                }
-            elif "XGB" in str(model.__class__):
-                param_grid = {
-                    'n_estimators': [100, 200],
-                    'learning_rate': [0.01, 0.1],
-                    'max_depth': [3, 6]
-                }
-        elif problem_type == "Clustering":
-            if isinstance(model, KMeans):
-                param_grid = {
-                    'n_clusters': [3, 4, 5, 6]
-                }
-            elif isinstance(model, DBSCAN):
-                param_grid = {
-                    'eps': [0.3, 0.5, 0.7],
-                    'min_samples': [5, 10, 15]
-                }
-            elif isinstance(model, AgglomerativeClustering):
-                param_grid = {
-                    'n_clusters': [3, 4, 5, 6],
-                    'linkage': ['ward', 'complete', 'average']
-                }
-
-        # Only run GridSearchCV if we have parameters to optimize
-        if param_grid:
-            if problem_type == "Clustering":
-                # For clustering, use silhouette score as the metric
-                from sklearn.metrics import make_scorer, silhouette_score
-                from sklearn.model_selection import GridSearchCV
-
-                # Custom scorer for clustering
-                def silhouette_scorer(estimator, X):
-                    labels = estimator.fit_predict(X)
-                    if len(set(labels)) <= 1:  # Check if all points are in one cluster
-                        return -1
-                    return silhouette_score(X, labels)
-
-                grid_search = GridSearchCV(
-                    estimator=model,
-                    param_grid=param_grid,
-                    scoring=make_scorer(silhouette_scorer),
-                    cv=3,
-                    n_jobs=-1
-                )
-                grid_search.fit(X_train)
-            else:
-                # For classification and regression
-                scoring = 'accuracy' if problem_type == "Classification" else 'r2'
-                grid_search = GridSearchCV(
-                    estimator=model,
-                    param_grid=param_grid,
-                    scoring=scoring,
-                    cv=5,
-                    n_jobs=-1
-                )
-                grid_search.fit(X_train, y_train)
-
-            best_model = grid_search.best_estimator_
-        else:
-            # If no param grid, just fit the model
-            model.fit(X_train, y_train)
-            best_model = model
-    else:  # Advanced optimization
-        # TODO: Implement advanced optimization with more parameters,
-        # RandomizedSearchCV or BayesianOptimization
-        pass
-
-    # Calculate training time
-    training_time = time.time() - start_time
-
-    # Get predictions for evaluation
-    if problem_type == "Clustering":
-        if hasattr(best_model, 'predict'):
-            y_pred = best_model.predict(X_test)
-        else:
-            y_pred = best_model.fit_predict(X_test)
-    else:
-        y_pred = best_model.predict(X_test)
-
-    return best_model, y_pred, training_time
-
 # Pages
 def data_upload_page():
-    "
-    st.
-
-    # Create a header with animation
-    col1, col2 = st.columns([1, 3])
-    with col1:
-        st_lottie(lottie_upload, height=150, key="upload_animation")
-    with col2:
-        st.markdown('<div class="step-header">', unsafe_allow_html=True)
-        st.markdown('<div class="step-counter">1</div>', unsafe_allow_html=True)
-        st.markdown('<div class="step-title">Data Upload & Exploratory Analysis</div>', unsafe_allow_html=True)
-        st.markdown('</div>', unsafe_allow_html=True)
-        st.markdown("Upload your dataset and get comprehensive insights before model training.")
-
-    # File uploader with enhanced UI
-    uploaded_file = st.file_uploader("Upload Dataset (CSV, Excel, or JSON)",
-                                     type=["csv", "xlsx", "json"],
-                                     help="Upload your data file to start analysis")
+    st.header("📤 Data Upload & Analysis")
+    uploaded_file = st.file_uploader("Upload Dataset", type=["csv"])
 
     if uploaded_file:
-
-
-
-
-            df = pd.read_csv(uploaded_file)
-        elif uploaded_file.name.endswith('xlsx'):
-            df = pd.read_excel(uploaded_file)
-        elif uploaded_file.name.endswith('json'):
-            df = pd.read_json(uploaded_file)
-
-        # Store in session state
-        st.session_state.df = df
-        st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
-        st.session_state.metrics = {}
-
-        # Generate insights
-        st.session_state.dataset_insights = generate_data_insights(df)
-
-        # Success notification
-        if 'notification' not in st.session_state or st.session_state.notification is None:
-            st.session_state.notification = "Data successfully loaded! 🎉"
-
-        # Display a notification
-        st.markdown(f"""
-        <div class="notification">
-            <div class="notification-icon">✓</div>
-            <div>{st.session_state.notification}</div>
-        </div>
-        """, unsafe_allow_html=True)
-        st.session_state.notification = None
-
-        # Create tabs for different data views
-        data_tabs = st.tabs(["📊 Overview", "🔍 Data Explorer", "📈 Visualizations", "📋 Profile Report"])
-
-        with data_tabs[0]:
-            st.subheader("Dataset Overview")
-            col1, col2, col3 = st.columns(3)
-
-            with col1:
-                st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
-                st.markdown(f'<div class="metrics-value">{df.shape[0]:,}</div>', unsafe_allow_html=True)
-                st.markdown('<div class="metrics-label">Total Samples</div>', unsafe_allow_html=True)
-                st.markdown('</div>', unsafe_allow_html=True)
-
-            with col2:
-                st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
-                st.markdown(f'<div class="metrics-value">{df.shape[1]}</div>', unsafe_allow_html=True)
-                st.markdown('<div class="metrics-label">Features</div>', unsafe_allow_html=True)
-                st.markdown('</div>', unsafe_allow_html=True)
-
-            with col3:
-                missing_pct = df.isna().sum().sum() / (df.shape[0] * df.shape[1]) * 100
-                st.markdown('<div class="metrics-card">', unsafe_allow_html=True)
-                st.markdown(f'<div class="metrics-value">{missing_pct:.1f}%</div>', unsafe_allow_html=True)
-                st.markdown('<div class="metrics-label">Missing Values</div>', unsafe_allow_html=True)
-                st.markdown('</div>', unsafe_allow_html=True)
-
-            # Data Types Breakdown
-            st.subheader("Data Types")
-            dtype_counts = df.dtypes.value_counts().reset_index()
-            dtype_counts.columns = ['Data Type', 'Count']
-            fig = px.pie(dtype_counts, values='Count', names='Data Type', hole=0.4,
-                         color_discrete_sequence=px.colors.qualitative.Bold)
-            fig.update_layout(margin=dict(t=0, b=0, l=0, r=0), height=300)
-            st.plotly_chart(fig, use_container_width=True)
-
-            # Key Insights
-            st.subheader("Key Insights")
-            insights = st.session_state.dataset_insights
-
-            insight_cols = st.columns(2)
-
-            with insight_cols[0]:
-                st.markdown("**Data Quality Issues**")
-                if insights['missing_patterns']:
-                    st.markdown("🔹 **Missing Values:**")
-                    for col, pct in sorted(insights['missing_patterns'], key=lambda x: x[1], reverse=True)[:5]:
-                        st.markdown(f"  • *{col}*: {pct:.1f}% missing")
-
-                if insights['outlier_columns']:
-                    st.markdown("🔹 **Outliers Detected:**")
-                    for col, count, pct in sorted(insights['outlier_columns'], key=lambda x: x[2], reverse=True)[:5]:
-                        st.markdown(f"  • *{col}*: {count} outliers ({pct:.1f}%)")
-
-            with insight_cols[1]:
-                st.markdown("**Feature Relationships**")
-                if 'correlated_features' in insights and insights['correlated_features']:
-                    st.markdown("🔹 **Highly Correlated Features:**")
-                    for col1, col2, corr in insights['correlated_features'][:5]:
-                        st.markdown(f"  • *{col1}* & *{col2}*: {corr:.2f} correlation")
-
-                if insights['skewed_columns']:
-                    st.markdown("🔹 **Skewed Distributions:**")
-                    for col, skew in sorted(insights['skewed_columns'], key=lambda x: abs(x[1]), reverse=True)[:5]:
-                        direction = "right" if skew > 0 else "left"
-                        st.markdown(f"  • *{col}*: {direction}-skewed ({skew:.2f})")
-
-        with data_tabs[1]:
-            st.subheader("Interactive Data Explorer")
-
-            # Filter and search options
-            col1, col2 = st.columns([2, 3])
-            with col1:
-                search_term = st.text_input("Search columns", "")
-            with col2:
-                selected_dtypes = st.multiselect(
-                    "Filter by data type",
-                    options=['numeric', 'object', 'datetime', 'category', 'bool'],
-                    default=['numeric', 'object']
-                )
-
-            # Apply filters
-            filtered_cols = df.columns
-            if search_term:
-                filtered_cols = [col for col in filtered_cols if search_term.lower() in col.lower()]
-
-            if selected_dtypes:
-                dtype_map = {
-                    'numeric': 'number',
-                    'object': 'object',
-                    'datetime': 'datetime',
-                    'category': 'category',
-                    'bool': 'bool'
-                }
-                dtype_filtered = []
-                for dtype in selected_dtypes:
-                    dtype_filtered.extend(df.select_dtypes(include=[dtype_map[dtype]]).columns)
-                filtered_cols = [col for col in filtered_cols if col in dtype_filtered]
-
-            # Display filtered dataframe
-            if filtered_cols:
-                st.dataframe(df[filtered_cols], height=400)
-
-                # Column statistics
-                selected_column = st.selectbox("Select column for detailed statistics", options=filtered_cols)
-
-                col1, col2 = st.columns(2)
-                with col1:
-                    st.subheader(f"Statistics for: {selected_column}")
-                    if pd.api.types.is_numeric_dtype(df[selected_column]):
-                        stats = df[selected_column].describe()
-                        for stat, value in stats.items():
-                            st.markdown(f"**{stat}:** {value:.4f}")
-
-                        # Additional stats
-                        st.markdown(f"**Skewness:** {df[selected_column].skew():.4f}")
-                        st.markdown(f"**Kurtosis:** {df[selected_column].kurtosis():.4f}")
-                    else:
-                        st.markdown(f"**Unique Values:** {df[selected_column].nunique()}")
-                        st.markdown(f"**Most Common:** {df[selected_column].value_counts().index[0]}")
-                        st.markdown(f"**Least Common:** {df[selected_column].value_counts().index[-1]}")
-                        st.markdown(f"**Missing Values:** {df[selected_column].isna().sum()} ({df[selected_column].isna().mean()*100:.2f}%)")
-
-                with col2:
-                    st.subheader("Distribution")
-                    if pd.api.types.is_numeric_dtype(df[selected_column]):
-                        # Histogram for numeric
-                        fig = px.histogram(df, x=selected_column, histnorm='probability density',
-                                           marginal='box', color_discrete_sequence=['#3B82F6'])
-                        fig.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=0))
-                    else:
-                        # Bar chart for categorical
-                        value_counts = df[selected_column].value_counts().reset_index()
-                        value_counts.columns = [selected_column, 'Count']
-                        value_counts = value_counts.head(15)  # Limit to top 15
-                        fig = px.bar(value_counts, x=selected_column, y='Count',
-                                     color_discrete_sequence=['#3B82F6'])
-                        fig.update_layout(height=300, margin=dict(l=0, r=0, t=20, b=0))
-
-                    st.plotly_chart(fig, use_container_width=True)
-            else:
-                st.warning("No columns match your filters.")
+        df = pd.read_csv(uploaded_file)
+        st.session_state.df = df
+        st.session_state.vector_store = create_vector_store(convert_df_to_text(df))
+        st.session_state.metrics = {}
 
-
-
-
-
-
-                options=["Scatter Plot", "Correlation Matrix", "Pair Plot", "Box Plot", "Violin Plot", "Line Chart"]
-            )
-
-            if viz_type == "Scatter Plot":
-                col1, col2, col3 = st.columns(3)
-                with col1:
-                    x_col = st.selectbox("X-axis", options=df.select_dtypes(include=['number']).columns)
-                with col2:
-                    y_col = st.selectbox("Y-axis", options=df.select_dtypes(include=['number']).columns,
-                                         index=min(1, len(df.select_dtypes(include=['number']).columns)-1))
-                with col3:
-                    color_col = st.selectbox("Color by", options=["None"] + list(df.columns), index=0)
-
-                # Create plot
-                if color_col == "None":
-                    fig = px.scatter(df, x=x_col, y=y_col, title=f"{x_col} vs {y_col}",
-                                     opacity=0.7, color_discrete_sequence=['#3B82F6'])
-                else:
-                    fig = px.scatter(df, x=x_col, y=y_col, color=color_col, title=f"{x_col} vs {y_col} by {color_col}",
-                                     opacity=0.7)
-
-                fig.update_layout(height=500)
-                st.plotly_chart(fig, use_container_width=True)
-
-                # Add regression line option
-                if st.checkbox("Add regression line"):
-                    fig = px.scatter(df, x=x_col, y=y_col, trendline="ols",
-                                     title=f"{x_col} vs {y_col} with Regression Line",
-                                     opacity=0.7, color_discrete_sequence=['#3B82F6'])
-                    fig.update_layout(height=500)
-                    st.plotly_chart(fig, use_container_width=True)
-
-            elif viz_type == "Correlation Matrix":
-                # Select columns for correlation
-                numeric_cols = df.select_dtypes(include=['number']).columns
-                selected_corr_cols = st.multiselect(
-                    "Select columns for correlation matrix",
-                    options=numeric_cols,
-                    default=list(numeric_cols)[:min(8, len(numeric_cols))]
-                )
-
-                if selected_corr_cols:
-                    # Correlation matrix
-                    corr = df[selected_corr_cols].corr()
-                    mask = np.triu(np.ones_like(corr, dtype=bool))
-
-                    # Create plot
-                    fig = px.imshow(corr, text_auto=True, color_continuous_scale='RdBu_r',
-                                    zmin=-1, zmax=1, aspect="auto")
-                    fig.update_layout(height=600)
-
-                    st.plotly_chart(fig, use_container_width=True)
-                else:
-                    st.warning("Please select at least one numeric column.")
-
-            elif viz_type == "Pair Plot":
-                # Select columns for pair plot
-                numeric_cols = df.select_dtypes(include=['number']).columns
-                selected_pair_cols = st.multiselect(
-                    "Select columns for pair plot (limit 4-5 for readability)",
-                    options=numeric_cols,
-                    default=list(numeric_cols)[:min(4, len(numeric_cols))]
-                )
-
-                if selected_pair_cols:
-                    if len(selected_pair_cols) > 6:
-                        st.warning("Too many columns may make the plot hard to read. Consider selecting 4-5 columns.")
-
-                    # Color option
-                    color_col = st.selectbox("Color by (categorical)",
-                                             options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
-                                             index=0)
-
-                    # Create pair plot
-                    if color_col == "None":
-                        fig = px.scatter_matrix(df, dimensions=selected_pair_cols, opacity=0.7)
-                    else:
-                        fig = px.scatter_matrix(df, dimensions=selected_pair_cols, color=color_col, opacity=0.7)
-
-                    fig.update_layout(height=700)
-                    st.plotly_chart(fig, use_container_width=True)
-                else:
-                    st.warning("Please select at least one numeric column.")
-
-            elif viz_type == "Box Plot":
-                col1, col2 = st.columns(2)
-                with col1:
-                    y_col = st.selectbox("Value column", options=df.select_dtypes(include=['number']).columns)
-                with col2:
-                    x_col = st.selectbox("Category column",
-                                         options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
-                                         index=0)
-
-                # Create plot
-                if x_col == "None":
-                    fig = px.box(df, y=y_col, title=f"Distribution of {y_col}",
-                                 color_discrete_sequence=['#3B82F6'])
-                else:
-                    fig = px.box(df, x=x_col, y=y_col, title=f"Distribution of {y_col} by {x_col}")
-
-                fig.update_layout(height=500)
-                st.plotly_chart(fig, use_container_width=True)
-
-            elif viz_type == "Violin Plot":
-                col1, col2 = st.columns(2)
-                with col1:
-                    y_col = st.selectbox("Value column", options=df.select_dtypes(include=['number']).columns, key="violin_y")
-                with col2:
-                    x_col = st.selectbox("Category column",
-                                         options=["None"] + list(df.select_dtypes(exclude=['number']).columns),
-                                         index=0, key="violin_x")
-
-                # Create plot
-                if x_col == "None":
-                    fig = px.violin(df, y=y_col, box=True, title=f"Distribution of {y_col}",
-                                    color_discrete_sequence=['#3B82F6'])
-                else:
-                    fig = px.violin(df, x=x_col, y=y_col, box=True, title=f"Distribution of {y_col} by {x_col}")
-
-                fig.update_layout(height=500)
-                st.plotly_chart(fig, use_container_width=True)
-
-            elif viz_type == "Line Chart":
-                # Identify potential date columns
-                date_cols = []
-                for col in df.columns:
-                    try:
-                        if pd.to_datetime(df[col], errors='coerce').notna().all():
-                            date_cols.append(col)
-                    except:
-                        pass
-
-                if date_cols:
-                    col1, col2 = st.columns(2)
-                    with col1:
-                        x_col = st.selectbox("Time axis", options=date_cols)
-                        # Convert to datetime if not already
-                        df[x_col] = pd.to_datetime(df[x_col])
-                    with col2:
-                        y_cols = st.multiselect("Value columns", options=df.select_dtypes(include=['number']).columns,
-                                                default=[df.select_dtypes(include=['number']).columns[0]])
-
-                    if y_cols:
-                        # Create line chart
-                        fig = go.Figure()
-                        for y_col in y_cols:
-                            fig.add_trace(go.Scatter(x=df[x_col], y=df[y_col], mode='lines', name=y_col))
-
-                        fig.update_layout(
-                            title=f"Time Series Plot",
-                            xaxis_title=x_col,
-                            yaxis_title="Values",
-                            height=500,
-                            legend=dict(orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)
-                        )
-
-                        st.plotly_chart(fig, use_container_width=True)
-                    else:
-                        st.warning("Please select at least one value column.")
-                else:
-                    st.warning("No datetime columns detected. Please ensure you have columns with date/time values.")
+        st.subheader("Dataset Health Check")
+        col1, col2, col3 = st.columns(3)
+        col1.metric("Total Samples", df.shape[0])
+        col2.metric("Features", df.shape[1])
+        col3.metric("Missing Values", df.isna().sum().sum())
 
-
-        st.
-
-
-            with profile_options[0]:
-                minimal = st.checkbox("Minimal Report (Faster)", value=True)
-            with profile_options[1]:
-                sample_data = st.checkbox("Use Sample (Faster for large datasets)", value=True)
-            with profile_options[2]:
-                report_percent = st.slider("Sample Size %", min_value=10, max_value=100, value=50, step=10)
-
-            if st.button("Generate Profile Report"):
-                with st.spinner("Generating comprehensive profile report..."):
-                    if sample_data and len(df) > 1000:
-                        profile_df = df.sample(int(len(df) * report_percent/100))
-                    else:
-                        profile_df = df
-
-                    if minimal:
-                        profile = ProfileReport(profile_df, minimal=True, title="Dataset Profile Report")
-                    else:
-                        profile = ProfileReport(profile_df, explorative=True, title="Dataset Profile Report")
-
-                    st_profile_report(profile)
-
-    st.markdown('</div>', unsafe_allow_html=True)
+        if st.button("Generate Full EDA Report"):
+            with st.spinner("Generating comprehensive analysis..."):
+                profile = ProfileReport(df, explorative=True)
+                st_profile_report(profile)
 
 def model_training_page():
-    "
-    st.markdown('<div class="card">', unsafe_allow_html=True)
-
-    # Header with animation
-    col1, col2 = st.columns([1, 3])
-    with col1:
-        st_lottie(lottie_neural, height=150, key="neural_animation")
-    with col2:
-        st.markdown('<div class="step-header">', unsafe_allow_html=True)
-        st.markdown('<div class="step-counter">2</div>', unsafe_allow_html=True)
-        st.markdown('<div class="step-title">Neural Network Training Studio</div>', unsafe_allow_html=True)
-        st.markdown('</div>', unsafe_allow_html=True)
-        st.markdown("Train advanced models with automated optimization and hyperparameter tuning.")
+    st.header("🧠 Neural Network Training Studio")
 
     if 'df' not in st.session_state:
-        st.warning("
-        st.markdown('</div>', unsafe_allow_html=True)
+        st.warning("Upload data first!")
         return
 
     df = st.session_state.df
 
-
+    problem_type = st.selectbox("Select Problem Type", ["Classification", "Regression", "Clustering"])
+    mode = st.selectbox("Domain Specialization", ["Legal", "Financial", "Academic", "Technical"])
 
-
-    st.
-
-
-    problem_type = st.selectbox(
-        "Select Problem Type",
-        ["Classification", "Regression", "Clustering"],
-        help="Classification: predict categories, Regression: predict continuous values, Clustering: group similar data points"
-    )
-
-    # Domain specialization
-    domain_col1, domain_col2 = st.columns(2)
-    with domain_col1:
-        mode = st.selectbox(
-            "Domain Specialization",
-            ["General", "Legal", "Financial", "Medical", "Technical", "Academic"],
-            help="Optimize the model for your specific domain"
-        )
-
-    with domain_col2:
-        experiment_name = st.text_input(
-            "Experiment Name",
-            value=f"{problem_type}_{datetime.datetime.now().strftime('%Y%m%d_%H%M')}",
-            help="Name your experiment for reference"
-        )
-
-    # Target variable selection
-    if problem_type != "Clustering":
-        target_col1, target_col2 = st.columns(2)
-
-        with target_col1:
-            target = st.selectbox("Select Target Variable", df.columns)
-
-        with target_col2:
-            if problem_type == "Classification":
-                st.info(f"Class Distribution: {df[target].value_counts().to_dict()}")
-            else:
-                st.info(f"Target Range: {df[target].min()} to {df[target].max()}")
+    if problem_type != "Clustering":
+        target = st.selectbox("Select Target Variable", df.columns)
+        X = df.drop(columns=[target])
+        y = df[target]
+    else:
+        X = df
+        y = None
 
-
-
-
-
-
-
-        )
-        st.session_state.selected_columns = selected_features + [target]
-    else:
-        st.session_state.selected_columns = df.columns.tolist()
-    else:
-        # For clustering, all columns are features
-        available_features = df.columns.tolist()
-        selected_features = st.multiselect(
-            "Select features for clustering",
-            options=available_features,
-            default=available_features[:min(5, len(available_features))]
-        )
-        st.session_state.selected_columns = selected_features
-
-    # Rest of the model_training_page function would go here...
-    st.markdown('</div>', unsafe_allow_html=True)
+    if st.button("Train Neural Network"):
+        with st.spinner("Training in progress..."):
+            X_scaled = StandardScaler().fit_transform(X)
+            X_train, X_test, y_train, y_test = train_test_split(X_scaled, y, test_size=0.2, random_state=42) if y is not None else (X_scaled, None, None, None)
+
+            if problem_type == "Classification":
+                model = MLPClassifier(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
+                model.fit(X_train, y_train)
+                y_pred = model.predict(X_test)
+                st.session_state.metrics = {
+                    "Accuracy": accuracy_score(y_test, y_pred),
+                    "Classification Report": classification_report(y_test, y_pred, output_dict=True)
+                }
+            elif problem_type == "Regression":
+                model = MLPRegressor(hidden_layer_sizes=(100, 50), max_iter=500, random_state=42)
+                model.fit(X_train, y_train)
+                y_pred = model.predict(X_test)
+                st.session_state.metrics = {
+                    "R2 Score": r2_score(y_test, y_pred),
+                    "Mean Squared Error": mean_squared_error(y_test, y_pred)
+                }
+            else:  # Clustering
+                model = KMeans(n_clusters=3, random_state=42)
+                labels = model.fit_predict(X_scaled)
+                st.session_state.metrics = {
+                    "Silhouette Score": silhouette_score(X_scaled, labels)
+                }
+
+            st.session_state.best_model = model
+            st.session_state.X_test = X_test
+            st.session_state.y_test = y_test
+            st.session_state.y_pred = y_pred if y is not None else labels
+            st.session_state.problem_type = problem_type
+            st.success(f"Model trained successfully in {mode} mode!")
 
 def visualization_page():
-    "
-    st.markdown('<div class="card">', unsafe_allow_html=True)
-
-    # Header with animation
-    col1, col2 = st.columns([1, 3])
-    with col1:
-        st_lottie(lottie_visualization, height=150, key="viz_animation")
-    with col2:
-        st.markdown('<div class="step-header">', unsafe_allow_html=True)
-        st.markdown('<div class="step-counter">3</div>', unsafe_allow_html=True)
-        st.markdown('<div class="step-title">Neural Network Evaluation Center</div>', unsafe_allow_html=True)
-        st.markdown('</div>', unsafe_allow_html=True)
-        st.markdown("Visualize, interpret, and validate your trained neural networks.")
+    st.header("🔍 Neural Network Evaluation Center")
 
     if 'best_model' not in st.session_state:
-        st.warning("
-        st.markdown('</div>', unsafe_allow_html=True)
+        st.warning("Train a model first!")
         return
 
-
-
-
-
+    st.subheader("Performance Analysis")
+    if st.session_state.problem_type == "Classification":
+        st.plotly_chart(plot_confusion_matrix(st.session_state.y_test, st.session_state.y_pred))
+        st.plotly_chart(plot_feature_importance(st.session_state.best_model, pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns[:-1])))
+    elif st.session_state.problem_type == "Regression":
+        st.plotly_chart(plot_residuals(st.session_state.y_test, st.session_state.y_pred))
+        st.plotly_chart(plot_feature_importance(st.session_state.best_model, pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns[:-1])))
+    else:  # Clustering
+        st.plotly_chart(plot_clusters(pd.DataFrame(st.session_state.X_test, columns=st.session_state.df.columns), st.session_state.y_pred))
 
-    st.
+    st.subheader("Metrics")
+    st.write(st.session_state.metrics)
 
+# Chatbot Interface
 def ai_assistant():
-    """AI Assistant for neural network development guidance"""
     st.markdown('<div class="chat-container">', unsafe_allow_html=True)
-    st.subheader("
-
-    user_input = st.text_area("Ask a question about your data or neural network development:", "")
-    use_web_search = st.checkbox("Enable web search for up-to-date information", value=False)
+    st.subheader("🧠 Neural Insight Assistant (RAG + Web Search)")
 
-
-
-            with st.spinner("Analyzing your question..."):
-                # Add user message to chat history
-                st.session_state.chat_history.append({"role": "user", "content": user_input})
-                for msg in st.session_state.chat_history:
-                    if msg["role"] == "user":
-                        st.markdown(f'<div class="user-message">{msg["content"]}</div>', unsafe_allow_html=True)
-                    else:
-                        st.markdown(f'<div class="bot-message">{msg["content"]}</div>', unsafe_allow_html=True)
-
-                # Generate response
-                try:
-                    ai_response = get_groq_response(user_input, st.session_state.get('mode', 'General'), use_web_search)
-                    st.session_state.chat_history.append({"role": "assistant", "content": ai_response})
-
-                    st.markdown(f'<div class="bot-message">{ai_response}</div>', unsafe_allow_html=True)
-                except Exception as e:
-                    st.error(f"Error getting AI response: {str(e)}")
-                    st.info("Falling back to alternative model...")
-                    try:
-                        # Fallback to OpenAI
-                        ai_response = "I'm sorry, I couldn't generate a proper response. Please try rephrasing your question."
-                        st.session_state.chat_history.append({"role": "assistant", "content": ai_response})
-                        st.markdown(f'<div class="bot-message">{ai_response}</div>', unsafe_allow_html=True)
-                    except:
-                        st.error("Both primary and fallback AI services failed. Please try again later.")
+    use_web_search = st.checkbox("Enable Tavily Web Search", value=False)
+    mode = st.selectbox("Domain Mode", ["Legal", "Financial", "Academic", "Technical"], key="chat_mode")
 
-    st.
-
-
-if 'notification' not in st.session_state:
-    st.session_state.notification = None
-
-# Import and initialize Lottie animations
-def load_lottie_url(url):
-    """Load Lottie animation from URL"""
-    try:
-        import json
-        import requests
-        r = requests.get(url)
-        if r.status_code != 200:
-            return None
-        return r.json()
-    except:
-        return None
-
-# Lottie animations
-lottie_upload = load_lottie_url("https://assets9.lottiefiles.com/packages/lf20_grdj1jti.json")
-lottie_neural = load_lottie_url("https://assets8.lottiefiles.com/private_files/lf30_8uvz2gcg.json")
-lottie_visualization = load_lottie_url("https://assets5.lottiefiles.com/packages/lf20_usmfx6bp.json")
-
-# Main function to run the app
-def main():
-    """Main function to run the app"""
-    # Main App Layout
-    st.markdown("""
-    <div class="header">
-        <h1 class="header-title">Neural-Vision Enhanced</h1>
-        <div class="header-subtitle">Neural Network Development for Domain-Specialized Analysis</div>
-    </div>
-    """, unsafe_allow_html=True)
-
-    st.
-
-        "Neural Network Training Studio",
-        "Neural Network Evaluation Center"
-    ])
-    st.session_state.active_page = page
-    st.markdown("---")
-    st.markdown("**Environment Setup**")
-
-
-        if tavily_api_input:
-            st.session_state.tavily_api_key = tavily_api_input
-            st.success("Tavily API Key submitted successfully!")
-        else:
-            st.warning("Please enter a valid API key.")
-
-    st.
-
-if __name__ == "__main__":
-    main()
+    for msg in st.session_state.chat_history:
+        with st.chat_message(msg["role"]):
+            st.markdown(f'<div class="{msg["role"]}-message">{msg["content"]}</div>', unsafe_allow_html=True)
+
+    if prompt := st.chat_input("Ask about data, models, or web insights..."):
+        st.session_state.chat_history.append({"role": "user", "content": prompt})
+        with st.chat_message("user"):
+            st.markdown(f'<div class="user-message">{prompt}</div>', unsafe_allow_html=True)
+
+        with st.spinner("Processing..."):
+            response = get_groq_response(prompt, mode, use_web_search)
+            st.session_state.chat_history.append({"role": "assistant", "content": response})
+
+        with st.chat_message("assistant"):
+            st.markdown(f'<div class="bot-message">{response}</div>', unsafe_allow_html=True)
+
+    st.markdown('</div>', unsafe_allow_html=True)
+
+# Main App Layout
+st.markdown("""
+<div class="header">
+    <h1 class="header-title">Neural-Vision Enhanced</h1>
+    <div class="header-subtitle">Neural Network Development for Domain-Specialized Analysis</div>
+</div>
+""", unsafe_allow_html=True)
+
+with st.sidebar:
+    st.title("🔮 Neural-Vision Enhanced")
+    page = st.selectbox("Navigation", [
+        "Data Upload & Analysis",
+        "Neural Network Training Studio",
+        "Neural Network Evaluation Center"
+    ])
+    st.session_state.active_page = page
+    st.markdown("---")
+    st.markdown("**Environment Setup**")
+
+    # Tavily API Key Input and Submit Button
+    tavily_api_input = st.text_input("Tavily API Key", type="password", help="Enter your Tavily API key for web search functionality")
+    if st.button("Submit API Key"):
+        if tavily_api_input:
+            st.session_state.tavily_api_key = tavily_api_input
+            st.success("Tavily API Key submitted successfully!")
+        else:
+            st.warning("Please enter a valid API key.")
+
+    st.markdown("---")
+    st.markdown("v5.0 | © 2025 Neural-Vision")
+
+# Page Routing
+if "Data Upload & Analysis" in page:
+    data_upload_page()
+elif "Neural Network Training Studio" in page:
+    model_training_page()
+else:
+    visualization_page()
+
+ai_assistant()
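Note on helpers: the new `data_upload_page` and `ai_assistant` call `create_vector_store`, `convert_df_to_text`, and `get_groq_response`, and `visualization_page` calls `plot_confusion_matrix`, `plot_residuals`, and `plot_feature_importance`. These are defined in unchanged parts of app.py and therefore do not appear in the hunks above. As orientation only, here is a minimal sketch of what the three RAG/chat helpers could look like — the names come from the diff, but the bodies, the HuggingFaceEmbeddings backend, and the llama-3.1-8b-instant model choice are assumptions, not the app's actual implementation:

import os
import pandas as pd
from groq import Groq
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.embeddings import HuggingFaceEmbeddings
from langchain_community.tools.tavily_search import TavilySearchResults
from langchain_community.vectorstores import FAISS

def convert_df_to_text(df: pd.DataFrame) -> str:
    # Serialize the dataframe so it can be chunked and embedded (assumed format).
    return df.to_csv(index=False)

def create_vector_store(text: str) -> FAISS:
    # Chunk the serialized data and index it for retrieval-augmented answers.
    splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=100)
    embeddings = HuggingFaceEmbeddings()  # assumed embedding backend
    return FAISS.from_texts(splitter.split_text(text), embeddings)

def get_groq_response(prompt: str, mode: str, use_web_search: bool) -> str:
    # Optionally pull fresh snippets via Tavily (requires TAVILY_API_KEY).
    context = ""
    if use_web_search:
        context = str(TavilySearchResults(max_results=3).run(prompt))
    # Ask a Groq-hosted model, steering it with the selected domain mode.
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    completion = client.chat.completions.create(
        model="llama-3.1-8b-instant",  # assumed model name
        messages=[
            {"role": "system", "content": f"You are a {mode} data-analysis assistant. Web context: {context}"},
            {"role": "user", "content": prompt},
        ],
    )
    return completion.choices[0].message.content

The vector store built in data_upload_page is not consulted in this sketch; in the app, st.session_state.vector_store would presumably be searched and its hits folded into the prompt the same way as the web context.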