DHEIVER's picture
Update app.py
0596f27
raw
history blame
2.4 kB
import gradio as gr
import tensorflow as tf
import numpy as np
from PIL import Image
import cv2
import datetime
class ImageClassifierApp:
    """Gradio app that classifies an image with a trained Keras model.

    Each incoming image is resized to ``input_size`` (height, width), scaled to
    ``[0, 1]`` and passed to the model. The predicted label (one of
    ``class_labels``) is then written, together with the analysis time, into a
    white strip appended below the resized image. That annotated image is what
    the interface displays.
    """

    def __init__(self, model_path, input_size=(192, 256)):
        """Load the model and set up the class labels.

        Args:
            model_path: Path to the saved Keras model file (e.g. ``.h5``).
            input_size: ``(height, width)`` that images are resized to before
                inference. Defaults to ``(192, 256)``, the size this app has
                always used. Presumably it must match the model's training
                input shape — TODO confirm.
        """
        self.model_path = model_path
        self.input_size = tuple(input_size)
        self.model = self.load_model()
        self.class_labels = ["Normal", "Cataract"]

    def load_model(self):
        """Load and return the Keras model stored at ``self.model_path``.

        The saved model refers to a custom layer named ``FixedDropout``, which
        must be registered as a custom object for deserialization to succeed.
        """

        # ``FixedDropout`` is not a built-in Keras layer and nothing in this file
        # provides it, so define a stand-in here. Dropout is the identity at
        # inference time (``predict`` runs with training=False). A plain Dropout
        # subclass carrying the same name is therefore sufficient to load the
        # model for prediction.
        class FixedDropout(tf.keras.layers.Dropout):
            """Inference-only stand-in for the custom ``FixedDropout`` layer."""

        with tf.keras.utils.custom_object_scope({'FixedDropout': FixedDropout}):
            model = tf.keras.models.load_model(self.model_path)
        return model

    def classify_image(self, input_image):
        """Classify ``input_image`` and return an annotated copy of it.

        Args:
            input_image: Image array supplied by Gradio (presumably an
                ``(H, W, 3)`` uint8 RGB array — TODO confirm), or ``None``
                when the image input is empty.

        Returns:
            A uint8 array containing the image resized to ``input_size``, with a
            white 50-px strip appended below it. The strip shows the analysis
            time and the predicted class. Returns ``None`` when ``input_image``
            is ``None``.
        """
        # Gradio passes None for an empty image input (e.g. after the user
        # clears it in this live interface); there is nothing to classify.
        if input_image is None:
            return None

        resized = tf.image.resize(input_image, self.input_size).numpy()
        # The model receives a batch of one image with pixel values in [0, 1].
        batch = np.expand_dims(resized / 255.0, axis=0)
        analysis_time = datetime.datetime.now()
        prediction = self.model.predict(batch)
        predicted_class = self._label_from_prediction(prediction)
        return self._annotate(resized, predicted_class, analysis_time)

    def _label_from_prediction(self, prediction):
        """Map the model output for a single image to a label in ``class_labels``.

        A multi-unit output is decoded with argmax over its scores. A
        single-unit output is thresholded at 0.5; argmax over a single value
        would always pick ``class_labels[0]``. That single value is presumably
        a sigmoid probability for ``class_labels[1]`` — TODO confirm against
        the model.
        """
        scores = np.asarray(prediction).ravel()
        if scores.size == 1:
            class_index = int(scores[0] >= 0.5)
        else:
            class_index = int(np.argmax(scores))
        return self.class_labels[class_index]

    def _annotate(self, image, predicted_class, analysis_time):
        """Return a uint8 copy of ``image`` with a label strip and a centred box."""
        label_strip_height = 50
        # Round before casting so float values such as 254.9999 become 255
        # rather than being truncated to 254.
        output_image = np.clip(np.rint(image), 0, 255).astype(np.uint8)
        # Append a white strip at the bottom to hold the two text lines.
        output_image = cv2.copyMakeBorder(
            output_image, 0, label_strip_height, 0, 0,
            cv2.BORDER_CONSTANT, value=(255, 255, 255),
        )

        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = 0.4
        text_color = (0, 0, 0)
        image_height, image_width = output_image.shape[:2]
        cv2.putText(
            output_image,
            f"Analysis Time: {analysis_time.strftime('%Y-%m-%d %H:%M:%S')}",
            (10, image_height - 30), font, font_scale, text_color, 1,
        )
        cv2.putText(
            output_image,
            f"Predicted Class: {predicted_class}",
            (10, image_height - 10), font, font_scale, text_color, 1,
        )

        # Fixed-size square centred on the whole output (image plus strip).
        # It is a visual marker only; its position is not produced by the model.
        # (255, 0, 0) is red when the array is RGB.
        box_size = 100
        box_x = (image_width - box_size) // 2
        box_y = (image_height - box_size) // 2
        object_box_color = (255, 0, 0)
        cv2.rectangle(
            output_image, (box_x, box_y),
            (box_x + box_size, box_y + box_size), object_box_color, 2,
        )
        return output_image

    def run_interface(self):
        """Build a live Gradio image-in/image-out interface and launch it."""
        input_interface = gr.Interface(
            fn=self.classify_image,
            inputs="image",
            outputs="image",
            live=True,
        )
        input_interface.launch()
if __name__ == "__main__":
    # Replace with the path to your own trained model.
    model_path = "modelo_treinado.h5"
    app = ImageClassifierApp(model_path=model_path)
    app.run_interface()