CosmickVisions committed
Commit 1e6151b · verified · 1 Parent(s): 339e41b

Update app.py

Files changed (1)
  1. app.py +32 -15
app.py CHANGED
@@ -735,6 +735,8 @@ elif app_mode == "Model Training":
             st.write(f"R-squared: {r2:.4f}")
         else:
             from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report #Import here to avoid library bloat
+            import seaborn as sns
+            import matplotlib.pyplot as plt #Added import statement
 
             #Weighted averaging for metrics for multiclass
             average_method = "weighted" #changed from None
@@ -763,6 +765,22 @@ elif app_mode == "Model Training":
             ax_conf.set_title('Confusion Matrix')
             st.pyplot(fig_conf)
 
+
+            # Feature Importance (Tree-based Models)
+            if model_name in ["Random Forest", "Gradient Boosting"] and problem_type == "Classification":
+                importances = model.feature_importances_ # Assumed tree-based model
+                feat_importances = pd.Series(importances, index=X_train.columns)
+                feat_importances = feat_importances.nlargest(20)
+
+                fig_feat, ax_feat = plt.subplots()
+                feat_importances.plot(kind='barh', ax=ax_feat)
+                ax_feat.set_xlabel('Relative Importance')
+                ax_feat.set_ylabel('Features')
+                ax_feat.set_title('Feature Importances')
+                st.pyplot(fig_feat)
+
+        except Exception as e:
+            st.error(f"An error occurred: {e}")
     else:
         st.write("Please upload and clean data first.")
 
@@ -784,22 +802,21 @@ elif app_mode == "Model Training":
         st.error(f"Error loading model: {e}")
 
     #Model Evaluation Section
-    if 'X_test' in locals() and st.session_state.model is not None and problem_type == "Regression":
-        y_pred = st.session_state.model.predict(X_test)
-
-        if problem_type == "Regression":
-            mse = mean_squared_error(y_test, y_pred)
-            r2 = r2_score(y_test, y_pred)
-            st.write(f"Mean Squared Error: {mse:.4f}")
-            st.write(f"R-squared: {r2:.4f}")
-        else:
-            from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report #Import here to avoid library bloat
-
-            #Weighted averaging for metrics for multiclass
-            average_method = "weighted" #changed from None
-
-            accuracy = accuracy_score(y_test, y_pred)
-            st.write(f"Accuracy: {accuracy:.4f}")
+    if 'X_test' in locals() and st.session_state.model is not None:
+        try:
+            y_pred = st.session_state.model.predict(X_test)
+
+            if problem_type == "Regression":
+                mse = mean_squared_error(y_test, y_pred)
+                r2 = r2_score(y_test, y_pred)
+                st.write(f"Mean Squared Error: {mse:.4f}")
+                st.write(f"R-squared: {r2:.4f}")
+            else:
+                from sklearn.metrics import confusion_matrix, roc_curve, auc, precision_recall_curve, classification_report #Import here to avoid library bloat
+                accuracy = accuracy_score(y_test, y_pred)
+                st.write(f"Accuracy: {accuracy:.4f}")
+        except Exception as e:
+            st.error(f"An error occurred during model evaluation: {e}")
 
 elif app_mode == "Predictions":
     st.title("🔮 Make Predictions")
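
The first hunk adds seaborn and matplotlib imports that back the plots drawn later in the training flow. The surrounding confusion-matrix context (fig_conf, ax_conf, st.pyplot(fig_conf)) suggests a seaborn heatmap; below is a minimal, self-contained sketch of that pattern with toy labels — the heatmap styling and variable setup are illustrative assumptions, not the app's exact code.

    # Sketch only: rendering a confusion matrix as a seaborn heatmap in Streamlit.
    # y_test / y_pred are toy stand-ins; the commit itself only shows fig_conf and
    # ax_conf being passed to st.pyplot.
    import seaborn as sns
    import matplotlib.pyplot as plt
    import streamlit as st
    from sklearn.metrics import confusion_matrix

    y_test = [0, 1, 1, 0, 2, 2, 1]   # toy true labels for illustration
    y_pred = [0, 1, 0, 0, 2, 1, 1]   # toy predictions

    cm = confusion_matrix(y_test, y_pred)

    fig_conf, ax_conf = plt.subplots()
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', ax=ax_conf)  # annotated counts
    ax_conf.set_xlabel('Predicted')
    ax_conf.set_ylabel('Actual')
    ax_conf.set_title('Confusion Matrix')
    st.pyplot(fig_conf)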
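
The feature-importance hunk relies on feature_importances_, an attribute that only fitted tree-ensemble estimators expose, which is why the block is gated on model_name being "Random Forest" or "Gradient Boosting". A standalone sketch of the same pattern on a sample dataset follows; the dataset and estimator here are illustrative, not taken from app.py.

    import pandas as pd
    import matplotlib.pyplot as plt
    from sklearn.datasets import load_breast_cancer
    from sklearn.ensemble import RandomForestClassifier

    # Stand-in training data; in app.py the fitted model and X_train come from the
    # training step above.
    data = load_breast_cancer(as_frame=True)
    X_train, y_train = data.data, data.target
    model = RandomForestClassifier(n_estimators=100, random_state=42).fit(X_train, y_train)

    # Same pattern as the new hunk: keep the 20 largest importances, plot horizontally.
    feat_importances = pd.Series(model.feature_importances_, index=X_train.columns)
    feat_importances = feat_importances.nlargest(20)

    fig_feat, ax_feat = plt.subplots()
    feat_importances.plot(kind='barh', ax=ax_feat)
    ax_feat.set_xlabel('Relative Importance')
    ax_feat.set_ylabel('Features')
    ax_feat.set_title('Feature Importances')
    plt.show()  # the app uses st.pyplot(fig_feat) instead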
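
The evaluation rewrite changes behavior as well as error handling: the old guard also required problem_type == "Regression", so the classification branch under it was unreachable, while the new guard drops that clause and wraps prediction in try/except so failures show up as st.error instead of aborting the script run. A self-contained sketch of the new flow, with hypothetical stand-ins for the session state that app.py sets up elsewhere:

    import streamlit as st
    from sklearn.datasets import load_iris
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.metrics import accuracy_score, mean_squared_error, r2_score
    from sklearn.model_selection import train_test_split

    # Hypothetical stand-ins; in app.py the model lives in st.session_state and
    # X_test / y_test / problem_type come from the training section.
    X, y = load_iris(return_X_y=True)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=42)
    model = RandomForestClassifier(random_state=42).fit(X_train, y_train)
    problem_type = "Classification"

    if 'X_test' in locals() and model is not None:
        try:
            y_pred = model.predict(X_test)
            if problem_type == "Regression":
                st.write(f"Mean Squared Error: {mean_squared_error(y_test, y_pred):.4f}")
                st.write(f"R-squared: {r2_score(y_test, y_pred):.4f}")
            else:
                st.write(f"Accuracy: {accuracy_score(y_test, y_pred):.4f}")
        except Exception as e:
            # Surface the failure in the UI instead of crashing the run.
            st.error(f"An error occurred during model evaluation: {e}")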