Update app.py
Browse files
app.py
CHANGED
@@ -19,8 +19,7 @@ def dice_loss(y_true, y_pred):
|
|
19 |
|
20 |
# Load the model from the same directory as the script
|
21 |
model_filename = "model.h5" # Replace with your actual model filename
|
22 |
-
model_path = os.path.join(os.path.dirname(__file__), model_filename)
|
23 |
-
|
24 |
|
25 |
def load_model(model_path):
|
26 |
try:
|
@@ -37,8 +36,8 @@ def perform_inference(image):
|
|
37 |
if model is None:
|
38 |
print("Model not loaded properly.")
|
39 |
return None, None, None
|
40 |
-
|
41 |
-
|
42 |
original_shape = image.shape[:2]
|
43 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
44 |
image_resized = cv2.resize(image, (256, 256))
|
@@ -50,18 +49,27 @@ def perform_inference(image):
|
|
50 |
mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
|
51 |
mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
|
52 |
|
|
|
|
|
53 |
|
54 |
-
# Apply the mask to the original image (for
|
55 |
heatmap_img = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
|
56 |
segmented_image = cv2.addWeighted(image, 0.7, heatmap_img, 0.3, 0)
|
|
|
57 |
|
58 |
-
#Convert back to BGR
|
59 |
-
segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
|
60 |
|
61 |
-
#
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
62 |
return (Image.fromarray(image),
|
63 |
-
Image.fromarray(mask_binary.astype(np.uint8)),
|
64 |
-
Image.fromarray(
|
65 |
|
66 |
|
67 |
# Gradio app
|
|
|
19 |
|
20 |
# Load the model from the same directory as the script
|
21 |
model_filename = "model.h5" # Replace with your actual model filename
|
22 |
+
model_path = os.path.join(os.path.dirname(__file__), model_filename)
|
|
|
23 |
|
24 |
def load_model(model_path):
|
25 |
try:
|
|
|
36 |
if model is None:
|
37 |
print("Model not loaded properly.")
|
38 |
return None, None, None
|
39 |
+
|
40 |
+
# Preprocess the image
|
41 |
original_shape = image.shape[:2]
|
42 |
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
43 |
image_resized = cv2.resize(image, (256, 256))
|
|
|
49 |
mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
|
50 |
mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
|
51 |
|
52 |
+
# Find contours in the binary mask
|
53 |
+
contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
54 |
|
55 |
+
# Apply the mask to the original image (heatmap for visualization)
|
56 |
heatmap_img = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
|
57 |
segmented_image = cv2.addWeighted(image, 0.7, heatmap_img, 0.3, 0)
|
58 |
+
segmented_image_with_box = segmented_image.copy() #make a copy for drawing box
|
59 |
|
|
|
|
|
60 |
|
61 |
+
# Get bounding boxes for all contours
|
62 |
+
for contour in contours:
|
63 |
+
x, y, w, h = cv2.boundingRect(contour)
|
64 |
+
cv2.rectangle(segmented_image_with_box, (x, y), (x + w, y + h), (0, 0, 255), 2)
|
65 |
+
|
66 |
+
|
67 |
+
#Convert back to BGR
|
68 |
+
segmented_image_with_box = cv2.cvtColor(segmented_image_with_box, cv2.COLOR_RGB2BGR)
|
69 |
+
#Return image with bounding box
|
70 |
return (Image.fromarray(image),
|
71 |
+
Image.fromarray(mask_binary.astype(np.uint8)),
|
72 |
+
Image.fromarray(segmented_image_with_box))
|
73 |
|
74 |
|
75 |
# Gradio app
|