iSushant commited on
Commit
2e137d9
·
verified ·
1 Parent(s): ec5c7f6

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import tensorflow as tf
3
+ import cv2
4
+ from PIL import Image
5
+ import gradio as gr
6
+ import os
7
+
8
+ # Define custom metrics
9
+ smooth = 1e-15
10
+
11
+ def dice_coef(y_true, y_pred):
12
+ y_true = tf.keras.layers.Flatten()(y_true)
13
+ y_pred = tf.keras.layers.Flatten()(y_pred)
14
+ intersection = tf.reduce_sum(y_true * y_pred)
15
+ return (2. * intersection + smooth) / (tf.reduce_sum(y_true) + tf.reduce_sum(y_pred) + smooth)
16
+
17
+ def dice_loss(y_true, y_pred):
18
+ return 1.0 - dice_coef(y_true, y_pred)
19
+
20
+ # Load the model from the same directory as the script
21
+ model_filename = "model.h5" # Replace with your actual model filename
22
+ model_path = os.path.join(os.path.dirname(__file__), model_filename) # Construct the full path
23
+
24
+
25
+ def load_model(model_path):
26
+ try:
27
+ model = tf.keras.models.load_model(model_path, custom_objects={'dice_loss': dice_loss, 'dice_coef': dice_coef}, compile=False)
28
+ return model
29
+ except Exception as e:
30
+ print(f"Error loading model: {e}")
31
+ return None
32
+
33
+ # Perform inference
34
+ model = load_model(model_path)
35
+
36
+ def perform_inference(image):
37
+ if model is None:
38
+ print("Model not loaded properly.")
39
+ return None, None, None
40
+
41
+ # Preprocess the image
42
+ original_shape = image.shape[:2]
43
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
44
+ image_resized = cv2.resize(image, (256, 256))
45
+ image_normalized = image_resized / 255.0
46
+ image_expanded = np.expand_dims(image_normalized, axis=0)
47
+
48
+ # Get the mask from the model prediction
49
+ mask = model.predict(image_expanded)[0]
50
+ mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
51
+ mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
52
+
53
+
54
+ # Apply the mask to the original image (for better visualization)
55
+ heatmap_img = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
56
+ segmented_image = cv2.addWeighted(image, 0.7, heatmap_img, 0.3, 0)
57
+
58
+ #Convert back to BGR
59
+ segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
60
+
61
+ # Convert results to PIL Images
62
+ return (Image.fromarray(image),
63
+ Image.fromarray(mask_binary.astype(np.uint8)), #Mask is already multiplied by 255
64
+ Image.fromarray(segmented_image))
65
+
66
+
67
+ # Gradio app
68
+ def gradio_app():
69
+ with gr.Blocks() as demo:
70
+ with gr.Row():
71
+ with gr.Column():
72
+ input_image = gr.Image(label="Input Image", type="numpy")
73
+ submit_btn = gr.Button("Submit")
74
+
75
+ with gr.Column():
76
+ original_image_output = gr.Image(label="Original Image")
77
+ mask_output = gr.Image(label="Predicted Mask")
78
+ segmented_image_output= gr.Image(label="Segmented Image")
79
+
80
+
81
+ submit_btn.click(perform_inference, inputs=input_image, outputs=[original_image_output, mask_output, segmented_image_output])
82
+
83
+ demo.launch()
84
+
85
+ if __name__ == "__main__":
86
+ gradio_app()