DHEIVER commited on
Commit
2d907f6
·
1 Parent(s): 6a96309

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -17
app.py CHANGED
@@ -1,6 +1,5 @@
1
  import gradio as gr
2
  import tensorflow as tf
3
- import requests
4
  import cv2
5
  import numpy as np
6
 
@@ -10,23 +9,22 @@ tf_model = tf.keras.models.load_model(tf_model_path)
10
 
11
  class_labels = ["Normal", "Cataract"]
12
 
13
- def predict(inp):
14
- # Use the TensorFlow model to predict Normal or Cataract
15
- img_array = cv2.cvtColor(np.array(inp), cv2.COLOR_RGB2BGR)
16
- img_array = cv2.resize(img_array, (224, 224))
17
- img_array = img_array / 255.0
18
- img_array = np.expand_dims(img_array, axis=0)
19
 
20
- prediction_tf = tf_model.predict(img_array)
21
- label_index = np.argmax(prediction_tf)
22
- confidence_tf = float(prediction_tf[0, label_index])
 
23
 
24
- return class_labels[label_index], confidence_tf
25
 
26
- demo = gr.Interface(
27
- fn=predict,
28
- inputs=gr.inputs.Image(type="pil"),
29
- outputs=["label", "number"],
30
- )
31
 
32
- demo.launch()
 
1
  import gradio as gr
2
  import tensorflow as tf
 
3
  import cv2
4
  import numpy as np
5
 
 
9
 
10
  class_labels = ["Normal", "Cataract"]
11
 
12
+ # Define a Gradio interface
13
+ def classify_image(input_image):
14
+ # Preprocess the input image
15
+ input_image = cv2.resize(input_image, (224, 224)) # Resize the image to match the model's input size
16
+ input_image = np.expand_dims(input_image, axis=0) # Add batch dimension
17
+ input_image = input_image / 255.0 # Normalize pixel values (assuming input range [0, 255])
18
 
19
+ # Make predictions using the loaded model
20
+ predictions = tf_model.predict(input_image)
21
+ class_index = np.argmax(predictions, axis=1)[0]
22
+ predicted_class = class_labels[class_index]
23
 
24
+ return predicted_class
25
 
26
+ # Create a Gradio interface
27
+ input_image = gr.inputs.Image(shape=(224, 224, 3)) # Define the input image shape
28
+ output_label = gr.outputs.Label() # Define the output label
 
 
29
 
30
+ gr.Interface(fn=classify_image, inputs=input_image, outputs=output_label).launch()