Update predict.py
predict.py (CHANGED: +93 additions, −54 deletions)
@@ -6,15 +6,18 @@ from typing import Union
@@ -23,132 +26,168 @@ def load_models()

Summary of the removed side (version 10.1; only partially preserved by the diff view):

- app = FastAPI(title="Wound Analyzer", version="10.1") # Version with depth calculation fix — becomes a multi-line 10.2 definition with a description.
- Short docstrings ("Enhances image for better segmentation.", "Segments wound from a preprocessed image, with a fallback.", "Draws heatmap and boundary on the image.") are expanded, and the # --- Segmentation --- / # --- Metrics --- / # --- Overlay --- section markers are reworded.
- cv2.createCLAHE(2.0, (8, 8)) becomes the keyword form cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).
- calculate_metrics() drops its lab = cv2.cvtColor(original_image, cv2.COLOR_BGR2LAB) and gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY) lines and its single-line return dict(area_cm2=round(area_cm2, 2), length_cm=round(length_cm, 2), breadth_cm=round(breadth_cm, 2), depth_cm=round(depth, 1), moisture=round(moisture, 0)).
- The endpoint no longer starts from preprocessed_roi = preprocessed_image.copy(); error details ("Invalid image file.", "Failed to encode image.") are reworded, and print(f"⚠️ YOLO detection failed: {e}") gains a fallback note.

predict.py after the change:
# --- Load Models ---
def load_models():
    """Loads machine learning models safely."""
    segmentation_model, yolo_model = None, None
    try:
        import tensorflow as tf
        # Ensure you have a valid model file at this path
        segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
        print("✅ Segmentation model loaded.")
    except Exception as e:
        print(f"⚠️ Failed to load segmentation model: {e}")
    try:
        from ultralytics import YOLO
        # Ensure you have a valid model file at this path
        yolo_model = YOLO("best.pt")
        print("✅ YOLO model loaded.")
    except Exception as e:
        # (the next two lines were collapsed by the diff view; restored from the
        # call site below, which unpacks two return values)
        print(f"⚠️ Failed to load YOLO model: {e}")
    return segmentation_model, yolo_model

segmentation_model, yolo_model = load_models()
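As a quick sanity check while developing, one could run the module directly; a minimal sketch, not part of the committed file:

    if __name__ == "__main__":
        seg, det = load_models()
        print("segmentation loaded:", seg is not None, "| yolo loaded:", det is not None)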
# --- Configuration ---
# This value should be calibrated by taking a picture of a ruler.
# Measure a known length (e.g., 5cm) in pixels, then divide pixels by cm.
PIXELS_PER_CM = 50.0
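The calibration those comments describe could look like the following sketch; pixels_per_cm and the example points are hypothetical helpers, not part of predict.py:

    import numpy as np

    def pixels_per_cm(p1, p2, known_cm):
        # e.g. p1=(120, 340), p2=(370, 340) spanning a 5 cm ruler mark -> 50.0
        return float(np.linalg.norm(np.subtract(p2, p1))) / known_cm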
app = FastAPI(
    title="Wound Analyzer",
    version="10.2",  # Version with improved depth logic
    description="Analyzes wound images, keeping the original API contract."
)
# --- Image Processing ---
def preprocess_image(image: np.ndarray) -> np.ndarray:
    """Enhances image for better segmentation by improving contrast and reducing noise."""
    img_denoised = cv2.medianBlur(image, 3)
    lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
    l_clahe = clahe.apply(l)
    lab_clahe = cv2.merge((l_clahe, a, b))
    result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
    gamma = 1.2
    return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
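Because the gamma of 1.2 is applied after normalizing to [0, 1], a value above 1 darkens the midtones slightly; for instance, a mid-gray input of 128 maps to about 112:

    print(round((128 / 255.0) ** 1.2 * 255))  # -> 112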
def segment_wound(image: np.ndarray) -> np.ndarray:
    """Segments wound from a preprocessed image, with a fallback to KMeans if the model fails."""
    if segmentation_model:
        try:
            input_size = segmentation_model.input_shape[1:3]
            resized = cv2.resize(image, (input_size[1], input_size[0]))
            norm = np.expand_dims(resized / 255.0, axis=0)
            prediction = segmentation_model.predict(norm, verbose=0)
            # Handle models with multiple outputs
            if isinstance(prediction, list):
                prediction = prediction[0]
            prediction = prediction[0]
            mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
            return (mask >= 0.5).astype(np.uint8) * 255
        except Exception as e:
            print(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")

    # Fallback method using color clustering if the primary model fails
    Z = image.reshape((-1, 3)).astype(np.float32)
    criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
    _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
    centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
    wound_idx = np.argmax(centers_lab[:, 1])  # Assume wound is the reddest cluster
    mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
    return mask
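The fallback's "reddest cluster" assumption leans on OpenCV's 8-bit LAB encoding, where channel 1 is a* (green-to-red, neutral at 128). A small sanity check, with made-up sample colors:

    import cv2
    import numpy as np

    samples = np.uint8([[[0, 0, 255], [140, 160, 180]]])  # BGR: pure red vs. a skin-like tone
    print(cv2.cvtColor(samples, cv2.COLOR_BGR2LAB)[0, :, 1])  # red's a* is clearly the larger value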
def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
    """
    Calculates all metrics. Depth is now calculated based on shadow/intensity analysis.
    """
    area_px = cv2.countNonZero(mask)
    if area_px == 0:
        return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    # --- Area and Dimensions ---
    area_cm2 = area_px / (PIXELS_PER_CM ** 2)
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if not contours:  # Handle case where mask is present but no contours are found
        return dict(area_cm2=area_cm2, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)

    rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
    (w, h) = rect[1]
    length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM

    # --- Depth and Moisture Calculation ---
    mask_bool = mask.astype(bool)
    gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)

    # **ENHANCED DEPTH CALCULATION**: Use standard deviation of pixel intensity.
    # A higher standard deviation implies more shadows and highlights, suggesting greater depth.
    # The result is scaled to a 0-10 range for a more intuitive, relative depth score.
    intensity_std_dev = np.std(gray_image[mask_bool])
    depth_score = (intensity_std_dev / 127.0) * 10.0  # Scale to a 0-10 range

    # Moisture uses the same intensity statistic, inverted
    moisture = max(0.0, 100.0 * (1 - intensity_std_dev / 127.0))

    return dict(
        area_cm2=round(area_cm2, 2),
        length_cm=round(length_cm, 2),
        breadth_cm=round(breadth_cm, 2),
        depth_cm=round(depth_score, 1),  # This is now the new depth score
        moisture=round(moisture, 1)
    )
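Since depth and moisture are driven by the same standard deviation, they move in opposite directions by construction. With σ = 25.4, for example:

    std = 25.4
    print(round(std / 127.0 * 10.0, 1))                   # depth score -> 2.0
    print(round(max(0.0, 100.0 * (1 - std / 127.0)), 1))  # moisture    -> 80.0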
def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Draws a heatmap and boundary on the image for visualization."""
    dist_transform = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
    cv2.normalize(dist_transform, dist_transform, 0, 1.0, cv2.NORM_MINMAX)

    heatmap = np.zeros_like(image, dtype=np.uint8)
    heatmap[dist_transform >= 0.66] = (0, 255, 255)  # Yellow - Core
    heatmap[(dist_transform >= 0.33) & (dist_transform < 0.66)] = (255, 0, 0)  # Blue - Moderate
    heatmap[(dist_transform > 0) & (dist_transform < 0.33)] = (0, 255, 0)  # Green - Periphery

    # Create a blended image only where the mask is active
    blended = image.copy()
    alpha = 0.4
    masked_pixels = mask.astype(bool)
    blended[masked_pixels] = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)[masked_pixels]

    # Draw a clean, white contour
    contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    cv2.drawContours(blended, contours, -1, (255, 255, 255), 2)
    return blended
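To eyeball the banding, one can feed the function a synthetic circular mask; an illustrative snippet that assumes the imports and function above are in scope:

    img = np.full((200, 200, 3), 128, np.uint8)
    mask = np.zeros((200, 200), np.uint8)
    cv2.circle(mask, (100, 100), 60, 255, -1)
    cv2.imwrite("overlay_demo.png", draw_overlay(img, mask))  # yellow core, blue ring, green rim, white outline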
# --- API Endpoint ---
@app.post("/analyze_wound")
async def analyze(file: UploadFile = File(...)):
    """
    Accepts an image, analyzes the wound, and returns an annotated image
    with metrics in the response headers, maintaining the original API contract.
    """
    contents = await file.read()
    original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
    if original_image is None:
        raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")

    # 1. Preprocess a copy of the image for object detection and segmentation
    preprocessed_image = preprocess_image(original_image)

    # 2. Detect ROI using YOLO if available, otherwise use the whole image
    roi_coords = (0, 0, original_image.shape[1], original_image.shape[0])
    if yolo_model:
        try:
            results = yolo_model.predict(preprocessed_image, verbose=False)
            if results and results[0].boxes:
                coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
                roi_coords = (coords[0], coords[1], coords[2], coords[3])
        except Exception as e:
            print(f"⚠️ YOLO detection failed: {e}. Using full image as ROI.")

    x1, y1, x2, y2 = roi_coords
    original_roi = original_image[y1:y2, x1:x2]
    preprocessed_roi = preprocessed_image[y1:y2, x1:x2]

    if original_roi.size == 0:
        raise HTTPException(status_code=404, detail="Wound region of interest could not be determined.")

    # 3. Get mask from the preprocessed ROI
    mask = segment_wound(preprocessed_roi)

    # 4. Calculate metrics using the ORIGINAL ROI for color/intensity accuracy
    metrics = calculate_metrics(mask, original_roi)

    # 5. Draw overlay on the ORIGINAL ROI for correct visualization
    annotated_image = draw_overlay(original_roi, mask)

    # 6. Encode image and prepare response
    success, out_bytes = cv2.imencode(".png", annotated_image)
    if not success:
        raise HTTPException(status_code=500, detail="Failed to encode output image.")

    headers = {
        "X-Length-Cm": str(metrics["length_cm"]),
        "X-Breadth-Cm": str(metrics["breadth_cm"]),
        "X-Depth-Cm": str(metrics["depth_cm"]),
        "X-Area-Cm2": str(metrics["area_cm2"]),
        "X-Moisture": str(metrics["moisture"]),
    }
    return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)
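An example client call, assuming the app is served locally (say, with uvicorn predict:app) and a test image named wound.jpg; both names are placeholders:

    import requests

    with open("wound.jpg", "rb") as f:
        r = requests.post("http://localhost:8000/analyze_wound", files={"file": f})
    r.raise_for_status()
    with open("annotated.png", "wb") as out:
        out.write(r.content)
    print({k: v for k, v in r.headers.items() if k.lower().startswith("x-")})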