JERNGOC's picture
Update app.py
c72b33d verified
raw
history blame
3.41 kB
import streamlit as st
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn.model_selection import train_test_split
from sklearn.linear_model import LinearRegression
from sklearn.tree import DecisionTreeClassifier
from sklearn.ensemble import RandomForestClassifier
from sklearn.preprocessing import StandardScaler
# Streamlit app: upload a CSV containing a "target" column and compare the
# feature importances reported by Linear Regression, CART and Random Forest.


def _prepare_features(df):
    """Split *df* into a model-ready feature frame and the target series.

    Feature columns that are neither numeric nor boolean are dropped because
    the scaler and models below cannot consume them. Rows with a missing value
    in any remaining feature or in the target are removed. Both reductions are
    reported to the user via ``st.warning``.

    Returns ``(X, y)``. ``X`` may be empty (no usable columns or no rows left);
    the caller checks for that.
    """
    features = df.drop('target', axis=1)
    usable = features.select_dtypes(include=['number', 'bool'])
    skipped = [col for col in features.columns if col not in usable.columns]
    if skipped:
        st.warning(f"以下非數值欄位無法用於模型,已略過:{', '.join(map(str, skipped))}")

    target = df['target']
    complete = usable.notna().all(axis=1) & target.notna()
    n_incomplete = int((~complete).sum())
    if n_incomplete:
        st.warning(f"已移除 {n_incomplete} 筆含有缺失值的資料列。")
    return usable[complete], target[complete]


def calculate_importance(X_train, y_train):
    """Fit the three models on the training split and return their importances.

    Returns ``(lr_importance, cart_importance, rf_importance)``, each a 1-D
    array with one value per column of ``X_train``:

    * Linear Regression: absolute coefficients, fitted on standardized
      features so that coefficient magnitudes are comparable across features.
    * CART / Random Forest: ``feature_importances_`` of classifiers fitted on
      the unscaled features (tree splits do not depend on feature scale).

    NOTE: scikit-learn normalizes the tree importances, while the linear
    coefficients are raw magnitudes, so the three series are not on a common
    scale.

    Raises ``ValueError`` (from scikit-learn) when the target is unsuitable,
    e.g. continuous values for the classifiers or non-numeric labels for the
    regression.
    """
    X_train_scaled = StandardScaler().fit_transform(X_train)

    lr = LinearRegression()
    lr.fit(X_train_scaled, y_train)
    lr_importance = np.abs(lr.coef_)

    cart = DecisionTreeClassifier(random_state=42)
    cart.fit(X_train, y_train)

    rf = RandomForestClassifier(n_estimators=100, random_state=42)
    rf.fit(X_train, y_train)

    return lr_importance, cart.feature_importances_, rf.feature_importances_


def plot_importance(feature_importance):
    """Render a grouped bar chart of the three models' importances.

    ``feature_importance`` must contain the columns 'Feature',
    'Linear Regression', 'CART' and 'Random Forest'; bars are drawn in the
    frame's current row order.
    """
    width = 0.25  # three bars of this width fit side by side within one x unit
    indices = np.arange(len(feature_importance))

    fig, ax = plt.subplots(figsize=(12, 8))
    ax.bar(indices - width, feature_importance['Linear Regression'], width=width, label='Linear Regression')
    ax.bar(indices, feature_importance['CART'], width=width, label='CART')
    ax.bar(indices + width, feature_importance['Random Forest'], width=width, label='Random Forest')
    ax.set_title('Feature Importance Comparison Across Models')
    ax.set_xlabel('Features')
    ax.set_ylabel('Importance')
    ax.set_xticks(indices)
    ax.set_xticklabels(feature_importance['Feature'], rotation=45, ha='right')
    ax.legend()

    st.pyplot(fig)
    # Streamlit reruns the script on every interaction; close the figure so
    # pyplot does not accumulate one open figure per rerun.
    plt.close(fig)


def main():
    """Build the page: upload, validate, train, then show chart and table."""
    st.title("自定義CSV檔案分析 - 特徵重要性分析")

    # Let the user upload a CSV file.
    uploaded_file = st.file_uploader("上傳一個 CSV 檔案", type="csv")
    if uploaded_file is None:
        return

    try:
        df = pd.read_csv(uploaded_file)
    except (pd.errors.EmptyDataError, pd.errors.ParserError, UnicodeDecodeError) as err:
        st.error(f"無法讀取上傳的 CSV 檔案:{err}")
        return

    # The data must contain a "target" column.
    if 'target' not in df.columns:
        st.error("上傳的檔案中找不到 'target' 欄位,請確認檔案格式。")
        return

    X, y = _prepare_features(df)
    if X.empty:
        st.error("沒有可用於分析的數值資料。")
        return

    try:
        # Only the training split is used; the fixed random_state keeps the
        # training subset reproducible across reruns.
        X_train, _, y_train, _ = train_test_split(X, y, test_size=0.2, random_state=42)
        lr_importance, cart_importance, rf_importance = calculate_importance(X_train, y_train)
    except ValueError as err:
        st.error(f"模型訓練失敗,請確認資料列足夠且 'target' 為數值型的分類標籤:{err}")
        return

    feature_importance = pd.DataFrame({
        'Feature': X.columns,
        'Linear Regression': lr_importance,
        'CART': cart_importance,
        'Random Forest': rf_importance,
    }).sort_values('Random Forest', ascending=False)

    st.write("以下是 Linear Regression、CART 和 Random Forest 的特徵重要性對比圖表:")
    plot_importance(feature_importance)

    st.write("特徵重要性數據:")
    st.dataframe(feature_importance)


# Streamlit executes this file top to bottom on every rerun.
main()