Update app.py

app.py (CHANGED)
@@ -12,12 +12,45 @@ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
12 | from sklearn.metrics import accuracy_score, mean_squared_error
13 | from ydata_profiling import ProfileReport
14 | from streamlit_pandas_profiling import st_profile_report
15 | - import joblib
16 | - import os # For file directory
17 | import shap
18 | from datetime import datetime
19 | - from stqdm import stqdm
20 |
21 |
22 | # --------------------------
23 | # Helper Functions
@@ -25,8 +58,7 @@ from stqdm import stqdm
25 | def enhance_section_title(title, icon="✨"):
26 | """Helper function to create a styled section title with an icon."""
27 | st.markdown(f"<h2 style='border-bottom: 2px solid #ccc; padding-bottom: 5px;'>{icon} {title}</h2>", unsafe_allow_html=True)
28 | -
29 | - @st.cache_data
30 | def update_cleaned_data(df):
31 | """Updates the cleaned data in session state."""
32 | st.session_state.cleaned_data = df

@@ -34,7 +66,6 @@ def update_cleaned_data(df):
34 | st.session_state.data_versions.append(df.copy())
35 | st.success("Action completed successfully!")
36 |
37 | - @st.cache_data
38 | def generate_quality_report(df):
39 | """Generate comprehensive data quality report"""
40 | report = {

@@ -60,8 +91,8 @@ def generate_quality_report(df):
60 | })
61 | report['columns'][col] = col_report
62 | return report
63 | -
64 | -
65 | def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
66 | """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
67 |
@@ -69,36 +100,42 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
69 | X = df[features]
70 | y = df[target]
71 |
72 | - # Input Validation
73 | if target not in df.columns:
74 | raise ValueError(f"Target variable '{target}' not found in DataFrame.")
75 | for feature in features:
76 | if feature not in df.columns:
77 | raise ValueError(f"Feature '{feature}' not found in DataFrame.")
78 |
79 | - # Preprocessing Pipeline
80 | numerical_features = X.select_dtypes(include=np.number).columns
81 | categorical_features = X.select_dtypes(exclude=np.number).columns
82 |
83 | imputer_numerical = SimpleImputer(strategy='mean') # Or 'median', 'most_frequent', 'constant'
84 | X[numerical_features] = imputer_numerical.fit_transform(X[numerical_features])
85 |
86 | -
87 |
88 | -
89 | if problem_type == "Classification" or problem_type == "Multiclass":
90 | label_encoder = LabelEncoder()
91 | y = label_encoder.fit_transform(y)
92 |
93 | X_train, X_test, y_train, y_test = train_test_split(
94 | X, y, test_size=test_size, random_state=42
95 | )
96 |
97 | -
98 | -
99 | -
100 |
101 | - # Model Selection and Hyperparameter Tuning
102 | if problem_type == "Regression":
103 | if model_type == "Random Forest":
104 | model = RandomForestRegressor(random_state=42)
@@ -115,9 +152,9 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
115 | 'max_depth': [3, 5]
116 | }
117 | elif model_type == "Neural Network":
118 | - model = MLPRegressor(random_state=42, max_iter=500)
119 | param_grid = {
120 | - 'hidden_layer_sizes': [(50,), (100,), (50, 50)],
121 | 'activation': ['relu', 'tanh'],
122 | 'alpha': [0.0001, 0.001]
123 | }

@@ -141,9 +178,9 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
141 | 'max_depth': [3, 5]
142 | }
143 | elif model_type == "Neural Network":
144 | - model = MLPClassifier(random_state=42, max_iter=500)
145 | param_grid = {
146 | - 'hidden_layer_sizes': [(50,), (100,), (50, 50)],
147 | 'activation': ['relu', 'tanh'],
148 | 'alpha': [0.0001, 0.001]
149 | }

@@ -153,11 +190,11 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
153 | elif problem_type == "Multiclass": #Multiclass
154 |
155 | if model_type == "Logistic Regression":
156 | - model = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')
157 | - param_grid = {'C': [0.1, 1.0, 10.0]}
158 |
159 | elif model_type == "Support Vector Machine":
160 | - model = SVC(random_state=42, probability=True)
161 | param_grid = {'C': [0.1, 1.0, 10.0], 'kernel': ['rbf', 'linear']}
162 |
163 | elif model_type == "Random Forest":

@@ -166,7 +203,7 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
166 | 'n_estimators': [100, 200],
167 | 'max_depth': [None, 5, 10],
168 | 'min_samples_split': [2, 5],
169 | - 'criterion': ['gini', 'entropy']
170 | }
171 |
172 | else:
@@ -174,47 +211,51 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
174 | else:
175 | raise ValueError(f"Invalid problem type: {problem_type}")
176 |
177 | - param_grid
178 |
179 | if use_grid_search:
180 | grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error', verbose=1, n_jobs=-1)
181 | - grid_search.fit(
182 | - model = grid_search.best_estimator_
183 | - st.write("Best hyperparameters found by Grid Search:", grid_search.best_params_)
184 |
185 | else:
186 | - model.fit(
187 |
188 | -
189 | st.write("Cross-validation scores:", cv_scores)
190 | st.write("Mean cross-validation score:", cv_scores.mean())
191 |
192 | - # Evaluation
193 | - y_pred = model.predict(
194 | - metrics = {}
195 |
196 | if problem_type == "Classification":
197 | metrics['accuracy'] = accuracy_score(y_test, y_pred)
198 | metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
199 | - metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
200 |
201 | elif problem_type == "Multiclass":
202 | metrics['accuracy'] = accuracy_score(y_test, y_pred)
203 | metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
204 | - metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
205 | else:
206 | metrics['mse'] = mean_squared_error(y_test, y_pred)
207 | metrics['r2'] = r2_score(y_test, y_pred)
208 |
209 | - # Feature Importance (
210 | try:
211 | - result = permutation_importance(model,
212 | importance = result.importances_mean
213 |
214 | except Exception as e:
215 | st.warning(f"Could not calculate feature importance: {e}")
216 | importance = None
217 |
218 | column_order = X.columns
219 |
220 | return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train # Return X_train and y_train
@@ -222,7 +263,8 @@ def train_model(df, target, features, problem_type, test_size, model_type, model
222 | except Exception as e:
223 | st.error(f"Training failed: {str(e)}")
224 | return None, None, None, None, None, None, None, None, None
225 | -
226 | def validate_model(model_path, df, target, features, test_size):
227 | """Loads a model, preprocesses data, and evaluates the model on a validation set."""
228 | try:

@@ -304,13 +346,6 @@ def prediction_input_form(features, default_values=None):
304 | input_data[feature] = st.number_input(f"{feature}:", value=default_value)
305 | return input_data
306 |
307 | - if 'raw_data' not in st.session_state:
308 | - st.session_state.raw_data = None
309 | - if 'cleaned_data' not in st.session_state:
310 | - st.session_state.cleaned_data = None
311 | - if 'data_versions' not in st.session_state:
312 | - st.session_state.data_versions = []
313 | -
314 | # --------------------------
315 | # Sidebar Navigation
316 | # --------------------------
@@ -330,20 +365,18 @@ with st.sidebar:
330 | # --------------------------
331 | if app_mode == "Data Upload":
332 | st.title("📤 Data Upload & Profiling")
333 | -
334 | - uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"]
335 | -
336 | if uploaded_file:
337 | try:
338 | if uploaded_file.name.endswith('.csv'):
339 | df = pd.read_csv(uploaded_file)
340 | else:
341 | df = pd.read_excel(uploaded_file)
342 | -
343 | st.session_state.raw_data = df
344 | -
345 | - st.session_state.data_versions = [df.copy()] # Initialize data versions
346 | -
347 | col1, col2, col3 = st.columns(3)
348 | with col1:
349 | st.metric("Rows", df.shape[0])

@@ -351,19 +384,57 @@ if app_mode == "Data Upload":
351 | st.metric("Columns", df.shape[1])
352 | with col3:
353 | st.metric("Missing Values", df.isna().sum().sum())
354 | -
355 | with st.expander("Data Preview", expanded=True):
356 | st.dataframe(df.head(10), use_container_width=True)
357 | -
358 | if st.button("Generate Full Profile Report"):
359 | with st.spinner("Generating comprehensive analysis..."):
360 | pr = ProfileReport(df, explorative=True)
361 | st_profile_report(pr)
362 |
363 | except Exception as e:
364 | st.error(f"Error loading file: {str(e)}")
365 |
366 | - #
367 | elif app_mode == "Data Cleaning":
368 | st.title("🧹 Smart Data Cleaning")
369 |
@@ -394,6 +465,14 @@ elif app_mode == "Data Cleaning":
394 | profile = ProfileReport(df, minimal=True)
395 | st_profile_report(profile)
396 |
397 | # Missing Value Handling
398 | enhance_section_title("Missing Values Treatment", "🔍")
399 | with st.expander("🔍 Missing Values Treatment", expanded=True):

@@ -430,7 +509,7 @@ elif app_mode == "Data Cleaning":
430 | new_df[cols] = new_df[cols].bfill()
431 |
432 | update_cleaned_data(new_df)
433 | - st.
434 |
435 | # Data Type Conversion
436 | enhance_section_title("Data Type Conversion", "🔄")

@@ -465,7 +544,7 @@ elif app_mode == "Data Cleaning":
465 | new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
466 |
467 | update_cleaned_data(new_df)
468 | - st.
469 | except Exception as e:
470 | st.error(f"Error: {str(e)}")
471 |

@@ -478,7 +557,7 @@ elif app_mode == "Data Cleaning":
478 | if st.button("Confirm Drop (Columns)"):
479 | new_df = df.drop(columns=columns_to_drop)
480 | update_cleaned_data(new_df)
481 | - st.
482 |
483 | # Label Encoding
484 | enhance_section_title("Label Encoding", "🔢")

@@ -491,7 +570,7 @@ elif app_mode == "Data Cleaning":
491 | le = LabelEncoder()
492 | new_df[col] = le.fit_transform(new_df[col].astype(str))
493 | update_cleaned_data(new_df)
494 | - st.
495 |
496 | # StandardScaler
497 | enhance_section_title("StandardScaler", "📏")

@@ -503,7 +582,7 @@ elif app_mode == "Data Cleaning":
503 | scaler = StandardScaler()
504 | new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
505 | update_cleaned_data(new_df)
506 | - st.
507 |
508 | # Pattern-Based Cleaning
509 | enhance_section_title("Pattern-Based Cleaning", "🕵️")
@@ -516,13 +595,24 @@ elif app_mode == "Data Cleaning":
516 | new_df = df.copy()
517 | new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
518 | update_cleaned_data(new_df)
519 | - st.
520 |
521 | enhance_section_title("✨ Cleaned Data Preview", "✨")
522 | with st.expander("✨ Cleaned Data Preview"):
523 | - st.dataframe(
524 | -
525 | -
526 |
527 |
528 | # --------------------------

@@ -531,16 +621,11 @@ elif app_mode == "Data Cleaning":
531 | elif app_mode == "EDA":
532 | st.title("🔍 Interactive Data Explorer")
533 |
534 | - if st.session_state.
535 | - st.warning("Please
536 | - st.stop()
537 | -
538 | - if 'cleaned_data' in st.session_state and st.session_state.cleaned_data is not None:
539 | - df = st.session_state.cleaned_data.copy() # Work on the latest cleaned data
540 | - else:
541 | - st.warning("No cleaned data available. Please clean your data first.")
542 | - st.stop()
543 |
544 |
545 | # --------------------------
546 | # Enhanced Data Overview
@@ -640,7 +725,7 @@ elif app_mode == "EDA":
640 | try:
641 | fig = None # Initialize fig to None
642 | if st.session_state.cleaned_data is None:
643 | - st.warning("Please
644 | st.stop()
645 |
646 | # Generate appropriate visualization with input validation

@@ -869,18 +954,23 @@ elif app_mode == "EDA":
869 | elif app_mode == "Model Training":
870 | st.title("🤖 Intelligent Model Training")
871 |
872 | - if st.session_state.
873 | - st.warning("Please
874 | -
875 | -
876 | -
877 | -
878 | -
879 | -
880 | -
881 | -
882 | -
883 |
884 |
885 | # Model Setup
886 | col1, col2, col3 = st.columns(3)
@@ -950,61 +1040,18 @@ elif app_mode == "Model Training":
950 |
951 | use_grid_search = st.checkbox("Use Grid Search for Hyperparameter Tuning")
952 |
953 | -
954 | if not features:
955 | st.error("Please select at least one feature.")
956 | st.stop()
957 |
958 | # Call the training function
959 | - model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance = train_model(df.copy(), target, features, problem_type, test_size, model_type, model_params, use_grid_search) # Pass a copy to avoid modifying the original
960 |
961 | if model: # Only proceed if training was successful
962 | st.success("Model trained successfully!")
963 |
964 | - #
965 | - st.subheader("Model Evaluation Metrics")
966 | - if problem_type in ["Classification", "Multiclass"]: #Combined here
967 | - st.metric("Accuracy", f"{metrics['accuracy']:.2%}")
968 | -
969 | - # Confusion Matrix Visualization
970 | - st.subheader("Confusion Matrix")
971 | - cm = metrics['confusion_matrix']
972 | - class_names = [str(i) for i in np.unique(df[target])] #Get original class names
973 | - fig_cm = px.imshow(cm,
974 | - labels=dict(x="Predicted", y="Actual"),
975 | - x=class_names,
976 | - y=class_names,
977 | - color_continuous_scale="Viridis")
978 | - st.plotly_chart(fig_cm, use_container_width=True)
979 | -
980 | - # Classification Report
981 | - st.subheader("Classification Report")
982 | - report = metrics['classification_report']
983 | - report_df = pd.DataFrame(report).transpose()
984 | - st.dataframe(report_df)
985 | -
986 | - else:
987 | - st.metric("MSE", f"{metrics['mse']:.2f}")
988 | - st.metric("R2", f"{metrics['r2']:.2f}")
989 | -
990 | - # Feature Importance
991 | - st.subheader("Feature Importance")
992 | - try:
993 | - fig_importance = px.bar(
994 | - x=importance,
995 | - y=column_order, #Use stored column order
996 | - orientation='h',
997 | - title="Feature Importance"
998 | - )
999 | - st.plotly_chart(fig_importance, use_container_width=True)
1000 | - except Exception as e:
1001 | - st.warning(f"Could not display feature importance: {e}")
1002 | -
1003 | - # Explainable AI (Placeholder)
1004 | - st.subheader("Explainable AI (XAI)")
1005 | - st.write("Future implementation will include model explanations using techniques like SHAP or LIME.") #To be implemented
1006 | - if st.checkbox("Show a random model explanation (example)"): #Example of a feature, to be implemented
1007 | - st.write("This feature is important because...")
1008 |
1009 | # Save Model
1010 | st.subheader("Save Model")
@@ -1060,22 +1107,12 @@ elif app_mode == "Model Training":
1060 | st.metric("MSE", f"{validation_metrics['mse']:.2f}")
1061 | st.metric("R2", f"{validation_metrics['r2']:.2f}")
1062 |
1063 | - # Predictions Section (Fixed)
1064 | elif app_mode == "Predictions":
1065 | - st.title("🔮 Predictive Analytics")
1066 | -
1067 | - if st.session_state.raw_data is None:
1068 | - st.warning("Please upload data first")
1069 | - st.stop() # Stop execution if no data uploaded
1070 | -
1071 | - if 'cleaned_data' in st.session_state and st.session_state.cleaned_data is not None:
1072 | - df = st.session_state.cleaned_data.copy() # Work on the latest cleaned data
1073 | - else:
1074 | - st.warning("No cleaned data available. Please clean your data first.")
1075 | - st.stop() # Stop execution if no cleaned data is available
1076 | -
1077 | - # Rest of the predictions code...
1078 |
1079 |
1080 | model_data = st.session_state.model # Get the entire dictionary
1081 | model = model_data['model'] # Access model
@@ -1095,8 +1132,8 @@ elif app_mode == "Predictions":
1095 |
1096 | with col2:
1097 | st.subheader("Data Overview")
1098 | - input_df = pd.DataFrame([input_data])
1099 | - st.dataframe(input_df,
1100 |
1101 | # Predicts Function and Displays Result
1102 | if st.button("Generate Prediction & Insights"):

@@ -1110,12 +1147,14 @@ elif app_mode == "Predictions":
1110 |
1111 | # 3. One-hot encode (handle unseen categories)
1112 | categorical_features = input_df.select_dtypes(exclude=np.number).columns
1113 | - input_df = pd.get_dummies(input_df, columns=categorical_features, dummy_na=False)
1114 |
1115 | # 4. Ensure correct column order
1116 | for col in column_order:
1117 | if col not in input_df.columns:
1118 | input_df[col] = 0
1119 | input_df = input_df[column_order]
1120 |
1121 | # 5. Scale the input
@@ -1138,21 +1177,29 @@ elif app_mode == "Predictions":
1138 |
1139 | if problem_type == "Classification":
1140 | explainer = shap.TreeExplainer(model)
1141 | - shap_values = explainer.shap_values(scaled_input)
1142 | -
1143 | -
1144 | else:
1145 | - explainer = shap.TreeExplainer(model)
1146 | - shap_values = explainer.shap_values(scaled_input)
1147 | -
1148 | -
1149 |
1150 | st.write("The visualization above explains how each feature contributed to the final prediction.")
1151 |
1152 | # 9. Add Permutation Feature Importance (for more global understanding)
1153 | try:
1154 | enhance_section_title("Global Feature Importance", "🌍")
1155 | - X = pd.DataFrame(scaler.transform(input_df), columns=input_df.columns)
1156 | result = permutation_importance(model, X, input_df, n_repeats=10, random_state=42)
1157 | importance = result.importances_mean
1158 |

@@ -1163,4 +1210,55 @@ elif app_mode == "Predictions":
1163 | st.warning(f"Could not calculate permutation feature importance: {e}")
1164 |
1165 | except Exception as e:
1166 | - st.error(f"Prediction failed: {str(e)}")
12 | from sklearn.metrics import accuracy_score, mean_squared_error
13 | from ydata_profiling import ProfileReport
14 | from streamlit_pandas_profiling import st_profile_report
15 | + import joblib
16 | import shap
17 | from datetime import datetime
18 |
19 | + # --------------------------
20 | + # Page Configuration
21 | + # --------------------------
22 | + st.set_page_config(
23 | + page_title="DataInsight Pro",
24 | + page_icon="🔮",
25 | + layout="wide",
26 | + initial_sidebar_state="expanded"
27 | + )
28 | +
29 | +
30 | + # --------------------------
31 | + # Custom Styling
32 | + # --------------------------
33 | + st.markdown("""
34 | + <style>
35 | + .main {background-color: #f8f9fa;}
36 | + .sidebar .sidebar-content {background-color: #2c3e50;}
37 | + .stButton>button {background-color: #3498db; color: white;}
38 | + .stTextInput>div>div>input {border: 1px solid #3498db;}
39 | + .stSelectbox>div>div>select {border: 1px solid #3498db;}
40 | + .stSlider>div>div>div>div {background-color: #3498db;}
41 | + .metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
42 | + </style>
43 | + """, unsafe_allow_html=True)
44 | +
45 | + # --------------------------
46 | + # Session State Initialization
47 | + # --------------------------
48 | + if 'raw_data' not in st.session_state:
49 | + st.session_state.raw_data = None
50 | + if 'cleaned_data' not in st.session_state:
51 | + st.session_state.cleaned_data = None
52 | + if 'model' not in st.session_state:
53 | + st.session_state.model = None
54 |
55 | # --------------------------
56 | # Helper Functions

58 | def enhance_section_title(title, icon="✨"):
59 | """Helper function to create a styled section title with an icon."""
60 | st.markdown(f"<h2 style='border-bottom: 2px solid #ccc; padding-bottom: 5px;'>{icon} {title}</h2>", unsafe_allow_html=True)
61 | +
62 | def update_cleaned_data(df):
63 | """Updates the cleaned data in session state."""
64 | st.session_state.cleaned_data = df

66 | st.session_state.data_versions.append(df.copy())
67 | st.success("Action completed successfully!")

69 | def generate_quality_report(df):
70 | """Generate comprehensive data quality report"""
71 | report = {

91 | })
92 | report['columns'][col] = col_report
93 | return report
94 | +
95 | + # Function to train the model (Separated for clarity and reusability)
96 | def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
97 | """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""

100 | X = df[features]
101 | y = df[target]
102 |
103 | + # Input Validation
104 | if target not in df.columns:
105 | raise ValueError(f"Target variable '{target}' not found in DataFrame.")
106 | for feature in features:
107 | if feature not in df.columns:
108 | raise ValueError(f"Feature '{feature}' not found in DataFrame.")
109 |
110 | + # Preprocessing Pipeline: Handles missing values, encoding, scaling
111 | + # Imputation: Handle missing values BEFORE encoding (numerical only for SimpleImputer)
112 | numerical_features = X.select_dtypes(include=np.number).columns
113 | categorical_features = X.select_dtypes(exclude=np.number).columns
114 |
115 | imputer_numerical = SimpleImputer(strategy='mean') # Or 'median', 'most_frequent', 'constant'
116 | X[numerical_features] = imputer_numerical.fit_transform(X[numerical_features])
117 |
118 | + # Encoding (One-Hot Encode Categorical Features)
119 | + X = pd.get_dummies(X, columns=categorical_features, dummy_na=False) # dummy_na = False. We imputed already.
120 |
121 | + # Target Encoding (if classification)
122 | + label_encoder = None #Initialize label_encoder
123 | if problem_type == "Classification" or problem_type == "Multiclass":
124 | label_encoder = LabelEncoder()
125 | y = label_encoder.fit_transform(y)
126 |
127 | +
128 | + # Split the data
129 | X_train, X_test, y_train, y_test = train_test_split(
130 | X, y, test_size=test_size, random_state=42
131 | )
132 |
133 | + # Scaling (AFTER splitting!)
134 | + scaler = StandardScaler() # Or try MinMaxScaler, RobustScaler, QuantileTransformer
135 | + X_train_scaled = scaler.fit_transform(X_train) #Fit to the training data ONLY
136 | + X_test_scaled = scaler.transform(X_test) #Transform the test data using the fitted scaler
137 |
138 | + # Model Selection and Hyperparameter Tuning
139 | if problem_type == "Regression":
140 | if model_type == "Random Forest":
141 | model = RandomForestRegressor(random_state=42)

152 | 'max_depth': [3, 5]
153 | }
154 | elif model_type == "Neural Network":
155 | + model = MLPRegressor(random_state=42, max_iter=500) #set max_iter to 500
156 | param_grid = {
157 | + 'hidden_layer_sizes': [(50,), (100,), (50, 50)], #example sizes for depth
158 | 'activation': ['relu', 'tanh'],
159 | 'alpha': [0.0001, 0.001]
160 | }

178 | 'max_depth': [3, 5]
179 | }
180 | elif model_type == "Neural Network":
181 | + model = MLPClassifier(random_state=42, max_iter=500) #set max_iter to 500
182 | param_grid = {
183 | + 'hidden_layer_sizes': [(50,), (100,), (50, 50)], #example sizes for depth
184 | 'activation': ['relu', 'tanh'],
185 | 'alpha': [0.0001, 0.001]
186 | }

190 | elif problem_type == "Multiclass": #Multiclass
191 |
192 | if model_type == "Logistic Regression":
193 | + model = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr') # 'ovr' for one-vs-rest
194 | + param_grid = {'C': [0.1, 1.0, 10.0]} # Regularization parameter
195 |
196 | elif model_type == "Support Vector Machine":
197 | + model = SVC(random_state=42, probability=True) # probability=True for probabilities
198 | param_grid = {'C': [0.1, 1.0, 10.0], 'kernel': ['rbf', 'linear']}
199 |
200 | elif model_type == "Random Forest":

203 | 'n_estimators': [100, 200],
204 | 'max_depth': [None, 5, 10],
205 | 'min_samples_split': [2, 5],
206 | + 'criterion': ['gini', 'entropy'] #criterion for decision
207 | }
208 |
209 | else:

211 | else:
212 | raise ValueError(f"Invalid problem type: {problem_type}")
213 |
214 | + # Update param_grid with user-defined parameters
215 | + param_grid.update(model_params) #This is key to use the model_params provided by user
216 |
217 | if use_grid_search:
218 | grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error', verbose=1, n_jobs=-1)
219 | + grid_search.fit(X_train_scaled, y_train) # Use scaled training data
220 | + model = grid_search.best_estimator_ # Use the best model found
221 | + st.write("Best hyperparameters found by Grid Search:", grid_search.best_params_) #Print best parameters
222 |
223 | else:
224 | + model.fit(X_train_scaled, y_train) # Use scaled training data
225 |
226 | + # Cross-Validation (after hyperparameter tuning, if applicable)
227 | + cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error') # Use scaled training data
228 | st.write("Cross-validation scores:", cv_scores)
229 | st.write("Mean cross-validation score:", cv_scores.mean())
230 |
231 | + # Evaluation
232 | + y_pred = model.predict(X_test_scaled) # Use scaled test data
233 | + metrics = {} #Store metrics in a dictionary
234 |
235 | if problem_type == "Classification":
236 | metrics['accuracy'] = accuracy_score(y_test, y_pred)
237 | metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
238 | + metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True) #Get report as dictionary
239 |
240 | elif problem_type == "Multiclass":
241 | +
242 | metrics['accuracy'] = accuracy_score(y_test, y_pred)
243 | metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
244 | + metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True) #Get report as dictionary
245 | else:
246 | metrics['mse'] = mean_squared_error(y_test, y_pred)
247 | metrics['r2'] = r2_score(y_test, y_pred)
248 |
249 | + # Feature Importance (Permutation Importance for potentially better handling of correlated features)
250 | try:
251 | + result = permutation_importance(model, X_test_scaled, y_test, n_repeats=10, random_state=42) #Permutation Feature Importance # Use scaled test data
252 | importance = result.importances_mean
253 |
254 | except Exception as e:
255 | st.warning(f"Could not calculate feature importance: {e}")
256 | importance = None
257 |
258 | + # Store the column order for prediction purposes
259 | column_order = X.columns
260 |
261 | return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train # Return X_train and y_train

263 | except Exception as e:
264 | st.error(f"Training failed: {str(e)}")
265 | return None, None, None, None, None, None, None, None, None
266 | +
267 | + # Model Validation Function
268 | def validate_model(model_path, df, target, features, test_size):
269 | """Loads a model, preprocesses data, and evaluates the model on a validation set."""
270 | try:

346 | input_data[feature] = st.number_input(f"{feature}:", value=default_value)
347 | return input_data
348 |

349 | # --------------------------
350 | # Sidebar Navigation
351 | # --------------------------

365 | # --------------------------
366 | if app_mode == "Data Upload":
367 | st.title("📤 Data Upload & Profiling")
368 | +
369 | + uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
370 | +
371 | if uploaded_file:
372 | try:
373 | if uploaded_file.name.endswith('.csv'):
374 | df = pd.read_csv(uploaded_file)
375 | else:
376 | df = pd.read_excel(uploaded_file)
377 | +
378 | st.session_state.raw_data = df
379 | +
380 | col1, col2, col3 = st.columns(3)
381 | with col1:
382 | st.metric("Rows", df.shape[0])

384 | st.metric("Columns", df.shape[1])
385 | with col3:
386 | st.metric("Missing Values", df.isna().sum().sum())
387 | +
388 | with st.expander("Data Preview", expanded=True):
389 | st.dataframe(df.head(10), use_container_width=True)
390 | +
391 | if st.button("Generate Full Profile Report"):
392 | with st.spinner("Generating comprehensive analysis..."):
393 | pr = ProfileReport(df, explorative=True)
394 | st_profile_report(pr)
395 | +
396 | + except Exception as e:
397 | + st.error(f"Error loading file: {str(e)}")
398 |
399 | + # --------------------------
400 | + # Page Content
401 | + # --------------------------
402 | + if app_mode == "Data Upload":
403 | + st.title("📤 Data Upload & Profiling")
404 | +
405 | + uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
406 | +
407 | + if uploaded_file:
408 | + try:
409 | + if uploaded_file.name.endswith('.csv'):
410 | + df = pd.read_csv(uploaded_file)
411 | + else:
412 | + df = pd.read_excel(uploaded_file)
413 | +
414 | + st.session_state.raw_data = df
415 | +
416 | + col1, col2, col3 = st.columns(3)
417 | + with col1:
418 | + st.metric("Rows", df.shape[0])
419 | + with col2:
420 | + st.metric("Columns", df.shape[1])
421 | + with col3:
422 | + st.metric("Missing Values", df.isna().sum().sum())
423 | +
424 | + with st.expander("Data Preview", expanded=True):
425 | + st.dataframe(df.head(10), use_container_width=True)
426 | +
427 | + if st.button("Generate Full Profile Report"):
428 | + with st.spinner("Generating comprehensive analysis..."):
429 | + pr = ProfileReport(df, explorative=True)
430 | + st_profile_report(pr)
431 | +
432 | except Exception as e:
433 | st.error(f"Error loading file: {str(e)}")
434 |
435 | + # --------------------------
436 | + # Page Content
437 | + # --------------------------
438 | elif app_mode == "Data Cleaning":
439 | st.title("🧹 Smart Data Cleaning")
440 |

465 | profile = ProfileReport(df, minimal=True)
466 | st_profile_report(profile)
467 |
468 | + # Undo Functionality
469 | + if len(st.session_state.data_versions) > 1:
470 | + if st.button("⏮️ Undo Last Action"):
471 | + st.session_state.data_versions.pop() # Remove current version
472 | + st.session_state.cleaned_data = st.session_state.data_versions[-1].copy() # Set data
473 | + st.success("Last action undone!")
474 | + st.experimental_rerun() #Force re-run after undo
475 | +
476 | # Missing Value Handling
477 | enhance_section_title("Missing Values Treatment", "🔍")
478 | with st.expander("🔍 Missing Values Treatment", expanded=True):

509 | new_df[cols] = new_df[cols].bfill()
510 |
511 | update_cleaned_data(new_df)
512 | + st.experimental_rerun() # Force re-run after apply
513 |
514 | # Data Type Conversion
515 | enhance_section_title("Data Type Conversion", "🔄")

544 | new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
545 |
546 | update_cleaned_data(new_df)
547 | + st.experimental_rerun() # Force re-run after apply
548 | except Exception as e:
549 | st.error(f"Error: {str(e)}")
550 |

557 | if st.button("Confirm Drop (Columns)"):
558 | new_df = df.drop(columns=columns_to_drop)
559 | update_cleaned_data(new_df)
560 | + st.experimental_rerun() # Force re-run after apply
561 |
562 | # Label Encoding
563 | enhance_section_title("Label Encoding", "🔢")

570 | le = LabelEncoder()
571 | new_df[col] = le.fit_transform(new_df[col].astype(str))
572 | update_cleaned_data(new_df)
573 | + st.experimental_rerun() # Force re-run after apply
574 |
575 | # StandardScaler
576 | enhance_section_title("StandardScaler", "📏")

582 | scaler = StandardScaler()
583 | new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
584 | update_cleaned_data(new_df)
585 | + st.experimental_rerun() # Force re-run after apply
586 |
587 | # Pattern-Based Cleaning
588 | enhance_section_title("Pattern-Based Cleaning", "🕵️")

595 | new_df = df.copy()
596 | new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
597 | update_cleaned_data(new_df)
598 | + st.experimental_rerun() # Force re-run after apply
599 |
600 | + # Bulk Operations
601 | + enhance_section_title("Bulk Actions", "🚀")
602 | + with st.expander("🚀 Bulk Actions"):
603 | + if st.button("Auto-Clean Common Issues (Cleaning)"):
604 | + new_df = df.copy()
605 | + new_df = new_df.dropna(axis=1, how='all') # Remove empty cols
606 | + new_df = new_df.convert_dtypes() # Better type inference
607 | + text_cols = new_df.select_dtypes(include='object').columns
608 | + new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
609 | + update_cleaned_data(new_df)
610 | + st.experimental_rerun() # Force re-run after apply
611 | +
612 | + # Cleaned Data Preview
613 | enhance_section_title("✨ Cleaned Data Preview", "✨")
614 | with st.expander("✨ Cleaned Data Preview"):
615 | + st.dataframe(st.session_state.cleaned_data.head(), use_container_width=True)

616 |
617 |
618 | # --------------------------

621 | elif app_mode == "EDA":
622 | st.title("🔍 Interactive Data Explorer")
623 |
624 | + if st.session_state.cleaned_data is None:
625 | + st.warning("Please clean your data first")
626 | + st.stop()

627 |
628 | + df = st.session_state.cleaned_data
629 |
630 | # --------------------------
631 | # Enhanced Data Overview

725 | try:
726 | fig = None # Initialize fig to None
727 | if st.session_state.cleaned_data is None:
728 | + st.warning("Please clean your data first")
729 | st.stop()
730 |
731 | # Generate appropriate visualization with input validation

954 | elif app_mode == "Model Training":
955 | st.title("🤖 Intelligent Model Training")
956 |
957 | + if st.session_state.get("cleaned_data") is None:
958 | + st.warning("Please clean your data first")
959 | + # Show Upload Clean Data button
960 | + uploaded_clean_file = st.file_uploader("Upload your cleaned dataset (CSV/XLSX)", type=["csv", "xlsx"])
961 | + if uploaded_clean_file:
962 | + try:
963 | + if uploaded_clean_file.name.endswith('.csv'):
964 | + df = pd.read_csv(uploaded_clean_file)
965 | + else:
966 | + df = pd.read_excel(uploaded_clean_file)
967 | + st.session_state.cleaned_data = df
968 | + st.success("Cleaned data uploaded successfully!")
969 | + except Exception as e:
970 | + st.error(f"Error loading file: {str(e)}")
971 | + st.stop()
972 |
973 | + df = st.session_state.cleaned_data
974 |
975 | # Model Setup
976 | col1, col2, col3 = st.columns(3)

1040 |
1041 | use_grid_search = st.checkbox("Use Grid Search for Hyperparameter Tuning")
1042 |
1043 | + if st.button("Train Model"):
1044 | if not features:
1045 | st.error("Please select at least one feature.")
1046 | st.stop()
1047 |
1048 | # Call the training function
1049 | + model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train = train_model(df.copy(), target, features, problem_type, test_size, model_type, model_params, use_grid_search) # Pass a copy to avoid modifying the original # Capture X_train and y_train
1050 |
1051 | if model: # Only proceed if training was successful
1052 | st.success("Model trained successfully!")
1053 |
1054 | + # ... (rest of the Model Training code - metrics display, feature importance, saving model) ...

1055 |
1056 | # Save Model
1057 | st.subheader("Save Model")

1107 | st.metric("MSE", f"{validation_metrics['mse']:.2f}")
1108 | st.metric("R2", f"{validation_metrics['r2']:.2f}")
1109 |

1110 | elif app_mode == "Predictions":
1111 | + st.title("🔮 Predictive Analytics - Informed Business Decisions")

1112 |
1113 | + if st.session_state.get("model") is None:
1114 | + st.warning("Please train a model first")
1115 | + st.stop()
1116 |
1117 | model_data = st.session_state.model # Get the entire dictionary
1118 | model = model_data['model'] # Access model

1132 |
1133 | with col2:
1134 | st.subheader("Data Overview")
1135 | + input_df = pd.DataFrame([input_data]) #Make DataFrame
1136 | + st.dataframe(input_df,use_container_width=True) #DataFrame of the input to see it
1137 |
1138 | # Predicts Function and Displays Result
1139 | if st.button("Generate Prediction & Insights"):

1147 |
1148 | # 3. One-hot encode (handle unseen categories)
1149 | categorical_features = input_df.select_dtypes(exclude=np.number).columns
1150 | + input_df = pd.get_dummies(input_df, columns=categorical_features, dummy_na=False) # dummy_na = False. We imputed already.
1151 |
1152 | # 4. Ensure correct column order
1153 | + # Add missing columns with 0 values
1154 | for col in column_order:
1155 | if col not in input_df.columns:
1156 | input_df[col] = 0
1157 | + # Reorder Columns
1158 | input_df = input_df[column_order]
1159 |
1160 | # 5. Scale the input

1177 |
1178 | if problem_type == "Classification":
1179 | explainer = shap.TreeExplainer(model)
1180 | + shap_values = explainer.shap_values(scaled_input) # Use the scaled input
1181 | + # class_names = [str(i) for i in range(len(shap_values))] # Dynamic class names - not needed for force plot
1182 | +
1183 | + fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False,link="logit") # shap_values[1] for class 1 - force plot
1184 | + st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900) # Adjust height and width as needed.
1185 | +
1186 | else:
1187 | + explainer = shap.TreeExplainer(model) # Regression
1188 | + shap_values = explainer.shap_values(scaled_input) # Use the scaled input
1189 | +
1190 | + fig = shap.force_plot(explainer.expected_value, shap_values, input_df, matplotlib=False) # shap_values single array for regression
1191 | + st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900) # Adjust height and width as needed.
1192 |
1193 | st.write("The visualization above explains how each feature contributed to the final prediction.")
1194 |
1195 | # 9. Add Permutation Feature Importance (for more global understanding)
1196 | try:
1197 | enhance_section_title("Global Feature Importance", "🌍")
1198 | + X = pd.DataFrame(scaler.transform(pd.get_dummies(pd.DataFrame(imputer_numerical.transform(input_df), columns=input_df.columns))), columns=input_df.columns) # Apply preprocessing for permutation
1199 | + #X = pd.DataFrame(scaler.transform(input_df), columns = input_df.columns)
1200 | + #X = input_df[input_df.columns]
1201 | + X_train = model_data['X_train'] #Get X train
1202 | + y_train = model_data['y_train'] #Get Y train
1203 | result = permutation_importance(model, X, input_df, n_repeats=10, random_state=42)
1204 | importance = result.importances_mean
1205 |

1210 | st.warning(f"Could not calculate permutation feature importance: {e}")
1211 |
1212 | except Exception as e:
1213 | + st.error(f"Prediction failed: {str(e)}")
1214 | +
1215 | + # Force rerun Streamlit app after data cleaning operations
1216 | + st.experimental_rerun()
1217 | +
1218 | + if __name__ == "__main__":
1219 | + # Session State Initialization
1220 | + if 'raw_data' not in st.session_state:
1221 | + st.session_state.raw_data = None
1222 | + if 'cleaned_data' not in st.session_state:
1223 | + st.session_state.cleaned_data = None
1224 | + if 'model' not in st.session_state:
1225 | + st.session_state.model = None
1226 | + if 'data_versions' not in st.session_state:
1227 | + st.session_state.data_versions = []
1228 | +
1229 | + # Custom Styling (Keep it in main if needed)
1230 | + st.markdown("""
1231 | + <style>
1232 | + .main {background-color: #f8f9fa;}
1233 | + .sidebar .sidebar-content {background-color: #2c3e50;}
1234 | + .stButton>button {background-color: #3498db; color: white;}
1235 | + .stTextInput>div>div>input {border: 1px solid #3498db;}
1236 | + .stSelectbox>div>div>select {border: 1px solid #3498db;}
1237 | + .stSlider>div>div>div>div {background-color: #3498db;}
1238 | + .metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
1239 | + </style>
1240 | + """, unsafe_allow_html=True)
1241 | +
1242 | + # Sidebar Navigation (Keep it in main)
1243 | + with st.sidebar:
1244 | + st.title("🔮 DataInsight Pro")
1245 | + app_mode = st.selectbox(
1246 | + "Navigation",
1247 | + ["Data Upload", "Data Cleaning", "EDA", "Model Training", "Predictions"],
1248 | + format_func=lambda x: f"📌 {x}"
1249 | + )
1250 | + st.markdown("---")
1251 | + st.markdown("Created by Calvin Allen-Crawford")
1252 | + st.markdown("v1.0 | © 2025")
1253 | +
1254 | + # Call app mode function based on selection
1255 | + if app_mode == "Data Upload":
1256 | + app_mode_data_upload()
1257 | + elif app_mode == "Data Cleaning":
1258 | + app_mode_data_cleaning()
1259 | + elif app_mode == "EDA":
1260 | + app_mode_eda()
1261 | + elif app_mode == "Model Training":
1262 | + app_mode_model_training()
1263 | + elif app_mode == "Predictions":
1264 | + app_mode_predictions()
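For readers skimming the diff, the core of the change to train_model is the preprocessing order: impute, then one-hot encode, then split, then scale with the scaler fitted on the training portion only. A minimal standalone sketch of that order follows; it is illustrative only (toy data and assumed column names), not code from the Space.

```python
# Minimal sketch (not part of the commit) of the impute -> encode -> split -> scale
# order used by the updated train_model; the toy DataFrame is an assumption.
import numpy as np
import pandas as pd
from sklearn.impute import SimpleImputer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler

df = pd.DataFrame({
    "age": [25.0, 32.0, np.nan, 47.0],   # numeric column with a missing value
    "city": ["NY", "LA", "NY", "SF"],    # categorical column
    "target": [0, 1, 0, 1],
})
X = df[["age", "city"]].copy()
y = df["target"]

num_cols = X.select_dtypes(include=np.number).columns
X[num_cols] = SimpleImputer(strategy="mean").fit_transform(X[num_cols])    # 1. impute numerics
X = pd.get_dummies(X, columns=X.select_dtypes(exclude=np.number).columns)  # 2. one-hot encode

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=42)  # 3. split

scaler = StandardScaler()
X_train_scaled = scaler.fit_transform(X_train)  # 4. fit the scaler on the training split only
X_test_scaled = scaler.transform(X_test)        #    and reuse the fitted scaler on the test split
```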
|