Update app.py
app.py
CHANGED
@@ -12,45 +12,12 @@ from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
 from sklearn.metrics import accuracy_score, mean_squared_error
 from ydata_profiling import ProfileReport
 from streamlit_pandas_profiling import st_profile_report
-import joblib
 import shap
 from datetime import datetime
 
-# --------------------------
-# Page Configuration
-# --------------------------
-st.set_page_config(
-    page_title="DataInsight Pro",
-    page_icon="🔮",
-    layout="wide",
-    initial_sidebar_state="expanded"
-)
-
-
-# --------------------------
-# Custom Styling
-# --------------------------
-st.markdown("""
-<style>
-    .main {background-color: #f8f9fa;}
-    .sidebar .sidebar-content {background-color: #2c3e50;}
-    .stButton>button {background-color: #3498db; color: white;}
-    .stTextInput>div>div>input {border: 1px solid #3498db;}
-    .stSelectbox>div>div>select {border: 1px solid #3498db;}
-    .stSlider>div>div>div>div {background-color: #3498db;}
-    .metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
-</style>
-""", unsafe_allow_html=True)
-
-# --------------------------
-# Session State Initialization
-# --------------------------
-if 'raw_data' not in st.session_state:
-    st.session_state.raw_data = None
-if 'cleaned_data' not in st.session_state:
-    st.session_state.cleaned_data = None
-if 'model' not in st.session_state:
-    st.session_state.model = None
 
 # --------------------------
 # Helper Functions
@@ -66,6 +33,7 @@ def update_cleaned_data(df):
     st.session_state.data_versions.append(df.copy())
     st.success("Action completed successfully!")
 
 def generate_quality_report(df):
     """Generate comprehensive data quality report"""
     report = {
@@ -91,8 +59,8 @@ def generate_quality_report(df):
         })
         report['columns'][col] = col_report
     return report
-
-
 def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
     """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
 
@@ -100,42 +68,36 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
         X = df[features]
         y = df[target]
 
-        # Input Validation
         if target not in df.columns:
             raise ValueError(f"Target variable '{target}' not found in DataFrame.")
         for feature in features:
             if feature not in df.columns:
                 raise ValueError(f"Feature '{feature}' not found in DataFrame.")
 
-        # Preprocessing Pipeline
-        # Imputation: Handle missing values BEFORE encoding (numerical only for SimpleImputer)
         numerical_features = X.select_dtypes(include=np.number).columns
         categorical_features = X.select_dtypes(exclude=np.number).columns
 
         imputer_numerical = SimpleImputer(strategy='mean')  # Or 'median', 'most_frequent', 'constant'
         X[numerical_features] = imputer_numerical.fit_transform(X[numerical_features])
 
-
-        X = pd.get_dummies(X, columns=categorical_features, dummy_na=False)  # dummy_na = False. We imputed already.
 
-
-        label_encoder = None  # Initialize label_encoder
         if problem_type == "Classification" or problem_type == "Multiclass":
             label_encoder = LabelEncoder()
             y = label_encoder.fit_transform(y)
 
-
-        # Split the data
         X_train, X_test, y_train, y_test = train_test_split(
             X, y, test_size=test_size, random_state=42
         )
 
-
-
-
-        X_test_scaled = scaler.transform(X_test)  # Transform the test data using the fitted scaler
 
-        # Model Selection and Hyperparameter Tuning
         if problem_type == "Regression":
             if model_type == "Random Forest":
                 model = RandomForestRegressor(random_state=42)
@@ -152,9 +114,9 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
                     'max_depth': [3, 5]
                 }
             elif model_type == "Neural Network":
-                model = MLPRegressor(random_state=42, max_iter=500)
                 param_grid = {
-                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],
                     'activation': ['relu', 'tanh'],
                     'alpha': [0.0001, 0.001]
                 }
@@ -178,9 +140,9 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
                     'max_depth': [3, 5]
                 }
             elif model_type == "Neural Network":
-                model = MLPClassifier(random_state=42, max_iter=500)
                 param_grid = {
-                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],
                     'activation': ['relu', 'tanh'],
                     'alpha': [0.0001, 0.001]
                 }
@@ -190,11 +152,11 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
         elif problem_type == "Multiclass":  # Multiclass
 
             if model_type == "Logistic Regression":
-                model = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')
-                param_grid = {'C': [0.1, 1.0, 10.0]}
 
             elif model_type == "Support Vector Machine":
-                model = SVC(random_state=42, probability=True)
                 param_grid = {'C': [0.1, 1.0, 10.0], 'kernel': ['rbf', 'linear']}
 
             elif model_type == "Random Forest":
@@ -203,7 +165,7 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
                     'n_estimators': [100, 200],
                     'max_depth': [None, 5, 10],
                     'min_samples_split': [2, 5],
-                    'criterion': ['gini', 'entropy']
                 }
 
             else:
@@ -211,51 +173,47 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
         else:
             raise ValueError(f"Invalid problem type: {problem_type}")
 
-
-        param_grid.update(model_params)  # This is key to use the model_params provided by user
 
         if use_grid_search:
             grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error', verbose=1, n_jobs=-1)
-            grid_search.fit(
-            model = grid_search.best_estimator_
-            st.write("Best hyperparameters found by Grid Search:", grid_search.best_params_)
 
         else:
-            model.fit(
 
-
-        cv_scores = cross_val_score(model, X_train_scaled, y_train, cv=5, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error')  # Use scaled training data
         st.write("Cross-validation scores:", cv_scores)
         st.write("Mean cross-validation score:", cv_scores.mean())
 
-        # Evaluation
-        y_pred = model.predict(
-        metrics = {}
 
         if problem_type == "Classification":
            metrics['accuracy'] = accuracy_score(y_test, y_pred)
            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
-           metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
 
        elif problem_type == "Multiclass":
-
            metrics['accuracy'] = accuracy_score(y_test, y_pred)
            metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
-           metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
        else:
            metrics['mse'] = mean_squared_error(y_test, y_pred)
            metrics['r2'] = r2_score(y_test, y_pred)
 
-        # Feature Importance (
        try:
-           result = permutation_importance(model,
           importance = result.importances_mean
 
        except Exception as e:
           st.warning(f"Could not calculate feature importance: {e}")
           importance = None
 
-        # Store the column order for prediction purposes
        column_order = X.columns
 
        return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train  # Return X_train and y_train
@@ -263,8 +221,7 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
     except Exception as e:
         st.error(f"Training failed: {str(e)}")
         return None, None, None, None, None, None, None, None, None
-
-# Model Validation Function
 def validate_model(model_path, df, target, features, test_size):
     """Loads a model, preprocesses data, and evaluates the model on a validation set."""
     try:
@@ -328,11 +285,9 @@ def validate_model(model_path, df, target, features, test_size):
 # Prediction helper Function
 def prediction_input_form(features, default_values=None):
     """Generates input forms for each feature and returns a dictionary of inputs.
-
     Args:
         features (list): List of feature names.
         default_values (dict, optional): Default values for each feature. Defaults to None.
-
     Returns:
         dict: Dictionary where keys are feature names and values are user inputs.
     """
@@ -365,54 +320,18 @@ with st.sidebar:
 # --------------------------
 if app_mode == "Data Upload":
     st.title("📤 Data Upload & Profiling")
-
-    uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
-
-    if uploaded_file:
-        try:
-            if uploaded_file.name.endswith('.csv'):
-                df = pd.read_csv(uploaded_file)
-            else:
-                df = pd.read_excel(uploaded_file)
-
-            st.session_state.raw_data = df
-
-            col1, col2, col3 = st.columns(3)
-            with col1:
-                st.metric("Rows", df.shape[0])
-            with col2:
-                st.metric("Columns", df.shape[1])
-            with col3:
-                st.metric("Missing Values", df.isna().sum().sum())
-
-            with st.expander("Data Preview", expanded=True):
-                st.dataframe(df.head(10), use_container_width=True)
-
-            if st.button("Generate Full Profile Report"):
-                with st.spinner("Generating comprehensive analysis..."):
-                    pr = ProfileReport(df, explorative=True)
-                    st_profile_report(pr)
-
-        except Exception as e:
-            st.error(f"Error loading file: {str(e)}")
 
-# --------------------------
-# Page Content
-# --------------------------
-if app_mode == "Data Upload":
-    st.title("📤 Data Upload & Profiling")
-
     uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
-
     if uploaded_file:
         try:
             if uploaded_file.name.endswith('.csv'):
                 df = pd.read_csv(uploaded_file)
             else:
                 df = pd.read_excel(uploaded_file)
-
             st.session_state.raw_data = df
-
             col1, col2, col3 = st.columns(3)
             with col1:
                 st.metric("Rows", df.shape[0])
@@ -420,15 +339,15 @@ if app_mode == "Data Upload":
                 st.metric("Columns", df.shape[1])
             with col3:
                 st.metric("Missing Values", df.isna().sum().sum())
-
             with st.expander("Data Preview", expanded=True):
                 st.dataframe(df.head(10), use_container_width=True)
-
             if st.button("Generate Full Profile Report"):
                 with st.spinner("Generating comprehensive analysis..."):
                     pr = ProfileReport(df, explorative=True)
                     st_profile_report(pr)
-
         except Exception as e:
             st.error(f"Error loading file: {str(e)}")
 
@@ -440,16 +359,21 @@ elif app_mode == "Data Cleaning":
 
     if st.session_state.raw_data is None:
         st.warning("Please upload data first")
-        st.stop()
 
-
-
-
-
-    st.
 
     # Data Health Dashboard
     enhance_section_title("Data Health Dashboard", "📊")
     with st.expander("📊 Data Health Dashboard", expanded=True):
         col1, col2, col3 = st.columns(3)
         with col1:
@@ -465,15 +389,19 @@ elif app_mode == "Data Cleaning":
         profile = ProfileReport(df, minimal=True)
         st_profile_report(profile)
 
     # Undo Functionality
     if len(st.session_state.data_versions) > 1:
         if st.button("⏮️ Undo Last Action"):
             st.session_state.data_versions.pop()  # Remove current version
             st.session_state.cleaned_data = st.session_state.data_versions[-1].copy()  # Set data
             st.success("Last action undone!")
-            st.
 
     # Missing Value Handling
     enhance_section_title("Missing Values Treatment", "🔍")
     with st.expander("🔍 Missing Values Treatment", expanded=True):
         missing_cols = df.columns[df.isna().any()].tolist()
@@ -491,27 +419,35 @@ elif app_mode == "Data Cleaning":
             custom_val = st.text_input("Enter custom value")
 
         if st.button("Apply Treatment (Missing)"):
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
 
-
-
 
     # Data Type Conversion
     enhance_section_title("Data Type Conversion", "🔄")
     with st.expander("🔄 Data Type Conversion"):
         col_to_convert = st.selectbox("Select column", df.columns)
@@ -524,8 +460,8 @@ elif app_mode == "Data Cleaning":
             date_format = st.text_input("Date format (e.g. %Y-%m-%d)", "%Y-%m-%d")
 
         if st.button("Convert (Data Type)"):
-            new_df = df.copy()
            try:
                if new_type == "String":
                    new_df[col_to_convert] = new_df[col_to_convert].astype(str)
                elif new_type == "Integer":
@@ -544,47 +480,61 @@ elif app_mode == "Data Cleaning":
                    new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
 
                update_cleaned_data(new_df)
-                st.
            except Exception as e:
                st.error(f"Error: {str(e)}")
 
    # Drop Columns
    enhance_section_title("Drop Columns", "🗑️")
    with st.expander("🗑️ Drop Columns"):
        columns_to_drop = st.multiselect("Select columns to drop", df.columns)
        if columns_to_drop:
            st.warning(f"Will drop: {', '.join(columns_to_drop)}")
            if st.button("Confirm Drop (Columns)"):
-                new_df = df.
                update_cleaned_data(new_df)
-                st.
 
    # Label Encoding
    enhance_section_title("Label Encoding", "🔢")
    with st.expander("🔢 Label Encoding"):
        data_to_encode = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
        if data_to_encode:
            if st.button("Apply Label Encoding (Encoding)"):
                new_df = df.copy()
                for col in data_to_encode:
                    le = LabelEncoder()
                    new_df[col] = le.fit_transform(new_df[col].astype(str))
                update_cleaned_data(new_df)
-                st.
 
    # StandardScaler
    enhance_section_title("StandardScaler", "📏")
    with st.expander("📏 StandardScaler"):
        scale_cols = st.multiselect("Select numeric columns to scale", df.select_dtypes(include=np.number).columns)
        if scale_cols:
            if st.button("Apply StandardScaler (Scaling)"):
-
-
-
-
-
 
    # Pattern-Based Cleaning
    enhance_section_title("Pattern-Based Cleaning", "🕵️")
    with st.expander("🕵️ Pattern-Based Cleaning"):
        selected_col = st.selectbox("Select text column", df.select_dtypes(include='object').columns)
@@ -592,12 +542,17 @@ elif app_mode == "Data Cleaning":
        replacement = st.text_input("Replacement value")
 
        if st.button("Apply Pattern Replacement (Replace)"):
-
-
-
-
 
    # Bulk Operations
    enhance_section_title("Bulk Actions", "🚀")
    with st.expander("🚀 Bulk Actions"):
        if st.button("Auto-Clean Common Issues (Cleaning)"):
@@ -607,12 +562,17 @@ elif app_mode == "Data Cleaning":
            text_cols = new_df.select_dtypes(include='object').columns
            new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
            update_cleaned_data(new_df)
-            st.
 
    # Cleaned Data Preview
-
-
-
 
 
 # --------------------------
@@ -725,7 +685,7 @@ elif app_mode == "EDA":
    try:
        fig = None  # Initialize fig to None
        if st.session_state.cleaned_data is None:
-            st.warning("Please
            st.stop()
 
        # Generate appropriate visualization with input validation
@@ -1046,12 +1006,55 @@ elif app_mode == "Model Training":
            st.stop()
 
        # Call the training function
-        model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance
 
        if model:  # Only proceed if training was successful
            st.success("Model trained successfully!")
 
-            #
 
            # Save Model
            st.subheader("Save Model")
@@ -1107,7 +1110,8 @@ elif app_mode == "Model Training":
            st.metric("MSE", f"{validation_metrics['mse']:.2f}")
            st.metric("R2", f"{validation_metrics['r2']:.2f}")
 
-
    st.title("🔮 Predictive Analytics - Informed Business Decisions")
 
    if st.session_state.get("model") is None:
@@ -1132,8 +1136,8 @@ elif app_mode == "Predictions":
 
    with col2:
        st.subheader("Data Overview")
-        input_df = pd.DataFrame([input_data])
-        st.dataframe(input_df,use_container_width=True)
 
    # Predicts Function and Displays Result
    if st.button("Generate Prediction & Insights"):
@@ -1147,14 +1151,12 @@ elif app_mode == "Predictions":
 
            # 3. One-hot encode (handle unseen categories)
            categorical_features = input_df.select_dtypes(exclude=np.number).columns
-            input_df = pd.get_dummies(input_df, columns=categorical_features, dummy_na=False)
 
            # 4. Ensure correct column order
-            # Add missing columns with 0 values
            for col in column_order:
                if col not in input_df.columns:
                    input_df[col] = 0
-            # Reorder Columns
            input_df = input_df[column_order]
 
            # 5. Scale the input
@@ -1177,29 +1179,21 @@ elif app_mode == "Predictions":
 
            if problem_type == "Classification":
                explainer = shap.TreeExplainer(model)
-                shap_values = explainer.shap_values(scaled_input)
-
-
-                fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")  # shap_values[1] for class 1 - force plot
-                st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900)  # Adjust height and width as needed.
-
            else:
-                explainer = shap.TreeExplainer(model)
-                shap_values = explainer.shap_values(scaled_input)
-
-
-                st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900)  # Adjust height and width as needed.
 
            st.write("The visualization above explains how each feature contributed to the final prediction.")
 
            # 9. Add Permutation Feature Importance (for more global understanding)
            try:
                enhance_section_title("Global Feature Importance", "🌍")
-                X = pd.DataFrame(scaler.transform(
-                #X = pd.DataFrame(scaler.transform(input_df), columns = input_df.columns)
-                #X = input_df[input_df.columns]
-                X_train = model_data['X_train']  # Get X train
-                y_train = model_data['y_train']  # Get Y train
                result = permutation_importance(model, X, input_df, n_repeats=10, random_state=42)
                importance = result.importances_mean
 
@@ -1210,55 +1204,4 @@ elif app_mode == "Predictions":
                st.warning(f"Could not calculate permutation feature importance: {e}")
 
        except Exception as e:
-            st.error(f"Prediction failed: {str(e)}")
-
-# Force rerun Streamlit app after data cleaning operations
-st.experimental_rerun()
-
-if __name__ == "__main__":
-    # Session State Initialization
-    if 'raw_data' not in st.session_state:
-        st.session_state.raw_data = None
-    if 'cleaned_data' not in st.session_state:
-        st.session_state.cleaned_data = None
-    if 'model' not in st.session_state:
-        st.session_state.model = None
-    if 'data_versions' not in st.session_state:
-        st.session_state.data_versions = []
-
-    # Custom Styling (Keep it in main if needed)
-    st.markdown("""
-    <style>
-        .main {background-color: #f8f9fa;}
-        .sidebar .sidebar-content {background-color: #2c3e50;}
-        .stButton>button {background-color: #3498db; color: white;}
-        .stTextInput>div>div>input {border: 1px solid #3498db;}
-        .stSelectbox>div>div>select {border: 1px solid #3498db;}
-        .stSlider>div>div>div>div {background-color: #3498db;}
-        .metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
-    </style>
-    """, unsafe_allow_html=True)
-
-    # Sidebar Navigation (Keep it in main)
-    with st.sidebar:
-        st.title("🔮 DataInsight Pro")
-        app_mode = st.selectbox(
-            "Navigation",
-            ["Data Upload", "Data Cleaning", "EDA", "Model Training", "Predictions"],
-            format_func=lambda x: f"📌 {x}"
-        )
-        st.markdown("---")
-        st.markdown("Created by Calvin Allen-Crawford")
-        st.markdown("v1.0 | © 2025")
-
-    # Call app mode function based on selection
-    if app_mode == "Data Upload":
-        app_mode_data_upload()
-    elif app_mode == "Data Cleaning":
-        app_mode_data_cleaning()
-    elif app_mode == "EDA":
-        app_mode_eda()
-    elif app_mode == "Model Training":
-        app_mode_model_training()
-    elif app_mode == "Predictions":
-        app_mode_predictions()
 from sklearn.metrics import accuracy_score, mean_squared_error
 from ydata_profiling import ProfileReport
 from streamlit_pandas_profiling import st_profile_report
+import joblib  # For saving and loading models
+import os  # For file directory
 import shap
 from datetime import datetime
+from stqdm import stqdm
 
 
 # --------------------------
 # Helper Functions
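joblib is now imported "for saving and loading models"; as a rough sketch (not code from this commit -- the bundle keys and file name are illustrative assumptions), the Save Model step could persist everything train_model returns and reload it later for validation or prediction:

import joblib

model_data = {                      # illustrative bundle; keys are assumptions
    "model": model,
    "scaler": scaler,
    "label_encoder": label_encoder,
    "imputer_numerical": imputer_numerical,
    "column_order": list(column_order),
}
joblib.dump(model_data, "trained_model.pkl")    # save to disk
model_data = joblib.load("trained_model.pkl")   # reload in validate_model / Predictions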
     st.session_state.data_versions.append(df.copy())
     st.success("Action completed successfully!")
 
+@st.cache_data
 def generate_quality_report(df):
     """Generate comprehensive data quality report"""
     report = {
         })
         report['columns'][col] = col_report
     return report
+
+@st.cache_data
 def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
     """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
 
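A note on the @st.cache_data decorators added above: Streamlit hashes every argument to build the cache key, and DataFrames are hashed by content, so re-running with identical inputs skips retraining. Arguments that should not (or cannot) be hashed can be excluded by prefixing the parameter name with an underscore. A minimal sketch of that convention, using a hypothetical helper rather than app.py's own code:

import streamlit as st
import pandas as pd

@st.cache_data
def summarize(df: pd.DataFrame, _status_widget=None):
    # df is hashed by content and forms the cache key;
    # _status_widget is skipped by the hasher because of the leading underscore.
    return df.describe()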
         X = df[features]
         y = df[target]
 
+        # Input Validation (rest of the input validation code remains the same)
         if target not in df.columns:
             raise ValueError(f"Target variable '{target}' not found in DataFrame.")
         for feature in features:
             if feature not in df.columns:
                 raise ValueError(f"Feature '{feature}' not found in DataFrame.")
 
+        # Preprocessing Pipeline (rest of preprocessing code remains the same)
         numerical_features = X.select_dtypes(include=np.number).columns
         categorical_features = X.select_dtypes(exclude=np.number).columns
 
         imputer_numerical = SimpleImputer(strategy='mean')  # Or 'median', 'most_frequent', 'constant'
         X[numerical_features] = imputer_numerical.fit_transform(X[numerical_features])
 
+        X = pd.get_dummies(X, columns=categorical_features, dummy_na=False)
 
+        label_encoder = None  # Initialize label_encoder
         if problem_type == "Classification" or problem_type == "Multiclass":
             label_encoder = LabelEncoder()
             y = label_encoder.fit_transform(y)
 
         X_train, X_test, y_train, y_test = train_test_split(
             X, y, test_size=test_size, random_state=42
         )
 
+        scaler = StandardScaler()
+        X_train = scaler.fit_transform(X_train)
+        X_test = scaler.transform(X_test)
 
+        # Model Selection and Hyperparameter Tuning (rest of model selection code remains the same)
         if problem_type == "Regression":
             if model_type == "Random Forest":
                 model = RandomForestRegressor(random_state=42)
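The fit_transform-on-train / transform-on-test ordering introduced above is what keeps the test split out of the fitted scaler. An equivalent, more compact pattern (a sketch under the same numerical/categorical column lists, not the app's code) is to bundle imputation, encoding and scaling in a ColumnTransformer inside a Pipeline, so every preprocessing step is fitted on the training fold only:

from sklearn.compose import ColumnTransformer
from sklearn.ensemble import RandomForestRegressor
from sklearn.impute import SimpleImputer
from sklearn.pipeline import Pipeline
from sklearn.preprocessing import OneHotEncoder, StandardScaler

# assumes X_train / X_test are the raw feature frames, split before any encoding or scaling
preprocess = ColumnTransformer([
    ("num", Pipeline([("impute", SimpleImputer(strategy="mean")),
                      ("scale", StandardScaler())]), list(numerical_features)),
    ("cat", OneHotEncoder(handle_unknown="ignore"), list(categorical_features)),
])
pipe = Pipeline([("prep", preprocess), ("model", RandomForestRegressor(random_state=42))])
pipe.fit(X_train, y_train)     # preprocessing is fitted on the training fold only
y_pred = pipe.predict(X_test)  # the test fold is only transformed, never fitted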
                     'max_depth': [3, 5]
                 }
             elif model_type == "Neural Network":
+                model = MLPRegressor(random_state=42, max_iter=500)
                 param_grid = {
+                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],
                     'activation': ['relu', 'tanh'],
                     'alpha': [0.0001, 0.001]
                 }
 
                     'max_depth': [3, 5]
                 }
             elif model_type == "Neural Network":
+                model = MLPClassifier(random_state=42, max_iter=500)
                 param_grid = {
+                    'hidden_layer_sizes': [(50,), (100,), (50, 50)],
                     'activation': ['relu', 'tanh'],
                     'alpha': [0.0001, 0.001]
                 }
 
         elif problem_type == "Multiclass":  # Multiclass
 
             if model_type == "Logistic Regression":
+                model = LogisticRegression(random_state=42, solver='liblinear', multi_class='ovr')
+                param_grid = {'C': [0.1, 1.0, 10.0]}
 
             elif model_type == "Support Vector Machine":
+                model = SVC(random_state=42, probability=True)
                 param_grid = {'C': [0.1, 1.0, 10.0], 'kernel': ['rbf', 'linear']}
 
             elif model_type == "Random Forest":
 
                     'n_estimators': [100, 200],
                     'max_depth': [None, 5, 10],
                     'min_samples_split': [2, 5],
+                    'criterion': ['gini', 'entropy']
                 }
 
             else:
 
         else:
             raise ValueError(f"Invalid problem type: {problem_type}")
 
+        param_grid.update(model_params)
 
         if use_grid_search:
             grid_search = GridSearchCV(model, param_grid, cv=3, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error', verbose=1, n_jobs=-1)
+            grid_search.fit(X_train, y_train)
+            model = grid_search.best_estimator_
+            st.write("Best hyperparameters found by Grid Search:", grid_search.best_params_)
 
         else:
+            model.fit(X_train, y_train)
 
+        cv_scores = cross_val_score(model, X_train, y_train, cv=5, scoring='accuracy' if problem_type in ['Classification', 'Multiclass'] else 'neg_mean_squared_error')
         st.write("Cross-validation scores:", cv_scores)
         st.write("Mean cross-validation score:", cv_scores.mean())
 
+        # Evaluation (rest of evaluation code remains the same)
+        y_pred = model.predict(X_test)
+        metrics = {}
 
         if problem_type == "Classification":
             metrics['accuracy'] = accuracy_score(y_test, y_pred)
             metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
 
         elif problem_type == "Multiclass":
             metrics['accuracy'] = accuracy_score(y_test, y_pred)
             metrics['confusion_matrix'] = confusion_matrix(y_test, y_pred)
+            metrics['classification_report'] = classification_report(y_test, y_pred, output_dict=True)
         else:
             metrics['mse'] = mean_squared_error(y_test, y_pred)
             metrics['r2'] = r2_score(y_test, y_pred)
 
+        # Feature Importance (rest of feature importance code remains the same)
         try:
+            result = permutation_importance(model, X_test, y_test, n_repeats=10, random_state=42)
             importance = result.importances_mean
 
         except Exception as e:
             st.warning(f"Could not calculate feature importance: {e}")
             importance = None
 
         column_order = X.columns
 
         return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train  # Return X_train and y_train
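One reading aid for the cross-validation output above: with scoring='neg_mean_squared_error' scikit-learn reports negated errors so that larger is always better, so regression cv_scores come out negative. A small sketch (not part of this commit) that converts them to RMSE before writing them out:

import numpy as np

if problem_type not in ["Classification", "Multiclass"]:
    rmse_scores = np.sqrt(-cv_scores)   # undo the sign flip, then take the square root
    st.write("Cross-validation RMSE:", rmse_scores)
    st.write("Mean CV RMSE:", rmse_scores.mean())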
     except Exception as e:
         st.error(f"Training failed: {str(e)}")
         return None, None, None, None, None, None, None, None, None
+
 def validate_model(model_path, df, target, features, test_size):
     """Loads a model, preprocesses data, and evaluates the model on a validation set."""
     try:
 
 # Prediction helper Function
 def prediction_input_form(features, default_values=None):
     """Generates input forms for each feature and returns a dictionary of inputs.
     Args:
         features (list): List of feature names.
         default_values (dict, optional): Default values for each feature. Defaults to None.
     Returns:
         dict: Dictionary where keys are feature names and values are user inputs.
     """
 # --------------------------
 if app_mode == "Data Upload":
     st.title("📤 Data Upload & Profiling")
 
     uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
+
     if uploaded_file:
         try:
             if uploaded_file.name.endswith('.csv'):
                 df = pd.read_csv(uploaded_file)
             else:
                 df = pd.read_excel(uploaded_file)
+
             st.session_state.raw_data = df
+
             col1, col2, col3 = st.columns(3)
             with col1:
                 st.metric("Rows", df.shape[0])
             with col2:
                 st.metric("Columns", df.shape[1])
             with col3:
                 st.metric("Missing Values", df.isna().sum().sum())
+
             with st.expander("Data Preview", expanded=True):
                 st.dataframe(df.head(10), use_container_width=True)
+
             if st.button("Generate Full Profile Report"):
                 with st.spinner("Generating comprehensive analysis..."):
                     pr = ProfileReport(df, explorative=True)
                     st_profile_report(pr)
+
         except Exception as e:
             st.error(f"Error loading file: {str(e)}")
 
 
     if st.session_state.raw_data is None:
         st.warning("Please upload data first")
+        st.stop()
 
+    df = st.session_state.raw_data.copy()  # Ensure df is defined in this section
+
+    # Initialize session state (only if it's not already there)
+    if 'data_versions' not in st.session_state:
+        st.session_state.data_versions = [st.session_state.raw_data.copy()]
+    if 'cleaned_data' not in st.session_state:  # Added a conditional value
+        st.session_state.cleaned_data = st.session_state.raw_data.copy()
 
+    # --------------------------
     # Data Health Dashboard
+    # --------------------------
     enhance_section_title("Data Health Dashboard", "📊")
+
     with st.expander("📊 Data Health Dashboard", expanded=True):
         col1, col2, col3 = st.columns(3)
         with col1:
389 |
profile = ProfileReport(df, minimal=True)
|
390 |
st_profile_report(profile)
|
391 |
|
392 |
+
# --------------------------
|
393 |
# Undo Functionality
|
394 |
+
# --------------------------
|
395 |
if len(st.session_state.data_versions) > 1:
|
396 |
if st.button("⏮️ Undo Last Action"):
|
397 |
st.session_state.data_versions.pop() # Remove current version
|
398 |
st.session_state.cleaned_data = st.session_state.data_versions[-1].copy() # Set data
|
399 |
st.success("Last action undone!")
|
400 |
+
st.rerun() #Force re-run after undo
|
401 |
|
402 |
+
# --------------------------
|
403 |
# Missing Value Handling
|
404 |
+
# --------------------------
|
405 |
enhance_section_title("Missing Values Treatment", "🔍")
|
406 |
with st.expander("🔍 Missing Values Treatment", expanded=True):
|
407 |
missing_cols = df.columns[df.isna().any()].tolist()
|
|
|
419 |
custom_val = st.text_input("Enter custom value")
|
420 |
|
421 |
if st.button("Apply Treatment (Missing)"):
|
422 |
+
try:
|
423 |
+
new_df = df.copy() # Create a copy to modify
|
424 |
+
if method == "Drop Missing":
|
425 |
+
new_df = new_df.dropna(subset=cols)
|
426 |
+
elif method == "Mean/Median":
|
427 |
+
for col in cols:
|
428 |
+
if pd.api.types.is_numeric_dtype(new_df[col]):
|
429 |
+
new_df[col] = new_df[col].fillna(new_df[col].median())
|
430 |
+
else:
|
431 |
+
new_df[col] = new_df[col].fillna(new_df[col].mode()[0])
|
432 |
+
elif method == "Custom Value" and custom_val:
|
433 |
+
for col in cols:
|
434 |
+
new_df[col] = new_df[col].fillna(custom_val)
|
435 |
+
elif method == "Forward Fill":
|
436 |
+
new_df[cols] = new_df[cols].ffill()
|
437 |
+
elif method == "Backward Fill":
|
438 |
+
new_df[cols] = new_df[cols].bfill()
|
439 |
+
|
440 |
+
update_cleaned_data(new_df)
|
441 |
+
st.rerun() #Force re-run after apply
|
442 |
|
443 |
+
except Exception as e:
|
444 |
+
st.error(f"Error: {str(e)}")
|
445 |
+
else:
|
446 |
+
st.success("✨ No missing values found!")
|
447 |
|
448 |
+
# --------------------------
|
449 |
# Data Type Conversion
|
450 |
+
# --------------------------
|
451 |
enhance_section_title("Data Type Conversion", "🔄")
|
452 |
with st.expander("🔄 Data Type Conversion"):
|
453 |
col_to_convert = st.selectbox("Select column", df.columns)
|
|
|
460 |
date_format = st.text_input("Date format (e.g. %Y-%m-%d)", "%Y-%m-%d")
|
461 |
|
462 |
if st.button("Convert (Data Type)"):
|
|
|
463 |
try:
|
464 |
+
new_df = df.copy()
|
465 |
if new_type == "String":
|
466 |
new_df[col_to_convert] = new_df[col_to_convert].astype(str)
|
467 |
elif new_type == "Integer":
|
|
|
480 |
new_df[col_to_convert] = pd.to_datetime(new_df[col_to_convert], format=date_format, errors='coerce')
|
481 |
|
482 |
update_cleaned_data(new_df)
|
483 |
+
st.rerun() #Force re-run after apply
|
484 |
except Exception as e:
|
485 |
st.error(f"Error: {str(e)}")
|
486 |
|
487 |
+
# --------------------------
|
488 |
# Drop Columns
|
489 |
+
# --------------------------
|
490 |
enhance_section_title("Drop Columns", "🗑️")
|
491 |
with st.expander("🗑️ Drop Columns"):
|
492 |
columns_to_drop = st.multiselect("Select columns to drop", df.columns)
|
493 |
if columns_to_drop:
|
494 |
st.warning(f"Will drop: {', '.join(columns_to_drop)}")
|
495 |
if st.button("Confirm Drop (Columns)"):
|
496 |
+
new_df = df.copy()
|
497 |
+
new_df = new_df.drop(columns=columns_to_drop)
|
498 |
update_cleaned_data(new_df)
|
499 |
+
st.rerun() #Force re-run after apply
|
500 |
|
501 |
+
# --------------------------
|
502 |
# Label Encoding
|
503 |
+
# --------------------------
|
504 |
enhance_section_title("Label Encoding", "🔢")
|
505 |
with st.expander("🔢 Label Encoding"):
|
506 |
data_to_encode = st.multiselect("Select categorical columns to encode", df.select_dtypes(include='object').columns)
|
507 |
if data_to_encode:
|
508 |
if st.button("Apply Label Encoding (Encoding)"):
|
509 |
new_df = df.copy()
|
510 |
+
label_encoders = {}
|
511 |
for col in data_to_encode:
|
512 |
le = LabelEncoder()
|
513 |
new_df[col] = le.fit_transform(new_df[col].astype(str))
|
514 |
+
label_encoders[col] = le
|
515 |
update_cleaned_data(new_df)
|
516 |
+
st.rerun() #Force re-run after apply
|
517 |
|
518 |
+
# --------------------------
|
519 |
# StandardScaler
|
520 |
+
# --------------------------
|
521 |
enhance_section_title("StandardScaler", "📏")
|
522 |
with st.expander("📏 StandardScaler"):
|
523 |
scale_cols = st.multiselect("Select numeric columns to scale", df.select_dtypes(include=np.number).columns)
|
524 |
if scale_cols:
|
525 |
if st.button("Apply StandardScaler (Scaling)"):
|
526 |
+
try:
|
527 |
+
new_df = df.copy()
|
528 |
+
scaler = StandardScaler()
|
529 |
+
new_df[scale_cols] = scaler.fit_transform(new_df[scale_cols])
|
530 |
+
update_cleaned_data(new_df)
|
531 |
+
st.rerun()#Force re-run after apply
|
532 |
+
except Exception as e:
|
533 |
+
st.error(f"Error: {str(e)}")
|
534 |
|
535 |
+
# --------------------------
|
536 |
# Pattern-Based Cleaning
|
537 |
+
# --------------------------
|
538 |
enhance_section_title("Pattern-Based Cleaning", "🕵️")
|
539 |
with st.expander("🕵️ Pattern-Based Cleaning"):
|
540 |
selected_col = st.selectbox("Select text column", df.select_dtypes(include='object').columns)
|
|
|
542 |
replacement = st.text_input("Replacement value")
|
543 |
|
544 |
if st.button("Apply Pattern Replacement (Replace)"):
|
545 |
+
try:
|
546 |
+
new_df = df.copy()
|
547 |
+
new_df[selected_col] = new_df[selected_col].str.replace(pattern, replacement, regex=True)
|
548 |
+
update_cleaned_data(new_df)
|
549 |
+
st.rerun() #Force re-run after apply
|
550 |
+
except Exception as e:
|
551 |
+
st.error(f"Error: {str(e)}")
|
552 |
|
553 |
+
# --------------------------
|
554 |
# Bulk Operations
|
555 |
+
# --------------------------
|
556 |
enhance_section_title("Bulk Actions", "🚀")
|
557 |
with st.expander("🚀 Bulk Actions"):
|
558 |
if st.button("Auto-Clean Common Issues (Cleaning)"):
|
|
|
562 |
text_cols = new_df.select_dtypes(include='object').columns
|
563 |
new_df[text_cols] = new_df[text_cols].apply(lambda x: x.str.strip())
|
564 |
update_cleaned_data(new_df)
|
565 |
+
st.rerun() #Force re-run after apply
|
566 |
|
567 |
+
# --------------------------
|
568 |
# Cleaned Data Preview
|
569 |
+
# --------------------------
|
570 |
+
if st.session_state.get("cleaned_data") is not None:
|
571 |
+
enhance_section_title("Cleaned Data Preview", "✨")
|
572 |
+
with st.expander("✨ Cleaned Data Preview", expanded=True):
|
573 |
+
st.dataframe(st.session_state.cleaned_data.head(), use_container_width=True)
|
574 |
+
|
575 |
+
|
576 |
|
577 |
|
578 |
# --------------------------
|
|
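In the Label Encoding block above, the fitted encoders are collected in label_encoders, but that dict is local to the button handler and is discarded on the rerun. A sketch (my assumption, not part of this commit) of keeping them in st.session_state so encoded columns can be decoded later with inverse_transform; the column name below is a placeholder:

# after fitting, inside the same button handler:
st.session_state["label_encoders"] = label_encoders

# later, e.g. when displaying results:
le = st.session_state.get("label_encoders", {}).get("some_column")  # hypothetical column name
if le is not None:
    decoded = le.inverse_transform(st.session_state.cleaned_data["some_column"])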
     try:
         fig = None  # Initialize fig to None
         if st.session_state.cleaned_data is None:
+            st.warning("Please upload data first")
             st.stop()
 
         # Generate appropriate visualization with input validation
             st.stop()
 
         # Call the training function
+        model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance = train_model(df.copy(), target, features, problem_type, test_size, model_type, model_params, use_grid_search)  # Pass a copy to avoid modifying the original
 
         if model:  # Only proceed if training was successful
             st.success("Model trained successfully!")
 
+            # Display Metrics
+            st.subheader("Model Evaluation Metrics")
+            if problem_type in ["Classification", "Multiclass"]:  # Combined here
+                st.metric("Accuracy", f"{metrics['accuracy']:.2%}")
+
+                # Confusion Matrix Visualization
+                st.subheader("Confusion Matrix")
+                cm = metrics['confusion_matrix']
+                class_names = [str(i) for i in np.unique(df[target])]  # Get original class names
+                fig_cm = px.imshow(cm,
+                                   labels=dict(x="Predicted", y="Actual"),
+                                   x=class_names,
+                                   y=class_names,
+                                   color_continuous_scale="Viridis")
+                st.plotly_chart(fig_cm, use_container_width=True)
+
+                # Classification Report
+                st.subheader("Classification Report")
+                report = metrics['classification_report']
+                report_df = pd.DataFrame(report).transpose()
+                st.dataframe(report_df)
+
+            else:
+                st.metric("MSE", f"{metrics['mse']:.2f}")
+                st.metric("R2", f"{metrics['r2']:.2f}")
+
+            # Feature Importance
+            st.subheader("Feature Importance")
+            try:
+                fig_importance = px.bar(
+                    x=importance,
+                    y=column_order,  # Use stored column order
+                    orientation='h',
+                    title="Feature Importance"
+                )
+                st.plotly_chart(fig_importance, use_container_width=True)
+            except Exception as e:
+                st.warning(f"Could not display feature importance: {e}")
+
+            # Explainable AI (Placeholder)
+            st.subheader("Explainable AI (XAI)")
+            st.write("Future implementation will include model explanations using techniques like SHAP or LIME.")  # To be implemented
+            if st.checkbox("Show a random model explanation (example)"):  # Example of a feature, to be implemented
+                st.write("This feature is important because...")
 
             # Save Model
             st.subheader("Save Model")
 
             st.metric("MSE", f"{validation_metrics['mse']:.2f}")
             st.metric("R2", f"{validation_metrics['r2']:.2f}")
 
+# Predictions Section (Fixed)
+if app_mode == "Predictions":
     st.title("🔮 Predictive Analytics - Informed Business Decisions")
 
     if st.session_state.get("model") is None:
 
     with col2:
         st.subheader("Data Overview")
+        input_df = pd.DataFrame([input_data])  # Make DataFrame
+        st.dataframe(input_df, use_container_width=True)  # DataFrame of the input to see it
 
     # Predicts Function and Displays Result
     if st.button("Generate Prediction & Insights"):
 
             # 3. One-hot encode (handle unseen categories)
             categorical_features = input_df.select_dtypes(exclude=np.number).columns
+            input_df = pd.get_dummies(input_df, columns=categorical_features, dummy_na=False)
 
             # 4. Ensure correct column order
             for col in column_order:
                 if col not in input_df.columns:
                     input_df[col] = 0
             input_df = input_df[column_order]
 
             # 5. Scale the input
 
             if problem_type == "Classification":
                 explainer = shap.TreeExplainer(model)
+                shap_values = explainer.shap_values(scaled_input)
+                fig = shap.force_plot(explainer.expected_value[1], shap_values[1], input_df, matplotlib=False, link="logit")
+                st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900)
             else:
+                explainer = shap.TreeExplainer(model)
+                shap_values = explainer.shap_values(scaled_input)
+                fig = shap.force_plot(explainer.expected_value, shap_values, input_df, matplotlib=False)
+                st.components.v1.html(shap.getjs() + fig.html(), height=400, width=900)
 
             st.write("The visualization above explains how each feature contributed to the final prediction.")
 
             # 9. Add Permutation Feature Importance (for more global understanding)
             try:
                 enhance_section_title("Global Feature Importance", "🌍")
+                X = pd.DataFrame(scaler.transform(input_df), columns=input_df.columns)
                 result = permutation_importance(model, X, input_df, n_repeats=10, random_state=42)
                 importance = result.importances_mean
 
                 st.warning(f"Could not calculate permutation feature importance: {e}")
 
         except Exception as e:
+            st.error(f"Prediction failed: {str(e)}")
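A caveat on the Global Feature Importance block above: permutation_importance expects true target values as its third argument and needs more than one row to permute, so scoring the single prediction row against input_df will not produce meaningful importances. A hedged sketch of the call as scikit-learn documents it, evaluated on a labelled set such as the X_train / y_train that train_model already returns (a held-out test split would be better still); X_eval and y_eval are placeholder names:

from sklearn.inspection import permutation_importance

# X_eval / y_eval: a labelled evaluation set (e.g. the X_train / y_train returned by
# train_model, or a separate held-out split) -- not the single-row prediction input.
result = permutation_importance(model, X_eval, y_eval, n_repeats=10, random_state=42)
importance = result.importances_mean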