Spaces:
Running
Running
Update predict.py
Browse files- predict.py +93 -140
predict.py
CHANGED
@@ -2,7 +2,6 @@ from fastapi import FastAPI, File, UploadFile, HTTPException, Response
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
5 |
-
import tensorflow as tf
|
6 |
import io
|
7 |
from typing import Union
|
8 |
|
@@ -11,56 +10,36 @@ PIXELS_PER_CM = 50.0
|
|
11 |
|
12 |
# --- App Initialization ---
|
13 |
app = FastAPI(
|
14 |
-
title="Wound Analysis API",
|
15 |
-
description="An API
|
16 |
-
version="
|
17 |
)
|
18 |
|
19 |
# --- Model Loading ---
|
20 |
-
def
|
21 |
-
"""Loads the
|
22 |
-
segmentation_model, yolo_model = None, None
|
23 |
try:
|
24 |
-
# Load your trained segmentation model
|
25 |
-
segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
|
26 |
-
print("Segmentation model 'segmentation_model.h5' loaded successfully.")
|
27 |
-
except Exception as e:
|
28 |
-
print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
|
29 |
-
|
30 |
-
try:
|
31 |
-
# Load your trained YOLO model for wound detection
|
32 |
yolo_model = YOLO("best.pt")
|
33 |
print("YOLO model 'best.pt' loaded successfully.")
|
|
|
34 |
except Exception as e:
|
35 |
-
print(f"
|
36 |
-
|
37 |
-
return segmentation_model, yolo_model
|
38 |
|
39 |
-
|
40 |
|
41 |
# --- Helper Functions ---
|
42 |
|
43 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
44 |
-
"""Applies a
|
45 |
-
|
46 |
-
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
47 |
-
l_channel, a_channel, b_channel = cv2.split(lab)
|
48 |
-
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
49 |
-
l_clahe = clahe.apply(l_channel)
|
50 |
-
lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
|
51 |
-
img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
|
52 |
-
gamma = 1.2
|
53 |
-
img_float = img_clahe.astype(np.float32) / 255.0
|
54 |
-
img_gamma = np.power(img_float, gamma)
|
55 |
-
return (img_gamma * 255).astype(np.uint8)
|
56 |
|
57 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
58 |
-
"""Detects the wound bounding box using the YOLO model."""
|
59 |
if not yolo_model: return None
|
60 |
try:
|
61 |
results = yolo_model.predict(image, verbose=False)
|
62 |
if results and results[0].boxes and len(results[0].boxes) > 0:
|
63 |
-
# Get the box with the highest confidence
|
64 |
best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
|
65 |
coords = best_box.xyxy[0].cpu().numpy()
|
66 |
return tuple(map(int, coords))
|
@@ -68,53 +47,33 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
|
68 |
print(f"YOLO prediction failed: {e}")
|
69 |
return None
|
70 |
|
71 |
-
def
|
72 |
-
"""
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
87 |
-
|
88 |
-
|
89 |
-
|
90 |
-
return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
91 |
-
except Exception as e:
|
92 |
-
print(f"Segmentation model prediction failed: {e}")
|
93 |
-
return None
|
94 |
-
|
95 |
-
def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
|
96 |
-
"""A fallback segmentation method using k-means clustering if the primary model fails."""
|
97 |
-
pixels = image.reshape((-1, 3)).astype(np.float32)
|
98 |
-
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
99 |
-
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
100 |
-
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
101 |
-
wound_cluster_idx = np.argmax(centers_lab[:, 1]) # 'a' channel in LAB is good for redness
|
102 |
-
mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
|
103 |
-
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
104 |
-
if contours:
|
105 |
-
largest_contour = max(contours, key=cv2.contourArea)
|
106 |
-
refined_mask = np.zeros_like(mask)
|
107 |
-
cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
|
108 |
-
return refined_mask
|
109 |
-
return mask
|
110 |
|
111 |
def calculate_metrics(mask: np.ndarray) -> dict:
|
112 |
-
"""Calculates dimensional
|
113 |
-
|
114 |
-
if
|
115 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
116 |
|
117 |
-
area_cm2 =
|
118 |
|
119 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
120 |
if not contours:
|
@@ -125,91 +84,85 @@ def calculate_metrics(mask: np.ndarray) -> dict:
|
|
125 |
|
126 |
length_cm = max(width, height) / PIXELS_PER_CM
|
127 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
128 |
-
|
129 |
-
|
130 |
-
# These would typically require more advanced sensors or algorithms.
|
131 |
-
depth_cm = 0.1 # Placeholder value
|
132 |
-
moisture = 75.0 # Placeholder value
|
133 |
|
134 |
return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
|
135 |
|
136 |
-
def
|
137 |
-
"""
|
138 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
139 |
|
140 |
-
#
|
141 |
-
|
142 |
-
|
143 |
-
|
144 |
-
|
145 |
-
|
146 |
-
#
|
147 |
-
|
148 |
-
|
149 |
-
|
150 |
-
|
151 |
-
|
152 |
-
|
153 |
-
|
154 |
-
|
155 |
-
|
156 |
-
|
157 |
-
zoom_ymax = min(image.shape[0], ymax + padding_y)
|
158 |
-
|
159 |
-
return annotated_img[zoom_ymin:zoom_ymax, zoom_xmin:zoom_xmax]
|
160 |
|
161 |
-
|
|
|
|
|
162 |
|
|
|
163 |
|
164 |
# --- Main API Endpoint ---
|
165 |
@app.post("/analyze_wound")
|
166 |
async def analyze_wound(file: UploadFile = File(...)):
|
167 |
-
|
168 |
-
|
169 |
-
|
170 |
-
"""
|
171 |
contents = await file.read()
|
172 |
-
|
173 |
-
original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
174 |
if original_image is None:
|
175 |
-
raise HTTPException(status_code=400, detail="Invalid
|
176 |
|
177 |
processed_image = preprocess_image(original_image)
|
178 |
|
179 |
-
# Use YOLO to find the general wound region
|
180 |
bbox = detect_wound_region_yolo(processed_image)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
181 |
|
182 |
-
|
183 |
-
|
184 |
-
|
185 |
-
|
186 |
-
|
187 |
-
# If no wound is detected, analyze the whole image as a fallback
|
188 |
-
cropped_for_segmentation = processed_image
|
189 |
-
|
190 |
-
# Segment the wound within the cropped region
|
191 |
-
mask = segment_wound_with_model(cropped_for_segmentation)
|
192 |
-
if mask is None:
|
193 |
-
mask = segment_wound_with_fallback(cropped_for_segmentation)
|
194 |
-
|
195 |
-
# Calculate metrics based on the precise mask
|
196 |
-
metrics = calculate_metrics(mask)
|
197 |
-
|
198 |
-
# Create a full-sized mask to pass to the visualization function
|
199 |
-
full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
|
200 |
-
if bbox:
|
201 |
-
full_mask[ymin:ymax, xmin:xmax] = mask
|
202 |
-
else:
|
203 |
-
full_mask = mask
|
204 |
-
|
205 |
-
# Generate the final visual output: draw polygon and zoom
|
206 |
-
final_image = create_visual_overlay_and_zoom(original_image, full_mask, bbox)
|
207 |
|
208 |
-
|
|
|
209 |
if not success:
|
210 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
211 |
|
212 |
-
# Set the custom headers
|
213 |
headers = {
|
214 |
'X-Length-Cm': f"{metrics['length_cm']:.2f}",
|
215 |
'X-Breadth-Cm': f"{metrics['breadth_cm']:.2f}",
|
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
|
|
5 |
import io
|
6 |
from typing import Union
|
7 |
|
|
|
10 |
|
11 |
# --- App Initialization ---
|
12 |
app = FastAPI(
|
13 |
+
title="Wound Heatmap Analysis API",
|
14 |
+
description="An API that generates a three-color heatmap (Red, Blue, Green) on a wound to show tissue characteristics.",
|
15 |
+
version="5.0.0" # Version updated for three-layer color heatmap
|
16 |
)
|
17 |
|
18 |
# --- Model Loading ---
|
19 |
+
def load_yolo_model():
|
20 |
+
"""Loads the YOLO model for initial wound detection."""
|
|
|
21 |
try:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
22 |
yolo_model = YOLO("best.pt")
|
23 |
print("YOLO model 'best.pt' loaded successfully.")
|
24 |
+
return yolo_model
|
25 |
except Exception as e:
|
26 |
+
print(f"Fatal: Could not load YOLO model. Error: {e}")
|
27 |
+
return None
|
|
|
28 |
|
29 |
+
yolo_model = load_yolo_model()
|
30 |
|
31 |
# --- Helper Functions ---
|
32 |
|
33 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
34 |
+
"""Applies a median blur to reduce noise."""
|
35 |
+
return cv2.medianBlur(image, 5)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
38 |
+
"""Detects the primary wound bounding box using the YOLO model."""
|
39 |
if not yolo_model: return None
|
40 |
try:
|
41 |
results = yolo_model.predict(image, verbose=False)
|
42 |
if results and results[0].boxes and len(results[0].boxes) > 0:
|
|
|
43 |
best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
|
44 |
coords = best_box.xyxy[0].cpu().numpy()
|
45 |
return tuple(map(int, coords))
|
|
|
47 |
print(f"YOLO prediction failed: {e}")
|
48 |
return None
|
49 |
|
50 |
+
def create_wound_bed_mask(cropped_image: np.ndarray) -> np.ndarray:
|
51 |
+
"""
|
52 |
+
Creates a general binary mask of the entire wound bed, which will be the area for the heatmap.
|
53 |
+
This uses a broader color range than specific tissue segmentation.
|
54 |
+
"""
|
55 |
+
lab_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2LAB)
|
56 |
+
|
57 |
+
# Broad thresholds to capture the entire wound area (slough, granulation, etc.)
|
58 |
+
lower_bound = np.array([0, 135, 135])
|
59 |
+
upper_bound = np.array([255, 200, 200])
|
60 |
+
|
61 |
+
mask = cv2.inRange(lab_image, lower_bound, upper_bound)
|
62 |
+
|
63 |
+
# Clean up the mask
|
64 |
+
kernel = np.ones((5, 5), np.uint8)
|
65 |
+
mask_cleaned = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
|
66 |
+
mask_cleaned = cv2.morphologyEx(mask_cleaned, cv2.MORPH_CLOSE, kernel, iterations=3)
|
67 |
+
|
68 |
+
return mask_cleaned
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def calculate_metrics(mask: np.ndarray) -> dict:
|
71 |
+
"""Calculates dimensional metrics from the overall wound mask."""
|
72 |
+
area_pixels = cv2.countNonZero(mask)
|
73 |
+
if area_pixels == 0:
|
74 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
75 |
|
76 |
+
area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
|
77 |
|
78 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
79 |
if not contours:
|
|
|
84 |
|
85 |
length_cm = max(width, height) / PIXELS_PER_CM
|
86 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
87 |
+
depth_cm = 0.1 # Placeholder
|
88 |
+
moisture = 75.0 # Placeholder
|
|
|
|
|
|
|
89 |
|
90 |
return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
|
91 |
|
92 |
+
def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
93 |
+
"""
|
94 |
+
Generates and overlays a three-color (Red, Blue, Green) heatmap onto the image
|
95 |
+
based on the intensity of the 'a' channel (redness) in the LAB color space.
|
96 |
+
"""
|
97 |
+
if cv2.countNonZero(mask) == 0:
|
98 |
+
return image
|
99 |
+
|
100 |
+
# Convert the region of interest to LAB color space for analysis
|
101 |
+
lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
102 |
+
a_channel = lab_image[:, :, 1] # The 'a' channel represents the green-red axis
|
103 |
+
|
104 |
+
# Create a color overlay image, initially transparent
|
105 |
+
overlay = np.zeros_like(image)
|
106 |
|
107 |
+
# Define thresholds for redness intensity. These values might need tuning.
|
108 |
+
# Values are based on the 'a' channel, where ~128 is neutral.
|
109 |
+
RED_THRESHOLD = 160 # Most intense red (e.g., fresh granulation)
|
110 |
+
BLUE_THRESHOLD = 145 # Medium red (e.g., developing tissue)
|
111 |
+
# Anything above a lower bound (e.g., 135) will be green.
|
112 |
+
|
113 |
+
# Apply colors based on thresholds, only within the wound mask
|
114 |
+
# Red for "most" affected
|
115 |
+
overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255) # BGR for Red
|
116 |
+
# Blue for "less" affected
|
117 |
+
overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0) # BGR for Blue
|
118 |
+
# Green for "least" affected (but still part of the wound bed)
|
119 |
+
overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0) # BGR for Green
|
120 |
+
|
121 |
+
# Blend the original image with the color overlay
|
122 |
+
# A weight of 0.4 for the overlay makes it visible but not overpowering
|
123 |
+
blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
|
|
|
|
|
|
|
124 |
|
125 |
+
# To make the result clean, we only apply the blended result where the mask is active
|
126 |
+
final_image = image.copy()
|
127 |
+
final_image[mask == 255] = blended_image[mask == 255]
|
128 |
|
129 |
+
return final_image
|
130 |
|
131 |
# --- Main API Endpoint ---
|
132 |
@app.post("/analyze_wound")
|
133 |
async def analyze_wound(file: UploadFile = File(...)):
|
134 |
+
if not yolo_model:
|
135 |
+
raise HTTPException(status_code=503, detail="YOLO model is not available.")
|
136 |
+
|
|
|
137 |
contents = await file.read()
|
138 |
+
original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
|
|
|
139 |
if original_image is None:
|
140 |
+
raise HTTPException(status_code=400, detail="Invalid image file")
|
141 |
|
142 |
processed_image = preprocess_image(original_image)
|
143 |
|
|
|
144 |
bbox = detect_wound_region_yolo(processed_image)
|
145 |
+
if not bbox:
|
146 |
+
raise HTTPException(status_code=404, detail="No wound detected in the image.")
|
147 |
+
|
148 |
+
xmin, ymin, xmax, ymax = bbox
|
149 |
+
cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
|
150 |
+
|
151 |
+
# Step 1: Create a mask for the entire wound bed in the cropped image
|
152 |
+
wound_mask = create_wound_bed_mask(cropped_image_roi)
|
153 |
|
154 |
+
# Step 2: Calculate metrics based on this overall wound mask
|
155 |
+
metrics = calculate_metrics(wound_mask)
|
156 |
+
|
157 |
+
# Step 3: Generate the three-color heatmap on the cropped image
|
158 |
+
heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
159 |
|
160 |
+
# Step 4: Encode the final, annotated (and already cropped) image
|
161 |
+
success, png_data = cv2.imencode(".png", heatmap_image)
|
162 |
if not success:
|
163 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
164 |
|
165 |
+
# Step 5: Set the custom headers
|
166 |
headers = {
|
167 |
'X-Length-Cm': f"{metrics['length_cm']:.2f}",
|
168 |
'X-Breadth-Cm': f"{metrics['breadth_cm']:.2f}",
|