CosmickVisions committed on
Commit
cff9e1f
·
verified ·
1 Parent(s): efd2599

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +57 -26
app.py CHANGED
@@ -722,10 +722,6 @@ elif app_mode == "Model Training":
722
 
723
  # Store model and preprocessor
724
  st.session_state.model = Pipeline(steps=[('preprocessor', preprocessor), ('model', model)])
725
- st.session_state.preprocessor = preprocessor
726
-
727
- # Store model and preprocessor
728
- st.session_state.model = Pipeline(steps=[('preprocessor', preprocessor), ('model', model)])
729
  st.session_state.preprocessor = preprocessor
730
 
731
  # Model Evaluation
@@ -764,7 +760,7 @@ elif app_mode == "Model Training":
764
 
765
  #Heatmap
766
  fig_conf, ax_conf = plt.subplots()
767
- sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax_conf)
768
  ax_conf.set_xlabel('Predicted Labels')
769
  ax_conf.set_ylabel('True Labels')
770
  ax_conf.set_title('Confusion Matrix')
@@ -774,11 +770,10 @@ elif app_mode == "Model Training":
774
  #Added section for model visualization
775
  st.subheader("Model Visualization")
776
 
777
- if problem_type == "Classification":
778
-
779
- try: #All the plotting code here.
780
  #Added code that calculates the learning curves
781
- train_sizes, train_scores, valid_scores = learning_curve(model, X_train_selected, y_train, cv=5, scoring='accuracy')
782
 
783
  #Then add a plot for the learning curve and use st.pyplot
784
  train_mean = np.mean(train_scores, axis=1)
@@ -799,25 +794,61 @@ elif app_mode == "Model Training":
799
  ax_lc.set_ylabel('Accuracy')
800
  ax_lc.legend(loc='best')
801
  st.pyplot(fig_lc) # Display the figure in Streamlit
 
 
 
802
 
 
 
 
 
 
 
803
 
804
- #Feature Importance (Tree-based Models)
805
- if model_name in ["Random Forest", "Gradient Boosting"] : #Make sure its the correct type for extraction
806
- importances = model.feature_importances_ # Assumed tree-based model
807
- feat_importances = pd.Series(importances, index=X_train.columns)
808
- feat_importances = feat_importances.nlargest(20)
809
 
810
- fig_feat, ax_feat = plt.subplots()
811
- feat_importances.plot(kind='barh', ax=ax_feat)
812
- ax_feat.set_xlabel('Relative Importance')
813
- ax_feat.set_ylabel('Features')
814
- ax_feat.set_title('Feature Importances')
815
- st.pyplot(fig_feat)
816
- except Exception as e: #Local error
817
- st.write(f"Plotting functions requires tree based-models and for classification: {e}")
818
 
819
- else:
820
- st.write("Please upload and clean data first.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
821
 
822
  # Model Saving
823
  model_filename = st.text_input("Enter Model Filename (without extension)", "trained_model")
@@ -839,6 +870,7 @@ elif app_mode == "Model Training":
839
  #Model Evaluation Section
840
  if 'X_test' in locals() and st.session_state.model is not None:
841
  try: #Error catching with new test data
 
842
  y_pred = st.session_state.model.predict(X_test)
843
 
844
  if problem_type == "Regression":
@@ -850,9 +882,8 @@ elif app_mode == "Model Training":
850
  from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report #Import here to avoid library bloat
851
  accuracy = accuracy_score(y_test, y_pred)
852
  st.write(f"Accuracy: {accuracy:.4f}")
853
-
854
  except Exception as e: #local error
855
- st.error(f"An error occurred during model evaluation: {e}")
856
 
857
  elif app_mode == "Predictions":
858
  st.title("🔮 Make Predictions")
 
722
 
723
  # Store model and preprocessor
724
  st.session_state.model = Pipeline(steps=[('preprocessor', preprocessor), ('model', model)])
 
 
 
 
725
  st.session_state.preprocessor = preprocessor
726
 
727
  # Model Evaluation
 
760
 
761
  #Heatmap
762
  fig_conf, ax_conf = plt.subplots()
763
+ sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax_conf)  # pass the Axes by keyword: a positional argument after keyword arguments is a SyntaxError
764
  ax_conf.set_xlabel('Predicted Labels')
765
  ax_conf.set_ylabel('True Labels')
766
  ax_conf.set_title('Confusion Matrix')
 
770
  #Added section for model visualization
771
  st.subheader("Model Visualization")
772
 
773
+ try: #All the plotting code here.
774
+ if problem_type == "Classification" and model_name not in ["Linear Regression","Logistic Regression","SVM","Naive Bayes", "KNN"]:
 
775
  #Added code that calculates the learning curves
776
+ train_sizes, train_scores, valid_scores = learning_curve(model, X_train_selected, y_train, cv=5, scoring='accuracy',n_jobs = -1)
777
 
778
  #Then add a plot for the learning curve and use st.pyplot
779
  train_mean = np.mean(train_scores, axis=1)
 
794
  ax_lc.set_ylabel('Accuracy')
795
  ax_lc.legend(loc='best')
796
  st.pyplot(fig_lc) # Display the figure in Streamlit
797
+ importances = model.feature_importances_ # Assumed tree-based model
798
+ feat_importances = pd.Series(importances, index=X_train.columns)
799
+ feat_importances = feat_importances.nlargest(20)
800
 
801
+ fig_feat, ax_feat = plt.subplots()
802
+ feat_importances.plot(kind='barh', ax=ax_feat)
803
+ ax_feat.set_xlabel('Relative Importance')
804
+ ax_feat.set_ylabel('Features')
805
+ ax_feat.set_title('Feature Importances')
806
+ st.pyplot(fig_feat)
807
 
808
+ elif problem_type == "Regression" and model_name not in ["Linear Regression","Logistic Regression","SVM","Naive Bayes", "KNN"]: #graph regressions with regressor based models
 
 
 
 
809
 
810
+ train_sizes, train_scores, valid_scores = learning_curve(model, X_train_selected, y_train, cv=5, scoring='neg_mean_squared_error', n_jobs=-1)
 
 
 
 
 
 
 
811
 
812
+ #Then add a plot for the learning curve and use st.pyplot
813
+ train_mean = np.mean(train_scores, axis=1)
814
+ train_std = np.std(train_scores, axis=1)
815
+ valid_mean = np.mean(valid_scores, axis=1)
816
+ valid_std = np.std(valid_scores, axis=1)
817
+
818
+ fig_lc, ax_lc = plt.subplots() #plot the curve in matplotlib
819
+
820
+
821
+ ax_lc.plot(train_sizes, train_mean, color='blue', marker='o', markersize=5, label='Training neg_mean_squared_error')
822
+ ax_lc.fill_between(train_sizes, train_mean + train_std, train_mean - train_std, alpha=0.15, color='blue')
823
+ ax_lc.plot(train_sizes, valid_mean, color='green', linestyle='--', marker='s', markersize=5, label='Validation neg_mean_squared_error')
824
+ ax_lc.fill_between(train_sizes, valid_mean + valid_std, valid_mean - valid_std, alpha=0.15, color='green')
825
+
826
+ ax_lc.set_title('Learning Curves')
827
+ ax_lc.set_xlabel('Training Set Size')
828
+ ax_lc.set_ylabel('Neg Mean Squared Error')
829
+ ax_lc.legend(loc='best')
830
+ st.pyplot(fig_lc) # Display the figure in Streamlit
831
+ importances = model.feature_importances_ # Assumed tree-based model
832
+ feat_importances = pd.Series(importances, index=X_train.columns)
833
+ feat_importances = feat_importances.nlargest(20)
834
+
835
+ fig_feat, ax_feat = plt.subplots()
836
+ feat_importances.plot(kind='barh', ax=ax_feat)
837
+ ax_feat.set_xlabel('Relative Importance')
838
+ ax_feat.set_ylabel('Features')
839
+ ax_feat.set_title('Feature Importances')
840
+ st.pyplot(fig_feat)
841
+
842
+
843
+ except Exception as e: #Local error
844
+ st.write(f"Plotting functions requires tree based-models and for classification: {e}")
845
+
846
+
847
+ except Exception as e:
848
+ st.error(f"An error occurred: {e}")
849
+
850
+ else:
851
+ st.write("Please upload and clean data first.")
852
 
853
  # Model Saving
854
  model_filename = st.text_input("Enter Model Filename (without extension)", "trained_model")
 
870
  #Model Evaluation Section
871
  if 'X_test' in locals() and st.session_state.model is not None:
872
  try: #Error catching with new test data
873
+
874
  y_pred = st.session_state.model.predict(X_test)
875
 
876
  if problem_type == "Regression":
 
882
  from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report #Import here to avoid library bloat
883
  accuracy = accuracy_score(y_test, y_pred)
884
  st.write(f"Accuracy: {accuracy:.4f}")
 
885
  except Exception as e: #local error
886
+ st.error(f"An error occurred during model evaluation: {e}")
887
 
888
  elif app_mode == "Predictions":
889
  st.title("🔮 Make Predictions")