Spaces:
Sleeping
Sleeping
Update app.py
Browse files
app.py
CHANGED
@@ -573,62 +573,134 @@ elif app_mode == "Advanced EDA":
|
|
573 |
elif app_mode == "Model Training":
|
574 |
st.title("🚂 Model Training")
|
575 |
|
576 |
-
|
577 |
-
|
578 |
-
if model_name == "Random Forest":
|
579 |
-
param_grid = {
|
580 |
-
'n_estimators': st.slider("Number of Estimators", 10, 200, 100, help="Number of trees in the forest."),
|
581 |
-
'max_depth': st.slider("Max Depth", 3, 20, 10, help="Maximum depth of the tree."),
|
582 |
-
'min_samples_split': st.slider("Minimum Samples Split", 2, 10, 2, help="Minimum samples required to split an internal node"), #New hyperparameter
|
583 |
-
'min_samples_leaf': st.slider("Minimum Samples Leaf", 1, 10, 1, help="Minimum samples required to be at a leaf node"), #New hyperparameter
|
584 |
-
}
|
585 |
|
586 |
-
#
|
587 |
-
|
588 |
-
|
589 |
-
|
590 |
-
|
591 |
-
|
592 |
-
|
593 |
-
|
594 |
-
|
595 |
-
|
596 |
-
|
597 |
-
|
598 |
-
|
599 |
-
|
600 |
-
|
601 |
-
|
602 |
-
|
603 |
-
|
604 |
-
|
605 |
-
|
606 |
-
|
607 |
-
|
608 |
-
|
609 |
-
|
610 |
-
|
611 |
-
|
612 |
-
|
613 |
-
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
620 |
|
621 |
-
|
622 |
-
|
623 |
-
|
624 |
-
|
|
|
625 |
|
626 |
-
|
627 |
-
|
628 |
-
st.write(f"Cross-validation scores: {cv_scores}")
|
629 |
-
st.write(f"Mean cross-validation score: {cv_scores.mean():.4f}")
|
630 |
|
631 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
632 |
|
633 |
# Model Saving
|
634 |
model_filename = st.text_input("Enter Model Filename (without extension)", "trained_model")
|
@@ -648,7 +720,7 @@ if st.button("Train Model"):
|
|
648 |
st.error(f"Error loading model: {e}")
|
649 |
|
650 |
#Model Evaluation Section
|
651 |
-
y_pred = model.predict(
|
652 |
|
653 |
if problem_type == "Regression":
|
654 |
mse = mean_squared_error(y_test, y_pred)
|
|
|
573 |
elif app_mode == "Model Training":
|
574 |
st.title("🚂 Model Training")
|
575 |
|
576 |
+
# Model Training page body: choose target / features / model, train on the
# cleaned data, report cross-validation and hold-out metrics, and keep the
# fitted pipeline in session state for the saving / evaluation code below.
if st.session_state.cleaned_data is not None:
    df = st.session_state.cleaned_data.copy()

    # Target Variable Selection
    target_column = st.selectbox("Select Target Variable", df.columns, help="Choose the column to predict.")

    # Problem Type Selection
    problem_type = st.radio("Select Problem Type", ["Regression", "Classification"], help="Choose the type of problem.")

    # Feature Selection
    feature_columns = st.multiselect("Select Feature Columns", df.drop(columns=[target_column]).columns, help="Choose features for training.")

    # Model Selection
    model_name = st.selectbox("Select Model", [
        "Linear Regression", "Logistic Regression", "Decision Tree",
        "Random Forest", "Gradient Boosting", "SVM"
    ], help="Choose a model.")

    feature_selection_method = st.selectbox("Feature Selection Method", ["None", "SelectKBest"])

    # The k slider must live outside the "Train Model" button branch: a widget
    # created only after a button click disappears on the rerun triggered by
    # moving it (the button is False again), so its value could never be used.
    k = None
    if feature_selection_method == "SelectKBest":
        if len(feature_columns) > 1:
            k = st.slider("Number of Features to Select", 1, len(feature_columns), len(feature_columns))
        else:
            # With zero or one feature there is nothing to choose, and this
            # avoids asking the slider for a degenerate (min == max) range.
            k = 1

    if model_name == "Random Forest":
        # Each slider yields a single scalar value for its hyperparameter.
        param_grid = {
            'n_estimators': st.slider("Number of Estimators", 10, 200, 100, help="Number of trees in the forest."),
            'max_depth': st.slider("Max Depth", 3, 20, 10, help="Maximum depth of the tree."),
            'min_samples_split': st.slider("Minimum Samples Split", 2, 10, 2, help="Minimum samples required to split an internal node"),
            'min_samples_leaf': st.slider("Minimum Samples Leaf", 1, 10, 1, help="Minimum samples required to be at a leaf node"),
        }

    # Train-Test Split
    test_size = st.slider("Test Size", 0.1, 0.5, 0.2, help="Proportion of the dataset to include in the test split.")

    if st.button("Train Model"):
        if not feature_columns:
            # Training on an empty feature frame can only fail deep inside
            # scikit-learn; tell the user what is actually missing instead.
            st.warning("Please select at least one feature column.")
        else:
            with st.spinner("Training model..."):
                try:
                    # Reject model / problem-type combinations that cannot
                    # produce meaningful metrics before doing any work.
                    if model_name == "Linear Regression" and problem_type != "Regression":
                        raise ValueError("Linear Regression requires the Regression problem type.")
                    if model_name == "Logistic Regression" and problem_type != "Classification":
                        raise ValueError("Logistic Regression requires the Classification problem type.")

                    X = df[feature_columns]
                    y = df[target_column]
                    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=test_size, random_state=42)

                    # Preprocessing Pipeline
                    numeric_features = X.select_dtypes(include=np.number).columns
                    categorical_features = X.select_dtypes(exclude=np.number).columns

                    numeric_transformer = Pipeline(steps=[
                        ('imputer', SimpleImputer(strategy='median')),
                        ('scaler', StandardScaler())
                    ])

                    categorical_transformer = Pipeline(steps=[
                        ('imputer', SimpleImputer(strategy='most_frequent')),
                        ('onehot', OneHotEncoder(handle_unknown='ignore'))
                    ])

                    preprocessor = ColumnTransformer(
                        transformers=[
                            ('num', numeric_transformer, numeric_features),
                            ('cat', categorical_transformer, categorical_features)
                        ])

                    # Fit preprocessing on the training split only, then
                    # apply the fitted transform to the test split.
                    X_train_processed = preprocessor.fit_transform(X_train)
                    X_test_processed = preprocessor.transform(X_test)

                    # Feature Selection
                    selector = None
                    if feature_selection_method == "SelectKBest":
                        # Local import: the file's import block is not
                        # guaranteed to expose the scoring functions.
                        from sklearn.feature_selection import f_classif, f_regression

                        # SelectKBest defaults to f_classif, which is only
                        # appropriate for classification targets.
                        score_func = f_regression if problem_type == "Regression" else f_classif
                        # Encoding can change the column count, so never ask
                        # for more features than the processed matrix has.
                        n_available = X_train_processed.shape[1]
                        selector = SelectKBest(score_func=score_func, k=min(k, n_available))
                        X_train_selected = selector.fit_transform(X_train_processed, y_train)
                        X_test_selected = selector.transform(X_test_processed)
                    else:
                        X_train_selected = X_train_processed
                        X_test_selected = X_test_processed

                    # Model Training and Hyperparameter Tuning
                    if model_name == "Linear Regression":
                        model = LinearRegression()
                    elif model_name == "Logistic Regression":
                        model = LogisticRegression(max_iter=1000)
                    elif model_name == "Decision Tree":
                        if problem_type == "Regression":
                            model = DecisionTreeRegressor()
                        else:
                            model = DecisionTreeClassifier()
                    elif model_name == "Random Forest":
                        # GridSearchCV rejects scalar grid values; wrap each
                        # slider value in a one-element list.
                        tuning_grid = {name: [value] for name, value in param_grid.items()}
                        if problem_type == "Regression":
                            model = RandomForestRegressor(random_state=42)
                            scoring = 'neg_mean_squared_error'
                        else:
                            model = RandomForestClassifier(random_state=42)
                            scoring = 'accuracy'
                        grid_search = GridSearchCV(model, tuning_grid, cv=3, scoring=scoring)
                        grid_search.fit(X_train_selected, y_train)
                        model = grid_search.best_estimator_
                        st.write("Best Parameters:", grid_search.best_params_)
                    elif model_name == "Gradient Boosting":
                        model = GradientBoostingRegressor() if problem_type == "Regression" else GradientBoostingClassifier()
                    elif model_name == "SVM":
                        model = SVR() if problem_type == "Regression" else SVC()

                    # Cross-validation (uses each estimator's default scorer)
                    cv_scores = cross_val_score(model, X_train_selected, y_train, cv=5)
                    st.write(f"Cross-validation scores: {cv_scores}")
                    st.write(f"Mean cross-validation score: {cv_scores.mean():.4f}")

                    model.fit(X_train_selected, y_train)

                    # Store model and preprocessor. The fitted selector has to
                    # be part of the stored pipeline, otherwise predicting on
                    # raw data would hand the model the wrong column count.
                    pipeline_steps = [('preprocessor', preprocessor)]
                    if selector is not None:
                        pipeline_steps.append(('selector', selector))
                    pipeline_steps.append(('model', model))
                    st.session_state.model = Pipeline(steps=pipeline_steps)
                    st.session_state.preprocessor = preprocessor

                    # Model Evaluation
                    y_pred = model.predict(X_test_selected)
                    if problem_type == "Regression":
                        mse = mean_squared_error(y_test, y_pred)
                        r2 = r2_score(y_test, y_pred)
                        st.write(f"Mean Squared Error: {mse:.4f}")
                        st.write(f"R-squared: {r2:.4f}")
                    else:
                        accuracy = accuracy_score(y_test, y_pred)
                        st.write(f"Accuracy: {accuracy:.4f}")

                    st.success("Model trained successfully!")

                except Exception as e:
                    # Page-level boundary: surface any training failure in the
                    # UI rather than crashing the Streamlit script.
                    st.error(f"An error occurred: {e}")
else:
    st.write("Please upload and clean data first.")
|
704 |
|
705 |
# Model Saving
|
706 |
model_filename = st.text_input("Enter Model Filename (without extension)", "trained_model")
|
|
|
720 |
st.error(f"Error loading model: {e}")
|
721 |
|
722 |
#Model Evaluation Section
|
723 |
+
y_pred = st.session_state.model.predict(X_test)
|
724 |
|
725 |
if problem_type == "Regression":
|
726 |
mse = mean_squared_error(y_test, y_pred)
|