import streamlit as st
import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input
import tensorflow as tf
from PIL import Image

# Path of the saved Keras model file.
MODEL_PATH = "model.h5"

# (height, width) every image is resized to before being fed to the model.
IMAGE_SIZE = (224, 224)

# Labels indexed by the model's argmax output.
# NOTE(review): order assumed to match the training label order — confirm.
CLASS_NAMES = [
    "MILD DEMENTED",
    "MODERATE DEMENTED",
    "NON DEMENTED",
    "VERY MILD DEMENTED",
]

# Bundled sample images shown when nothing is uploaded, keyed by class label.
EXAMPLE_IMAGES = {
    "MILD DEMENTED": "mild_468_0_4983.jpg",
    "MODERATE DEMENTED": "moderate_2_0_72.jpg",
    "NON DEMENTED": "non_61.jpg",
    "VERY MILD DEMENTED": "verymild_37_0_2606.jpg",
}


@st.cache_resource
def load_model(path=MODEL_PATH):
    """Load the saved Keras model once and reuse it across reruns.

    Streamlit re-executes the whole script on every widget interaction;
    caching avoids re-reading the model file from disk each time.

    Args:
        path: location of the saved model file.

    Returns:
        The loaded (uncompiled) Keras model.
    """
    return tf.keras.models.load_model(path, compile=False)


def predict_image_class(model, img_source):
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Keras model returned by ``load_model``.
        img_source: file path, or a file-like object such as a Streamlit
            ``UploadedFile``, accepted by ``image.load_img``.

    Returns:
        int index into ``CLASS_NAMES``.

    Raises:
        OSError: if the input cannot be decoded as an image (PIL raises
            ``UnidentifiedImageError``, an ``OSError`` subclass).
    """
    img = image.load_img(img_source, target_size=IMAGE_SIZE)
    img_array = image.img_to_array(img)
    img_array = np.expand_dims(img_array, axis=0)  # add a batch dimension of 1
    img_array = preprocess_input(img_array)
    predictions = model.predict(img_array)
    return int(np.argmax(predictions, axis=1)[0])


def show_prediction(predicted_class):
    """Render the predicted stage label for class index *predicted_class*."""
    st.markdown(
        "<h3 style='text-align: center;'>Predicted Alzheimer's stage is: "
        f"{CLASS_NAMES[predicted_class]}</h3>",
        unsafe_allow_html=True,
    )


def main():
    """Render the page: classify an uploaded MRI image or a bundled example."""
    # set_page_config must be the first Streamlit command of the script run.
    st.set_page_config(
        page_title="Alzheimer's Detection",
        layout="centered"
    )
    # NOTE(review): the original heading markup was lost; reconstructed as a
    # centered <h1> — adjust styling if the original differed.
    st.markdown(
        "<h1 style='text-align: center;'>Alzheimer's Detection Tool🧠</h1>",
        unsafe_allow_html=True,
    )

    model = load_model()

    uploaded_file = st.file_uploader('Upload an MRI image...', type=['jpg', 'png', 'jpeg'])

    if uploaded_file is not None:
        try:
            predicted_class = predict_image_class(model, uploaded_file)
        except OSError:
            st.error("Could not read the uploaded file as an image.")
            return
        show_prediction(predicted_class)
        # load_img consumed the stream; rewind before handing it to st.image.
        uploaded_file.seek(0)
        st.image(uploaded_file, width=300)
    else:
        st.write("Or select an example image below:")
        # With a non-empty option list, selectbox defaults to the first option,
        # so the selection is never empty here.
        selected_example = st.selectbox("Select an example image:", list(EXAMPLE_IMAGES))
        example_image_path = EXAMPLE_IMAGES[selected_example]

        # Display the selected example image, closing the file afterwards.
        with Image.open(example_image_path) as example_image:
            st.image(example_image, caption=selected_example, width=300)

        # Predict based on the selected example image.
        show_prediction(predict_image_class(model, example_image_path))


if __name__ == '__main__':
    main()