Spaces:
Running
Running
Create app.py
Browse files
app.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
import pandas as pd
|
3 |
+
import seaborn as sns
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
import plotly.express as px
|
6 |
+
import numpy as np
|
7 |
+
import xgboost as xgb
|
8 |
+
import os
|
9 |
+
from sklearn.model_selection import train_test_split
|
10 |
+
from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
|
11 |
+
|
12 |
+
# Set page configuration at the very top
|
13 |
+
# Browser-tab title, wide layout and page icon; this is the first st.* call the script makes.
st.set_page_config(page_title="Healthcare Dashboard", layout="wide", page_icon="💡")
|
14 |
+
|
15 |
+
# Define human-readable labels for prediction outcomes
|
16 |
+
# Lookup table from the model's integer class code (0-4) to the plain-language
# outcome text shown to users; the position in the tuple is the class code.
OUTCOME_MAP = dict(
    enumerate(
        (
            "Patient recovered and went home",
            "Transferred to another hospital or care center",
            "Transferred to a rehab center",
            "Left the hospital early without approval",
            "Passed away or had a very serious event",
        )
    )
)
|
23 |
+
|
24 |
+
# Function to load the model
|
25 |
+
def load_model(model_path="xgboost_patient_model.json"):
    """Load the pre-trained XGBoost classifier from a saved model file.

    Args:
        model_path: Path of the saved model. Defaults to the JSON file the
            dashboard has always used, so existing ``load_model()`` calls
            behave exactly as before.

    Returns:
        The loaded ``xgb.XGBClassifier``, or ``None`` when loading fails. On
        failure the error is reported to the user through ``st.error``.
    """
    try:
        model = xgb.XGBClassifier()
        model.load_model(model_path)
    except Exception as e:
        # UI boundary: show the problem on the page instead of crashing the app.
        st.error(f"Error loading model: {e}")
        return None
    return model
|
33 |
+
|
34 |
+
# Ensure data matches model's feature requirements
|
35 |
+
def preprocess_data(data, model_features):
    """Coerce *data* to numeric values and align its columns with the model.

    Args:
        data: DataFrame of raw feature values. Every column is converted with
            ``pd.to_numeric``; values that cannot be parsed become NaN.
        model_features: Ordered feature names the model expects, or ``None``
            when the model carries no feature names. With ``None`` there is
            nothing to align against, so the numeric frame is returned with
            its columns unchanged.

    Returns:
        A DataFrame restricted to ``model_features`` in that order, or
        ``None`` when a required feature is missing or preprocessing fails.
        Problems are reported through ``st.error`` / ``st.warning``.
    """
    try:
        data = data.apply(pd.to_numeric, errors='coerce')

        if model_features is None:
            # Previously this case hit the generic except with a confusing
            # "'NoneType' object is not iterable" message.
            return data

        model_features = list(model_features)
        available = set(data.columns)
        expected = set(model_features)
        missing_features = [f for f in model_features if f not in available]
        extra_features = [f for f in data.columns if f not in expected]

        if missing_features:
            st.error(f"❌ Missing required features: {missing_features}")
            return None
        if extra_features:
            # Extra columns are only reported; the selection below drops them.
            st.warning(f"⚠️ Extra features in uploaded data: {extra_features}")

        return data[model_features]
    except Exception as e:
        st.error(f"Data preprocessing error: {e}")
        return None
|
51 |
+
|
52 |
+
# Predict patient outcomes
|
53 |
+
def predict_outcome(model, data):
    """Predict outcomes for every row of *data* and report accuracy if possible.

    Args:
        model: Loaded classifier exposing ``get_booster()`` and ``predict()``,
            or ``None`` when the model could not be loaded.
        data: DataFrame of patient features. An optional ``"target"`` column
            holds the true class codes; it is split off before prediction.
            The caller's frame is not modified.

    Returns:
        A 4-tuple ``(actual_target, predictions, mapped_predictions,
        actual_labels)``. ``actual_target`` is ``None`` and ``actual_labels``
        is a list of ``"N/A"`` when no target column was supplied. All four
        entries are ``None`` when the model is missing, preprocessing fails,
        or prediction raises.
    """
    if model is None:
        return None, None, None, None

    # Split off the ground truth without mutating the caller's DataFrame.
    actual_target = None
    if "target" in data.columns:
        actual_target = data["target"]
        data = data.drop(columns=["target"])

    def describe(code):
        # An unexpected class code gets a placeholder label instead of a
        # KeyError that would abort the whole batch.
        return OUTCOME_MAP.get(code, f"Unknown outcome ({code})")

    try:
        model_features = model.get_booster().feature_names
        data = preprocess_data(data, model_features)
        if data is None:
            return None, None, None, None

        predictions = model.predict(data)

        # Convert numerical class codes to human-readable labels
        mapped_predictions = [describe(pred) for pred in predictions]
        if actual_target is not None:
            actual_labels = [describe(actual) for actual in actual_target]
        else:
            actual_labels = ["N/A"] * len(predictions)

        # Accuracy is only meaningful with at least one labelled row.
        if actual_target is not None and len(actual_target) > 0:
            correct_predictions = int((predictions == actual_target).sum())
            total_predictions = len(actual_target)
            accuracy = (correct_predictions / total_predictions) * 100
            st.write(f"✅ Correct Predictions: {correct_predictions}/{total_predictions}")
            st.write(f"📊 Model Accuracy: **{accuracy:.2f}%**")

        return actual_target, predictions, mapped_predictions, actual_labels
    except Exception as e:
        st.error(f"Prediction error: {e}")
        return None, None, None, None
|
83 |
+
|
84 |
+
# Load the data
|
85 |
+
file_path = 'final_cleaned_patient_data.csv'
try:
    df = pd.read_csv(file_path)
except Exception as load_error:
    # Fall back to an empty frame so the rest of the script can still run
    # when the CSV is missing or unreadable.
    df = pd.DataFrame()
    st.error(f"Error loading data: {load_error}")
|
91 |
+
|
92 |
+
# Sidebar navigation
|
93 |
+
# Sidebar heading for the whole dashboard
st.sidebar.title('Healthcare Data Dashboard')

# Credits: one sidebar text line per contributor under a markdown sub-heading
st.sidebar.markdown("### 🏆 Team Members:")
team_members = [
    "1. R. Sai Somnath",
    "2. S. Sreevardhan",
    "3. S. Mohammad Basha",
    "4. V. Hussain Basha",
    "5. P. Charles",
]
for member in team_members:
    st.sidebar.text(member)

# Horizontal rule between the credits and the navigation control
st.sidebar.markdown("---")

# Navigation: the selected label decides which page section is rendered
section_names = [
    'Data Overview',
    'Data Visualization',
    'Interactive Reports',
    'Correlation Analysis',
    'Data Insights',
    'Patient Outcome Prediction',
    'Batch Prediction',
    'Model Performance',
]
option = st.sidebar.selectbox('Choose a section', section_names)
|
121 |
+
|
122 |
+
# Apply a Streamlit theme with a dark background for a modern look
|
123 |
+
# Raw CSS injected into the page to give the app a dark colour scheme.
# NOTE(review): the `.css-*` and `.sidebar .sidebar-content` selectors look
# like generated/internal class names and may not match every Streamlit
# release - confirm they still apply.
dark_theme_css = """
<style>
h1 { color: #00FFAA; }
.stApp { background-color: #121212; color: #FFFFFF; }
.sidebar .sidebar-content { background-color: #333333; color: #FFFFFF; }
.css-1d391kg { color: #FFFFFF; }
.css-18e3th9 { background-color: #1E1E1E; }
</style>
"""
st.markdown(dark_theme_css, unsafe_allow_html=True)
|
132 |
+
|
133 |
+
# Data Overview Section
|
134 |
+
# Sections that read the bundled dataset directly. They are skipped with a
# warning when the CSV could not be loaded (df is then an empty frame), because
# calls such as df.describe() or the plotting helpers raise on an empty frame.
DATA_SECTIONS = {
    'Data Overview',
    'Data Visualization',
    'Interactive Reports',
    'Correlation Analysis',
    'Data Insights',
}

if option in DATA_SECTIONS and df.empty:
    st.title(option)
    st.warning("No data available: the dataset is empty or could not be loaded.")

# Data Overview Section: first rows, shape, column names and summary statistics
elif option == 'Data Overview':
    st.title('📊 Data Overview')
    st.write(df.head())
    st.write(f"Dataset Shape: {df.shape}")
    st.write(f"Column Names: {df.columns.tolist()}")
    st.write("Basic Statistical Overview:")
    st.write(df.describe())

    if st.checkbox('Show Missing Values'):
        st.write(df.isnull().sum())

# Data Visualization Section: one Plotly figure for a user-selected column
elif option == 'Data Visualization':
    st.title('📈 Data Visualization')
    column = st.selectbox('Select Column for Visualization', df.columns)
    plot_type = st.radio(
        'Choose plot type',
        ['Histogram', 'Boxplot', 'Violin Plot', 'Scatter Plot', 'Line Plot', 'Animated Plot'],
    )

    # Every radio option has a branch below, so `fig` is always assigned.
    if plot_type == 'Animated Plot':
        time_col = st.selectbox('Select Time Column (if applicable)', df.columns)
        fig = px.scatter(df, x=column, y=column, animation_frame=time_col, size_max=60)
    elif plot_type == 'Histogram':
        fig = px.histogram(df, x=column, marginal='box', nbins=30)
    elif plot_type == 'Boxplot':
        fig = px.box(df, y=column)
    elif plot_type == 'Violin Plot':
        fig = px.violin(df, y=column, box=True, points='all')
    elif plot_type == 'Scatter Plot':
        x_col = st.selectbox('Select X axis', df.columns)
        fig = px.scatter(df, x=x_col, y=column, color=column)
    elif plot_type == 'Line Plot':
        x_col = st.selectbox('Select X axis for Line Plot', df.columns)
        fig = px.line(df, x=x_col, y=column)

    st.plotly_chart(fig)

# Correlation Analysis Section: heatmap of pairwise correlations
elif option == 'Correlation Analysis':
    st.title('🔎 Correlation Analysis')
    # Restrict to numeric columns: DataFrame.corr() raises on text columns in
    # recent pandas releases.
    numeric_df = df.select_dtypes(include='number')
    if numeric_df.empty:
        st.warning("No numeric columns available for correlation analysis.")
    else:
        corr_matrix = numeric_df.corr()
        fig, ax = plt.subplots(figsize=(12, 8))
        sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', ax=ax)
        st.pyplot(fig)

# Interactive Reports Section: column picker, substring filter and CSV export
elif option == 'Interactive Reports':
    st.title('📂 Interactive Reports')
    st.write("Filter and explore the data.")
    selected_columns = st.multiselect('Select columns to display', df.columns)
    st.dataframe(df[selected_columns] if selected_columns else df)

    st.write("Filter the Data:")
    filter_column = st.selectbox('Select column to filter by', df.columns)
    filter_value = st.text_input('Enter filter value')
    if filter_value:
        # regex=False: the user's text is matched literally, so input such as
        # "(" or "+" can no longer raise a regex compilation error.
        matches = df[filter_column].astype(str).str.contains(filter_value, case=False, regex=False)
        filtered_data = df[matches]
        st.write(filtered_data)

        # Download option
        csv_data = filtered_data.to_csv(index=False).encode('utf-8')
        st.download_button(
            label='Download Filtered Data as CSV',
            data=csv_data,
            file_name='filtered_data.csv',
            mime='text/csv',
        )

# Data Insights Section: cardinality and most frequent values per column
elif option == 'Data Insights':
    st.title('🧠 Data Insights')
    st.write("Gain insights into the data using various metrics.")
    st.write("Total Unique Values per Column:")
    st.write(df.nunique())

    st.write("Top 5 Frequent Values for Each Column:")
    for col in df.columns:
        st.write(f"{col}: {df[col].value_counts().head(5)}")

# Patient Outcome Prediction Section: single-patient form and prediction
elif option == 'Patient Outcome Prediction':
    st.title('🤖 Patient Outcome Prediction')

    # Load the pre-trained model
    model = load_model()

    if model is not None:
        st.success("✅ Pre-trained model loaded successfully!")

        # Reuse the module-level label table so single and batch predictions
        # describe each class with the same wording.
        class_descriptions = OUTCOME_MAP

        # Display target class distribution if target column exists
        target_column = 'target'
        if target_column in df.columns:
            st.subheader("Target Class Distribution")
            target_counts = df[target_column].value_counts().reset_index()
            target_counts.columns = ['Class', 'Count']
            target_counts['Description'] = target_counts['Class'].map(class_descriptions)
            st.write(target_counts)

            fig = px.pie(target_counts, values='Count', names='Description', title='Target Class Distribution')
            st.plotly_chart(fig)

        # Prediction interface
        st.subheader("Make Predictions")
        st.write("Enter values for the features to predict the patient outcome:")

        col1, col2, col3 = st.columns(3)

        def yes_no(flag):
            """Label a 0/1 flag as "No"/"Yes" for the select boxes."""
            return "No" if flag == 0 else "Yes"

        # One entry per model feature, keyed by feature name.
        input_values = {}

        with col1:
            input_values['age'] = st.number_input("Age", min_value=0, max_value=120, value=51)
            input_values['gender'] = st.selectbox("Gender", [0, 1], index=1, format_func=lambda x: "Male" if x == 0 else "Female")
            input_values['previous_hospitalizations'] = st.number_input("Previous Hospitalizations", min_value=0, value=4)
            input_values['heart_rate'] = st.number_input("Heart Rate", min_value=30, max_value=200, value=63)
            input_values['respiratory_rate'] = st.number_input("Respiratory Rate", min_value=5, max_value=60, value=16)
            input_values['blood_pressure_sys'] = st.number_input("Blood Pressure (Systolic)", min_value=50, max_value=250, value=86)
            input_values['blood_pressure_dia'] = st.number_input("Blood Pressure (Diastolic)", min_value=30, max_value=150, value=58)
            input_values['temperature'] = st.number_input("Temperature (°C)", min_value=35.0, max_value=42.0, value=35.86, step=0.1)
            input_values['wbc_count'] = st.number_input("WBC Count", min_value=0.0, max_value=50.0, value=7.15, step=0.1)
            input_values['creatinine'] = st.number_input("Creatinine", min_value=0.1, max_value=10.0, value=2.93, step=0.1)

        with col2:
            input_values['bilirubin'] = st.number_input("Bilirubin", min_value=0.1, max_value=30.0, value=1.72, step=0.1)
            input_values['glucose'] = st.number_input("Glucose", min_value=40, max_value=500, value=137)
            input_values['bun'] = st.number_input("BUN", min_value=5, max_value=150, value=36)
            input_values['pH'] = st.number_input("pH", min_value=6.8, max_value=7.8, value=7.34, step=0.01)
            input_values['pao2'] = st.number_input("PaO2", min_value=40, max_value=300, value=72)
            input_values['pco2'] = st.number_input("PCO2", min_value=20, max_value=100, value=58)
            input_values['fio2'] = st.number_input("FiO2", min_value=0.21, max_value=1.0, value=0.88, step=0.01)
            input_values['gcs'] = st.slider("GCS Score", 3, 15, 5)
            input_values['comorbidity_index'] = st.slider("Comorbidity Index", 0, 10, 1)
            input_values['admission_source'] = st.selectbox("Admission Source", [0, 1, 2, 3], index=1, format_func=lambda x: ["Emergency", "OPD", "Transfer", "Other"][x])

        with col3:
            input_values['elective_surgery'] = st.selectbox("Elective Surgery", [0, 1], index=1, format_func=yes_no)
            input_values['num_medications'] = st.number_input("Number of Medications", min_value=0, value=18)
            input_values['charlson_comorbidity_index'] = st.slider("Charlson Comorbidity Index", 0, 15, 1)
            input_values['ews_score'] = st.slider("EWS Score", 0, 20, 7)
            input_values['severity_score'] = st.slider("Severity Score", 0, 10, 4)
            input_values['bed_occupancy_rate'] = st.slider("Bed Occupancy Rate (%)", 50, 100, 68)
            input_values['staff_to_patient_ratio'] = st.slider("Staff to Patient Ratio", 0.1, 2.0, 0.99, step=0.1)
            input_values['past_icu_admissions'] = st.number_input("Past ICU Admissions", min_value=0, value=2)
            input_values['previous_surgery'] = st.selectbox("Previous Surgery", [0, 1], index=1, format_func=yes_no)
            input_values['high_risk_treatment'] = st.selectbox("High Risk Treatment", [0, 1], index=1, format_func=yes_no)
            input_values['discharge_support'] = st.selectbox("Discharge Support", [0, 1], index=0, format_func=yes_no)

        if st.button("Predict Outcome"):
            # Column order handed to the model (must match the model's
            # expected input features).
            input_columns = [
                'age', 'gender', 'previous_hospitalizations', 'heart_rate',
                'respiratory_rate', 'blood_pressure_sys', 'blood_pressure_dia',
                'temperature', 'wbc_count', 'creatinine', 'bilirubin', 'glucose', 'bun',
                'pH', 'pao2', 'pco2', 'fio2', 'gcs', 'comorbidity_index',
                'admission_source', 'elective_surgery', 'num_medications',
                'charlson_comorbidity_index', 'ews_score', 'severity_score',
                'bed_occupancy_rate', 'staff_to_patient_ratio', 'past_icu_admissions',
                'previous_surgery', 'high_risk_treatment', 'discharge_support',
            ]

            # The row is built purely from the form values, so the prediction
            # no longer depends on the bundled dataset being loaded.
            sample_input = pd.DataFrame([input_values], columns=input_columns)

            try:
                prediction = int(model.predict(sample_input)[0])
                prediction_proba = model.predict_proba(sample_input)[0]

                # Display prediction
                st.subheader("Prediction Result")
                st.write(f"Predicted Class: {prediction} - {class_descriptions.get(prediction, 'Unknown')}")

                # Display probability for each class
                st.write("Prediction Probabilities:")
                proba_df = pd.DataFrame({
                    'Class': [class_descriptions.get(i, f"Class {i}") for i in range(len(prediction_proba))],
                    'Probability': prediction_proba,
                })
                fig = px.bar(proba_df, x='Class', y='Probability', title='Prediction Probabilities')
                st.plotly_chart(fig)
            except Exception as e:
                st.error(f"Error making prediction: {e}")
    else:
        st.error("Failed to load model. Please check if 'xgboost_patient_model.json' exists in the current directory.")

# Batch Prediction Section: score every row of an uploaded CSV
elif option == 'Batch Prediction':
    st.title("🏥 Batch Patient Outcome Prediction")
    st.write("Upload a CSV file with patient data to predict outcomes for multiple patients at once.")

    uploaded_file = st.file_uploader("📂 Upload CSV file", type=["csv"])

    if uploaded_file is not None:
        # The upload is untrusted input: report a parse failure on the page
        # instead of letting the exception take the app down.
        try:
            batch_df = pd.read_csv(uploaded_file).dropna().reset_index(drop=True)
        except Exception as e:
            st.error(f"Error reading uploaded file: {e}")
            batch_df = None

        if batch_df is not None:
            st.write("## Preview of Uploaded Data")
            st.dataframe(batch_df.head(), use_container_width=True)

            model = load_model()
            actual_target, predicted_classes, predicted_outcomes, _ = predict_outcome(model, batch_df.copy())

            if predicted_classes is not None:
                st.write("## 🏥 Prediction Results")
                result_df = pd.DataFrame({
                    "Patient ID": range(1, len(predicted_classes) + 1),
                    "Actual Class": actual_target if actual_target is not None else ["N/A"] * len(predicted_classes),
                    "Predicted Class": predicted_classes,
                    "Predicted Outcome": predicted_outcomes,
                })
                st.dataframe(result_df, use_container_width=True)

                # Visualize how the predictions are distributed across outcomes
                st.write("## Prediction Distribution")
                results_count = pd.Series(predicted_outcomes).value_counts().reset_index()
                results_count.columns = ['Predicted Outcome', 'Count']
                fig = px.pie(results_count, values='Count', names='Predicted Outcome',
                             title='Distribution of Predicted Outcomes')
                st.plotly_chart(fig)

                # Offer download of results
                csv_results = result_df.to_csv(index=False).encode('utf-8')
                st.download_button(
                    label="Download Prediction Results",
                    data=csv_results,
                    file_name="patient_predictions.csv",
                    mime="text/csv",
                )

# Model Performance Section: accuracy, confusion matrix, classification report
elif option == 'Model Performance':
    st.title("📊 Model Performance Analysis")

    # Check if data exists and contains target variable
    if len(df) > 0 and 'target' in df.columns:
        st.write("Analyze the model's performance on the dataset.")

        model = load_model()

        if model is not None:
            try:
                # The split lives inside the try because a stratified split
                # raises when a class has too few rows.
                # NOTE(review): the hold-out rows come from the bundled
                # dataset; if the model was trained on it the scores are
                # optimistic - confirm.
                X = df.drop(columns=["target"])
                y = df["target"]
                _, X_test, _, y_test = train_test_split(
                    X, y, test_size=0.2, random_state=42, stratify=y
                )

                y_pred = model.predict(X_test)

                # Calculate metrics
                accuracy = accuracy_score(y_test, y_pred)
                conf_matrix = confusion_matrix(y_test, y_pred)
                class_report = classification_report(y_test, y_pred, output_dict=True)

                col1, col2 = st.columns(2)

                with col1:
                    st.metric("Model Accuracy", f"{accuracy:.2%}")

                    # Plot confusion matrix
                    st.write("### Confusion Matrix")
                    fig, ax = plt.subplots(figsize=(10, 8))
                    sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
                    ax.set_xlabel('Predicted Labels')
                    ax.set_ylabel('True Labels')
                    st.pyplot(fig)

                with col2:
                    # Tabulate the per-class precision / recall / F1 report
                    st.write("### Classification Report")
                    report_df = pd.DataFrame(class_report).transpose()
                    st.dataframe(report_df.style.format({
                        'precision': '{:.2f}',
                        'recall': '{:.2f}',
                        'f1-score': '{:.2f}',
                        'support': '{:.0f}',
                    }))

            except Exception as e:
                st.error(f"Error performing analysis: {e}")
        else:
            st.error("Model could not be loaded. Please check if the model file exists.")
    else:
        st.error("Cannot perform model analysis. Dataset is empty or missing target variable.")

# Project tagline shown at the bottom of the sidebar on every page
st.sidebar.write("Forecasting discharge outcomes for critically ILL patients using machine learning")
|