Spaces:
Running
Running
Update predict.py
Browse files- predict.py +147 -131
predict.py
CHANGED
@@ -1,176 +1,192 @@
|
|
|
|
|
|
1 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Response
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
-
from ultralytics import YOLO
|
5 |
-
import tensorflow as tf
|
6 |
import io
|
7 |
from typing import Union
|
8 |
|
9 |
-
# --- Configuration ---
|
10 |
-
PIXELS_PER_CM = 50.0
|
11 |
-
|
12 |
-
# --- App Initialization ---
|
13 |
-
app = FastAPI(
|
14 |
-
title="High-Quality Wound Heatmap API",
|
15 |
-
description="Generates a high-quality wound heatmap using a DL model, preserving original image quality.",
|
16 |
-
version="7.0.0" # Version updated for quality preservation
|
17 |
-
)
|
18 |
-
|
19 |
# --- Model Loading ---
|
|
|
|
|
20 |
def load_models():
|
21 |
-
"""Loads
|
22 |
-
|
|
|
23 |
try:
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
28 |
|
29 |
try:
|
|
|
30 |
segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
|
31 |
-
print("
|
32 |
-
except Exception as e:
|
33 |
-
print(f"Warning:
|
34 |
|
35 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
36 |
|
37 |
-
yolo_model, segmentation_model = load_models()
|
38 |
|
39 |
# --- Helper Functions ---
|
40 |
|
41 |
-
def
|
42 |
-
"""
|
43 |
-
|
44 |
-
|
45 |
-
|
46 |
-
|
47 |
-
|
48 |
-
|
49 |
-
|
50 |
-
|
51 |
-
|
52 |
-
|
53 |
-
|
54 |
-
|
55 |
-
|
56 |
-
Segments the wound using the TF
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
|
72 |
-
|
73 |
-
|
74 |
-
|
75 |
-
|
76 |
-
|
77 |
-
|
78 |
-
|
79 |
-
|
80 |
-
|
81 |
-
|
82 |
-
|
83 |
-
|
84 |
-
|
85 |
-
|
86 |
-
|
|
|
|
|
87 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
88 |
|
89 |
-
area_cm2 =
|
90 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
91 |
-
if not contours:
|
92 |
-
return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
93 |
-
|
94 |
largest_contour = max(contours, key=cv2.contourArea)
|
95 |
(_, (width, height), _) = cv2.minAreaRect(largest_contour)
|
96 |
-
|
97 |
length_cm = max(width, height) / PIXELS_PER_CM
|
98 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
99 |
-
|
100 |
-
|
101 |
-
|
102 |
-
|
103 |
-
|
104 |
-
|
105 |
-
|
106 |
-
|
107 |
-
|
|
|
|
|
|
|
108 |
|
109 |
-
|
110 |
-
|
|
|
|
|
|
|
|
|
|
|
111 |
overlay = np.zeros_like(image)
|
112 |
|
113 |
-
|
114 |
-
|
115 |
-
|
116 |
-
overlay[(
|
117 |
-
overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0)
|
118 |
-
overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0)
|
119 |
-
|
120 |
-
# Blend the overlay with the original image
|
121 |
-
blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
|
122 |
|
123 |
-
|
124 |
final_image = image.copy()
|
125 |
-
final_image[mask
|
|
|
|
|
|
|
|
|
|
|
|
|
126 |
|
127 |
return final_image
|
128 |
|
129 |
# --- Main API Endpoint ---
|
130 |
@app.post("/analyze_wound")
|
131 |
async def analyze_wound(file: UploadFile = File(...)):
|
132 |
-
if not yolo_model or not segmentation_model:
|
133 |
-
raise HTTPException(status_code=503, detail="A required model is not available.")
|
134 |
-
|
135 |
contents = await file.read()
|
136 |
-
|
137 |
-
original_image = cv2.imdecode(
|
138 |
if original_image is None:
|
139 |
-
raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
|
140 |
-
|
141 |
-
# No preprocessing is done to the original image to preserve quality.
|
142 |
-
# A copy of the original image is used for detection.
|
143 |
-
bbox = detect_wound_region_yolo(original_image.copy())
|
144 |
-
if not bbox:
|
145 |
-
raise HTTPException(status_code=404, detail="No wound detected in the image.")
|
146 |
-
|
147 |
-
xmin, ymin, xmax, ymax = bbox
|
148 |
-
# Crop the region of interest from the original, high-quality image
|
149 |
-
cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
|
150 |
-
|
151 |
-
# Step 1: Use the DL model on the high-quality crop to get a precise mask
|
152 |
-
wound_mask = segment_wound_with_model(cropped_image_roi)
|
153 |
-
if wound_mask is None or cv2.countNonZero(wound_mask) == 0:
|
154 |
-
raise HTTPException(status_code=404, detail="Segmentation model failed to identify a wound in the detected region.")
|
155 |
-
|
156 |
-
# Step 2: Calculate metrics based on the full-resolution mask
|
157 |
-
metrics = calculate_metrics(wound_mask)
|
158 |
|
159 |
-
|
160 |
-
heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
|
161 |
|
162 |
-
|
163 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
164 |
if not success:
|
165 |
-
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
166 |
|
167 |
-
# Step 5: Set the custom headers
|
168 |
headers = {
|
169 |
-
'X-Length-Cm':
|
170 |
-
'X-Breadth-Cm':
|
171 |
-
'X-Depth-Cm':
|
172 |
-
'X-Area-Cm2':
|
173 |
-
'X-Moisture':
|
174 |
}
|
175 |
|
176 |
return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
|
|
|
1 |
+
# main.py
|
2 |
+
|
3 |
from fastapi import FastAPI, File, UploadFile, HTTPException, Response
|
4 |
import cv2
|
5 |
import numpy as np
|
|
|
|
|
6 |
import io
|
7 |
from typing import Union
|
8 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
9 |
# --- Model Loading ---
|
10 |
+
# This section attempts to load the specific deep learning models you are using.
|
11 |
+
|
12 |
def load_models():
|
13 |
+
"""Loads TensorFlow and YOLO models using your specified filenames."""
|
14 |
+
segmentation_model, yolo_detector = None, None
|
15 |
+
|
16 |
try:
|
17 |
+
from ultralytics import YOLO
|
18 |
+
yolo_detector = YOLO("best.pt")
|
19 |
+
print("YOLOv8 detection model 'best.pt' loaded successfully.")
|
20 |
+
except (ImportError, IOError, Exception) as e:
|
21 |
+
print(f"Warning: YOLOv8 model not loaded. Using contour-based region detection. Error: {e}")
|
22 |
|
23 |
try:
|
24 |
+
import tensorflow as tf
|
25 |
segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
|
26 |
+
print("TensorFlow segmentation model 'segmentation_model.h5' loaded successfully.")
|
27 |
+
except (ImportError, IOError, Exception) as e:
|
28 |
+
print(f"Warning: TensorFlow segmentation model not loaded. Using OpenCV fallback. Error: {e}")
|
29 |
|
30 |
+
return segmentation_model, yolo_detector
|
31 |
+
|
32 |
+
segmentation_model, yolo_model = load_models()
|
33 |
+
|
34 |
+
|
35 |
+
# --- Configuration ---
|
36 |
+
PIXELS_PER_CM = 50.0
|
37 |
+
|
38 |
+
# --- App Initialization ---
|
39 |
+
app = FastAPI(
|
40 |
+
title="Wound Analysis API",
|
41 |
+
description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
|
42 |
+
version="9.0.0" # Version with improved visualization (Yellow Heatmap + Boundary)
|
43 |
+
)
|
44 |
|
|
|
45 |
|
46 |
# --- Helper Functions ---
|
47 |
|
48 |
+
def preprocess_image(image: np.ndarray) -> np.ndarray:
|
49 |
+
"""Applies the full preprocessing pipeline: Denoise -> CLAHE -> Gamma Correction."""
|
50 |
+
img_denoised = cv2.medianBlur(image, 3)
|
51 |
+
lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
|
52 |
+
l_channel, a_channel, b_channel = cv2.split(lab)
|
53 |
+
clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
|
54 |
+
l_clahe = clahe.apply(l_channel)
|
55 |
+
lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
|
56 |
+
img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
|
57 |
+
gamma = 1.2
|
58 |
+
img_float = img_clahe.astype(np.float32) / 255.0
|
59 |
+
img_gamma = np.power(img_float, gamma)
|
60 |
+
return (img_gamma * 255).astype(np.uint8)
|
61 |
+
|
62 |
+
def segment_wound(image: np.ndarray) -> np.ndarray:
|
63 |
+
"""Segments the wound using the TF model if available, otherwise falls back to color clustering."""
|
64 |
+
if segmentation_model:
|
65 |
+
try:
|
66 |
+
orig_h, orig_w = image.shape[:2]
|
67 |
+
model_input_size = segmentation_model.input.shape[1:3]
|
68 |
+
img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
|
69 |
+
img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
|
70 |
+
pred_mask = segmentation_model.predict(img_norm, verbose=0)[0]
|
71 |
+
pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
|
72 |
+
mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
73 |
+
if cv2.countNonZero(mask) > 0:
|
74 |
+
return mask
|
75 |
+
except Exception as e:
|
76 |
+
print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
|
77 |
+
|
78 |
+
pixels = image.reshape((-1, 3)).astype(np.float32)
|
79 |
+
criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
|
80 |
+
_, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
|
81 |
+
centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
|
82 |
+
wound_cluster_idx = np.argmax(centers_lab[:, 1])
|
83 |
+
mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
|
84 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
85 |
+
if contours:
|
86 |
+
largest_contour = max(contours, key=cv2.contourArea)
|
87 |
+
refined_mask = np.zeros_like(mask)
|
88 |
+
cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
|
89 |
+
return refined_mask
|
90 |
+
return mask
|
91 |
+
|
92 |
+
def calculate_all_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
|
93 |
+
"""Computes all specified wound metrics from the mask and original image."""
|
94 |
+
wound_pixels = cv2.countNonZero(mask)
|
95 |
+
if wound_pixels == 0:
|
96 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
97 |
|
98 |
+
area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
|
99 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
|
|
|
|
|
|
100 |
largest_contour = max(contours, key=cv2.contourArea)
|
101 |
(_, (width, height), _) = cv2.minAreaRect(largest_contour)
|
|
|
102 |
length_cm = max(width, height) / PIXELS_PER_CM
|
103 |
breadth_cm = min(width, height) / PIXELS_PER_CM
|
104 |
+
mask_bool = mask.astype(bool)
|
105 |
+
lab_img = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
106 |
+
mean_a = np.mean(lab_img[:, :, 1][mask_bool])
|
107 |
+
depth_score = mean_a - 128.0
|
108 |
+
gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
|
109 |
+
texture_std = np.std(gray_img[mask_bool])
|
110 |
+
moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
|
111 |
+
|
112 |
+
return {
|
113 |
+
"area_cm2": f"{area_cm2:.2f}", "length_cm": f"{length_cm:.2f}", "breadth_cm": f"{breadth_cm:.2f}",
|
114 |
+
"depth_cm": f"{depth_score:.1f}", "moisture": f"{moisture_score:.0f}"
|
115 |
+
}
|
116 |
|
117 |
+
def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
118 |
+
"""
|
119 |
+
Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary.
|
120 |
+
"""
|
121 |
+
# --- 1. Create the Color Heatmap ---
|
122 |
+
dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
|
123 |
+
cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
|
124 |
overlay = np.zeros_like(image)
|
125 |
|
126 |
+
# **CHANGE**: Use Yellow instead of Red for the most affected area for better visibility.
|
127 |
+
overlay[dist >= 0.66] = (0, 255, 255) # Yellow in BGR
|
128 |
+
overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue in BGR
|
129 |
+
overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green in BGR
|
|
|
|
|
|
|
|
|
|
|
130 |
|
131 |
+
blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
|
132 |
final_image = image.copy()
|
133 |
+
final_image[mask.astype(bool)] = blended[mask.astype(bool)]
|
134 |
+
|
135 |
+
# --- 2. Draw the Boundary Contour ---
|
136 |
+
# Find contours from the mask to draw the boundary line.
|
137 |
+
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
138 |
+
# **NEW**: Draw a crisp white boundary on the final image.
|
139 |
+
cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1) # White color, 1px thickness
|
140 |
|
141 |
return final_image
|
142 |
|
143 |
# --- Main API Endpoint ---
|
144 |
@app.post("/analyze_wound")
|
145 |
async def analyze_wound(file: UploadFile = File(...)):
|
|
|
|
|
|
|
146 |
contents = await file.read()
|
147 |
+
image_array = np.frombuffer(contents, np.uint8)
|
148 |
+
original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
|
149 |
if original_image is None:
|
150 |
+
raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
151 |
|
152 |
+
processed_image = preprocess_image(original_image)
|
|
|
153 |
|
154 |
+
roi_image = processed_image
|
155 |
+
original_roi = original_image
|
156 |
+
if yolo_model:
|
157 |
+
try:
|
158 |
+
results = yolo_model.predict(processed_image, verbose=False)
|
159 |
+
if results and results[0].boxes and len(results[0].boxes) > 0:
|
160 |
+
best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
|
161 |
+
coords = best_box.xyxy[0].cpu().numpy()
|
162 |
+
x1, y1, x2, y2 = map(int, coords)
|
163 |
+
roi_image = processed_image[y1:y2, x1:x2]
|
164 |
+
original_roi = original_image[y1:y2, x1:x2]
|
165 |
+
except Exception as e:
|
166 |
+
print(f"YOLO prediction failed, analyzing full image. Error: {e}")
|
167 |
+
|
168 |
+
wound_mask = segment_wound(roi_image)
|
169 |
+
if cv2.countNonZero(wound_mask) == 0:
|
170 |
+
_, png_data = cv2.imencode(".png", original_image)
|
171 |
+
headers = {
|
172 |
+
'X-Length-Cm': '0.0', 'X-Breadth-Cm': '0.0', 'X-Depth-Cm': '0.0',
|
173 |
+
'X-Area-Cm2': '0.0', 'X-Moisture': '0.0'
|
174 |
+
}
|
175 |
+
return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
|
176 |
+
|
177 |
+
metrics = calculate_all_metrics(wound_mask, original_roi)
|
178 |
+
annotated_image = create_visual_overlay(original_roi, wound_mask)
|
179 |
+
|
180 |
+
success, png_data = cv2.imencode(".png", annotated_image)
|
181 |
if not success:
|
182 |
+
raise HTTPException(status_code=500, detail="Failed to encode output image.")
|
183 |
|
|
|
184 |
headers = {
|
185 |
+
'X-Length-Cm': metrics['length_cm'],
|
186 |
+
'X-Breadth-Cm': metrics['breadth_cm'],
|
187 |
+
'X-Depth-Cm': metrics['depth_cm'],
|
188 |
+
'X-Area-Cm2': metrics['area_cm2'],
|
189 |
+
'X-Moisture': metrics['moisture']
|
190 |
}
|
191 |
|
192 |
return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
|