Ani14 commited on
Commit
bfe1160
·
verified ·
1 Parent(s): 7dda7cf

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +45 -39
predict.py CHANGED
@@ -24,50 +24,48 @@ def load_models():
24
  segmentation_model, yolo_model = load_models()
25
 
26
  PIXELS_PER_CM = 50.0
27
- app = FastAPI(title="Wound Analyzer", version="10.0")
28
 
29
  # --- Preprocessing ---
30
  def preprocess_image(image: np.ndarray) -> np.ndarray:
 
31
  img_denoised = cv2.medianBlur(image, 3)
32
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
33
  l, a, b = cv2.split(lab)
34
  clahe = cv2.createCLAHE(2.0, (8, 8))
35
- l = clahe.apply(l)
36
- lab = cv2.merge((l, a, b))
37
- result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
38
  gamma = 1.2
39
  return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
40
 
41
  # --- Segmentation ---
42
  def segment_wound(image: np.ndarray) -> np.ndarray:
 
43
  try:
44
  if segmentation_model:
45
  input_size = segmentation_model.input_shape[1:3]
46
  resized = cv2.resize(image, (input_size[1], input_size[0]))
47
  norm = np.expand_dims(resized / 255.0, axis=0)
48
-
49
  prediction = segmentation_model.predict(norm, verbose=0)
50
- if isinstance(prediction, list): # <-- Fix
51
  prediction = prediction[0]
52
- prediction = prediction[0] # remove batch dim
53
-
54
  mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
55
  return (mask >= 0.5).astype(np.uint8) * 255
56
  except Exception as e:
57
  print(f"⚠️ Model prediction failed: {e}")
58
 
59
- # Fallback segmentation
60
  Z = image.reshape((-1, 3)).astype(np.float32)
61
- _, labels, centers = cv2.kmeans(Z, 2, None,
62
- (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0),
63
- 5, cv2.KMEANS_PP_CENTERS)
64
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
65
  wound_idx = np.argmax(centers_lab[:, 1])
66
  mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
67
  return mask
68
 
69
  # --- Metrics ---
70
- def calculate_metrics(mask: np.ndarray, image: np.ndarray):
 
71
  area_px = cv2.countNonZero(mask)
72
  if area_px == 0:
73
  return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
@@ -79,26 +77,24 @@ def calculate_metrics(mask: np.ndarray, image: np.ndarray):
79
  length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
80
 
81
  mask_bool = mask.astype(bool)
82
- lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
83
- gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
 
 
84
  depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
85
  moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
86
 
87
- return dict(
88
- area_cm2=round(area_cm2, 2),
89
- length_cm=round(length_cm, 2),
90
- breadth_cm=round(breadth_cm, 2),
91
- depth_cm=round(depth, 1),
92
- moisture=round(moisture, 0)
93
- )
94
 
95
  # --- Overlay ---
96
  def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
 
 
97
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
98
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
99
  heatmap = np.zeros_like(image)
100
 
101
- heatmap[dist >= 0.66] = (0, 0, 255) # Red - Most Affected
102
  heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
103
  heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
104
 
@@ -107,42 +103,52 @@ def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
107
  annotated[mask.astype(bool)] = blended[mask.astype(bool)]
108
 
109
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
110
- cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2) # White outline
111
  return annotated
112
 
113
  # --- API Endpoint ---
114
  @app.post("/analyze_wound")
115
  async def analyze(file: UploadFile = File(...)):
116
- image = cv2.imdecode(np.frombuffer(await file.read(), np.uint8), cv2.IMREAD_COLOR)
117
- if image is None:
 
118
  raise HTTPException(status_code=400, detail="Invalid image file.")
119
 
120
- image = preprocess_image(image)
121
- crop = image.copy()
 
 
 
 
122
 
123
  if yolo_model:
124
  try:
125
- results = yolo_model.predict(image, verbose=False)
126
  if results and results[0].boxes:
127
  coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
128
  x1, y1, x2, y2 = coords
129
- crop = image[y1:y2, x1:x2]
 
 
130
  except Exception as e:
131
  print(f"⚠️ YOLO detection failed: {e}")
132
 
133
- mask = segment_wound(crop)
134
- metrics = calculate_metrics(mask, crop)
135
- annotated = draw_overlay(crop, mask)
 
 
 
 
 
136
 
137
- success, out = cv2.imencode(".png", annotated)
138
  if not success:
139
  raise HTTPException(status_code=500, detail="Failed to encode image.")
140
 
141
  headers = {
142
- "X-Length-Cm": str(metrics["length_cm"]),
143
- "X-Breadth-Cm": str(metrics["breadth_cm"]),
144
- "X-Depth-Cm": str(metrics["depth_cm"]),
145
- "X-Area-Cm2": str(metrics["area_cm2"]),
146
  "X-Moisture": str(metrics["moisture"]),
147
  }
148
- return Response(content=out.tobytes(), media_type="image/png", headers=headers)
 
24
  segmentation_model, yolo_model = load_models()
25
 
26
  PIXELS_PER_CM = 50.0
27
+ app = FastAPI(title="Wound Analyzer", version="10.1") # Version with depth calculation fix
28
 
29
  # --- Preprocessing ---
30
  def preprocess_image(image: np.ndarray) -> np.ndarray:
31
+ """Enhances image for better segmentation."""
32
  img_denoised = cv2.medianBlur(image, 3)
33
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
34
  l, a, b = cv2.split(lab)
35
  clahe = cv2.createCLAHE(2.0, (8, 8))
36
+ l_clahe = clahe.apply(l)
37
+ lab_clahe = cv2.merge((l_clahe, a, b))
38
+ result = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
39
  gamma = 1.2
40
  return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
41
 
42
  # --- Segmentation ---
43
  def segment_wound(image: np.ndarray) -> np.ndarray:
44
+ """Segments wound from a preprocessed image, with a fallback."""
45
  try:
46
  if segmentation_model:
47
  input_size = segmentation_model.input_shape[1:3]
48
  resized = cv2.resize(image, (input_size[1], input_size[0]))
49
  norm = np.expand_dims(resized / 255.0, axis=0)
 
50
  prediction = segmentation_model.predict(norm, verbose=0)
51
+ if isinstance(prediction, list):
52
  prediction = prediction[0]
53
+ prediction = prediction[0]
 
54
  mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
55
  return (mask >= 0.5).astype(np.uint8) * 255
56
  except Exception as e:
57
  print(f"⚠️ Model prediction failed: {e}")
58
 
 
59
  Z = image.reshape((-1, 3)).astype(np.float32)
60
+ _, labels, centers = cv2.kmeans(Z, 2, None, (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0), 5, cv2.KMEANS_PP_CENTERS)
 
 
61
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
62
  wound_idx = np.argmax(centers_lab[:, 1])
63
  mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
64
  return mask
65
 
66
  # --- Metrics ---
67
+ def calculate_metrics(mask: np.ndarray, original_image: np.ndarray):
68
+ """Calculates metrics using the mask and the ORIGINAL image for color accuracy."""
69
  area_px = cv2.countNonZero(mask)
70
  if area_px == 0:
71
  return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
 
77
  length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
78
 
79
  mask_bool = mask.astype(bool)
80
+ # **FIX**: Use the original, unaltered image for color-based metrics.
81
+ lab = cv2.cvtColor(original_image, cv2.COLOR_BGR2LAB)
82
+ gray = cv2.cvtColor(original_image, cv2.COLOR_BGR2GRAY)
83
+
84
  depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
85
  moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
86
 
87
+ return dict(area_cm2=round(area_cm2, 2), length_cm=round(length_cm, 2), breadth_cm=round(breadth_cm, 2), depth_cm=round(depth, 1), moisture=round(moisture, 0))
 
 
 
 
 
 
88
 
89
  # --- Overlay ---
90
  def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
91
+ """Draws heatmap and boundary on the image."""
92
+ # **CHANGE**: Use Yellow for high-intensity areas for better visibility
93
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
94
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
95
  heatmap = np.zeros_like(image)
96
 
97
+ heatmap[dist >= 0.66] = (0, 255, 255) # Yellow - Most Affected
98
  heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
99
  heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
100
 
 
103
  annotated[mask.astype(bool)] = blended[mask.astype(bool)]
104
 
105
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
106
+ cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2)
107
  return annotated
108
 
109
  # --- API Endpoint ---
110
  @app.post("/analyze_wound")
111
  async def analyze(file: UploadFile = File(...)):
112
+ contents = await file.read()
113
+ original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
114
+ if original_image is None:
115
  raise HTTPException(status_code=400, detail="Invalid image file.")
116
 
117
+ # 1. Create a preprocessed version for segmentation
118
+ preprocessed_image = preprocess_image(original_image)
119
+
120
+ # 2. Define Regions of Interest (ROI) for both original and preprocessed images
121
+ original_roi = original_image.copy()
122
+ preprocessed_roi = preprocessed_image.copy()
123
 
124
  if yolo_model:
125
  try:
126
+ results = yolo_model.predict(preprocessed_image, verbose=False)
127
  if results and results[0].boxes:
128
  coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
129
  x1, y1, x2, y2 = coords
130
+ # Crop both the original and preprocessed images to the same ROI
131
+ original_roi = original_image[y1:y2, x1:x2]
132
+ preprocessed_roi = preprocessed_image[y1:y2, x1:x2]
133
  except Exception as e:
134
  print(f"⚠️ YOLO detection failed: {e}")
135
 
136
+ # 3. Get mask from the preprocessed ROI
137
+ mask = segment_wound(preprocessed_roi)
138
+
139
+ # 4. Calculate metrics using the ORIGINAL ROI for color accuracy
140
+ metrics = calculate_metrics(mask, original_roi)
141
+
142
+ # 5. Draw overlay on the ORIGINAL ROI for correct visualization
143
+ annotated_image = draw_overlay(original_roi, mask)
144
 
145
+ success, out_bytes = cv2.imencode(".png", annotated_image)
146
  if not success:
147
  raise HTTPException(status_code=500, detail="Failed to encode image.")
148
 
149
  headers = {
150
+ "X-Length-Cm": str(metrics["length_cm"]), "X-Breadth-Cm": str(metrics["breadth_cm"]),
151
+ "X-Depth-Cm": str(metrics["depth_cm"]), "X-Area-Cm2": str(metrics["area_cm2"]),
 
 
152
  "X-Moisture": str(metrics["moisture"]),
153
  }
154
+ return Response(content=out_bytes.tobytes(), media_type="image/png", headers=headers)