Spaces:
Running
Running
Update predict.py
Browse files- predict.py +45 -39
predict.py
CHANGED
@@ -24,50 +24,48 @@ def load_models():
|
|
24 |
segmentation_model, yolo_model = load_models()
|
25 |
|
26 |
PIXELS_PER_CM = 50.0
|
27 |
-
app = FastAPI(title="Wound Analyzer", version="10.
|
28 |
|
29 |
# --- Preprocessing ---
|
30 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
|
|
31 |
img_denoised = cv2.medianBlur(image, 3)
|
32 |
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
33 |
l, a, b = cv2.split(lab)
|
34 |
clahe = cv2.createCLAHE(2.0, (8, 8))
|
35 |
-
|
36 |
-
|
37 |
-
result = cv2.cvtColor(
|
38 |
gamma = 1.2
|
39 |
return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
|
40 |
|
41 |
# --- Segmentation ---
|
42 |
def segment_wound(image: np.ndarray) -> np.ndarray:
|
|
|
43 |
try:
|
44 |
if segmentation_model:
|
45 |
input_size = segmentation_model.input_shape[1:3]
|
46 |
resized = cv2.resize(image, (input_size[1], input_size[0]))
|
47 |
norm = np.expand_dims(resized / 255.0, axis=0)
|
48 |
-
|
49 |
prediction = segmentation_model.predict(norm, verbose=0)
|
50 |
-
if isinstance(prediction, list):
|
51 |
prediction = prediction[0]
|
52 |
-
prediction = prediction[0]
|
53 |
-
|
54 |
mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
|
55 |
return (mask >= 0.5).astype(np.uint8) * 255
|
56 |
except Exception as e:
|
57 |
print(f"⚠️ Model prediction failed: {e}")
|
58 |
|
59 |
-
# Fallback segmentation
|
60 |
Z = image.reshape((-1, 3)).astype(np.float32)
|
61 |
-
_, labels, centers = cv2.kmeans(Z, 2, None,
|
62 |
-
(cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0),
|
63 |
-
5, cv2.KMEANS_PP_CENTERS)
|
64 |
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
65 |
wound_idx = np.argmax(centers_lab[:, 1])
|
66 |
mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
|
67 |
return mask
|
68 |
|
69 |
# --- Metrics ---
|
70 |
-
def calculate_metrics(mask: np.ndarray,
|
|
|
71 |
area_px = cv2.countNonZero(mask)
|
72 |
if area_px == 0:
|
73 |
return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
|
@@ -79,26 +77,24 @@ def calculate_metrics(mask: np.ndarray, image: np.ndarray):
|
|
79 |
length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
|
80 |
|
81 |
mask_bool = mask.astype(bool)
|
82 |
-
|
83 |
-
|
|
|
|
|
84 |
depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
|
85 |
moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
|
86 |
|
87 |
-
return dict(
|
88 |
-
area_cm2=round(area_cm2, 2),
|
89 |
-
length_cm=round(length_cm, 2),
|
90 |
-
breadth_cm=round(breadth_cm, 2),
|
91 |
-
depth_cm=round(depth, 1),
|
92 |
-
moisture=round(moisture, 0)
|
93 |
-
)
|
94 |
|
95 |
# --- Overlay ---
|
96 |
def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
|
|
|
|
97 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
98 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
99 |
heatmap = np.zeros_like(image)
|
100 |
|
101 |
-
heatmap[dist >= 0.66] = (0,
|
102 |
heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
|
103 |
heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
|
104 |
|
@@ -107,42 +103,52 @@ def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
|
107 |
annotated[mask.astype(bool)] = blended[mask.astype(bool)]
|
108 |
|
109 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
110 |
-
cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2)
|
111 |
return annotated
|
112 |
|
113 |
# --- API Endpoint ---
|
114 |
@app.post("/analyze_wound")
|
115 |
async def analyze(file: UploadFile = File(...)):
|
116 |
-
|
117 |
-
|
|
|
118 |
raise HTTPException(status_code=400, detail="Invalid image file.")
|
119 |
|
120 |
-
|
121 |
-
|
|
|
|
|
|
|
|
|
122 |
|
123 |
if yolo_model:
|
124 |
try:
|
125 |
-
results = yolo_model.predict(
|
126 |
if results and results[0].boxes:
|
127 |
coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
|
128 |
x1, y1, x2, y2 = coords
|
129 |
-
|
|
|
|
|
130 |
except Exception as e:
|
131 |
print(f"⚠️ YOLO detection failed: {e}")
|
132 |
|
133 |
-
mask
|
134 |
-
|
135 |
-
|
|
|
|
|
|
|
|
|
|
|
136 |
|
137 |
-
success,
|
138 |
if not success:
|
139 |
raise HTTPException(status_code=500, detail="Failed to encode image.")
|
140 |
|
141 |
headers = {
|
142 |
-
"X-Length-Cm": str(metrics["length_cm"]),
|
143 |
-
"X-
|
144 |
-
"X-Depth-Cm": str(metrics["depth_cm"]),
|
145 |
-
"X-Area-Cm2": str(metrics["area_cm2"]),
|
146 |
"X-Moisture": str(metrics["moisture"]),
|
147 |
}
|
148 |
-
return Response(content=
|
|
|
24 |
segmentation_model, yolo_model = load_models()
|
25 |
|
26 |
PIXELS_PER_CM = 50.0
|
27 |
+
app = FastAPI(title="Wound Analyzer", version="10.1") # Version with depth calculation fix
|
28 |
|
29 |
# --- Preprocessing ---
|
30 |
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
31 |
+
"""Enhances image for better segmentation."""
|
32 |
img_denoised = cv2.medianBlur(image, 3)
|
33 |
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
34 |
l, a, b = cv2.split(lab)
|
35 |
clahe = cv2.createCLAHE(2.0, (8, 8))
|
36 |
+
l_clahe = clahe.apply(l)
|
37 |
+
lab_clahe = cv2.merge((l_clahe, a, b))
|
38 |
+
result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
|
39 |
gamma = 1.2
|
40 |
return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
|
41 |
|
42 |
# --- Segmentation ---
|
43 |
def segment_wound(image: np.ndarray) -> np.ndarray:
|
44 |
+
"""Segments wound from a preprocessed image, with a fallback."""
|
45 |
try:
|
46 |
if segmentation_model:
|
47 |
input_size = segmentation_model.input_shape[1:3]
|
48 |
resized = cv2.resize(image, (input_size[1], input_size[0]))
|
49 |
norm = np.expand_dims(resized / 255.0, axis=0)
|
|
|
50 |
prediction = segmentation_model.predict(norm, verbose=0)
|
51 |
+
if isinstance(prediction, list):
|
52 |
prediction = prediction[0]
|
53 |
+
prediction = prediction[0]
|
|
|
54 |
mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
|
55 |
return (mask >= 0.5).astype(np.uint8) * 255
|
56 |
except Exception as e:
|
57 |
print(f"⚠️ Model prediction failed: {e}")
|
58 |
|
|
|
59 |
Z = image.reshape((-1, 3)).astype(np.float32)
|
60 |
+
_, labels, centers = cv2.kmeans(Z, 2, None, (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0), 5, cv2.KMEANS_PP_CENTERS)
|
|
|
|
|
61 |
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
62 |
wound_idx = np.argmax(centers_lab[:, 1])
|
63 |
mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
|
64 |
return mask
|
65 |
|
66 |
# --- Metrics ---
|
67 |
+
def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
|
68 |
+
"""Calculates metrics using the mask and the ORIGINAL image for color accuracy."""
|
69 |
area_px = cv2.countNonZero(mask)
|
70 |
if area_px == 0:
|
71 |
return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
|
|
|
77 |
length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
|
78 |
|
79 |
mask_bool = mask.astype(bool)
|
80 |
+
# **FIX**: Use the original, unaltered image for color-based metrics.
|
81 |
+
lab = cv2.cvtColor(original_image, cv2.COLOR_BGR2LAB)
|
82 |
+
gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
|
83 |
+
|
84 |
depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
|
85 |
moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
|
86 |
|
87 |
+
return dict(area_cm2=round(area_cm2, 2), length_cm=round(length_cm, 2), breadth_cm=round(breadth_cm, 2), depth_cm=round(depth, 1), moisture=round(moisture, 0))
|
|
|
|
|
|
|
|
|
|
|
|
|
88 |
|
89 |
# --- Overlay ---
|
90 |
def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
91 |
+
"""Draws heatmap and boundary on the image."""
|
92 |
+
# **CHANGE**: Use Yellow for high-intensity areas for better visibility
|
93 |
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
94 |
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
95 |
heatmap = np.zeros_like(image)
|
96 |
|
97 |
+
heatmap[dist >= 0.66] = (0, 255, 255) # Yellow - Most Affected
|
98 |
heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
|
99 |
heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
|
100 |
|
|
|
103 |
annotated[mask.astype(bool)] = blended[mask.astype(bool)]
|
104 |
|
105 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
106 |
+
cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2)
|
107 |
return annotated
|
108 |
|
109 |
# --- API Endpoint ---
|
110 |
@app.post("/analyze_wound")
|
111 |
async def analyze(file: UploadFile = File(...)):
|
112 |
+
contents = await file.read()
|
113 |
+
original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
|
114 |
+
if original_image is None:
|
115 |
raise HTTPException(status_code=400, detail="Invalid image file.")
|
116 |
|
117 |
+
# 1. Create a preprocessed version for segmentation
|
118 |
+
preprocessed_image = preprocess_image(original_image)
|
119 |
+
|
120 |
+
# 2. Define Regions of Interest (ROI) for both original and preprocessed images
|
121 |
+
original_roi = original_image.copy()
|
122 |
+
preprocessed_roi = preprocessed_image.copy()
|
123 |
|
124 |
if yolo_model:
|
125 |
try:
|
126 |
+
results = yolo_model.predict(preprocessed_image, verbose=False)
|
127 |
if results and results[0].boxes:
|
128 |
coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
|
129 |
x1, y1, x2, y2 = coords
|
130 |
+
# Crop both the original and preprocessed images to the same ROI
|
131 |
+
original_roi = original_image[y1:y2, x1:x2]
|
132 |
+
preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
|
133 |
except Exception as e:
|
134 |
print(f"⚠️ YOLO detection failed: {e}")
|
135 |
|
136 |
+
# 3. Get mask from the preprocessed ROI
|
137 |
+
mask = segment_wound(preprocessed_roi)
|
138 |
+
|
139 |
+
# 4. Calculate metrics using the ORIGINAL ROI for color accuracy
|
140 |
+
metrics = calculate_metrics(mask, original_roi)
|
141 |
+
|
142 |
+
# 5. Draw overlay on the ORIGINAL ROI for correct visualization
|
143 |
+
annotated_image = draw_overlay(original_roi, mask)
|
144 |
|
145 |
+
success, out_bytes = cv2.imencode(".png", annotated_image)
|
146 |
if not success:
|
147 |
raise HTTPException(status_code=500, detail="Failed to encode image.")
|
148 |
|
149 |
headers = {
|
150 |
+
"X-Length-Cm": str(metrics["length_cm"]), "X-Breadth-Cm": str(metrics["breadth_cm"]),
|
151 |
+
"X-Depth-Cm": str(metrics["depth_cm"]), "X-Area-Cm2": str(metrics["area_cm2"]),
|
|
|
|
|
152 |
"X-Moisture": str(metrics["moisture"]),
|
153 |
}
|
154 |
+
return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)
|