|
import streamlit as st |
|
import pandas as pd |
|
from joblib import load |
|
import numpy as np |
|
import matplotlib.pyplot as plt |
|
|
|
@st.cache_data
def load_data(path: str = './lung_disease_data.csv') -> pd.DataFrame:
    """
    Load the lung-disease ``.csv`` with pandas and impute missing values.

    Missing values in the numerical columns ('Age', 'Lung Capacity',
    'Hospital Visits') are replaced by the column mean. Missing values in the
    categorical columns ('Gender', 'Smoking Status', 'Disease Type',
    'Treatment Type', 'Recovered') are replaced by the column mode.

    Cached with ``st.cache_data`` (not ``st.cache_resource``), so every call
    receives its own copy of the frame. A caller that mutates it cannot
    corrupt the cached value shared by all sessions.

    :param path: location of the CSV file (defaults to the bundled dataset).
    :return: the imputed DataFrame.
    """
    df = pd.read_csv(path)

    numerical_columns = ['Age', 'Lung Capacity', 'Hospital Visits']
    df[numerical_columns] = df[numerical_columns].fillna(df[numerical_columns].mean())

    categorical_columns = ['Gender', 'Smoking Status', 'Disease Type', 'Treatment Type', 'Recovered']
    # mode() may return several rows when there are ties; take the first one.
    df[categorical_columns] = df[categorical_columns].fillna(df[categorical_columns].mode().iloc[0])

    return df
|
|
|
@st.cache_resource
def load_models(models_dir: str = './models') -> dict:
    """
    Load the trained models used for prediction.

    Each model is read with ``joblib.load`` from ``<models_dir>/<file>.pkl``.
    The models are returned as a dict keyed by a human-readable name, and the
    key order is the order shown in the model selector.

    :param models_dir: directory containing the pickled models.
    :return: mapping of display name -> loaded model object.
    """
    # Display name -> pickle file name. Insertion order is preserved.
    model_files = {
        'Gaussian Naive Bayes': 'GaussianNB.pkl',
        'Logistic Regression': 'LogisticRegression.pkl',
        'Random Forest': 'RandomForests.pkl',
        'Support Vector Machines': 'SVM.pkl',
        'XG Boost': 'XGBoost.pkl',
    }

    # NOTE: joblib.load unpickles arbitrary objects; only point this at
    # trusted, locally produced model files.
    return {name: load(f'{models_dir}/{filename}') for name, filename in model_files.items()}
|
|
|
def prediction(model, age: int, gender: str,
               smoke_status: str, lung_capacity: float,
               disease_type: str, treatment_type: str,
               hospital_visits: int
               ) -> int:
    """
    Encode one patient's answers and return the model's prediction.

    Builds a single-row feature vector with the columns, in order:
    Age, Hospital Visits, Lung Capacity, Gender (1 = "Male"),
    Smoking Status (1 = "Yes"), one indicator per disease type
    (Asthma, Bronchitis, COPD, Lung Cancer, Pneumonia) and one indicator
    per treatment type (Medication, Surgery, Therapy).

    The vector is passed to ``model.predict`` as a plain NumPy array, so
    this column order must match the one used at training time
    (NOTE(review): confirm against the training code).

    Categorical answers are one-hot encoded by *exact* match. An unknown
    disease or treatment leaves all of its indicator columns at 0; it is
    never matched by substring.

    :param model: fitted estimator exposing ``predict``.
    :param age: patient age.
    :param gender: "Male" maps to 1, anything else to 0.
    :param smoke_status: "Yes" maps to 1, anything else to 0.
    :param lung_capacity: measured lung capacity.
    :param disease_type: one of the disease names listed above.
    :param treatment_type: one of the treatment names listed above.
    :param hospital_visits: number of hospital visits.
    :return: the first element of ``model.predict(...)`` as a native Python scalar.
    """
    disease_names = ('Asthma', 'Bronchitis', 'COPD', 'Lung Cancer', 'Pneumonia')
    treatment_names = ('Medication', 'Surgery', 'Therapy')

    # dict insertion order fixes the column order of the resulting frame.
    features = {
        'Age': [age],
        'Hospital Visits': [hospital_visits],
        'Lung Capacity': [lung_capacity],
        'Gender': [1 if gender == "Male" else 0],
        'Smoking Status': [1 if smoke_status == "Yes" else 0],
    }
    for name in disease_names:
        features[f'Disease Type_{name}'] = [1 if disease_type == name else 0]
    for name in treatment_names:
        features[f'Treatment Type_{name}'] = [1 if treatment_type == name else 0]

    input_arr = np.array(pd.DataFrame(features))

    predicted = model.predict(input_arr)[0]

    # asarray(...).item() converts NumPy scalars and plain Python values alike
    # into a native Python scalar.
    return np.asarray(predicted).item()
|
|
|
def _render_about_tab(df: pd.DataFrame) -> None:
    """Render the 'About Data' tab: the raw table plus a description of each column."""
    st.header('About the Data')
    st.caption('In this tab, we will explore the particular details about our data.')

    st.caption('Take a look at the data table.')
    st.dataframe(df)

    col1, col2 = st.columns(2)

    with col1:
        st.caption('This dataset captures detailed information about patients suffering from various lung conditions. It includes:')
        st.caption('**Age & Gender**: Patient demographics to understand the spread across age groups and gender.')
        st.caption('**Smoking Status**: Whether the patient is a smoker or non-smoker.')
        st.caption('**Lung Capacity**: Measured lung function to assess disease severity.')
        st.caption('**Disease Type**: The specific lung condition, like COPD or Bronchitis.')

    with col2:
        st.caption('**Treatment Type**: Different treatments patients received, including therapy, medication, or surgery.')
        st.caption('**Hospital Visits**: Number of visits to the hospital for managing the condition.')
        st.caption('**Recovery Status**: Indicates whether the patient recovered after treatment.')

    url = 'https://www.kaggle.com/datasets/samikshadalvi/lungs-diseases-dataset'
    st.caption(f'For more details, check out the original [source]({url}) of the dataset.')


def _render_prediction_tab(df: pd.DataFrame) -> None:
    """Render the 'Prediction' tab: model picker, patient inputs and the predict button."""
    st.header('Prediction Tab')
    st.caption('In this tab, our ML models will predict if you will recover based on your data.')

    models = load_models()

    model_name = st.selectbox('Select preferred model for prediction', list(models))
    model_predictor = models[model_name]

    col1, col2 = st.columns(2)

    with col1:
        age = st.number_input('What is your age?', min_value=0, max_value=100)
        # Answer options come from the values present in the dataset.
        gender = st.radio('What is your gender?', df['Gender'].unique())
        disease = st.selectbox('What is your lung condition?', df['Disease Type'].unique())
        treatment = st.selectbox('Which treatment did you receive?', df['Treatment Type'].unique())

    with col2:
        visits = st.number_input('How many times do you visit the hospital? (Annually)', min_value=0, max_value=365)
        capacity = st.slider('What is your lung capacity?', min_value=1.00, max_value=df['Lung Capacity'].max()+5)
        smoke = st.radio('Do you smoke?', ['Yes', 'No'])

    if st.button('Predict!'):
        pred = prediction(model_predictor, age, gender, smoke, capacity, disease, treatment, visits)
        rec = 'Recovered!' if pred == 1 else 'I am sorry.'
        st.header(rec)


def _build_overview_figure(df: pd.DataFrame) -> plt.Figure:
    """Build the 4x2 grid of histograms and grouped bar charts shown in the 'Plots' tab."""
    fig, axes = plt.subplots(nrows=4, ncols=2, figsize=(15, 25))

    # Histograms of the numerical features.
    hist_specs = (
        (axes[0, 0], 'Age'),
        (axes[0, 1], 'Lung Capacity'),
        (axes[1, 0], 'Hospital Visits'),
    )
    for ax, column in hist_specs:
        ax.hist(df[column])
        ax.set_xlabel(column)
        ax.set_ylabel('Frequency')
        ax.set_title(f'{column} Distribution')

    # Grouped bar charts: counts of each (row_key, col_key) combination.
    count_specs = (
        (axes[1, 1], 'Gender', 'Recovered', 'Gender Count by Recovery', 'Recovered'),
        (axes[2, 0], 'Smoking Status', 'Recovered', 'Smoking Status by Recovery', 'Recovered'),
        (axes[2, 1], 'Disease Type', 'Recovered', 'Disease Type by Recovery', 'Recovered'),
        (axes[3, 0], 'Treatment Type', 'Recovered', 'Treatment Type by Recovery', 'Recovered'),
        (axes[3, 1], 'Disease Type', 'Treatment Type', 'Disease Type by Treatment Type', 'Treatment'),
    )
    for ax, row_key, col_key, title, legend_title in count_specs:
        count_data = df.groupby([row_key, col_key]).size().unstack(fill_value=0)
        count_data.plot(kind='bar', stacked=False, ax=ax)
        ax.set_xlabel(row_key)
        ax.set_ylabel('Count')
        ax.set_title(title)
        ax.legend(title=legend_title)

    # Must run before rendering, otherwise it has no effect on the output.
    fig.tight_layout()
    return fig


def _render_viz_tab(df: pd.DataFrame) -> None:
    """Render the 'Data Viz' tab: pre-built plots plus a user-configurable chart."""
    st.title('Data Viz Tab')
    st.caption('In this tab, we can visualize the relationships among our data.')
    st.caption('See our pre-existing plots and you can also plot your own!')

    dviz_tab1, dviz_tab2 = st.tabs(['Plots', 'Custom Plot'])

    with dviz_tab1:
        st.title('Feature Distribution and Relationships')
        st.caption('In this tab we will see the feature distributions of the dataset.')
        st.caption('We can see the relationships of the features among each other.')

        fig = _build_overview_figure(df)
        st.pyplot(fig)
        # Release the figure; otherwise every rerun leaves another open figure.
        plt.close(fig)

    with dviz_tab2:
        x = st.selectbox("Choose X for plotting.", tuple(df.columns))
        # Y choices exclude the column already picked for X.
        y = st.selectbox("Choose Y for plotting.", tuple(df.drop(x, axis=1).columns))

        plot = st.selectbox("Select type of plot.", ("Scatter", "Bar", "Line"))

        if st.button("Plot X and Y!"):
            if plot == "Scatter":
                st.scatter_chart(data=df, x=x, y=y, size='Recovered')
            elif plot == "Bar":
                st.bar_chart(data=df, x=x, y=y)
            elif plot == "Line":
                st.line_chart(data=df, x=x, y=y)


def main():
    """Build the Streamlit page: header, then the About / Prediction / Data Viz tabs."""
    st.header("Lung Disease Recovery Predictor")
    st.caption('Prepared by `hydraadra112` | John Manuel Carado')

    data_tab, pred_tab, data_viz = st.tabs(['About Data', 'Prediction', 'Data Viz'])
    df = load_data()

    with data_tab:
        _render_about_tab(df)

    with pred_tab:
        _render_prediction_tab(df)

    with data_viz:
        _render_viz_tab(df)
|
|
|
# Entry point: build the page only when this file is executed as the main script.
if __name__ == "__main__":
    main()