Ani14 commited on
Commit
7dda7cf
·
verified ·
1 Parent(s): ba0f49e

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +107 -153
predict.py CHANGED
@@ -1,194 +1,148 @@
1
- # main.py
2
-
3
  from fastapi import FastAPI, File, UploadFile, HTTPException, Response
4
  import cv2
5
  import numpy as np
6
  import io
7
  from typing import Union
8
 
9
- # --- Model Loading ---
10
  def load_models():
11
- """Loads TensorFlow and YOLO models using your specified filenames."""
12
- segmentation_model, yolo_detector = None, None
13
-
14
- try:
15
- from ultralytics import YOLO
16
- yolo_detector = YOLO("best.pt")
17
- print("YOLOv8 detection model 'best.pt' loaded successfully.")
18
- except (ImportError, IOError, Exception) as e:
19
- print(f"Warning: YOLOv8 model not loaded. Using contour-based region detection. Error: {e}")
20
-
21
  try:
22
  import tensorflow as tf
23
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
24
- print("TensorFlow segmentation model 'segmentation_model.h5' loaded successfully.")
25
- except (ImportError, IOError, Exception) as e:
26
- print(f"Warning: TensorFlow segmentation model not loaded. Using OpenCV fallback. Error: {e}")
27
-
28
- return segmentation_model, yolo_detector
 
 
 
 
 
29
 
30
  segmentation_model, yolo_model = load_models()
31
 
32
-
33
- # --- Configuration ---
34
  PIXELS_PER_CM = 50.0
 
35
 
36
- # --- App Initialization ---
37
- app = FastAPI(
38
- title="Wound Analysis API",
39
- description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
40
- version="9.1.0" # Version with fix for model prediction output format
41
- )
42
-
43
-
44
- # --- Helper Functions ---
45
-
46
  def preprocess_image(image: np.ndarray) -> np.ndarray:
47
- """Applies the full preprocessing pipeline: Denoise -> CLAHE -> Gamma Correction."""
48
  img_denoised = cv2.medianBlur(image, 3)
49
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
50
- l_channel, a_channel, b_channel = cv2.split(lab)
51
- clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
52
- l_clahe = clahe.apply(l_channel)
53
- lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
54
- img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
55
  gamma = 1.2
56
- img_float = img_clahe.astype(np.float32) / 255.0
57
- img_gamma = np.power(img_float, gamma)
58
- return (img_gamma * 255).astype(np.uint8)
59
 
 
60
  def segment_wound(image: np.ndarray) -> np.ndarray:
61
- """Segments the wound using the TF model if available, otherwise falls back to color clustering."""
62
- if segmentation_model:
63
- try:
64
- orig_h, orig_w = image.shape[:2]
65
- model_input_size = segmentation_model.input.shape[1:3]
66
- img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
67
- img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
68
-
69
- prediction = segmentation_model.predict(img_norm, verbose=0)
70
-
71
- # --- FIX: Handle model output that is a list ---
72
- # If the prediction is a list, extract the first element which is the actual numpy array.
73
- if isinstance(prediction, list):
74
- pred_mask = prediction[0]
75
- else:
76
- pred_mask = prediction
77
- # --- END FIX ---
78
-
79
- pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
80
- mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
81
- if cv2.countNonZero(mask) > 0:
82
- return mask
83
- except Exception as e:
84
- print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
85
-
86
- # Fallback Method
87
- pixels = image.reshape((-1, 3)).astype(np.float32)
88
- criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
89
- _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
90
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
91
- wound_cluster_idx = np.argmax(centers_lab[:, 1])
92
- mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
93
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
94
- if contours:
95
- largest_contour = max(contours, key=cv2.contourArea)
96
- refined_mask = np.zeros_like(mask)
97
- cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
98
- return refined_mask
99
  return mask
100
 
101
- def calculate_all_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
102
- """Computes all specified wound metrics from the mask and original image."""
103
- wound_pixels = cv2.countNonZero(mask)
104
- if wound_pixels == 0:
105
- return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
106
-
107
- area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
108
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
109
- largest_contour = max(contours, key=cv2.contourArea)
110
- (_, (width, height), _) = cv2.minAreaRect(largest_contour)
111
- length_cm = max(width, height) / PIXELS_PER_CM
112
- breadth_cm = min(width, height) / PIXELS_PER_CM
113
- mask_bool = mask.astype(bool)
114
- lab_img = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
115
- mean_a = np.mean(lab_img[:, :, 1][mask_bool])
116
- depth_score = mean_a - 128.0
117
- gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
118
- texture_std = np.std(gray_img[mask_bool])
119
- moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
120
 
121
- return {
122
- "area_cm2": f"{area_cm2:.2f}", "length_cm": f"{length_cm:.2f}", "breadth_cm": f"{breadth_cm:.2f}",
123
- "depth_cm": f"{depth_score:.1f}", "moisture": f"{moisture_score:.0f}"
124
- }
 
125
 
126
- def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
127
- """Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary."""
 
 
 
 
 
 
 
 
 
 
 
 
 
 
128
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
129
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
130
- overlay = np.zeros_like(image)
131
-
132
- overlay[dist >= 0.66] = (0, 255, 255)
133
- overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
134
- overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
135
-
136
- blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
137
- final_image = image.copy()
138
- final_image[mask.astype(bool)] = blended[mask.astype(bool)]
139
 
140
- contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
141
- cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1)
 
 
 
 
 
142
 
143
- return final_image
 
 
144
 
145
- # --- Main API Endpoint ---
146
  @app.post("/analyze_wound")
147
- async def analyze_wound(file: UploadFile = File(...)):
148
- contents = await file.read()
149
- image_array = np.frombuffer(contents, np.uint8)
150
- original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
151
- if original_image is None:
152
- raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
153
-
154
- processed_image = preprocess_image(original_image)
155
-
156
- roi_image = processed_image
157
- original_roi = original_image
158
  if yolo_model:
159
  try:
160
- results = yolo_model.predict(processed_image, verbose=False)
161
- if results and results[0].boxes and len(results[0].boxes) > 0:
162
- best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
163
- coords = best_box.xyxy[0].cpu().numpy()
164
- x1, y1, x2, y2 = map(int, coords)
165
- roi_image = processed_image[y1:y2, x1:x2]
166
- original_roi = original_image[y1:y2, x1:x2]
167
  except Exception as e:
168
- print(f"YOLO prediction failed, analyzing full image. Error: {e}")
169
 
170
- wound_mask = segment_wound(roi_image)
171
- if cv2.countNonZero(wound_mask) == 0:
172
- _, png_data = cv2.imencode(".png", original_image)
173
- headers = {
174
- 'X-Length-Cm': '0.0', 'X-Breadth-Cm': '0.0', 'X-Depth-Cm': '0.0',
175
- 'X-Area-Cm2': '0.0', 'X-Moisture': '0.0'
176
- }
177
- return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
178
 
179
- metrics = calculate_all_metrics(wound_mask, original_roi)
180
- annotated_image = create_visual_overlay(original_roi, wound_mask)
181
-
182
- success, png_data = cv2.imencode(".png", annotated_image)
183
  if not success:
184
- raise HTTPException(status_code=500, detail="Failed to encode output image.")
185
 
186
  headers = {
187
- 'X-Length-Cm': metrics['length_cm'],
188
- 'X-Breadth-Cm': metrics['breadth_cm'],
189
- 'X-Depth-Cm': metrics['depth_cm'],
190
- 'X-Area-Cm2': metrics['area_cm2'],
191
- 'X-Moisture': metrics['moisture']
192
  }
193
-
194
- return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
 
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Response
2
  import cv2
3
  import numpy as np
4
  import io
5
  from typing import Union
6
 
7
+ # --- Load Models ---
8
  def load_models():
9
+ segmentation_model, yolo_model = None, None
 
 
 
 
 
 
 
 
 
10
  try:
11
  import tensorflow as tf
12
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
13
+ print(" Segmentation model loaded.")
14
+ except Exception as e:
15
+ print(f"⚠️ Failed to load segmentation model: {e}")
16
+ try:
17
+ from ultralytics import YOLO
18
+ yolo_model = YOLO("best.pt")
19
+ print("✅ YOLO model loaded.")
20
+ except Exception as e:
21
+ print(f"⚠️ Failed to load YOLO model: {e}")
22
+ return segmentation_model, yolo_model
23
 
24
  segmentation_model, yolo_model = load_models()
25
 
 
 
26
  PIXELS_PER_CM = 50.0
27
+ app = FastAPI(title="Wound Analyzer", version="10.0")
28
 
29
+ # --- Preprocessing ---
 
 
 
 
 
 
 
 
 
30
  def preprocess_image(image: np.ndarray) -> np.ndarray:
 
31
  img_denoised = cv2.medianBlur(image, 3)
32
  lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
33
+ l, a, b = cv2.split(lab)
34
+ clahe = cv2.createCLAHE(2.0, (8, 8))
35
+ l = clahe.apply(l)
36
+ lab = cv2.merge((l, a, b))
37
+ result = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
38
  gamma = 1.2
39
+ return np.clip((result / 255.0) ** gamma * 255, 0, 255).astype(np.uint8)
 
 
40
 
41
+ # --- Segmentation ---
42
  def segment_wound(image: np.ndarray) -> np.ndarray:
43
+ try:
44
+ if segmentation_model:
45
+ input_size = segmentation_model.input_shape[1:3]
46
+ resized = cv2.resize(image, (input_size[1], input_size[0]))
47
+ norm = np.expand_dims(resized / 255.0, axis=0)
48
+
49
+ prediction = segmentation_model.predict(norm, verbose=0)
50
+ if isinstance(prediction, list): # <-- Fix
51
+ prediction = prediction[0]
52
+ prediction = prediction[0] # remove batch dim
53
+
54
+ mask = cv2.resize(prediction.squeeze(), (image.shape[1], image.shape[0]))
55
+ return (mask >= 0.5).astype(np.uint8) * 255
56
+ except Exception as e:
57
+ print(f"⚠️ Model prediction failed: {e}")
58
+
59
+ # Fallback segmentation
60
+ Z = image.reshape((-1, 3)).astype(np.float32)
61
+ _, labels, centers = cv2.kmeans(Z, 2, None,
62
+ (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0),
63
+ 5, cv2.KMEANS_PP_CENTERS)
 
 
 
 
 
 
 
 
64
  centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
65
+ wound_idx = np.argmax(centers_lab[:, 1])
66
+ mask = (labels.reshape(image.shape[:2]) == wound_idx).astype(np.uint8) * 255
 
 
 
 
 
 
67
  return mask
68
 
69
+ # --- Metrics ---
70
+ def calculate_metrics(mask: np.ndarray, image: np.ndarray):
71
+ area_px = cv2.countNonZero(mask)
72
+ if area_px == 0:
73
+ return dict(area_cm2=0.0, length_cm=0.0, breadth_cm=0.0, depth_cm=0.0, moisture=0.0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ area_cm2 = area_px / (PIXELS_PER_CM ** 2)
76
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
77
+ rect = cv2.minAreaRect(max(contours, key=cv2.contourArea))
78
+ (w, h) = rect[1]
79
+ length_cm, breadth_cm = max(w, h) / PIXELS_PER_CM, min(w, h) / PIXELS_PER_CM
80
 
81
+ mask_bool = mask.astype(bool)
82
+ lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
83
+ gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
84
+ depth = np.mean(lab[:, :, 1][mask_bool]) - 128.0
85
+ moisture = max(0.0, 100.0 * (1 - np.std(gray[mask_bool]) / 127.0))
86
+
87
+ return dict(
88
+ area_cm2=round(area_cm2, 2),
89
+ length_cm=round(length_cm, 2),
90
+ breadth_cm=round(breadth_cm, 2),
91
+ depth_cm=round(depth, 1),
92
+ moisture=round(moisture, 0)
93
+ )
94
+
95
+ # --- Overlay ---
96
+ def draw_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
97
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
98
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
99
+ heatmap = np.zeros_like(image)
 
 
 
 
 
 
 
 
100
 
101
+ heatmap[dist >= 0.66] = (0, 0, 255) # Red - Most Affected
102
+ heatmap[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue - Moderate
103
+ heatmap[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green - Least
104
+
105
+ blended = cv2.addWeighted(image, 0.7, heatmap, 0.3, 0)
106
+ annotated = image.copy()
107
+ annotated[mask.astype(bool)] = blended[mask.astype(bool)]
108
 
109
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
110
+ cv2.drawContours(annotated, contours, -1, (255, 255, 255), 2) # White outline
111
+ return annotated
112
 
113
+ # --- API Endpoint ---
114
  @app.post("/analyze_wound")
115
+ async def analyze(file: UploadFile = File(...)):
116
+ image = cv2.imdecode(np.frombuffer(await file.read(), np.uint8), cv2.IMREAD_COLOR)
117
+ if image is None:
118
+ raise HTTPException(status_code=400, detail="Invalid image file.")
119
+
120
+ image = preprocess_image(image)
121
+ crop = image.copy()
122
+
 
 
 
123
  if yolo_model:
124
  try:
125
+ results = yolo_model.predict(image, verbose=False)
126
+ if results and results[0].boxes:
127
+ coords = results[0].boxes[0].xyxy[0].cpu().numpy().astype(int)
128
+ x1, y1, x2, y2 = coords
129
+ crop = image[y1:y2, x1:x2]
 
 
130
  except Exception as e:
131
+ print(f"⚠️ YOLO detection failed: {e}")
132
 
133
+ mask = segment_wound(crop)
134
+ metrics = calculate_metrics(mask, crop)
135
+ annotated = draw_overlay(crop, mask)
 
 
 
 
 
136
 
137
+ success, out = cv2.imencode(".png", annotated)
 
 
 
138
  if not success:
139
+ raise HTTPException(status_code=500, detail="Failed to encode image.")
140
 
141
  headers = {
142
+ "X-Length-Cm": str(metrics["length_cm"]),
143
+ "X-Breadth-Cm": str(metrics["breadth_cm"]),
144
+ "X-Depth-Cm": str(metrics["depth_cm"]),
145
+ "X-Area-Cm2": str(metrics["area_cm2"]),
146
+ "X-Moisture": str(metrics["moisture"]),
147
  }
148
+ return Response(content=out.tobytes(), media_type="image/png", headers=headers)