Update app.py
app.py CHANGED
@@ -3,54 +3,27 @@ import pandas as pd
 import numpy as np
 import plotly.express as px
 import plotly.graph_objects as go
-import matplotlib.pyplot as plt
+import matplotlib.pyplot as plt # For SHAP charts
 from scipy.stats import pearsonr, spearmanr
 from sklearn.inspection import permutation_importance
 from sklearn.preprocessing import StandardScaler, LabelEncoder
-from sklearn.model_selection import train_test_split
-from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor
-from sklearn.
-from
-from
-import joblib
+from sklearn.model_selection import train_test_split, GridSearchCV, cross_val_score
+from sklearn.ensemble import RandomForestClassifier, RandomForestRegressor, GradientBoostingClassifier, GradientBoostingRegressor
+from sklearn.neural_network import MLPClassifier, MLPRegressor
+from sklearn.metrics import accuracy_score, mean_squared_error, r2_score, confusion_matrix, classification_report
+from sklearn.impute import SimpleImputer
+import joblib # For saving and loading models
+import os # For file directory
 import shap
 from datetime import datetime
+from stqdm import stqdm
 
-#
-
-#
-
-
-    page_icon="🔮",
-    layout="wide",
-    initial_sidebar_state="expanded"
-)
-
+# Constants used (global)
+PATH_FILES = "/".join(('.', "files"))
+# Ensure upload location exists; make dir if it didn't create one.
+if not os.path.isdir("..") / PATH_FILES:
+    os.makedirs("created", 0o777, exist_ok=True)
 
-# --------------------------
-# Custom Styling
-# --------------------------
-st.markdown("""
-<style>
-.main {background-color: #f8f9fa;}
-.sidebar .sidebar-content {background-color: #2c3e50;}
-.stButton>button {background-color: #3498db; color: white;}
-.stTextInput>div>div>input {border: 1px solid #3498db;}
-.stSelectbox>div>div>select {border: 1px solid #3498db;}
-.stSlider>div>div>div>div {background-color: #3498db;}
-.metric {padding: 15px; background-color: white; border-radius: 10px; box-shadow: 0 2px 5px rgba(0,0,0,0.1);}
-</style>
-""", unsafe_allow_html=True)
-
-# --------------------------
-# Session State Initialization
-# --------------------------
-if 'raw_data' not in st.session_state:
-    st.session_state.raw_data = None
-if 'cleaned_data' not in st.session_state:
-    st.session_state.cleaned_data = None
-if 'model' not in st.session_state:
-    st.session_state.model = None
 
 # --------------------------
 # Helper Functions
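Note: the directory-setup lines added in this hunk will not run as written. Because / binds tighter than not, the condition evaluates os.path.isdir("..") to a bool and then tries to divide it by PATH_FILES, which raises a TypeError at import time; and even if it ran, os.makedirs("created", ...) would create a folder literally named "created" rather than the PATH_FILES location. A minimal corrected sketch of the presumed intent, creating ./files at startup if it is missing (illustrative, not taken from the commit):

    import os

    # Presumed intent: make sure the upload directory ./files exists.
    PATH_FILES = os.path.join(".", "files")
    os.makedirs(PATH_FILES, exist_ok=True)  # no error if the folder is already there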
@@ -92,7 +65,6 @@ def generate_quality_report(df):
         report['columns'][col] = col_report
     return report
 
-# Function to train the model (Separated for clarity and reusability)
 def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
     """Trains a model with hyperparameter tuning, cross-validation, and customizable model architecture."""
 
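Note: train_model exposes a use_grid_search flag, and the new imports pull in GridSearchCV, so the trainer presumably switches between a plain fit and a tuned fit. The helper below is a hypothetical sketch of that branch; the function name, parameter grid, and scoring choice are illustrative and not taken from app.py:

    from sklearn.ensemble import RandomForestClassifier
    from sklearn.model_selection import GridSearchCV

    def fit_with_optional_grid_search(X_train, y_train, use_grid_search=False):
        # Plain fit by default; exhaustive search over a small grid when requested.
        base = RandomForestClassifier(random_state=42)
        if not use_grid_search:
            base.fit(X_train, y_train)
            return base
        grid = {"n_estimators": [100, 300], "max_depth": [None, 10]}
        search = GridSearchCV(base, grid, cv=5, scoring="accuracy", n_jobs=-1)
        search.fit(X_train, y_train)
        return search.best_estimator_  # refit on the full training split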
@@ -258,13 +230,12 @@ def train_model(df, target, features, problem_type, test_size, model_type, model_params, use_grid_search=False):
         # Store the column order for prediction purposes
         column_order = X.columns
 
-        return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance
+        return model, scaler, label_encoder, imputer_numerical, metrics, column_order, importance, X_train, y_train # Return X_train and y_train
 
     except Exception as e:
         st.error(f"Training failed: {str(e)}")
-        return None, None, None, None, None, None, None
-
-# Model Validation Function
+        return None, None, None, None, None, None, None, None, None
+
 def validate_model(model_path, df, target, features, test_size):
     """Loads a model, preprocesses data, and evaluates the model on a validation set."""
     try:
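Note: train_model now also returns X_train and y_train, and matplotlib is imported "For SHAP charts", so the extra values are presumably fed to SHAP by the caller. A hypothetical caller-side sketch, assuming a tree-based model (RandomForest or GradientBoosting), that training succeeded, and that df, target, features, and the model settings are already defined by the app:

    import matplotlib.pyplot as plt
    import shap
    import streamlit as st

    # Unpack the widened return value; the names mirror the return statement above.
    (model, scaler, label_encoder, imputer_numerical, metrics,
     column_order, importance, X_train, y_train) = train_model(
        df, target, features, problem_type, test_size, model_type, model_params)

    if model is not None:
        explainer = shap.TreeExplainer(model)          # tree-based models only
        shap_values = explainer.shap_values(X_train)   # a list of arrays for classifiers
        shap.summary_plot(shap_values, X_train, show=False)
        st.pyplot(plt.gcf())                           # hand the matplotlib figure to Streamlit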
@@ -365,18 +336,18 @@ with st.sidebar:
 # --------------------------
 if app_mode == "Data Upload":
     st.title("📤 Data Upload & Profiling")
-
+
     uploaded_file = st.file_uploader("Upload your dataset (CSV/XLSX)", type=["csv", "xlsx"])
-
+
     if uploaded_file:
         try:
             if uploaded_file.name.endswith('.csv'):
                 df = pd.read_csv(uploaded_file)
             else:
                 df = pd.read_excel(uploaded_file)
-
+
             st.session_state.raw_data = df
-
+
             col1, col2, col3 = st.columns(3)
             with col1:
                 st.metric("Rows", df.shape[0])
@@ -384,15 +355,15 @@ if app_mode == "Data Upload":
                 st.metric("Columns", df.shape[1])
             with col3:
                 st.metric("Missing Values", df.isna().sum().sum())
-
+
             with st.expander("Data Preview", expanded=True):
                 st.dataframe(df.head(10), use_container_width=True)
-
+
             if st.button("Generate Full Profile Report"):
                 with st.spinner("Generating comprehensive analysis..."):
                     pr = ProfileReport(df, explorative=True)
                     st_profile_report(pr)
-
+
         except Exception as e:
             st.error(f"Error loading file: {str(e)}")
 
@@ -406,6 +377,8 @@ elif app_mode == "Data Cleaning":
         st.warning("Please upload data first")
         st.stop()
 
+    df = st.session_state.raw_data.copy() # Ensure df is defined in this section
+
     # Initialize session state (only if it's not already there)
     if 'data_versions' not in st.session_state:
         st.session_state.data_versions = [st.session_state.raw_data.copy()]
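Note: the Data Cleaning page now re-derives df from st.session_state.raw_data and seeds a data_versions list with a copy of the upload. Keeping every intermediate frame in that list gives the page a natural undo history; the helpers below are a hypothetical sketch of that pattern (function names are illustrative, not from app.py):

    import streamlit as st

    def current_df():
        # The working frame is always the last entry in the history.
        return st.session_state.data_versions[-1]

    def push_version(new_df):
        # Record the result of a cleaning step.
        st.session_state.data_versions.append(new_df.copy())

    def undo_last_step():
        # Drop the latest version but always keep the original upload.
        if len(st.session_state.data_versions) > 1:
            st.session_state.data_versions.pop()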