File size: 2,664 Bytes
7d19637
 
 
 
 
c617a8a
7d19637
 
 
c617a8a
 
 
7d19637
c617a8a
7d19637
c617a8a
 
7d19637
c617a8a
 
 
 
 
7d19637
c617a8a
 
7d19637
c617a8a
7d19637
c617a8a
7d19637
08a822f
 
 
 
 
 
 
c617a8a
 
7d19637
c617a8a
7d19637
c617a8a
 
f1fd2ca
08a822f
 
 
 
 
 
 
 
bb8b34d
f1fd2ca
08a822f
 
f1fd2ca
08a822f
f1fd2ca
08a822f
7d19637
 
c617a8a
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
import numpy as np
from tensorflow.keras.preprocessing import image
from tensorflow.keras.applications.efficientnet import preprocess_input
import tensorflow as tf
from PIL import Image

# Output labels, indexed by the class id returned by the model's argmax.
# NOTE(review): this order must match the label order used when model.h5
# was trained — TODO confirm against the training pipeline.
CLASS_NAMES = ["MILD DEMENTED", "MODERATE DEMENTED", "NON DEMENTED", "VERY MILD DEMENTED"]

# Bundled example images (paths relative to the working directory), keyed by
# the stage label shown in the example selectbox.
EXAMPLE_IMAGES = {
    "MILD DEMENTED": "mild_468_0_4983.jpg",
    "MODERATE DEMENTED": "moderate_2_0_72.jpg",
    "NON DEMENTED": "non_61.jpg",
    "VERY MILD DEMENTED": "verymild_37_0_2606.jpg",
}

# Size every input image is resized to before classification; presumably the
# model's training input resolution — TODO confirm.
IMAGE_SIZE = (224, 224)


def _predict_image_class(model, img_source):
    """Return the index of the highest-scoring class for a single image.

    Args:
        model: Loaded Keras model whose ``predict`` output has one score
            per class along axis 1.
        img_source: File path, or a file-like object such as Streamlit's
            UploadedFile; it is passed unchanged to keras' ``load_img``.

    Returns:
        The argmax class index as a plain ``int``.
    """
    img = image.load_img(img_source, target_size=IMAGE_SIZE)
    # Add a leading batch dimension of size 1 before preprocessing.
    img_array = np.expand_dims(image.img_to_array(img), axis=0)
    img_array = preprocess_input(img_array)
    predictions = model.predict(img_array)
    return int(np.argmax(predictions, axis=1)[0])


def _show_prediction(class_index):
    """Render the predicted stage label for ``class_index`` as a large centered line."""
    st.markdown(
        "<div style='text-align: center; font-size: 30px;'>"
        f"Predicted Alzheimer's stage is: <b>{CLASS_NAMES[class_index]}</b></div>",
        unsafe_allow_html=True,
    )


def main():
    """Render the Alzheimer's detection page.

    If the user uploads an MRI image, it is classified and shown. Otherwise
    the user picks one of the bundled example images, which is shown and
    then classified. Either way, the predicted stage label is displayed.
    """
    # Configure the page before any other element is rendered.
    st.set_page_config(
        page_title="Alzheimer's Detection",
        layout="centered",
    )

    st.markdown("<h1 style='text-align: center;'>Alzheimer's Detection Tool🧠</h1>", unsafe_allow_html=True)

    # Streamlit re-executes this script on every widget interaction, so an
    # uncached load would re-read model.h5 from disk each time. cache_resource
    # keeps one loaded model across reruns.
    # NOTE: st.cache_resource requires Streamlit >= 1.18.
    @st.cache_resource
    def load_model():
        return tf.keras.models.load_model("model.h5", compile=False)

    loaded_model = load_model()

    uploaded_file = st.file_uploader('Upload an MRI image...', type=['jpg', 'png', 'jpeg'])

    if uploaded_file is not None:
        _show_prediction(_predict_image_class(loaded_model, uploaded_file))
        st.image(uploaded_file, use_column_width=False, width=300)
        return

    st.write("Or select an example image below:")
    selected_example = st.selectbox("Select an example image:", list(EXAMPLE_IMAGES))
    if not selected_example:
        return

    example_image_path = EXAMPLE_IMAGES[selected_example]
    # Close the file handle once st.image has consumed the picture.
    with Image.open(example_image_path) as example_image:
        st.image(example_image, caption=selected_example, use_column_width=False, width=300)

    _show_prediction(_predict_image_class(loaded_model, example_image_path))

# Launch the app only when this file is executed as the main script.
if __name__ == "__main__":
    main()