Roberta2024 commited on
Commit
1835df3
·
verified ·
1 Parent(s): 2c3f886

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +51 -19
app.py CHANGED
@@ -1,7 +1,7 @@
1
  import streamlit as st
2
  import pandas as pd
3
- import plotly.express as px
4
- import plotly.graph_objects as go
5
  from sklearn.ensemble import RandomForestClassifier
6
  from xgboost import XGBClassifier
7
  from sklearn.tree import DecisionTreeClassifier
@@ -38,14 +38,26 @@ def calculate_importances(file):
38
  feature_names = X.columns
39
 
40
  # Prepare DataFrame
41
- rf_importance = pd.DataFrame({'Feature': feature_names, 'Importance': rf_importances})
42
- xgb_importance = pd.DataFrame({'Feature': feature_names, 'Importance': xgb_importances})
43
- cart_importance = pd.DataFrame({'Feature': feature_names, 'Importance': cart_importances})
 
 
 
 
 
 
 
 
44
 
45
  # Correlation Matrix
46
  corr_matrix = heart_df.corr()
47
 
48
- return rf_importance, xgb_importance, cart_importance, corr_matrix
 
 
 
 
49
 
50
  # Streamlit interface
51
  st.title("Feature Importance Calculation")
@@ -55,24 +67,44 @@ uploaded_file = st.file_uploader("Upload heart.csv file", type=['csv'])
55
 
56
  if uploaded_file is not None:
57
  # Process the file and get results
58
- rf_importance, xgb_importance, cart_importance, corr_matrix = calculate_importances(uploaded_file)
59
 
60
- # Display the correlation matrix as an interactive heatmap with Plotly
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  st.write("Correlation Matrix:")
62
- fig_corr = px.imshow(corr_matrix, text_auto=True, aspect="auto", title="Correlation Matrix", color_continuous_scale="coolwarm")
63
- st.plotly_chart(fig_corr)
 
64
 
65
- # Plot and display Random Forest Feature Importances with Plotly
66
  st.write("Random Forest Feature Importance:")
67
- fig_rf = px.bar(rf_importance, x='Importance', y='Feature', orientation='h', title="Random Forest Feature Importances")
68
- st.plotly_chart(fig_rf)
 
 
69
 
70
- # Plot and display XGBoost Feature Importances with Plotly
71
  st.write("XGBoost Feature Importance:")
72
- fig_xgb = px.bar(xgb_importance, x='Importance', y='Feature', orientation='h', title="XGBoost Feature Importances")
73
- st.plotly_chart(fig_xgb)
 
 
74
 
75
- # Plot and display CART (Decision Tree) Feature Importances with Plotly
76
  st.write("CART (Decision Tree) Feature Importance:")
77
- fig_cart = px.bar(cart_importance, x='Importance', y='Feature', orientation='h', title="CART (Decision Tree) Feature Importances")
78
- st.plotly_chart(fig_cart)
 
 
 
1
  import streamlit as st
2
  import pandas as pd
3
+ import matplotlib.pyplot as plt
4
+ import seaborn as sns
5
  from sklearn.ensemble import RandomForestClassifier
6
  from xgboost import XGBClassifier
7
  from sklearn.tree import DecisionTreeClassifier
 
38
  feature_names = X.columns
39
 
40
  # Prepare DataFrame
41
+ rf_importance = {'Feature': feature_names, 'Random Forest': rf_importances}
42
+ xgb_importance = {'Feature': feature_names, 'XGBoost': xgb_importances}
43
+ cart_importance = {'Feature': feature_names, 'CART': cart_importances}
44
+
45
+ # Create DataFrames
46
+ rf_df = pd.DataFrame(rf_importance)
47
+ xgb_df = pd.DataFrame(xgb_importance)
48
+ cart_df = pd.DataFrame(cart_importance)
49
+
50
+ # Merge DataFrames
51
+ importance_df = rf_df.merge(xgb_df, on='Feature').merge(cart_df, on='Feature')
52
 
53
  # Correlation Matrix
54
  corr_matrix = heart_df.corr()
55
 
56
+ # Save to Excel
57
+ file_name = 'feature_importances.xlsx'
58
+ importance_df.to_excel(file_name, index=False)
59
+
60
+ return file_name, importance_df, corr_matrix, rf_importances, xgb_importances, cart_importances, feature_names
61
 
62
  # Streamlit interface
63
  st.title("Feature Importance Calculation")
 
67
 
68
  if uploaded_file is not None:
69
  # Process the file and get results
70
+ excel_file, importance_df, corr_matrix, rf_importances, xgb_importances, cart_importances, feature_names = calculate_importances(uploaded_file)
71
 
72
+ # Display a preview of the DataFrame
73
+ st.write("Feature Importances (Preview):")
74
+ st.dataframe(importance_df.head())
75
+
76
+ # Provide a link to download the Excel file
77
+ with open(excel_file, "rb") as file:
78
+ btn = st.download_button(
79
+ label="Download Excel File",
80
+ data=file,
81
+ file_name=excel_file,
82
+ mime="application/vnd.ms-excel"
83
+ )
84
+
85
+ # Plot and display the Correlation Matrix
86
  st.write("Correlation Matrix:")
87
+ plt.figure(figsize=(10, 8))
88
+ sns.heatmap(corr_matrix, annot=True, fmt=".2f", cmap="coolwarm", cbar=True)
89
+ st.pyplot(plt)
90
 
91
+ # Plot and display the Feature Importance (Random Forest)
92
  st.write("Random Forest Feature Importance:")
93
+ fig_rf, ax_rf = plt.subplots()
94
+ sns.barplot(x=rf_importances, y=feature_names, ax=ax_rf)
95
+ ax_rf.set_title('Random Forest Feature Importances')
96
+ st.pyplot(fig_rf)
97
 
98
+ # Plot and display the Feature Importance (XGBoost)
99
  st.write("XGBoost Feature Importance:")
100
+ fig_xgb, ax_xgb = plt.subplots()
101
+ sns.barplot(x=xgb_importances, y=feature_names, ax=ax_xgb)
102
+ ax_xgb.set_title('XGBoost Feature Importances')
103
+ st.pyplot(fig_xgb)
104
 
105
+ # Plot and display the Feature Importance (Decision Tree - CART)
106
  st.write("CART (Decision Tree) Feature Importance:")
107
+ fig_cart, ax_cart = plt.subplots()
108
+ sns.barplot(x=cart_importances, y=feature_names, ax=ax_cart)
109
+ ax_cart.set_title('CART (Decision Tree) Feature Importances')
110
+ st.pyplot(fig_cart)