Ani14 commited on
Commit
af8bbd0
·
verified ·
1 Parent(s): 0e2bcdc

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +93 -140
predict.py CHANGED
@@ -2,7 +2,6 @@ from fastapi import FastAPI, File, UploadFile, HTTPException, Response
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
- import tensorflow as tf
6
  import io
7
  from typing import Union
8
 
@@ -11,56 +10,36 @@ PIXELS_PER_CM = 50.0
11
 
12
  # --- App Initialization ---
13
  app = FastAPI(
14
- title="Wound Analysis API",
15
- description="An API to analyze wound images, zoom to the wound, and return an annotated image with data in headers.",
16
- version="3.5.0" # Version updated for polygon and zoom features
17
  )
18
 
19
  # --- Model Loading ---
20
- def load_models():
21
- """Loads the segmentation and YOLO models, handling potential errors."""
22
- segmentation_model, yolo_model = None, None
23
  try:
24
- # Load your trained segmentation model
25
- segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
26
- print("Segmentation model 'segmentation_model.h5' loaded successfully.")
27
- except Exception as e:
28
- print(f"Warning: Could not load segmentation model. Using fallback. Error: {e}")
29
-
30
- try:
31
- # Load your trained YOLO model for wound detection
32
  yolo_model = YOLO("best.pt")
33
  print("YOLO model 'best.pt' loaded successfully.")
 
34
  except Exception as e:
35
- print(f"Warning: Could not load YOLO model. Using fallback. Error: {e}")
36
-
37
- return segmentation_model, yolo_model
38
 
39
- segmentation_model, yolo_model = load_models()
40
 
41
  # --- Helper Functions ---
42
 
43
  def preprocess_image(image: np.ndarray) -> np.ndarray:
44
- """Applies a series of preprocessing steps to enhance the image for analysis."""
45
- img_denoised = cv2.medianBlur(image, 3)
46
- lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
47
- l_channel, a_channel, b_channel = cv2.split(lab)
48
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
49
- l_clahe = clahe.apply(l_channel)
50
- lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
51
- img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
52
- gamma = 1.2
53
- img_float = img_clahe.astype(np.float32) / 255.0
54
- img_gamma = np.power(img_float, gamma)
55
- return (img_gamma * 255).astype(np.uint8)
56
 
57
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
58
- """Detects the wound bounding box using the YOLO model."""
59
  if not yolo_model: return None
60
  try:
61
  results = yolo_model.predict(image, verbose=False)
62
  if results and results[0].boxes and len(results[0].boxes) > 0:
63
- # Get the box with the highest confidence
64
  best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
65
  coords = best_box.xyxy[0].cpu().numpy()
66
  return tuple(map(int, coords))
@@ -68,53 +47,33 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
68
  print(f"YOLO prediction failed: {e}")
69
  return None
70
 
71
- def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
72
- """Segments the wound from the image using the primary segmentation model."""
73
- if not segmentation_model:
74
- return None
75
- try:
76
- input_shape = segmentation_model.input_shape[1:3]
77
- img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
78
- img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
79
-
80
- prediction = segmentation_model.predict(img_norm, verbose=0)
81
-
82
- # Handle nested list or Tensor output from some model versions
83
- while isinstance(prediction, list):
84
- prediction = prediction[0]
85
- if isinstance(prediction, tf.Tensor):
86
- prediction = prediction.numpy()
87
-
88
- pred_mask = prediction[0]
89
- pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
90
- return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
91
- except Exception as e:
92
- print(f"Segmentation model prediction failed: {e}")
93
- return None
94
-
95
- def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
96
- """A fallback segmentation method using k-means clustering if the primary model fails."""
97
- pixels = image.reshape((-1, 3)).astype(np.float32)
98
- criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
99
- _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
100
- centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
101
- wound_cluster_idx = np.argmax(centers_lab[:, 1]) # 'a' channel in LAB is good for redness
102
- mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
103
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
104
- if contours:
105
- largest_contour = max(contours, key=cv2.contourArea)
106
- refined_mask = np.zeros_like(mask)
107
- cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
108
- return refined_mask
109
- return mask
110
 
111
  def calculate_metrics(mask: np.ndarray) -> dict:
112
- """Calculates dimensional and analytical metrics from the wound mask."""
113
- wound_pixels = cv2.countNonZero(mask)
114
- if wound_pixels == 0:
115
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
116
 
117
- area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
118
 
119
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
120
  if not contours:
@@ -125,91 +84,85 @@ def calculate_metrics(mask: np.ndarray) -> dict:
125
 
126
  length_cm = max(width, height) / PIXELS_PER_CM
127
  breadth_cm = min(width, height) / PIXELS_PER_CM
128
-
129
- # Placeholder for depth and moisture calculation.
130
- # These would typically require more advanced sensors or algorithms.
131
- depth_cm = 0.1 # Placeholder value
132
- moisture = 75.0 # Placeholder value
133
 
134
  return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
135
 
136
- def create_visual_overlay_and_zoom(image: np.ndarray, mask: np.ndarray, bbox: tuple = None) -> np.ndarray:
137
- """Draws a polygon around the wound, applies a color overlay, and zooms to the region."""
138
- annotated_img = image.copy()
 
 
 
 
 
 
 
 
 
 
 
139
 
140
- # Find contours to draw the polygon
141
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
142
- if contours:
143
- # Draw a distinct polygon outline around the wound
144
- cv2.drawContours(annotated_img, contours, -1, (0, 255, 255), 2) # Yellow, 2px thick
145
-
146
- # Zoom to the wound area if a bounding box is available
147
- if bbox:
148
- xmin, ymin, xmax, ymax = bbox
149
- # Add a 10% padding around the bounding box for better context
150
- padding_x = int((xmax - xmin) * 0.10)
151
- padding_y = int((ymax - ymin) * 0.10)
152
-
153
- # Ensure coordinates are within image bounds
154
- zoom_xmin = max(0, xmin - padding_x)
155
- zoom_ymin = max(0, ymin - padding_y)
156
- zoom_xmax = min(image.shape[1], xmax + padding_x)
157
- zoom_ymax = min(image.shape[0], ymax + padding_y)
158
-
159
- return annotated_img[zoom_ymin:zoom_ymax, zoom_xmin:zoom_xmax]
160
 
161
- return annotated_img
 
 
162
 
 
163
 
164
  # --- Main API Endpoint ---
165
  @app.post("/analyze_wound")
166
  async def analyze_wound(file: UploadFile = File(...)):
167
- """
168
- Receives an image, analyzes the wound, and returns a zoomed,
169
- annotated image with metrics in the response headers.
170
- """
171
  contents = await file.read()
172
- image_array = np.frombuffer(contents, np.uint8)
173
- original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
174
  if original_image is None:
175
- raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
176
 
177
  processed_image = preprocess_image(original_image)
178
 
179
- # Use YOLO to find the general wound region
180
  bbox = detect_wound_region_yolo(processed_image)
 
 
 
 
 
 
 
 
181
 
182
- if bbox:
183
- xmin, ymin, xmax, ymax = bbox
184
- # Crop to the detected region for more accurate segmentation
185
- cropped_for_segmentation = processed_image[ymin:ymax, xmin:xmax]
186
- else:
187
- # If no wound is detected, analyze the whole image as a fallback
188
- cropped_for_segmentation = processed_image
189
-
190
- # Segment the wound within the cropped region
191
- mask = segment_wound_with_model(cropped_for_segmentation)
192
- if mask is None:
193
- mask = segment_wound_with_fallback(cropped_for_segmentation)
194
-
195
- # Calculate metrics based on the precise mask
196
- metrics = calculate_metrics(mask)
197
-
198
- # Create a full-sized mask to pass to the visualization function
199
- full_mask = np.zeros(original_image.shape[:2], dtype=np.uint8)
200
- if bbox:
201
- full_mask[ymin:ymax, xmin:xmax] = mask
202
- else:
203
- full_mask = mask
204
-
205
- # Generate the final visual output: draw polygon and zoom
206
- final_image = create_visual_overlay_and_zoom(original_image, full_mask, bbox)
207
 
208
- success, png_data = cv2.imencode(".png", final_image)
 
209
  if not success:
210
  raise HTTPException(status_code=500, detail="Failed to encode output image")
211
 
212
- # Set the custom headers as requested
213
  headers = {
214
  'X-Length-Cm': f"{metrics['length_cm']:.2f}",
215
  'X-Breadth-Cm': f"{metrics['breadth_cm']:.2f}",
 
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
 
5
  import io
6
  from typing import Union
7
 
 
10
 
11
  # --- App Initialization ---
12
  app = FastAPI(
13
+ title="Wound Heatmap Analysis API",
14
+ description="An API that generates a three-color heatmap (Red, Blue, Green) on a wound to show tissue characteristics.",
15
+ version="5.0.0" # Version updated for three-layer color heatmap
16
  )
17
 
18
  # --- Model Loading ---
19
+ def load_yolo_model():
20
+ """Loads the YOLO model for initial wound detection."""
 
21
  try:
 
 
 
 
 
 
 
 
22
  yolo_model = YOLO("best.pt")
23
  print("YOLO model 'best.pt' loaded successfully.")
24
+ return yolo_model
25
  except Exception as e:
26
+ print(f"Fatal: Could not load YOLO model. Error: {e}")
27
+ return None
 
28
 
29
+ yolo_model = load_yolo_model()
30
 
31
  # --- Helper Functions ---
32
 
33
  def preprocess_image(image: np.ndarray) -> np.ndarray:
34
+ """Applies a median blur to reduce noise."""
35
+ return cv2.medianBlur(image, 5)
 
 
 
 
 
 
 
 
 
 
36
 
37
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
38
+ """Detects the primary wound bounding box using the YOLO model."""
39
  if not yolo_model: return None
40
  try:
41
  results = yolo_model.predict(image, verbose=False)
42
  if results and results[0].boxes and len(results[0].boxes) > 0:
 
43
  best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
44
  coords = best_box.xyxy[0].cpu().numpy()
45
  return tuple(map(int, coords))
 
47
  print(f"YOLO prediction failed: {e}")
48
  return None
49
 
50
+ def create_wound_bed_mask(cropped_image: np.ndarray) -> np.ndarray:
51
+ """
52
+ Creates a general binary mask of the entire wound bed, which will be the area for the heatmap.
53
+ This uses a broader color range than specific tissue segmentation.
54
+ """
55
+ lab_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2LAB)
56
+
57
+ # Broad thresholds to capture the entire wound area (slough, granulation, etc.)
58
+ lower_bound = np.array([0, 135, 135])
59
+ upper_bound = np.array([255, 200, 200])
60
+
61
+ mask = cv2.inRange(lab_image, lower_bound, upper_bound)
62
+
63
+ # Clean up the mask
64
+ kernel = np.ones((5, 5), np.uint8)
65
+ mask_cleaned = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
66
+ mask_cleaned = cv2.morphologyEx(mask_cleaned, cv2.MORPH_CLOSE, kernel, iterations=3)
67
+
68
+ return mask_cleaned
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
69
 
70
  def calculate_metrics(mask: np.ndarray) -> dict:
71
+ """Calculates dimensional metrics from the overall wound mask."""
72
+ area_pixels = cv2.countNonZero(mask)
73
+ if area_pixels == 0:
74
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
75
 
76
+ area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
77
 
78
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
  if not contours:
 
84
 
85
  length_cm = max(width, height) / PIXELS_PER_CM
86
  breadth_cm = min(width, height) / PIXELS_PER_CM
87
+ depth_cm = 0.1 # Placeholder
88
+ moisture = 75.0 # Placeholder
 
 
 
89
 
90
  return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
91
 
92
+ def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
93
+ """
94
+ Generates and overlays a three-color (Red, Blue, Green) heatmap onto the image
95
+ based on the intensity of the 'a' channel (redness) in the LAB color space.
96
+ """
97
+ if cv2.countNonZero(mask) == 0:
98
+ return image
99
+
100
+ # Convert the region of interest to LAB color space for analysis
101
+ lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
102
+ a_channel = lab_image[:, :, 1] # The 'a' channel represents the green-red axis
103
+
104
+ # Create a color overlay image, initially transparent
105
+ overlay = np.zeros_like(image)
106
 
107
+ # Define thresholds for redness intensity. These values might need tuning.
108
+ # Values are based on the 'a' channel, where ~128 is neutral.
109
+ RED_THRESHOLD = 160 # Most intense red (e.g., fresh granulation)
110
+ BLUE_THRESHOLD = 145 # Medium red (e.g., developing tissue)
111
+ # Anything above a lower bound (e.g., 135) will be green.
112
+
113
+ # Apply colors based on thresholds, only within the wound mask
114
+ # Red for "most" affected
115
+ overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255) # BGR for Red
116
+ # Blue for "less" affected
117
+ overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0) # BGR for Blue
118
+ # Green for "least" affected (but still part of the wound bed)
119
+ overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0) # BGR for Green
120
+
121
+ # Blend the original image with the color overlay
122
+ # A weight of 0.4 for the overlay makes it visible but not overpowering
123
+ blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
 
 
 
124
 
125
+ # To make the result clean, we only apply the blended result where the mask is active
126
+ final_image = image.copy()
127
+ final_image[mask == 255] = blended_image[mask == 255]
128
 
129
+ return final_image
130
 
131
  # --- Main API Endpoint ---
132
  @app.post("/analyze_wound")
133
  async def analyze_wound(file: UploadFile = File(...)):
134
+ if not yolo_model:
135
+ raise HTTPException(status_code=503, detail="YOLO model is not available.")
136
+
 
137
  contents = await file.read()
138
+ original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
 
139
  if original_image is None:
140
+ raise HTTPException(status_code=400, detail="Invalid image file")
141
 
142
  processed_image = preprocess_image(original_image)
143
 
 
144
  bbox = detect_wound_region_yolo(processed_image)
145
+ if not bbox:
146
+ raise HTTPException(status_code=404, detail="No wound detected in the image.")
147
+
148
+ xmin, ymin, xmax, ymax = bbox
149
+ cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
150
+
151
+ # Step 1: Create a mask for the entire wound bed in the cropped image
152
+ wound_mask = create_wound_bed_mask(cropped_image_roi)
153
 
154
+ # Step 2: Calculate metrics based on this overall wound mask
155
+ metrics = calculate_metrics(wound_mask)
156
+
157
+ # Step 3: Generate the three-color heatmap on the cropped image
158
+ heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
159
 
160
+ # Step 4: Encode the final, annotated (and already cropped) image
161
+ success, png_data = cv2.imencode(".png", heatmap_image)
162
  if not success:
163
  raise HTTPException(status_code=500, detail="Failed to encode output image")
164
 
165
+ # Step 5: Set the custom headers
166
  headers = {
167
  'X-Length-Cm': f"{metrics['length_cm']:.2f}",
168
  'X-Breadth-Cm': f"{metrics['breadth_cm']:.2f}",