Ani14 commited on
Commit
17d50ed
·
verified ·
1 Parent(s): bfe1160

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +93 -54
predict.py CHANGED
@@ -6,15 +6,18 @@ from typing import Union
6
 
7
  # --- Load Models ---
8
  def load_models():
 
9
  segmentation_model, yolo_model = None, None
10
  try:
11
  import tensorflow as tf
 
12
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
13
  print("✅ Segmentation model loaded.")
14
  except Exception as e:
15
  print(f"⚠️ Failed to load segmentation model: {e}")
16
  try:
17
  from ultralytics import YOLO
 
18
  yolo_model = YOLO("best.pt")
19
  print("✅ YOLO model loaded.")
20
  except Exception as e:
@@ -23,132 +26,168 @@ def load_models():
23
 
24
  segmentation_model, yolo_model = load_models()
25
 
 
 
 
26
  PIXELS_PER_CM = 50.0
27
- app = FastAPI(title="Wound Analyzer", version="10.1") # Version with depth calculation fix
28
 
29
- # --- Preprocessing ---
 
 
 
 
 
 
30
  def preprocess_image(image: np.ndarray) -> np.ndarray:
31
- """Enhances image for better segmentation."""
32
  img_denoised = cv2.medianBlur(image, 3)
33
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
34
  l, a, b = cv2.split(lab)
35
- clahe = cv2.createCLAHE(2.0, (8, 8))
36
  l_clahe = clahe.apply(l)
37
  lab_clahe = cv2.merge((l_clahe, a, b))
38
  result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
39
  gamma = 1.2
40
  return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
41
 
42
- # --- Segmentation ---
43
  def segment_wound(image: np.ndarray) -> np.ndarray:
44
- """Segments wound from a preprocessed image, with a fallback."""
45
- try:
46
- if segmentation_model:
47
  input_size = segmentation_model.input_shape[1:3]
48
  resized = cv2.resize(image, (input_size[1], input_size[0]))
49
  norm = np.expand_dims(resized / 255.0, axis=0)
50
  prediction = segmentation_model.predict(norm, verbose=0)
 
51
  if isinstance(prediction, list):
52
  prediction = prediction[0]
53
  prediction = prediction[0]
54
  mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
55
  return (mask >= 0.5).astype(np.uint8) * 255
56
- except Exception as e:
57
- print(f"⚠️ Model prediction failed: {e}")
58
 
 
59
  Z = image.reshape((-1, 3)).astype(np.float32)
60
- _, labels, centers = cv2.kmeans(Z, 2, None, (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0), 5, cv2.KMEANS_PP_CENTERS)
 
61
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
62
- wound_idx = np.argmax(centers_lab[:, 1])
63
  mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
64
  return mask
65
 
66
- # --- Metrics ---
67
  def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
68
- """Calculates metrics using the mask and the ORIGINAL image for color accuracy."""
 
 
69
  area_px = cv2.countNonZero(mask)
70
  if area_px == 0:
71
  return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
72
-
 
73
  area_cm2 = area_px / (PIXELS_PER_CM ** 2)
74
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
75
  rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
76
  (w, h) = rect[1]
77
  length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
78
 
 
79
  mask_bool = mask.astype(bool)
80
- # **FIX**: Use the original, unaltered image for color-based metrics.
81
- lab = cv2.cvtColor(original_image, cv2.COLOR_BGR2LAB)
82
- gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
83
 
84
- depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
85
- moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
86
 
87
- return dict(area_cm2=round(area_cm2, 2), length_cm=round(length_cm, 2), breadth_cm=round(breadth_cm, 2), depth_cm=round(depth, 1), moisture=round(moisture, 0))
88
-
89
- # --- Overlay ---
90
  def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
91
- """Draws heatmap and boundary on the image."""
92
- # **CHANGE**: Use Yellow for high-intensity areas for better visibility
93
- dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
94
- cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
95
- heatmap = np.zeros_like(image)
96
-
97
- heatmap[dist >= 0.66] = (0, 255, 255) # Yellow - Most Affected
98
- heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
99
- heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
100
-
101
- blended = cv2.addWeighted(image, 0.7, heatmap, 0.3, 0)
102
- annotated = image.copy()
103
- annotated[mask.astype(bool)] = blended[mask.astype(bool)]
104
-
 
 
105
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
106
- cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2)
107
- return annotated
108
 
109
  # --- API Endpoint ---
110
  @app.post("/analyze_wound")
111
  async def analyze(file: UploadFile = File(...)):
 
 
 
 
112
  contents = await file.read()
113
  original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
114
  if original_image is None:
115
- raise HTTPException(status_code=400, detail="Invalid image file.")
116
 
117
- # 1. Create a preprocessed version for segmentation
118
  preprocessed_image = preprocess_image(original_image)
119
 
120
- # 2. Define Regions of Interest (ROI) for both original and preprocessed images
121
- original_roi = original_image.copy()
122
- preprocessed_roi = preprocessed_image.copy()
123
-
124
  if yolo_model:
125
  try:
126
  results = yolo_model.predict(preprocessed_image, verbose=False)
127
  if results and results[0].boxes:
128
  coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
129
- x1, y1, x2, y2 = coords
130
- # Crop both the original and preprocessed images to the same ROI
131
- original_roi = original_image[y1:y2, x1:x2]
132
- preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
133
  except Exception as e:
134
- print(f"⚠️ YOLO detection failed: {e}")
 
 
 
 
 
 
 
135
 
136
  # 3. Get mask from the preprocessed ROI
137
  mask = segment_wound(preprocessed_roi)
138
 
139
- # 4. Calculate metrics using the ORIGINAL ROI for color accuracy
140
  metrics = calculate_metrics(mask, original_roi)
141
 
142
  # 5. Draw overlay on the ORIGINAL ROI for correct visualization
143
  annotated_image = draw_overlay(original_roi, mask)
144
 
 
145
  success, out_bytes = cv2.imencode(".png", annotated_image)
146
  if not success:
147
- raise HTTPException(status_code=500, detail="Failed to encode image.")
148
 
149
  headers = {
150
- "X-Length-Cm": str(metrics["length_cm"]), "X-Breadth-Cm": str(metrics["breadth_cm"]),
151
- "X-Depth-Cm": str(metrics["depth_cm"]), "X-Area-Cm2": str(metrics["area_cm2"]),
 
 
152
  "X-Moisture": str(metrics["moisture"]),
153
  }
154
  return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)
 
6
 
7
  # --- Load Models ---
8
  def load_models():
9
+ """Loads machine learning models safely."""
10
  segmentation_model, yolo_model = None, None
11
  try:
12
  import tensorflow as tf
13
+ # Ensure you have a valid model file at this path
14
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
15
  print("✅ Segmentation model loaded.")
16
  except Exception as e:
17
  print(f"⚠️ Failed to load segmentation model: {e}")
18
  try:
19
  from ultralytics import YOLO
20
+ # Ensure you have a valid model file at this path
21
  yolo_model = YOLO("best.pt")
22
  print("✅ YOLO model loaded.")
23
  except Exception as e:
 
26
 
27
  segmentation_model, yolo_model = load_models()
28
 
29
+ # --- Configuration ---
30
+ # This value should be calibrated by taking a picture of a ruler.
31
+ # Measure a known length (e.g., 5cm) in pixels, then divide pixels by cm.
32
  PIXELS_PER_CM = 50.0
 
33
 
34
+ app = FastAPI(
35
+ title="Wound Analyzer",
36
+ version="10.2", # Version with improved depth logic
37
+ description="Analyzes wound images, keeping the original API contract."
38
+ )
39
+
40
+ # --- Image Processing ---
41
  def preprocess_image(image: np.ndarray) -> np.ndarray:
42
+ """Enhances image for better segmentation by improving contrast and reducing noise."""
43
  img_denoised = cv2.medianBlur(image, 3)
44
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
45
  l, a, b = cv2.split(lab)
46
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
47
  l_clahe = clahe.apply(l)
48
  lab_clahe = cv2.merge((l_clahe, a, b))
49
  result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
50
  gamma = 1.2
51
  return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
52
 
 
53
  def segment_wound(image: np.ndarray) -> np.ndarray:
54
+ """Segments wound from a preprocessed image, with a fallback to KMeans if the model fails."""
55
+ if segmentation_model:
56
+ try:
57
  input_size = segmentation_model.input_shape[1:3]
58
  resized = cv2.resize(image, (input_size[1], input_size[0]))
59
  norm = np.expand_dims(resized / 255.0, axis=0)
60
  prediction = segmentation_model.predict(norm, verbose=0)
61
+ # Handle models with multiple outputs
62
  if isinstance(prediction, list):
63
  prediction = prediction[0]
64
  prediction = prediction[0]
65
  mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
66
  return (mask >= 0.5).astype(np.uint8) * 255
67
+ except Exception as e:
68
+ print(f"⚠️ Segmentation model prediction failed: {e}. Falling back to KMeans.")
69
 
70
+ # Fallback method using color clustering if the primary model fails
71
  Z = image.reshape((-1, 3)).astype(np.float32)
72
+ criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
73
+ _, labels, centers = cv2.kmeans(Z, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
74
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
75
+ wound_idx = np.argmax(centers_lab[:, 1]) # Assume wound is the reddest cluster
76
  mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
77
  return mask
78
 
 
79
  def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
80
+ """
81
+ Calculates all metrics. Depth is now calculated based on shadow/intensity analysis.
82
+ """
83
  area_px = cv2.countNonZero(mask)
84
  if area_px == 0:
85
  return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
86
+
87
+ # --- Area and Dimensions ---
88
  area_cm2 = area_px / (PIXELS_PER_CM ** 2)
89
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
90
+ if not contours: # Handle case where mask is present but no contours are found
91
+ return dict(area_cm2=area_cm2, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
92
+
93
  rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
94
  (w, h) = rect[1]
95
  length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
96
 
97
+ # --- Depth and Moisture Calculation ---
98
  mask_bool = mask.astype(bool)
99
+ gray_image = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
 
 
100
 
101
+ # **ENHANCED DEPTH CALCULATION**: Use standard deviation of pixel intensity.
102
+ # A higher standard deviation implies more shadows and highlights, suggesting greater depth.
103
+ # The result is scaled to a 0-10 range for a more intuitive, relative depth score.
104
+ intensity_std_dev = np.std(gray_image[mask_bool])
105
+ depth_score = (intensity_std_dev / 127.0) * 10.0 # Scale to a 0-10 range
106
+
107
+ # Moisture calculation remains the same, based on intensity variance
108
+ moisture = max(0.0, 100.0 * (1 - np.std(gray_image[mask_bool]) / 127.0))
109
+
110
+ return dict(
111
+ area_cm2=round(area_cm2, 2),
112
+ length_cm=round(length_cm, 2),
113
+ breadth_cm=round(breadth_cm, 2),
114
+ depth_cm=round(depth_score, 1), # This is now the new depth score
115
+ moisture=round(moisture, 1)
116
+ )
117
 
 
 
 
118
  def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
119
+ """Draws a heatmap and boundary on the image for visualization."""
120
+ dist_transform = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
121
+ cv2.normalize(dist_transform, dist_transform, 0, 1.0, cv2.NORM_MINMAX)
122
+
123
+ heatmap = np.zeros_like(image, dtype=np.uint8)
124
+ heatmap[dist_transform >= 0.66] = (0, 255, 255) # Yellow - Core
125
+ heatmap[(dist_transform >= 0.33) & (dist_transform < 0.66)] = (255, 0, 0) # Blue - Moderate
126
+ heatmap[(dist_transform > 0) & (dist_transform < 0.33)] = (0, 255, 0) # Green - Periphery
127
+
128
+ # Create a blended image only where the mask is active
129
+ blended = image.copy()
130
+ alpha = 0.4
131
+ masked_pixels = mask.astype(bool)
132
+ blended[masked_pixels] = cv2.addWeighted(image, 1 - alpha, heatmap, alpha, 0)[masked_pixels]
133
+
134
+ # Draw a clean, white contour
135
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
136
+ cv2.drawContours(blended, contours, -1, (255, 255, 255), 2)
137
+ return blended
138
 
139
  # --- API Endpoint ---
140
  @app.post("/analyze_wound")
141
  async def analyze(file: UploadFile = File(...)):
142
+ """
143
+ Accepts an image, analyzes the wound, and returns an annotated image
144
+ with metrics in the response headers, maintaining the original API contract.
145
+ """
146
  contents = await file.read()
147
  original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
148
  if original_image is None:
149
+ raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
150
 
151
+ # 1. Preprocess a copy of the image for object detection and segmentation
152
  preprocessed_image = preprocess_image(original_image)
153
 
154
+ # 2. Detect ROI using YOLO if available, otherwise use the whole image
155
+ roi_coords = (0, 0, original_image.shape[1], original_image.shape[0])
 
 
156
  if yolo_model:
157
  try:
158
  results = yolo_model.predict(preprocessed_image, verbose=False)
159
  if results and results[0].boxes:
160
  coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
161
+ roi_coords = (coords[0], coords[1], coords[2], coords[3])
 
 
 
162
  except Exception as e:
163
+ print(f"⚠️ YOLO detection failed: {e}. Using full image as ROI.")
164
+
165
+ x1, y1, x2, y2 = roi_coords
166
+ original_roi = original_image[y1:y2, x1:x2]
167
+ preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
168
+
169
+ if original_roi.size == 0:
170
+ raise HTTPException(status_code=404, detail="Wound region of interest could not be determined.")
171
 
172
  # 3. Get mask from the preprocessed ROI
173
  mask = segment_wound(preprocessed_roi)
174
 
175
+ # 4. Calculate metrics using the ORIGINAL ROI for color/intensity accuracy
176
  metrics = calculate_metrics(mask, original_roi)
177
 
178
  # 5. Draw overlay on the ORIGINAL ROI for correct visualization
179
  annotated_image = draw_overlay(original_roi, mask)
180
 
181
+ # 6. Encode image and prepare response
182
  success, out_bytes = cv2.imencode(".png", annotated_image)
183
  if not success:
184
+ raise HTTPException(status_code=500, detail="Failed to encode output image.")
185
 
186
  headers = {
187
+ "X-Length-Cm": str(metrics["length_cm"]),
188
+ "X-Breadth-Cm": str(metrics["breadth_cm"]),
189
+ "X-Depth-Cm": str(metrics["depth_cm"]),
190
+ "X-Area-Cm2": str(metrics["area_cm2"]),
191
  "X-Moisture": str(metrics["moisture"]),
192
  }
193
  return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)