Spaces:
Running
Running
Update predict.py
Browse files- predict.py +71 -69
predict.py
CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, File, UploadFile, HTTPException, Response
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
|
|
5 |
import io
|
6 |
from typing import Union
|
7 |
|
@@ -10,32 +11,35 @@ PIXELS_PER_CM = 50.0
|
|
10 |
|
11 |
# --- App Initialization ---
|
12 |
app = FastAPI(
|
13 |
-
title="Wound Heatmap
|
14 |
-
description="
|
15 |
-
version="
|
16 |
)
|
17 |
|
18 |
# --- Model Loading ---
|
19 |
-
def
|
20 |
-
"""Loads the YOLO
|
|
|
21 |
try:
|
22 |
yolo_model = YOLO("best.pt")
|
23 |
print("YOLO model 'best.pt' loaded successfully.")
|
24 |
-
return yolo_model
|
25 |
except Exception as e:
|
26 |
-
print(f"
|
27 |
-
return None
|
28 |
|
29 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
30 |
|
31 |
-
|
32 |
|
33 |
-
|
34 |
-
"""Applies a median blur to reduce noise."""
|
35 |
-
return cv2.medianBlur(image, 5)
|
36 |
|
37 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
38 |
-
"""Detects the primary wound bounding box using the YOLO model."""
|
39 |
if not yolo_model: return None
|
40 |
try:
|
41 |
results = yolo_model.predict(image, verbose=False)
|
@@ -47,34 +51,42 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
|
47 |
print(f"YOLO prediction failed: {e}")
|
48 |
return None
|
49 |
|
50 |
-
def
|
51 |
"""
|
52 |
-
|
53 |
-
|
54 |
"""
|
55 |
-
|
56 |
-
|
57 |
-
|
58 |
-
|
59 |
-
|
60 |
-
|
61 |
-
|
62 |
-
|
63 |
-
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
69 |
|
70 |
def calculate_metrics(mask: np.ndarray) -> dict:
|
71 |
-
"""Calculates dimensional metrics from the
|
72 |
area_pixels = cv2.countNonZero(mask)
|
73 |
if area_pixels == 0:
|
74 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
75 |
|
76 |
area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
|
77 |
-
|
78 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
79 |
if not contours:
|
80 |
return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
@@ -90,39 +102,25 @@ def calculate_metrics(mask: np.ndarray) -> dict:
|
|
90 |
return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
|
91 |
|
92 |
def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
93 |
-
"""
|
94 |
-
Generates and overlays a three-color (Red, Blue, Green) heatmap onto the image
|
95 |
-
based on the intensity of the 'a' channel (redness) in the LAB color space.
|
96 |
-
"""
|
97 |
if cv2.countNonZero(mask) == 0:
|
98 |
return image
|
99 |
|
100 |
-
# Convert the region of interest to LAB color space for analysis
|
101 |
lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
102 |
-
a_channel = lab_image[:, :, 1]
|
103 |
-
|
104 |
-
# Create a color overlay image, initially transparent
|
105 |
overlay = np.zeros_like(image)
|
106 |
|
107 |
-
|
108 |
-
|
109 |
-
|
110 |
-
|
111 |
-
|
112 |
-
|
113 |
-
|
114 |
-
#
|
115 |
-
overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255) # BGR for Red
|
116 |
-
# Blue for "less" affected
|
117 |
-
overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0) # BGR for Blue
|
118 |
-
# Green for "least" affected (but still part of the wound bed)
|
119 |
-
overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0) # BGR for Green
|
120 |
-
|
121 |
-
# Blend the original image with the color overlay
|
122 |
-
# A weight of 0.4 for the overlay makes it visible but not overpowering
|
123 |
blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
|
124 |
|
125 |
-
#
|
126 |
final_image = image.copy()
|
127 |
final_image[mask == 255] = blended_image[mask == 255]
|
128 |
|
@@ -131,33 +129,37 @@ def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarra
|
|
131 |
# --- Main API Endpoint ---
|
132 |
@app.post("/analyze_wound")
|
133 |
async def analyze_wound(file: UploadFile = File(...)):
|
134 |
-
if not yolo_model:
|
135 |
-
raise HTTPException(status_code=503, detail="
|
136 |
|
137 |
contents = await file.read()
|
|
|
138 |
original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
|
139 |
if original_image is None:
|
140 |
-
raise HTTPException(status_code=400, detail="Invalid image file")
|
141 |
|
142 |
-
|
143 |
-
|
144 |
-
bbox = detect_wound_region_yolo(
|
145 |
if not bbox:
|
146 |
raise HTTPException(status_code=404, detail="No wound detected in the image.")
|
147 |
|
148 |
xmin, ymin, xmax, ymax = bbox
|
|
|
149 |
cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
|
150 |
|
151 |
-
# Step 1:
|
152 |
-
wound_mask =
|
|
|
|
|
153 |
|
154 |
-
# Step 2: Calculate metrics based on
|
155 |
metrics = calculate_metrics(wound_mask)
|
156 |
|
157 |
-
# Step 3: Generate the
|
158 |
heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
|
159 |
|
160 |
-
# Step 4: Encode the final
|
161 |
success, png_data = cv2.imencode(".png", heatmap_image)
|
162 |
if not success:
|
163 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|
|
|
2 |
import cv2
|
3 |
import numpy as np
|
4 |
from ultralytics import YOLO
|
5 |
+
import tensorflow as tf
|
6 |
import io
|
7 |
from typing import Union
|
8 |
|
|
|
11 |
|
12 |
# --- App Initialization ---
|
13 |
app = FastAPI(
|
14 |
+
title="High-Quality Wound Heatmap API",
|
15 |
+
description="Generates a high-quality wound heatmap using a DL model, preserving original image quality.",
|
16 |
+
version="7.0.0" # Version updated for quality preservation
|
17 |
)
|
18 |
|
19 |
# --- Model Loading ---
|
20 |
+
def load_models():
|
21 |
+
"""Loads both the YOLO and the TensorFlow segmentation models."""
|
22 |
+
yolo_model, segmentation_model = None, None
|
23 |
try:
|
24 |
yolo_model = YOLO("best.pt")
|
25 |
print("YOLO model 'best.pt' loaded successfully.")
|
|
|
26 |
except Exception as e:
|
27 |
+
print(f"Warning: Could not load YOLO model. Error: {e}")
|
|
|
28 |
|
29 |
+
try:
|
30 |
+
segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
|
31 |
+
print("Segmentation model 'segmentation_model.h5' loaded successfully.")
|
32 |
+
except Exception as e:
|
33 |
+
print(f"Warning: Could not load segmentation model. Error: {e}")
|
34 |
+
|
35 |
+
return yolo_model, segmentation_model
|
36 |
|
37 |
+
yolo_model, segmentation_model = load_models()
|
38 |
|
39 |
+
# --- Helper Functions ---
|
|
|
|
|
40 |
|
41 |
def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
|
42 |
+
"""Detects the primary wound bounding box using the YOLO model on the original quality image."""
|
43 |
if not yolo_model: return None
|
44 |
try:
|
45 |
results = yolo_model.predict(image, verbose=False)
|
|
|
51 |
print(f"YOLO prediction failed: {e}")
|
52 |
return None
|
53 |
|
54 |
+
def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
|
55 |
"""
|
56 |
+
Segments the wound using the TF/Keras model.
|
57 |
+
It resizes a copy for the model but returns a mask matching the original image's dimensions.
|
58 |
"""
|
59 |
+
if not segmentation_model:
|
60 |
+
print("Segmentation model not loaded, cannot create mask.")
|
61 |
+
return None
|
62 |
+
try:
|
63 |
+
input_shape = segmentation_model.input_shape[1:3]
|
64 |
+
# A temporary, resized copy is made for the model's prediction
|
65 |
+
img_resized_for_model = cv2.resize(image, (input_shape[1], input_shape[0]))
|
66 |
+
img_norm = np.expand_dims(img_resized_for_model.astype(np.float32) / 255.0, axis=0)
|
67 |
+
|
68 |
+
prediction = segmentation_model.predict(img_norm, verbose=0)
|
69 |
+
|
70 |
+
while isinstance(prediction, list):
|
71 |
+
prediction = prediction[0]
|
72 |
+
if isinstance(prediction, tf.Tensor):
|
73 |
+
prediction = prediction.numpy()
|
74 |
+
|
75 |
+
pred_mask = prediction[0]
|
76 |
+
# The resulting mask is resized back to the original image's dimensions to ensure perfect alignment
|
77 |
+
pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
|
78 |
+
return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
|
79 |
+
except Exception as e:
|
80 |
+
print(f"Segmentation model prediction failed: {e}")
|
81 |
+
return None
|
82 |
|
83 |
def calculate_metrics(mask: np.ndarray) -> dict:
|
84 |
+
"""Calculates dimensional metrics from the full-resolution wound mask."""
|
85 |
area_pixels = cv2.countNonZero(mask)
|
86 |
if area_pixels == 0:
|
87 |
return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
88 |
|
89 |
area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
|
|
|
90 |
contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
|
91 |
if not contours:
|
92 |
return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
|
|
|
102 |
return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
|
103 |
|
104 |
def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
|
105 |
+
"""Generates and overlays a three-color heatmap, preserving the underlying image quality."""
|
|
|
|
|
|
|
106 |
if cv2.countNonZero(mask) == 0:
|
107 |
return image
|
108 |
|
|
|
109 |
lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
|
110 |
+
a_channel = lab_image[:, :, 1]
|
|
|
|
|
111 |
overlay = np.zeros_like(image)
|
112 |
|
113 |
+
RED_THRESHOLD = 160
|
114 |
+
BLUE_THRESHOLD = 145
|
115 |
+
|
116 |
+
overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255)
|
117 |
+
overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0)
|
118 |
+
overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0)
|
119 |
+
|
120 |
+
# Blend the overlay with the original image
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
121 |
blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
|
122 |
|
123 |
+
# Create the final image by taking the original and replacing only the masked area with the blended version
|
124 |
final_image = image.copy()
|
125 |
final_image[mask == 255] = blended_image[mask == 255]
|
126 |
|
|
|
129 |
# --- Main API Endpoint ---
|
130 |
@app.post("/analyze_wound")
|
131 |
async def analyze_wound(file: UploadFile = File(...)):
|
132 |
+
if not yolo_model or not segmentation_model:
|
133 |
+
raise HTTPException(status_code=503, detail="A required model is not available.")
|
134 |
|
135 |
contents = await file.read()
|
136 |
+
# Read the image from buffer into a cv2 object without any lossy conversions
|
137 |
original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
|
138 |
if original_image is None:
|
139 |
+
raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
|
140 |
|
141 |
+
# No preprocessing is done to the original image to preserve quality.
|
142 |
+
# A copy of the original image is used for detection.
|
143 |
+
bbox = detect_wound_region_yolo(original_image.copy())
|
144 |
if not bbox:
|
145 |
raise HTTPException(status_code=404, detail="No wound detected in the image.")
|
146 |
|
147 |
xmin, ymin, xmax, ymax = bbox
|
148 |
+
# Crop the region of interest from the original, high-quality image
|
149 |
cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
|
150 |
|
151 |
+
# Step 1: Use the DL model on the high-quality crop to get a precise mask
|
152 |
+
wound_mask = segment_wound_with_model(cropped_image_roi)
|
153 |
+
if wound_mask is None or cv2.countNonZero(wound_mask) == 0:
|
154 |
+
raise HTTPException(status_code=404, detail="Segmentation model failed to identify a wound in the detected region.")
|
155 |
|
156 |
+
# Step 2: Calculate metrics based on the full-resolution mask
|
157 |
metrics = calculate_metrics(wound_mask)
|
158 |
|
159 |
+
# Step 3: Generate the heatmap on the high-quality cropped image
|
160 |
heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
|
161 |
|
162 |
+
# Step 4: Encode the final image into PNG (lossless format) to preserve quality
|
163 |
success, png_data = cv2.imencode(".png", heatmap_image)
|
164 |
if not success:
|
165 |
raise HTTPException(status_code=500, detail="Failed to encode output image")
|