Thomas Lucchetta
Update app.py
0854c30 unverified
raw
history blame
7.16 kB
import os
import os.path
from statistics import mean

import nibabel as nib
import numpy as np
import streamlit as st
import torch
import torch.nn.functional as F
from huggingface_hub import hf_hub_download
from monai.transforms import (
    Compose,
    EnsureChannelFirst,
    LoadImage,
    Resize,
    ScaleIntensity,
)
from nilearn import plotting
from streamlit.errors import StreamlitAPIException

from constants import CLASSES
from model.download_model import load_model
# Page configuration; set_page_config must be the first Streamlit command of the script.
st.set_page_config(page_title = "Alzheimer Classifier", page_icon = ":brain:", layout = "wide")


@st.cache_resource
def _get_model():
    """Load the classifier once and reuse it across Streamlit reruns.

    Streamlit re-executes this whole script on every widget interaction
    (each slider move, each button click). Without caching, ``load_model()``
    would run again on every one of those reruns.
    """
    return load_model()


model = _get_model()

# Preprocessing applied to each loaded NIfTI volume before inference:
# intensity scaling, channel-first layout, then a resize to 96 x 96 x 96.
transforms = Compose([
    ScaleIntensity(),
    EnsureChannelFirst(),
    Resize((96, 96, 96)),
])
# image_only=True: return just the image data, not a (data, metadata) pair.
load_img = LoadImage(image_only=True)

# Class labels; they also serve as the selectable image names in the sidebar.
class_names = CLASSES

# Silence Streamlit's warning about calling st.pyplot() without a figure.
# Streamlit releases that dropped this option raise StreamlitAPIException when
# it is set. Treat it as a best-effort setting instead of crashing at startup.
try:
    st.set_option('deprecation.showPyplotGlobalUse', False)
except StreamlitAPIException:
    pass
# Create the two workflow flags once per browser session.
# "clicked_pp" turns True when "Preprocess Image" is pressed.
# "clicked_pred" turns True when "Run Prediction" is pressed.
# Values that already exist are left alone so they survive Streamlit reruns.
for _state_key in ("clicked_pp", "clicked_pred"):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = False
del _state_key
def click_pp_true():
    """Button callback: remember that "Preprocess Image" was pressed."""
    st.session_state["clicked_pp"] = True
def click_pred_true():
    """Button callback: remember that "Run Prediction" was pressed."""
    st.session_state["clicked_pred"] = True
def click_false():
    """Selectbox callback: clear both flags so a newly chosen image starts at the raw view."""
    for flag in ("clicked_pp", "clicked_pred"):
        st.session_state[flag] = False
# =============================================================================
#                              STREAMLIT APP
# =============================================================================
with st.sidebar:
    st.title("Alzheimer Classifier Demo")
    # Picking a different image resets the workflow back to the raw view.
    img_path = st.selectbox("Select Image", tuple(class_names), on_change=click_false)

    left_col, right_col = st.columns((1, 1))
    with left_col:
        run_preprocess = st.button("Preprocess Image", on_click=click_pp_true)
    # The prediction button is only offered once the preprocessed view is active.
    if st.session_state.clicked_pp:
        with right_col:
            run_pred = st.button("Run Prediction", on_click=click_pred_true)
# Hugging Face dataset hosting the demo volumes.
# Raw scans live at "raw/<name>.nii"; preprocessed ones at "preprocessed/<name>.nii.gz".
_DATASET_REPO_ID = "rootstrap-org/Alzheimer-Classifier-Demo"
# Streamlit text colour per predicted class index. Indices past the end of this
# tuple use _FALLBACK_CLASS_COLOR instead of failing with an unbound name.
_CLASS_COLORS = ("red", "blue", "green")
_FALLBACK_CLASS_COLOR = "violet"


def _download_volume(subfolder, filename):
    """Fetch *filename* from *subfolder* of the demo dataset and return its local path."""
    return hf_hub_download(
        repo_id=_DATASET_REPO_ID,
        repo_type="dataset",
        subfolder=subfolder,
        filename=filename,
    )


def _cut_sliders(image, label_suffix):
    """Draw coronal/sagittal/axial cut sliders in the sidebar and return (x, y, z).

    Each slider spans the auto-mask bounds of *image* along its axis and starts
    at the midpoint. *label_suffix* is appended to every label. The preprocessed
    views pass " " and the raw view passes "". This exact difference is kept from
    the original labels because Streamlit derives widget identity from the label,
    so do not normalise it.
    """
    # NOTE(review): _get_auto_mask_bounds is a private nilearn helper and may
    # change between nilearn releases.
    bounds = plotting.find_cuts._get_auto_mask_bounds(image)
    st.sidebar.write("#")

    def axis_slider(plane, axis):
        low, high = bounds[axis][0], bounds[axis][1]
        return st.sidebar.slider(
            f"Move the slider to adjust the {plane} cut{label_suffix}",
            low,
            high,
            mean([low, high]),
        )

    # Created in the original on-screen order: coronal (y), sagittal (x), axial (z).
    y_cut = axis_slider("coronal", 1)
    x_cut = axis_slider("sagittal", 0)
    z_cut = axis_slider("axial", 2)
    return x_cut, y_cut, z_cut


def _show_orthoslices(image, cut_coords):
    """Render *image* as orthogonal slices at *cut_coords* in the main area."""
    display = plotting.plot_img(image, cmap="grey", cut_coords=cut_coords, black_bg=True)
    # Pass the figure explicitly: st.pyplot() without a figure relies on
    # matplotlib's global state, which Streamlit deprecated. Then close the
    # figure so a new, never-closed figure does not pile up on every rerun.
    st.pyplot(display.frame_axes.figure)
    display.close()


def _classify(path):
    """Run the model on the NIfTI file at *path*.

    Returns ``(class_index, percentages)``, where ``percentages`` holds the
    softmax probability of every class multiplied by 100.
    """
    volume = transforms(load_img(path))
    # Wrapping in a list prepends a batch axis of size 1.
    batch = torch.from_numpy(np.array([volume]))
    with torch.no_grad():
        logits = model(batch)
    probabilities = F.softmax(logits, dim=1).numpy()[0]
    return np.argmax(probabilities), probabilities * 100


def _probability_color(percentage):
    """Map a confidence percentage to a Streamlit text colour."""
    if percentage > 80:
        return "green"
    if percentage > 60:
        # "orange" instead of "yellow": older Streamlit releases do not accept
        # "yellow" in the :color[...] markdown syntax, so it showed up as literal text.
        return "orange"
    return "red"


with st.container():
    if img_path != "":
        if not st.session_state.clicked_pp:
            # Raw scan, shown until "Preprocess Image" is pressed.
            raw_image = nib.load(_download_volume("raw", img_path + ".nii"))
            raw_cuts = _cut_sliders(raw_image, "")
            _show_orthoslices(raw_image, raw_cuts)
        else:
            with st.container():
                # Download once and reuse the path for display and for inference.
                preprocessed_path = _download_volume("preprocessed", img_path + ".nii.gz")
                pred_image = nib.load(preprocessed_path)
                pred_cuts = _cut_sliders(pred_image, " ")
                if st.session_state.clicked_pred:
                    class_index, percentages = _classify(preprocessed_path)
                    probability = percentages[class_index]
                    class_color = (
                        _CLASS_COLORS[class_index]
                        if class_index < len(_CLASS_COLORS)
                        else _FALLBACK_CLASS_COLOR
                    )
                    st.sidebar.write("#")
                    class_col, pred_col = st.columns((1, 1))
                    with class_col:
                        st.write(f"### Predicted Class: :{class_color}[{class_names[class_index]}]")
                    with pred_col:
                        st.write(f"### Probability: :{_probability_color(probability)}[{probability:.2f}%]")
                _show_orthoslices(pred_image, pred_cuts)