AdithyaSNair's picture
update
79a1c68
raw
history blame
1.26 kB
import numpy as np
import os
import keras
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
from keras.models import Sequential
from PIL import Image
from keras.layers import Conv2D, Flatten, Dense, Dropout, BatchNormalization, MaxPooling2D
from sklearn.preprocessing import OneHotEncoder
import pickle
import tensorflow as tf
import gradio as gr
# Load the model
# Loaded once at import time; predict_dementia reuses this object for every call.
# NOTE(review): the file is named "model.pkl" but is read with Keras'
# load_model rather than pickle — presumably it is a Keras-saved model despite
# the extension; confirm the on-disk format.
model_path = "model.pkl"
model = tf.keras.models.load_model(model_path)
# Define the labels
# Class names indexed by model output position: predict_dementia returns
# labels[argmax(prediction)].
# NOTE(review): this order must match the class order used during training
# (not visible in this file) — verify, since a mismatch silently mislabels.
labels = ['Non Demented', 'Mild Dementia', 'Moderate Dementia', 'Very Mild Dementia']
# Define the prediction function
def predict_dementia(image):
    """Classify a brain image and return the predicted label from ``labels``.

    Args:
        image: Image as a NumPy array (as delivered by the Gradio ``"image"``
            input), or None when nothing was uploaded. Grayscale and RGBA
            arrays are accepted and converted to 3-channel RGB.

    Returns:
        The label string of the highest-scoring class, or a short prompt
        asking for an image when ``image`` is None.
    """
    # Submitting without an upload yields None; guard instead of crashing on
    # ``None.astype``.
    if image is None:
        return "Please upload an image."
    img = Image.fromarray(image.astype('uint8'))
    # The model input is (1, 128, 128, 3); forcing RGB lets grayscale (2-D)
    # and RGBA (4-channel) images fit that shape. No-op for RGB input.
    img = img.convert("RGB").resize((128, 128))
    # Add the leading batch axis: (128, 128, 3) -> (1, 128, 128, 3).
    batch = np.expand_dims(np.array(img), axis=0)
    prediction = model.predict(batch)
    # Single-image batch: take the argmax over its row of class scores.
    prediction_class = int(np.argmax(prediction[0]))
    return labels[prediction_class]
# Create the Gradio interface
# Single image input mapped to a text label by predict_dementia.
iface = gr.Interface(
    fn=predict_dementia,
    inputs="image",
    outputs="text",
    title="Dementia Classification",
    description="Classify dementia based on brain images",
    examples=[["Non(1).jpg"], ["Mild.jpg"], ["Moderate.jpg"], ["Very(1).jpg"]],
    # allow_flagging takes "never" / "auto" / "manual"; a boolean is
    # deprecated and only remapped (with a warning) to "never".
    allow_flagging="never",
)
# Launch the interface only when this file is run as a script, so importing
# the module (e.g. for testing) does not start a server.
if __name__ == "__main__":
    iface.launch(debug=True)