# Spaces: Sleeping
import streamlit as st | |
import numpy as np | |
from tensorflow.keras.preprocessing import image | |
from tensorflow.keras.applications.efficientnet import preprocess_input | |
import tensorflow as tf | |
from PIL import Image | |
# Model output labels: index i names the model's i-th output score
# (the prediction is the argmax over those scores).
CLASS_NAMES = ("MILD DEMENTED", "MODERATE DEMENTED", "NON DEMENTED", "VERY MILD DEMENTED")

# Bundled example images, keyed by the label shown in the selectbox.
EXAMPLE_IMAGES = {
    "MILD DEMENTED": "mild_468_0_4983.jpg",
    "MODERATE DEMENTED": "moderate_2_0_72.jpg",
    "NON DEMENTED": "non_61.jpg",
    "VERY MILD DEMENTED": "verymild_37_0_2606.jpg",
}

MODEL_PATH = "model.h5"
IMAGE_SIZE = (224, 224)  # (height, width) the image is resized to before prediction


def _predict_image_class(model, img_source):
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Keras model loaded from ``MODEL_PATH``.
        img_source: Anything ``image.load_img`` accepts; the app passes both a
            file-path string and the object returned by ``st.file_uploader``.

    Returns:
        Index into ``CLASS_NAMES`` (a NumPy integer from ``np.argmax``).
    """
    img = image.load_img(img_source, target_size=IMAGE_SIZE)
    # Add a leading batch axis so the model receives a batch of one image.
    batch = np.expand_dims(image.img_to_array(img), axis=0)
    batch = preprocess_input(batch)
    # verbose=0 stops Keras from printing a progress bar for every prediction.
    predictions = model.predict(batch, verbose=0)
    return np.argmax(predictions, axis=1)[0]


def _show_prediction(label: str) -> None:
    """Render *label* as the large, centred prediction headline."""
    st.markdown(f"<div style='text-align: center; font-size: 30px;'>Predicted Alzheimer's stage is: <b>{label}</b></div>", unsafe_allow_html=True)


def main():
    """Render the Alzheimer's-detection page.

    If the user uploads an MRI image, it is classified and the prediction is
    shown followed by the image. Otherwise, the user picks one of the bundled
    example images; it is displayed, then classified.
    """
    st.set_page_config(
        page_title="Alzheimer's Detection",
        layout="centered"
    )
    st.markdown("<h1 style='text-align: center;'>Alzheimer's Detection Tool🧠</h1>", unsafe_allow_html=True)

    # Load the saved model. Streamlit re-executes the script on every widget
    # interaction, so cache the model instead of reloading the .h5 file on
    # each rerun (st.cache_resource requires Streamlit >= 1.18).
    @st.cache_resource
    def _load_model():
        return tf.keras.models.load_model(MODEL_PATH, compile=False)

    loaded_model = _load_model()

    uploaded_file = st.file_uploader('Upload an MRI image...', type=['jpg', 'png', 'jpeg'])
    if uploaded_file is not None:
        predicted_class = _predict_image_class(loaded_model, uploaded_file)
        _show_prediction(CLASS_NAMES[predicted_class])
        st.image(uploaded_file, width=300)
        return

    st.write("Or select an example image below:")
    selected_example = st.selectbox("Select an example image:", list(EXAMPLE_IMAGES))
    if not selected_example:
        return

    example_image_path = EXAMPLE_IMAGES[selected_example]
    # Display the selected example image. The context manager releases the
    # PIL file handle once st.image has encoded the picture.
    with Image.open(example_image_path) as example_image:
        st.image(example_image, caption=selected_example, width=300)
    # Predict based on the selected example image.
    predicted_class = _predict_image_class(loaded_model, example_image_path)
    _show_prediction(CLASS_NAMES[predicted_class])
if __name__ == "__main__":
    # Render the app only when this file is executed as a script, not on import.
    main()