Somnath3570 commited on
Commit
badd206
·
verified ·
1 Parent(s): 148dc4c

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +441 -0
app.py ADDED
@@ -0,0 +1,441 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import streamlit as st
2
+ import pandas as pd
3
+ import seaborn as sns
4
+ import matplotlib.pyplot as plt
5
+ import plotly.express as px
6
+ import numpy as np
7
+ import xgboost as xgb
8
+ import os
9
+ from sklearn.model_selection import train_test_split
10
+ from sklearn.metrics import accuracy_score, classification_report, confusion_matrix
11
+
12
+ # Set page configuration at the very top
13
+ st.set_page_config(page_title="Healthcare Dashboard", layout="wide", page_icon="💡")
14
+
15
+ # Define human-readable labels for prediction outcomes
16
+ OUTCOME_MAP = {
17
+ 0: "Patient recovered and went home",
18
+ 1: "Transferred to another hospital or care center",
19
+ 2: "Transferred to a rehab center",
20
+ 3: "Left the hospital early without approval",
21
+ 4: "Passed away or had a very serious event",
22
+ }
23
+
24
+ # Function to load the model
25
+ def load_model():
26
+ try:
27
+ model = xgb.XGBClassifier()
28
+ model.load_model("xgboost_patient_model.json")
29
+ return model
30
+ except Exception as e:
31
+ st.error(f"Error loading model: {e}")
32
+ return None
33
+
34
+ # Ensure data matches model's feature requirements
35
+ def preprocess_data(data, model_features):
36
+ try:
37
+ data = data.apply(pd.to_numeric, errors='coerce')
38
+ missing_features = [f for f in model_features if f not in data.columns]
39
+ extra_features = [f for f in data.columns if f not in model_features]
40
+
41
+ if missing_features:
42
+ st.error(f"❌ Missing required features: {missing_features}")
43
+ return None
44
+ if extra_features:
45
+ st.warning(f"⚠️ Extra features in uploaded data: {extra_features}")
46
+
47
+ return data[model_features]
48
+ except Exception as e:
49
+ st.error(f"Data preprocessing error: {e}")
50
+ return None
51
+
52
+ # Predict patient outcomes
53
+ def predict_outcome(model, data):
54
+ if model is None:
55
+ return None, None, None, None
56
+
57
+ actual_target = data.pop("target") if "target" in data.columns else None
58
+
59
+ try:
60
+ model_features = model.get_booster().feature_names
61
+ data = preprocess_data(data, model_features)
62
+ if data is None:
63
+ return None, None, None, None
64
+
65
+ predictions = model.predict(data)
66
+
67
+ # Convert numerical predictions to human-readable labels
68
+ mapped_predictions = [OUTCOME_MAP[pred] for pred in predictions]
69
+ actual_labels = [OUTCOME_MAP[actual] for actual in actual_target] if actual_target is not None else ["N/A"] * len(predictions)
70
+
71
+ # Debugging information
72
+ if actual_target is not None:
73
+ correct_predictions = (predictions == actual_target).sum()
74
+ total_predictions = len(actual_target)
75
+ accuracy = (correct_predictions / total_predictions) * 100
76
+ st.write(f"✅ Correct Predictions: {correct_predictions}/{total_predictions}")
77
+ st.write(f"📊 Model Accuracy: **{accuracy:.2f}%**")
78
+
79
+ return actual_target, predictions, mapped_predictions, actual_labels
80
+ except Exception as e:
81
+ st.error(f"Prediction error: {e}")
82
+ return None, None, None, None
83
+
84
+ # Load the data
85
+ file_path = 'final_cleaned_patient_data.csv'
86
+ try:
87
+ df = pd.read_csv(file_path)
88
+ except Exception as e:
89
+ st.error(f"Error loading data: {e}")
90
+ df = pd.DataFrame() # Create empty DataFrame if file doesn't exist
91
+
92
+ # Sidebar navigation
93
+ st.sidebar.title('Healthcare Data Dashboard')
94
+
95
+ # Team Members Section
96
+ st.sidebar.markdown("### 🏆 Team Members:")
97
+ team_members = [
98
+ "1. R. Sai Somnath",
99
+ "2. S. Sreevardhan",
100
+ "3. S. Mohammad Basha",
101
+ "4. V. Hussain Basha",
102
+ "5. P. Charles"
103
+ ]
104
+ for member in team_members:
105
+ st.sidebar.text(member)
106
+
107
+ # Add a divider
108
+ st.sidebar.markdown("---")
109
+
110
+ # Section navigation - Added two new sections
111
+ option = st.sidebar.selectbox('Choose a section', [
112
+ 'Data Overview',
113
+ 'Data Visualization',
114
+ 'Interactive Reports',
115
+ 'Correlation Analysis',
116
+ 'Data Insights',
117
+ 'Patient Outcome Prediction',
118
+ 'Batch Prediction',
119
+ 'Model Performance'
120
+ ])
121
+
122
+ # Apply a Streamlit theme with a dark background for a modern look
123
+ st.markdown("""
124
+ <style>
125
+ h1 { color: #00FFAA; }
126
+ .stApp { background-color: #121212; color: #FFFFFF; }
127
+ .sidebar .sidebar-content { background-color: #333333; color: #FFFFFF; }
128
+ .css-1d391kg { color: #FFFFFF; }
129
+ .css-18e3th9 { background-color: #1E1E1E; }
130
+ </style>
131
+ """, unsafe_allow_html=True)
132
+
133
+ # Data Overview Section
134
+ if option == 'Data Overview':
135
+ st.title('📊 Data Overview')
136
+ st.write(df.head())
137
+ st.write(f"Dataset Shape: {df.shape}")
138
+ st.write(f"Column Names: {df.columns.tolist()}")
139
+ st.write("Basic Statistical Overview:")
140
+ st.write(df.describe())
141
+
142
+ if st.checkbox('Show Missing Values'):
143
+ st.write(df.isnull().sum())
144
+
145
+ # Data Visualization Section
146
+ elif option == 'Data Visualization':
147
+ st.title('📈 Data Visualization')
148
+ column = st.selectbox('Select Column for Visualization', df.columns)
149
+ plot_type = st.radio('Choose plot type', ['Histogram', 'Boxplot', 'Violin Plot', 'Scatter Plot', 'Line Plot', 'Animated Plot'])
150
+
151
+ if plot_type == 'Animated Plot':
152
+ time_col = st.selectbox('Select Time Column (if applicable)', df.columns)
153
+ fig = px.scatter(df, x=column, y=column, animation_frame=time_col, size_max=60)
154
+ elif plot_type == 'Histogram':
155
+ fig = px.histogram(df, x=column, marginal='box', nbins=30)
156
+ elif plot_type == 'Boxplot':
157
+ fig = px.box(df, y=column)
158
+ elif plot_type == 'Violin Plot':
159
+ fig = px.violin(df, y=column, box=True, points='all')
160
+ elif plot_type == 'Scatter Plot':
161
+ x_col = st.selectbox('Select X axis', df.columns)
162
+ fig = px.scatter(df, x=x_col, y=column, color=column)
163
+ elif plot_type == 'Line Plot':
164
+ x_col = st.selectbox('Select X axis for Line Plot', df.columns)
165
+ fig = px.line(df, x=x_col, y=column)
166
+
167
+ st.plotly_chart(fig)
168
+
169
+ # Correlation Analysis Section
170
+ elif option == 'Correlation Analysis':
171
+ st.title('🔎 Correlation Analysis')
172
+ corr_matrix = df.corr()
173
+ fig, ax = plt.subplots(figsize=(12, 8))
174
+ sns.heatmap(corr_matrix, annot=True, cmap='coolwarm', fmt='.2f', ax=ax)
175
+ st.pyplot(fig)
176
+
177
+ # Interactive Reports Section
178
+ elif option == 'Interactive Reports':
179
+ st.title('📂 Interactive Reports')
180
+ st.write("Filter and explore the data.")
181
+ selected_columns = st.multiselect('Select columns to display', df.columns)
182
+ st.dataframe(df[selected_columns] if selected_columns else df)
183
+
184
+ st.write("Filter the Data:")
185
+ filter_column = st.selectbox('Select column to filter by', df.columns)
186
+ filter_value = st.text_input('Enter filter value')
187
+ if filter_value:
188
+ filtered_data = df[df[filter_column].astype(str).str.contains(filter_value, case=False)]
189
+ st.write(filtered_data)
190
+
191
+ # Download option
192
+ csv_data = filtered_data.to_csv(index=False).encode('utf-8')
193
+ st.download_button(label='Download Filtered Data as CSV', data=csv_data, file_name='filtered_data.csv', mime='text/csv')
194
+
195
+ # Data Insights Section
196
+ elif option == 'Data Insights':
197
+ st.title('🧠 Data Insights')
198
+ st.write("Gain insights into the data using various metrics.")
199
+ st.write("Total Unique Values per Column:")
200
+ st.write(df.nunique())
201
+
202
+ st.write("Top 5 Frequent Values for Each Column:")
203
+ for col in df.columns:
204
+ st.write(f"{col}: {df[col].value_counts().head(5)}")
205
+
206
+ # Patient Outcome Prediction Section
207
+ elif option == 'Patient Outcome Prediction':
208
+ st.title('🤖 Patient Outcome Prediction')
209
+
210
+ # Load the pre-trained model
211
+ model = load_model()
212
+
213
+ if model is not None:
214
+ st.success("✅ Pre-trained model loaded successfully!")
215
+
216
+ # Define class descriptions
217
+ class_descriptions = {
218
+ 0: "Patient recovered and went home",
219
+ 1: "Patient transferred to another hospital",
220
+ 2: "Patient moved to rehab facility",
221
+ 3: "Patient left against medical advice",
222
+ 4: "Patient deceased or serious outcome"
223
+ }
224
+
225
+ # Display target class distribution if target column exists
226
+ target_column = 'target'
227
+ if target_column in df.columns:
228
+ st.subheader("Target Class Distribution")
229
+ target_counts = df[target_column].value_counts().reset_index()
230
+ target_counts.columns = ['Class', 'Count']
231
+ target_counts['Description'] = target_counts['Class'].map(class_descriptions)
232
+ st.write(target_counts)
233
+
234
+ fig = px.pie(target_counts, values='Count', names='Description', title='Target Class Distribution')
235
+ st.plotly_chart(fig)
236
+
237
+ # Prediction interface
238
+ st.subheader("Make Predictions")
239
+ st.write("Enter values for the features to predict the patient outcome:")
240
+
241
+ # Create a more interactive UI for prediction with all input values
242
+ col1, col2, col3 = st.columns(3)
243
+
244
+ # Create input fields for all required features
245
+ input_values = {}
246
+
247
+ with col1:
248
+ input_values['age'] = st.number_input("Age", min_value=0, max_value=120, value=51)
249
+ input_values['gender'] = st.selectbox("Gender", [0, 1], index=1, format_func=lambda x: "Male" if x == 0 else "Female")
250
+ input_values['previous_hospitalizations'] = st.number_input("Previous Hospitalizations", min_value=0, value=4)
251
+ input_values['heart_rate'] = st.number_input("Heart Rate", min_value=30, max_value=200, value=63)
252
+ input_values['respiratory_rate'] = st.number_input("Respiratory Rate", min_value=5, max_value=60, value=16)
253
+ input_values['blood_pressure_sys'] = st.number_input("Blood Pressure (Systolic)", min_value=50, max_value=250, value=86)
254
+ input_values['blood_pressure_dia'] = st.number_input("Blood Pressure (Diastolic)", min_value=30, max_value=150, value=58)
255
+ input_values['temperature'] = st.number_input("Temperature (°C)", min_value=35.0, max_value=42.0, value=35.86, step=0.1)
256
+ input_values['wbc_count'] = st.number_input("WBC Count", min_value=0.0, max_value=50.0, value=7.15, step=0.1)
257
+ input_values['creatinine'] = st.number_input("Creatinine", min_value=0.1, max_value=10.0, value=2.93, step=0.1)
258
+
259
+ with col2:
260
+ input_values['bilirubin'] = st.number_input("Bilirubin", min_value=0.1, max_value=30.0, value=1.72, step=0.1)
261
+ input_values['glucose'] = st.number_input("Glucose", min_value=40, max_value=500, value=137)
262
+ input_values['bun'] = st.number_input("BUN", min_value=5, max_value=150, value=36)
263
+ input_values['pH'] = st.number_input("pH", min_value=6.8, max_value=7.8, value=7.34, step=0.01)
264
+ input_values['pao2'] = st.number_input("PaO2", min_value=40, max_value=300, value=72)
265
+ input_values['pco2'] = st.number_input("PCO2", min_value=20, max_value=100, value=58)
266
+ input_values['fio2'] = st.number_input("FiO2", min_value=0.21, max_value=1.0, value=0.88, step=0.01)
267
+ input_values['gcs'] = st.slider("GCS Score", 3, 15, 5)
268
+ input_values['comorbidity_index'] = st.slider("Comorbidity Index", 0, 10, 1)
269
+ input_values['admission_source'] = st.selectbox("Admission Source", [0, 1, 2, 3], index=1, format_func=lambda x: ["Emergency", "OPD", "Transfer", "Other"][x])
270
+
271
+ with col3:
272
+ input_values['elective_surgery'] = st.selectbox("Elective Surgery", [0, 1], index=1, format_func=lambda x: "No" if x == 0 else "Yes")
273
+ input_values['num_medications'] = st.number_input("Number of Medications", min_value=0, value=18)
274
+ input_values['charlson_comorbidity_index'] = st.slider("Charlson Comorbidity Index", 0, 15, 1)
275
+ input_values['ews_score'] = st.slider("EWS Score", 0, 20, 7)
276
+ input_values['severity_score'] = st.slider("Severity Score", 0, 10, 4)
277
+ input_values['bed_occupancy_rate'] = st.slider("Bed Occupancy Rate (%)", 50, 100, int(68.67))
278
+ input_values['staff_to_patient_ratio'] = st.slider("Staff to Patient Ratio", 0.1, 2.0, 0.99, step=0.1)
279
+ input_values['past_icu_admissions'] = st.number_input("Past ICU Admissions", min_value=0, value=2)
280
+ input_values['previous_surgery'] = st.selectbox("Previous Surgery", [0, 1], index=1, format_func=lambda x: "No" if x == 0 else "Yes")
281
+ input_values['high_risk_treatment'] = st.selectbox("High Risk Treatment", [0, 1], index=1, format_func=lambda x: "No" if x == 0 else "Yes")
282
+ input_values['discharge_support'] = st.selectbox("Discharge Support", [0, 1], index=0, format_func=lambda x: "No" if x == 0 else "Yes")
283
+
284
+
285
+ if st.button("Predict Outcome"):
286
+ # Define input columns (must match your model's expected input features)
287
+ input_columns = [
288
+ 'age', 'gender', 'previous_hospitalizations', 'heart_rate',
289
+ 'respiratory_rate', 'blood_pressure_sys', 'blood_pressure_dia',
290
+ 'temperature', 'wbc_count', 'creatinine', 'bilirubin', 'glucose', 'bun',
291
+ 'pH', 'pao2', 'pco2', 'fio2', 'gcs', 'comorbidity_index',
292
+ 'admission_source', 'elective_surgery', 'num_medications',
293
+ 'charlson_comorbidity_index', 'ews_score', 'severity_score',
294
+ 'bed_occupancy_rate', 'staff_to_patient_ratio', 'past_icu_admissions',
295
+ 'previous_surgery', 'high_risk_treatment', 'discharge_support'
296
+ ]
297
+
298
+ # Create a sample input for prediction (using a template from your dataset)
299
+ if len(df) > 0:
300
+ sample_input = pd.DataFrame([{col: 0 for col in input_columns}])
301
+
302
+ # Update with user inputs
303
+ for feature, value in input_values.items():
304
+ if feature in sample_input.columns:
305
+ sample_input[feature] = value
306
+
307
+ # Make prediction
308
+ try:
309
+ prediction = model.predict(sample_input)[0]
310
+ prediction_proba = model.predict_proba(sample_input)[0]
311
+
312
+ # Display prediction
313
+ st.subheader("Prediction Result")
314
+ st.write(f"Predicted Class: {prediction} - {class_descriptions.get(prediction, 'Unknown')}")
315
+
316
+ # Display probability for each class
317
+ st.write("Prediction Probabilities:")
318
+ proba_df = pd.DataFrame({
319
+ 'Class': [class_descriptions.get(i, f"Class {i}") for i in range(len(prediction_proba))],
320
+ 'Probability': prediction_proba
321
+ })
322
+ fig = px.bar(proba_df, x='Class', y='Probability', title='Prediction Probabilities')
323
+ st.plotly_chart(fig)
324
+ except Exception as e:
325
+ st.error(f"Error making prediction: {e}")
326
+ else:
327
+ st.error("Dataset is empty, cannot create input template.")
328
+ else:
329
+ st.error("Failed to load model. Please check if 'xgboost_patient_model.json' exists in the current directory.")
330
+
331
+ # NEW SECTION 1: Batch Prediction
332
+ elif option == 'Batch Prediction':
333
+ st.title("🏥 Batch Patient Outcome Prediction")
334
+ st.write("Upload a CSV file with patient data to predict outcomes for multiple patients at once.")
335
+
336
+ uploaded_file = st.file_uploader("📂 Upload CSV file", type=["csv"])
337
+
338
+ if uploaded_file is not None:
339
+ batch_df = pd.read_csv(uploaded_file)
340
+ batch_df = batch_df.dropna().reset_index(drop=True)
341
+
342
+ st.write("## Preview of Uploaded Data")
343
+ st.dataframe(batch_df.head(), use_container_width=True)
344
+
345
+ model = load_model()
346
+ actual_target, predicted_classes, predicted_outcomes, actual_outcomes = predict_outcome(model, batch_df.copy())
347
+
348
+ if predicted_classes is not None:
349
+ st.write("## 🏥 Prediction Results")
350
+ result_df = pd.DataFrame({
351
+ "Patient ID": range(1, len(predicted_classes) + 1),
352
+ "Actual Class": actual_target if actual_target is not None else ["N/A"] * len(predicted_classes),
353
+ "Predicted Class": predicted_classes,
354
+ "Predicted Outcome": predicted_outcomes
355
+ })
356
+ st.dataframe(result_df, use_container_width=True)
357
+
358
+ # Add visualization of batch prediction results
359
+ st.write("## Prediction Distribution")
360
+ results_count = pd.Series(predicted_outcomes).value_counts().reset_index()
361
+ results_count.columns = ['Predicted Outcome', 'Count']
362
+ fig = px.pie(results_count, values='Count', names='Predicted Outcome',
363
+ title='Distribution of Predicted Outcomes')
364
+ st.plotly_chart(fig)
365
+
366
+ # Offer download of results
367
+ csv_results = result_df.to_csv(index=False).encode('utf-8')
368
+ st.download_button(
369
+ label="Download Prediction Results",
370
+ data=csv_results,
371
+ file_name="patient_predictions.csv",
372
+ mime="text/csv"
373
+ )
374
+
375
+ # NEW SECTION 2: Model Performance
376
+ elif option == 'Model Performance':
377
+ st.title("📊 Model Performance Analysis")
378
+
379
+ # Check if data exists and contains target variable
380
+ if len(df) > 0 and 'target' in df.columns:
381
+ st.write("Analyze the model's performance on the dataset.")
382
+
383
+ # Split data into features and target
384
+ X = df.drop(columns=["target"]) # Features
385
+ y = df["target"] # Target
386
+
387
+ # Split data for testing
388
+ X_train, X_test, y_train, y_test = train_test_split(
389
+ X, y, test_size=0.2, random_state=42, stratify=y
390
+ )
391
+
392
+
393
+
394
+ # Load the model
395
+ model = load_model()
396
+
397
+ if model is not None:
398
+ # Make predictions
399
+ try:
400
+ y_pred = model.predict(X_test)
401
+ y_prob = model.predict_proba(X_test)
402
+
403
+ # Calculate metrics
404
+ accuracy = accuracy_score(y_test, y_pred)
405
+ conf_matrix = confusion_matrix(y_test, y_pred)
406
+ class_report = classification_report(y_test, y_pred, output_dict=True)
407
+
408
+ # Display metrics
409
+ col1, col2 = st.columns(2)
410
+
411
+ with col1:
412
+ st.metric("Model Accuracy", f"{accuracy:.2%}")
413
+
414
+ # Plot confusion matrix
415
+ st.write("### Confusion Matrix")
416
+ fig, ax = plt.subplots(figsize=(10, 8))
417
+ sns.heatmap(conf_matrix, annot=True, fmt='d', cmap='Blues', ax=ax)
418
+ ax.set_xlabel('Predicted Labels')
419
+ ax.set_ylabel('True Labels')
420
+ st.pyplot(fig)
421
+
422
+ with col2:
423
+ # Plot classification report
424
+ st.write("### Classification Report")
425
+ report_df = pd.DataFrame(class_report).transpose()
426
+ st.dataframe(report_df.style.format({
427
+ 'precision': '{:.2f}',
428
+ 'recall': '{:.2f}',
429
+ 'f1-score': '{:.2f}',
430
+ 'support': '{:.0f}'
431
+ }))
432
+
433
+
434
+ except Exception as e:
435
+ st.error(f"Error performing analysis: {e}")
436
+ else:
437
+ st.error("Model could not be loaded. Please check if the model file exists.")
438
+ else:
439
+ st.error("Cannot perform model analysis. Dataset is empty or missing target variable.")
440
+
441
+ st.sidebar.write("Forecasting discharge outcomes for critically ILL patients using machine learning")