# TESTTT / app.py by Roberta2024
# Last commit: 5c0cabb ("Update app.py"), 3.13 kB
import streamlit as st
import pandas as pd
import plotly.express as px
import seaborn as sns
from sklearn.ensemble import RandomForestClassifier
from xgboost import XGBClassifier
from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import train_test_split
import numpy as np
# Function to process data and return feature importances and correlation matrix
def calculate_importances(file, target='target', test_size=0.25, random_state=42):
    """Train three tree-based classifiers on a CSV and return their feature importances.

    Every column except ``target`` is used as a feature. The models are fit on
    the training portion of a ``train_test_split``; the held-out rows are not
    used by this function.

    Parameters
    ----------
    file : str, path or file-like
        Anything ``pandas.read_csv`` accepts, e.g. the object returned by
        ``st.file_uploader``.
    target : str, default 'target'
        Name of the label column.
    test_size : float, default 0.25
        Fraction of rows excluded from training by ``train_test_split``.
    random_state : int, default 42
        Seed shared by the split and all three models so results are reproducible.

    Returns
    -------
    rf_importance, xgb_importance, cart_importance : pandas.DataFrame
        Random Forest, XGBoost and CART (decision tree) importances, each with
        ``Feature`` and ``Importance`` columns in the CSV's column order.
    corr_matrix : pandas.DataFrame
        Pairwise correlation of the numeric/boolean columns of the full
        dataset (label column included).

    Raises
    ------
    KeyError
        If the CSV has no ``target`` column.
    """
    heart_df = pd.read_csv(file)
    if target not in heart_df.columns:
        raise KeyError(
            f"column {target!r} not found in uploaded CSV; "
            f"available columns: {list(heart_df.columns)}"
        )

    X = heart_df.drop(columns=target)
    y = heart_df[target]
    # Importances come from the training split only; the test split is discarded.
    X_train, _, y_train, _ = train_test_split(
        X, y, test_size=test_size, random_state=random_state
    )

    # `use_label_encoder` is deprecated/removed in recent XGBoost releases, so it is not passed.
    models = (
        RandomForestClassifier(random_state=random_state),
        XGBClassifier(eval_metric='logloss', random_state=random_state),
        DecisionTreeClassifier(random_state=random_state),
    )
    importance_frames = []
    for model in models:
        model.fit(X_train, y_train)
        importance_frames.append(
            pd.DataFrame({'Feature': X.columns, 'Importance': model.feature_importances_})
        )

    # Restrict to numeric/boolean columns: since pandas 2.0, corr() raises on text columns.
    corr_matrix = heart_df.select_dtypes(include=['number', 'bool']).corr()

    rf_importance, xgb_importance, cart_importance = importance_frames
    return rf_importance, xgb_importance, cart_importance, corr_matrix
# Streamlit interface
st.title("Feature Importance Calculation")

# File upload
uploaded_file = st.file_uploader("Upload heart.csv file", type=['csv'])

if uploaded_file is not None:
    # Process the file and get results; report bad input in the UI instead of a traceback
    # (e.g. a missing label column, an unparsable CSV, or non-numeric features).
    try:
        rf_importance, xgb_importance, cart_importance, corr_matrix = calculate_importances(uploaded_file)
    except (KeyError, ValueError) as err:
        st.error(f"Could not process the uploaded file: {err}")
        st.stop()

    # Correlation heatmap, rendered with Plotly like the importance charts below.
    # The original called `plt` (matplotlib.pyplot), which is never imported in this file.
    st.write("Correlation Matrix:")
    fig_corr = px.imshow(
        corr_matrix,
        text_auto=".2f",
        color_continuous_scale="RdBu_r",
        zmin=-1,
        zmax=1,
        aspect="auto",
    )
    st.plotly_chart(fig_corr)

    # Each (section heading, chart title, importances) triple is rendered the same way.
    importance_charts = (
        ("Random Forest Feature Importance:", "Random Forest Feature Importances", rf_importance),
        ("XGBoost Feature Importance:", "XGBoost Feature Importances", xgb_importance),
        ("CART (Decision Tree) Feature Importance:", "CART (Decision Tree) Feature Importances", cart_importance),
    )
    for heading, title, importance in importance_charts:
        st.write(heading)
        # Ascending sort: Plotly stacks horizontal categorical bars bottom-up,
        # so the most important feature ends up at the top of the chart.
        fig = px.bar(
            importance.sort_values('Importance'),
            x='Importance',
            y='Feature',
            orientation='h',
            title=title,
        )
        st.plotly_chart(fig)