Spaces:
Running
Running
Update predict.py
Browse files- predict.py +18 -16
predict.py
CHANGED
@@ -7,8 +7,6 @@ import io
|
|
7 |
from typing import Union
|
8 |
|
9 |
# --- Model Loading ---
|
10 |
-
# This section attempts to load the specific deep learning models you are using.
|
11 |
-
|
12 |
def load_models():
|
13 |
"""Loads TensorFlow and YOLO models using your specified filenames."""
|
14 |
segmentation_model, yolo_detector = None, None
|
@@ -39,7 +37,7 @@ PIXELS_PER_CM = 50.0
|
|
39 |
app = FastAPI(
|
40 |
title="Wound Analysis API",
|
41 |
description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
|
42 |
-
version="9.
|
43 |
)
|
44 |
|
45 |
|
@@ -67,7 +65,17 @@ def segment_wound(image: np.ndarray) -> np.ndarray:
|
|
67 |
model_input_size = segmentation_model.input.shape[1:3]
|
68 |
img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
|
69 |
img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
|
70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
71 |
pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
72 |
mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
73 |
if cv2.countNonZero(mask) > 0:
|
@@ -75,6 +83,7 @@ def segment_wound(image: np.ndarray) -> np.ndarray:
|
|
75 |
except Exception as e:
|
76 |
print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
|
77 |
|
|
|
78 |
pixels = image.reshape((-1, 3)).astype(np.float32)
|
79 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
80 |
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
@@ -115,28 +124,21 @@ def calculate_all_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
|
|
115 |
}
|
116 |
|
117 |
def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
118 |
-
"""
|
119 |
-
Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary.
|
120 |
-
"""
|
121 |
-
# --- 1. Create the Color Heatmap ---
|
122 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
123 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
124 |
overlay = np.zeros_like(image)
|
125 |
|
126 |
-
|
127 |
-
overlay[dist >= 0.66] = (
|
128 |
-
overlay[(dist
|
129 |
-
overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green in BGR
|
130 |
|
131 |
blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
|
132 |
final_image = image.copy()
|
133 |
final_image[mask.astype(bool)] = blended[mask.astype(bool)]
|
134 |
|
135 |
-
# --- 2. Draw the Boundary Contour ---
|
136 |
-
# Find contours from the mask to draw the boundary line.
|
137 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
138 |
-
|
139 |
-
cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1) # White color, 1px thickness
|
140 |
|
141 |
return final_image
|
142 |
|
|
|
7 |
from typing import Union
|
8 |
|
9 |
# --- Model Loading ---
|
|
|
|
|
10 |
def load_models():
|
11 |
"""Loads TensorFlow and YOLO models using your specified filenames."""
|
12 |
segmentation_model, yolo_detector = None, None
|
|
|
37 |
app = FastAPI(
|
38 |
title="Wound Analysis API",
|
39 |
description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
|
40 |
+
version="9.1.0" # Version with fix for model prediction output format
|
41 |
)
|
42 |
|
43 |
|
|
|
65 |
model_input_size = segmentation_model.input.shape[1:3]
|
66 |
img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
|
67 |
img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
|
68 |
+
|
69 |
+
prediction = segmentation_model.predict(img_norm, verbose=0)
|
70 |
+
|
71 |
+
# --- FIX: Handle model output that is a list ---
|
72 |
+
# If the prediction is a list, extract the first element which is the actual numpy array.
|
73 |
+
if isinstance(prediction, list):
|
74 |
+
pred_mask = prediction[0]
|
75 |
+
else:
|
76 |
+
pred_mask = prediction
|
77 |
+
# --- END FIX ---
|
78 |
+
|
79 |
pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
80 |
mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
81 |
if cv2.countNonZero(mask) > 0:
|
|
|
83 |
except Exception as e:
|
84 |
print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
|
85 |
|
86 |
+
# Fallback Method
|
87 |
pixels = image.reshape((-1, 3)).astype(np.float32)
|
88 |
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
89 |
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
|
|
124 |
}
|
125 |
|
126 |
def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
127 |
+
"""Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary."""
|
|
|
|
|
|
|
128 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
129 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
130 |
overlay = np.zeros_like(image)
|
131 |
|
132 |
+
overlay[dist >= 0.66] = (0, 255, 255)
|
133 |
+
overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
|
134 |
+
overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
|
|
|
135 |
|
136 |
blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
|
137 |
final_image = image.copy()
|
138 |
final_image[mask.astype(bool)] = blended[mask.astype(bool)]
|
139 |
|
|
|
|
|
140 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
141 |
+
cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1)
|
|
|
142 |
|
143 |
return final_image
|
144 |
|