AdithyaSNair commited on
Commit
16617ac
·
1 Parent(s): 7e0f513

updated app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -41
app.py CHANGED
@@ -11,48 +11,34 @@ from sklearn.preprocessing import OneHotEncoder
11
  import pickle
12
  import gradio as gr
13
 
14
- def load_model():
15
- save_path = 'model_architecture.json'
16
- with open(save_path, 'r') as file:
17
- model = keras.models.model_from_json(file.read())
18
- model.load_weights('model_weights.h5')
19
- return model
20
 
21
- def predict_dementia(images, model):
22
- predictions = []
23
- for image in images:
24
- img = Image.fromarray(image.astype('uint8'))
25
- img = img.resize((128, 128))
26
- img = np.array(img)
27
- img = img.reshape(1, 128, 128, 3)
28
-
29
- prediction = model.predict(img)
30
- prediction_class = np.argmax(prediction)
31
- predictions.append(names(prediction_class))
32
- return predictions
33
 
34
- def names(number):
35
- if number == 0:
36
- return 'Non Demented'
37
- elif number == 1:
38
- return 'Mild Dementia'
39
- elif number == 2:
40
- return 'Moderate Dementia'
41
- elif number == 3:
42
- return 'Very Mild Dementia'
43
- else:
44
- return 'Error in Prediction'
45
 
46
- def main(images):
47
- model = load_model()
48
- predictions = predict_dementia(images, model)
49
- return predictions
 
 
 
 
 
 
50
 
51
- iface = gr.Interface(fn=main,
52
- inputs="image",
53
- outputs="text",
54
- title="Dementia Classification",
55
- description="Classify dementia based on brain images",
56
- examples=[["Non(1).jpg"],["Moderate.jpg"],["Mild.jpg"]])
57
-
58
- iface.launch(debug =True)
 
11
  import pickle
12
  import gradio as gr
13
 
14
+ # Load the model
15
+ model_path = "model.pkl"
16
+ model = tf.keras.models.load_model(model_path)
 
 
 
17
 
18
+ # Define the labels
19
+ labels = ['Non Demented', 'Mild Dementia', 'Moderate Dementia', 'Very Mild Dementia']
 
 
 
 
 
 
 
 
 
 
20
 
21
+ # Define the prediction function
22
+ def predict_dementia(image):
23
+ img = Image.fromarray(image.astype('uint8'))
24
+ img = img.resize((128, 128))
25
+ img = np.array(img)
26
+ img = img.reshape(1, 128, 128, 3)
27
+
28
+ prediction = model.predict(img)
29
+ prediction_class = np.argmax(prediction)
30
+ return labels[prediction_class]
 
31
 
32
+ # Create the Gradio interface
33
+ iface = gr.Interface(
34
+ fn=predict_dementia,
35
+ inputs="image",
36
+ outputs="text",
37
+ title="Dementia Classification",
38
+ description="Classify dementia based on brain images",
39
+ examples=[["Non(1).jpg"],["Mild.jpg"],["Moderate.jpg"],["Very(1).jpg"]],
40
+ allow_flagging=False
41
+ )
42
 
43
+ # Launch the interface
44
+ iface.launch(debug=True)