Ani14 commited on
Commit
7ff1d87
·
verified ·
1 Parent(s): 831ed15

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +147 -131
predict.py CHANGED
@@ -1,176 +1,192 @@
 
 
1
  from fastapi import FastAPI, File, UploadFile, HTTPException, Response
2
  import cv2
3
  import numpy as np
4
- from ultralytics import YOLO
5
- import tensorflow as tf
6
  import io
7
  from typing import Union
8
 
9
- # --- Configuration ---
10
- PIXELS_PER_CM = 50.0
11
-
12
- # --- App Initialization ---
13
- app = FastAPI(
14
- title="High-Quality Wound Heatmap API",
15
- description="Generates a high-quality wound heatmap using a DL model, preserving original image quality.",
16
- version="7.0.0" # Version updated for quality preservation
17
- )
18
-
19
  # --- Model Loading ---
 
 
20
  def load_models():
21
- """Loads both the YOLO and the TensorFlow segmentation models."""
22
- yolo_model, segmentation_model = None, None
 
23
  try:
24
- yolo_model = YOLO("best.pt")
25
- print("YOLO model 'best.pt' loaded successfully.")
26
- except Exception as e:
27
- print(f"Warning: Could not load YOLO model. Error: {e}")
 
28
 
29
  try:
 
30
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
31
- print("Segmentation model 'segmentation_model.h5' loaded successfully.")
32
- except Exception as e:
33
- print(f"Warning: Could not load segmentation model. Error: {e}")
34
 
35
- return yolo_model, segmentation_model
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
- yolo_model, segmentation_model = load_models()
38
 
39
  # --- Helper Functions ---
40
 
41
- def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
42
- """Detects the primary wound bounding box using the YOLO model on the original quality image."""
43
- if not yolo_model: return None
44
- try:
45
- results = yolo_model.predict(image, verbose=False)
46
- if results and results[0].boxes and len(results[0].boxes) > 0:
47
- best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
48
- coords = best_box.xyxy[0].cpu().numpy()
49
- return tuple(map(int, coords))
50
- except Exception as e:
51
- print(f"YOLO prediction failed: {e}")
52
- return None
53
-
54
- def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
55
- """
56
- Segments the wound using the TF/Keras model.
57
- It resizes a copy for the model but returns a mask matching the original image's dimensions.
58
- """
59
- if not segmentation_model:
60
- print("Segmentation model not loaded, cannot create mask.")
61
- return None
62
- try:
63
- input_shape = segmentation_model.input_shape[1:3]
64
- # A temporary, resized copy is made for the model's prediction
65
- img_resized_for_model = cv2.resize(image, (input_shape[1], input_shape[0]))
66
- img_norm = np.expand_dims(img_resized_for_model.astype(np.float32) / 255.0, axis=0)
67
-
68
- prediction = segmentation_model.predict(img_norm, verbose=0)
69
-
70
- while isinstance(prediction, list):
71
- prediction = prediction[0]
72
- if isinstance(prediction, tf.Tensor):
73
- prediction = prediction.numpy()
74
-
75
- pred_mask = prediction[0]
76
- # The resulting mask is resized back to the original image's dimensions to ensure perfect alignment
77
- pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
78
- return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
79
- except Exception as e:
80
- print(f"Segmentation model prediction failed: {e}")
81
- return None
82
-
83
- def calculate_metrics(mask: np.ndarray) -> dict:
84
- """Calculates dimensional metrics from the full-resolution wound mask."""
85
- area_pixels = cv2.countNonZero(mask)
86
- if area_pixels == 0:
 
 
87
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
88
 
89
- area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
90
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
91
- if not contours:
92
- return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
93
-
94
  largest_contour = max(contours, key=cv2.contourArea)
95
  (_, (width, height), _) = cv2.minAreaRect(largest_contour)
96
-
97
  length_cm = max(width, height) / PIXELS_PER_CM
98
  breadth_cm = min(width, height) / PIXELS_PER_CM
99
- depth_cm = 0.1 # Placeholder
100
- moisture = 75.0 # Placeholder
101
-
102
- return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
103
-
104
- def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
105
- """Generates and overlays a three-color heatmap, preserving the underlying image quality."""
106
- if cv2.countNonZero(mask) == 0:
107
- return image
 
 
 
108
 
109
- lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
110
- a_channel = lab_image[:, :, 1]
 
 
 
 
 
111
  overlay = np.zeros_like(image)
112
 
113
- RED_THRESHOLD = 160
114
- BLUE_THRESHOLD = 145
115
-
116
- overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255)
117
- overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0)
118
- overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0)
119
-
120
- # Blend the overlay with the original image
121
- blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
122
 
123
- # Create the final image by taking the original and replacing only the masked area with the blended version
124
  final_image = image.copy()
125
- final_image[mask == 255] = blended_image[mask == 255]
 
 
 
 
 
 
126
 
127
  return final_image
128
 
129
  # --- Main API Endpoint ---
130
  @app.post("/analyze_wound")
131
  async def analyze_wound(file: UploadFile = File(...)):
132
- if not yolo_model or not segmentation_model:
133
- raise HTTPException(status_code=503, detail="A required model is not available.")
134
-
135
  contents = await file.read()
136
- # Read the image from buffer into a cv2 object without any lossy conversions
137
- original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
138
  if original_image is None:
139
- raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
140
-
141
- # No preprocessing is done to the original image to preserve quality.
142
- # A copy of the original image is used for detection.
143
- bbox = detect_wound_region_yolo(original_image.copy())
144
- if not bbox:
145
- raise HTTPException(status_code=404, detail="No wound detected in the image.")
146
-
147
- xmin, ymin, xmax, ymax = bbox
148
- # Crop the region of interest from the original, high-quality image
149
- cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
150
-
151
- # Step 1: Use the DL model on the high-quality crop to get a precise mask
152
- wound_mask = segment_wound_with_model(cropped_image_roi)
153
- if wound_mask is None or cv2.countNonZero(wound_mask) == 0:
154
- raise HTTPException(status_code=404, detail="Segmentation model failed to identify a wound in the detected region.")
155
-
156
- # Step 2: Calculate metrics based on the full-resolution mask
157
- metrics = calculate_metrics(wound_mask)
158
 
159
- # Step 3: Generate the heatmap on the high-quality cropped image
160
- heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
161
 
162
- # Step 4: Encode the final image into PNG (lossless format) to preserve quality
163
- success, png_data = cv2.imencode(".png", heatmap_image)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
164
  if not success:
165
- raise HTTPException(status_code=500, detail="Failed to encode output image")
166
 
167
- # Step 5: Set the custom headers
168
  headers = {
169
- 'X-Length-Cm': f"{metrics['length_cm']:.2f}",
170
- 'X-Breadth-Cm': f"{metrics['breadth_cm']:.2f}",
171
- 'X-Depth-Cm': f"{metrics['depth_cm']:.2f}",
172
- 'X-Area-Cm2': f"{metrics['area_cm2']:.2f}",
173
- 'X-Moisture': f"{metrics['moisture']:.1f}"
174
  }
175
 
176
  return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
 
1
+ # main.py
2
+
3
  from fastapi import FastAPI, File, UploadFile, HTTPException, Response
4
  import cv2
5
  import numpy as np
 
 
6
  import io
7
  from typing import Union
8
 
 
 
 
 
 
 
 
 
 
 
9
  # --- Model Loading ---
10
+ # This section attempts to load the specific deep learning models you are using.
11
+
12
  def load_models():
13
+ """Loads TensorFlow and YOLO models using your specified filenames."""
14
+ segmentation_model, yolo_detector = None, None
15
+
16
  try:
17
+ from ultralytics import YOLO
18
+ yolo_detector = YOLO("best.pt")
19
+ print("YOLOv8 detection model 'best.pt' loaded successfully.")
20
+ except (ImportError, IOError, Exception) as e:
21
+ print(f"Warning: YOLOv8 model not loaded. Using contour-based region detection. Error: {e}")
22
 
23
  try:
24
+ import tensorflow as tf
25
  segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
26
+ print("TensorFlow segmentation model 'segmentation_model.h5' loaded successfully.")
27
+ except (ImportError, IOError, Exception) as e:
28
+ print(f"Warning: TensorFlow segmentation model not loaded. Using OpenCV fallback. Error: {e}")
29
 
30
+ return segmentation_model, yolo_detector
31
+
32
+ segmentation_model, yolo_model = load_models()
33
+
34
+
35
+ # --- Configuration ---
36
+ PIXELS_PER_CM = 50.0
37
+
38
+ # --- App Initialization ---
39
+ app = FastAPI(
40
+ title="Wound Analysis API",
41
+ description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
42
+ version="9.0.0" # Version with improved visualization (Yellow Heatmap + Boundary)
43
+ )
44
 
 
45
 
46
  # --- Helper Functions ---
47
 
48
+ def preprocess_image(image: np.ndarray) -> np.ndarray:
49
+ """Applies the full preprocessing pipeline: Denoise -> CLAHE -> Gamma Correction."""
50
+ img_denoised = cv2.medianBlur(image, 3)
51
+ lab = cv2.cvtColor(img_denoised, cv2.COLOR_BGR2LAB)
52
+ l_channel, a_channel, b_channel = cv2.split(lab)
53
+ clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
54
+ l_clahe = clahe.apply(l_channel)
55
+ lab_clahe = cv2.merge([l_clahe, a_channel, b_channel])
56
+ img_clahe = cv2.cvtColor(lab_clahe, cv2.COLOR_LAB2BGR)
57
+ gamma = 1.2
58
+ img_float = img_clahe.astype(np.float32) / 255.0
59
+ img_gamma = np.power(img_float, gamma)
60
+ return (img_gamma * 255).astype(np.uint8)
61
+
62
+ def segment_wound(image: np.ndarray) -> np.ndarray:
63
+ """Segments the wound using the TF model if available, otherwise falls back to color clustering."""
64
+ if segmentation_model:
65
+ try:
66
+ orig_h, orig_w = image.shape[:2]
67
+ model_input_size = segmentation_model.input.shape[1:3]
68
+ img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
69
+ img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
70
+ pred_mask = segmentation_model.predict(img_norm, verbose=0)[0]
71
+ pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
72
+ mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
73
+ if cv2.countNonZero(mask) > 0:
74
+ return mask
75
+ except Exception as e:
76
+ print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
77
+
78
+ pixels = image.reshape((-1, 3)).astype(np.float32)
79
+ criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
80
+ _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
81
+ centers_lab = cv2.cvtColor(centers.reshape(1, -1, 3).astype(np.uint8), cv2.COLOR_BGR2LAB)[0]
82
+ wound_cluster_idx = np.argmax(centers_lab[:, 1])
83
+ mask = (labels.reshape(image.shape[:2]) == wound_cluster_idx).astype(np.uint8) * 255
84
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
85
+ if contours:
86
+ largest_contour = max(contours, key=cv2.contourArea)
87
+ refined_mask = np.zeros_like(mask)
88
+ cv2.drawContours(refined_mask, [largest_contour], -1, 255, cv2.FILLED)
89
+ return refined_mask
90
+ return mask
91
+
92
+ def calculate_all_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
93
+ """Computes all specified wound metrics from the mask and original image."""
94
+ wound_pixels = cv2.countNonZero(mask)
95
+ if wound_pixels == 0:
96
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
97
 
98
+ area_cm2 = wound_pixels / (PIXELS_PER_CM ** 2)
99
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
 
 
 
100
  largest_contour = max(contours, key=cv2.contourArea)
101
  (_, (width, height), _) = cv2.minAreaRect(largest_contour)
 
102
  length_cm = max(width, height) / PIXELS_PER_CM
103
  breadth_cm = min(width, height) / PIXELS_PER_CM
104
+ mask_bool = mask.astype(bool)
105
+ lab_img = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
106
+ mean_a = np.mean(lab_img[:, :, 1][mask_bool])
107
+ depth_score = mean_a - 128.0
108
+ gray_img = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
109
+ texture_std = np.std(gray_img[mask_bool])
110
+ moisture_score = max(0.0, 100.0 * (1.0 - texture_std / 127.0))
111
+
112
+ return {
113
+ "area_cm2": f"{area_cm2:.2f}", "length_cm": f"{length_cm:.2f}", "breadth_cm": f"{breadth_cm:.2f}",
114
+ "depth_cm": f"{depth_score:.1f}", "moisture": f"{moisture_score:.0f}"
115
+ }
116
 
117
+ def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
118
+ """
119
+ Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary.
120
+ """
121
+ # --- 1. Create the Color Heatmap ---
122
+ dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
123
+ cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
124
  overlay = np.zeros_like(image)
125
 
126
+ # **CHANGE**: Use Yellow instead of Red for the most affected area for better visibility.
127
+ overlay[dist >= 0.66] = (0, 255, 255) # Yellow in BGR
128
+ overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue in BGR
129
+ overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green in BGR
 
 
 
 
 
130
 
131
+ blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
132
  final_image = image.copy()
133
+ final_image[mask.astype(bool)] = blended[mask.astype(bool)]
134
+
135
+ # --- 2. Draw the Boundary Contour ---
136
+ # Find contours from the mask to draw the boundary line.
137
+ contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
138
+ # **NEW**: Draw a crisp white boundary on the final image.
139
+ cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1) # White color, 1px thickness
140
 
141
  return final_image
142
 
143
  # --- Main API Endpoint ---
144
  @app.post("/analyze_wound")
145
  async def analyze_wound(file: UploadFile = File(...)):
 
 
 
146
  contents = await file.read()
147
+ image_array = np.frombuffer(contents, np.uint8)
148
+ original_image = cv2.imdecode(image_array, cv2.IMREAD_COLOR)
149
  if original_image is None:
150
+ raise HTTPException(status_code=400, detail="Invalid or corrupt image file.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
151
 
152
+ processed_image = preprocess_image(original_image)
 
153
 
154
+ roi_image = processed_image
155
+ original_roi = original_image
156
+ if yolo_model:
157
+ try:
158
+ results = yolo_model.predict(processed_image, verbose=False)
159
+ if results and results[0].boxes and len(results[0].boxes) > 0:
160
+ best_box = sorted(results[0].boxes, key=lambda b: b.conf[0], reverse=True)[0]
161
+ coords = best_box.xyxy[0].cpu().numpy()
162
+ x1, y1, x2, y2 = map(int, coords)
163
+ roi_image = processed_image[y1:y2, x1:x2]
164
+ original_roi = original_image[y1:y2, x1:x2]
165
+ except Exception as e:
166
+ print(f"YOLO prediction failed, analyzing full image. Error: {e}")
167
+
168
+ wound_mask = segment_wound(roi_image)
169
+ if cv2.countNonZero(wound_mask) == 0:
170
+ _, png_data = cv2.imencode(".png", original_image)
171
+ headers = {
172
+ 'X-Length-Cm': '0.0', 'X-Breadth-Cm': '0.0', 'X-Depth-Cm': '0.0',
173
+ 'X-Area-Cm2': '0.0', 'X-Moisture': '0.0'
174
+ }
175
+ return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)
176
+
177
+ metrics = calculate_all_metrics(wound_mask, original_roi)
178
+ annotated_image = create_visual_overlay(original_roi, wound_mask)
179
+
180
+ success, png_data = cv2.imencode(".png", annotated_image)
181
  if not success:
182
+ raise HTTPException(status_code=500, detail="Failed to encode output image.")
183
 
 
184
  headers = {
185
+ 'X-Length-Cm': metrics['length_cm'],
186
+ 'X-Breadth-Cm': metrics['breadth_cm'],
187
+ 'X-Depth-Cm': metrics['depth_cm'],
188
+ 'X-Area-Cm2': metrics['area_cm2'],
189
+ 'X-Moisture': metrics['moisture']
190
  }
191
 
192
  return Response(content=png_data.tobytes(), media_type="image/png", headers=headers)