# NEW_BTD / app.py
# Author: iSushant - commit 87751d8 "Update app.py" (verified), 3.35 kB
import numpy as np
import tensorflow as tf
import cv2
from PIL import Image
import gradio as gr
import os
# Custom metric helpers. `smooth` is a tiny constant added to both the
# numerator and the denominator of the Dice coefficient so that two empty
# masks do not cause a division by zero.
smooth = 1.0e-15
def dice_coef(y_true, y_pred):
y_true = tf.keras.layers.Flatten()(y_true)
y_pred = tf.keras.layers.Flatten()(y_pred)
intersection = tf.reduce_sum(y_true * y_pred)
return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
def dice_loss(y_true, y_pred):
    """Return the Dice loss: one minus ``dice_coef(y_true, y_pred)``."""
    coefficient = dice_coef(y_true, y_pred)
    return 1.0 - coefficient
# Load the model from the same directory as the script.
# abspath() makes the location independent of the current working directory:
# if __file__ is relative (e.g. the __main__ script on Python < 3.9),
# dirname() returns "" and the model would otherwise be looked up in the CWD.
model_filename = "model.h5"  # Replace with your actual model filename
model_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), model_filename)
def load_model(model_path):
    """Load the Keras model saved at *model_path* without compiling it.

    ``dice_loss`` and ``dice_coef`` are registered as custom objects so that a
    model saved with them can be deserialized. Any failure is printed and
    ``None`` is returned, so the caller can keep running without a model.
    """
    try:
        loaded = tf.keras.models.load_model(
            model_path,
            custom_objects={'dice_loss': dice_loss, 'dice_coef': dice_coef},
            compile=False,
        )
    except Exception as err:
        print(f"Error loading model: {err}")
        return None
    return loaded
# Load the model once at import time so every request reuses it. The value is
# None when loading failed; perform_inference checks for that before predicting.
model = load_model(model_path=model_path)
def _ensure_rgb(image):
    """Return *image* as an (H, W, 3) array, expanding grayscale and dropping an alpha channel."""
    if image.ndim == 2 or (image.ndim == 3 and image.shape[2] == 1):
        return cv2.cvtColor(image.reshape(image.shape[:2]), cv2.COLOR_GRAY2RGB)
    if image.shape[2] == 4:
        return cv2.cvtColor(image, cv2.COLOR_RGBA2RGB)
    return image


def _predict_mask(net, rgb_image, input_size=(256, 256), threshold=0.5):
    """Run *net* on *rgb_image* and return a 0/255 uint8 mask at the image's own resolution."""
    height, width = rgb_image.shape[:2]
    # The model has always been fed R/B-swapped (BGR-ordered) pixels, because
    # COLOR_BGR2RGB was applied to Gradio's RGB array. That ordering is kept
    # because the model was presumably trained on cv2.imread (BGR) data - TODO confirm.
    model_input = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    model_input = cv2.resize(model_input, input_size) / 255.0
    # verbose=0: avoid printing a progress bar for every request
    probabilities = net.predict(np.expand_dims(model_input, axis=0), verbose=0)[0]
    # squeeze drops a trailing singleton channel; cv2.resize takes (width, height)
    probabilities = cv2.resize(np.squeeze(probabilities), (width, height))
    return (probabilities > threshold).astype(np.uint8) * 255


def _draw_overlay(rgb_image, mask_binary):
    """Blend a JET heatmap of *mask_binary* over *rgb_image*, box each region, and return RGB."""
    # applyColorMap produces BGR colours, so compose in BGR and convert to RGB once at the end.
    canvas = cv2.cvtColor(rgb_image, cv2.COLOR_RGB2BGR)
    heatmap = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
    canvas = cv2.addWeighted(canvas, 0.7, heatmap, 0.3, 0)
    # [-2] selects the contour list under both OpenCV 3 (3 return values) and 4 (2 return values);
    # pass a copy because older OpenCV releases modified the source image in place.
    contours = cv2.findContours(mask_binary.copy(), cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    for contour in contours:
        x, y, w, h = cv2.boundingRect(contour)
        cv2.rectangle(canvas, (x, y), (x + w, y + h), (0, 0, 255), 2)  # red in BGR
        # Keep the label inside the frame when the box touches the top edge.
        cv2.putText(canvas, "Tumour Detected", (x, max(y - 10, 20)),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.9, (0, 0, 255), 2)
    return cv2.cvtColor(canvas, cv2.COLOR_BGR2RGB)


def perform_inference(image):
    """Segment *image* and return (original, mask, annotated overlay) as PIL images.

    Args:
        image: array from the ``gr.Image(type="numpy")`` input, expected to be RGB.
            Grayscale or RGBA arrays are converted to 3-channel RGB first.
            ``None`` (nothing uploaded) is tolerated.

    Returns:
        A 3-tuple of PIL images:
        - the input in its original colours;
        - the binary mask (0/255);
        - the input blended with a heatmap of the mask, with a labelled box
          around each detected region.
        ``(None, None, None)`` is returned when the model is unavailable or no
        image was provided.
    """
    if model is None:
        print("Model not loaded properly.")
        return None, None, None
    if image is None:
        # Submit was clicked without an uploaded image; nothing to segment.
        return None, None, None

    rgb_image = _ensure_rgb(np.asarray(image))
    mask_binary = _predict_mask(model, rgb_image)
    overlay = _draw_overlay(rgb_image, mask_binary)

    # Return the untouched RGB input so "Original Image" shows true colours
    # (returning the R/B-swapped model input made PIL display red and blue exchanged).
    return (Image.fromarray(rgb_image),
            Image.fromarray(mask_binary),
            Image.fromarray(overlay))
def gradio_app():
    """Build the Gradio Blocks UI around perform_inference and launch it.

    The first column holds the image upload and a Submit button. The second
    column shows the three images returned by perform_inference.
    """
    with gr.Blocks() as app_ui:
        with gr.Row():
            with gr.Column():
                uploaded = gr.Image(label="Input Image", type="numpy")
                run_button = gr.Button("Submit")
            with gr.Column():
                result_views = [
                    gr.Image(label="Original Image"),
                    gr.Image(label="Predicted Mask"),
                    gr.Image(label="Segmented Image"),
                ]
        run_button.click(perform_inference, inputs=uploaded, outputs=result_views)
    app_ui.launch()
if __name__ == "__main__":
    # Start the web UI only when run as a script, not when imported.
    gradio_app()