Ani14 commited on
Commit
831ed15
·
verified ·
1 Parent(s): af8bbd0

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +71 -69
predict.py CHANGED
@@ -2,6 +2,7 @@ from fastapi import FastAPI, File, UploadFile, HTTPException, Response
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
 
5
  import io
6
  from typing import Union
7
 
@@ -10,32 +11,35 @@ PIXELS_PER_CM = 50.0
10
 
11
  # --- App Initialization ---
12
  app = FastAPI(
13
- title="Wound Heatmap Analysis API",
14
- description="An API that generates a three-color heatmap (Red, Blue, Green) on a wound to show tissue characteristics.",
15
- version="5.0.0" # Version updated for three-layer color heatmap
16
  )
17
 
18
  # --- Model Loading ---
19
- def load_yolo_model():
20
- """Loads the YOLO model for initial wound detection."""
 
21
  try:
22
  yolo_model = YOLO("best.pt")
23
  print("YOLO model 'best.pt' loaded successfully.")
24
- return yolo_model
25
  except Exception as e:
26
- print(f"Fatal: Could not load YOLO model. Error: {e}")
27
- return None
28
 
29
- yolo_model = load_yolo_model()
 
 
 
 
 
 
30
 
31
- # --- Helper Functions ---
32
 
33
- def preprocess_image(image: np.ndarray) -> np.ndarray:
34
- """Applies a median blur to reduce noise."""
35
- return cv2.medianBlur(image, 5)
36
 
37
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
38
- """Detects the primary wound bounding box using the YOLO model."""
39
  if not yolo_model: return None
40
  try:
41
  results = yolo_model.predict(image, verbose=False)
@@ -47,34 +51,42 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
47
  print(f"YOLO prediction failed: {e}")
48
  return None
49
 
50
- def create_wound_bed_mask(cropped_image: np.ndarray) -> np.ndarray:
51
  """
52
- Creates a general binary mask of the entire wound bed, which will be the area for the heatmap.
53
- This uses a broader color range than specific tissue segmentation.
54
  """
55
- lab_image = cv2.cvtColor(cropped_image, cv2.COLOR_BGR2LAB)
56
-
57
- # Broad thresholds to capture the entire wound area (slough, granulation, etc.)
58
- lower_bound = np.array([0, 135, 135])
59
- upper_bound = np.array([255, 200, 200])
60
-
61
- mask = cv2.inRange(lab_image, lower_bound, upper_bound)
62
-
63
- # Clean up the mask
64
- kernel = np.ones((5, 5), np.uint8)
65
- mask_cleaned = cv2.morphologyEx(mask, cv2.MORPH_OPEN, kernel, iterations=2)
66
- mask_cleaned = cv2.morphologyEx(mask_cleaned, cv2.MORPH_CLOSE, kernel, iterations=3)
67
-
68
- return mask_cleaned
 
 
 
 
 
 
 
 
 
69
 
70
  def calculate_metrics(mask: np.ndarray) -> dict:
71
- """Calculates dimensional metrics from the overall wound mask."""
72
  area_pixels = cv2.countNonZero(mask)
73
  if area_pixels == 0:
74
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
75
 
76
  area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
77
-
78
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
79
  if not contours:
80
  return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
@@ -90,39 +102,25 @@ def calculate_metrics(mask: np.ndarray) -> dict:
90
  return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
91
 
92
  def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
93
- """
94
- Generates and overlays a three-color (Red, Blue, Green) heatmap onto the image
95
- based on the intensity of the 'a' channel (redness) in the LAB color space.
96
- """
97
  if cv2.countNonZero(mask) == 0:
98
  return image
99
 
100
- # Convert the region of interest to LAB color space for analysis
101
  lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
102
- a_channel = lab_image[:, :, 1] # The 'a' channel represents the green-red axis
103
-
104
- # Create a color overlay image, initially transparent
105
  overlay = np.zeros_like(image)
106
 
107
- # Define thresholds for redness intensity. These values might need tuning.
108
- # Values are based on the 'a' channel, where ~128 is neutral.
109
- RED_THRESHOLD = 160 # Most intense red (e.g., fresh granulation)
110
- BLUE_THRESHOLD = 145 # Medium red (e.g., developing tissue)
111
- # Anything above a lower bound (e.g., 135) will be green.
112
-
113
- # Apply colors based on thresholds, only within the wound mask
114
- # Red for "most" affected
115
- overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255) # BGR for Red
116
- # Blue for "less" affected
117
- overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0) # BGR for Blue
118
- # Green for "least" affected (but still part of the wound bed)
119
- overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0) # BGR for Green
120
-
121
- # Blend the original image with the color overlay
122
- # A weight of 0.4 for the overlay makes it visible but not overpowering
123
  blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
124
 
125
- # To make the result clean, we only apply the blended result where the mask is active
126
  final_image = image.copy()
127
  final_image[mask == 255] = blended_image[mask == 255]
128
 
@@ -131,33 +129,37 @@ def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarra
131
  # --- Main API Endpoint ---
132
  @app.post("/analyze_wound")
133
  async def analyze_wound(file: UploadFile = File(...)):
134
- if not yolo_model:
135
- raise HTTPException(status_code=503, detail="YOLO model is not available.")
136
 
137
  contents = await file.read()
 
138
  original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
139
  if original_image is None:
140
- raise HTTPException(status_code=400, detail="Invalid image file")
141
 
142
- processed_image = preprocess_image(original_image)
143
-
144
- bbox = detect_wound_region_yolo(processed_image)
145
  if not bbox:
146
  raise HTTPException(status_code=404, detail="No wound detected in the image.")
147
 
148
  xmin, ymin, xmax, ymax = bbox
 
149
  cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
150
 
151
- # Step 1: Create a mask for the entire wound bed in the cropped image
152
- wound_mask = create_wound_bed_mask(cropped_image_roi)
 
 
153
 
154
- # Step 2: Calculate metrics based on this overall wound mask
155
  metrics = calculate_metrics(wound_mask)
156
 
157
- # Step 3: Generate the three-color heatmap on the cropped image
158
  heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
159
 
160
- # Step 4: Encode the final, annotated (and already cropped) image
161
  success, png_data = cv2.imencode(".png", heatmap_image)
162
  if not success:
163
  raise HTTPException(status_code=500, detail="Failed to encode output image")
 
2
  import cv2
3
  import numpy as np
4
  from ultralytics import YOLO
5
+ import tensorflow as tf
6
  import io
7
  from typing import Union
8
 
 
11
 
12
  # --- App Initialization ---
13
  app = FastAPI(
14
+ title="High-Quality Wound Heatmap API",
15
+ description="Generates a high-quality wound heatmap using a DL model, preserving original image quality.",
16
+ version="7.0.0" # Version updated for quality preservation
17
  )
18
 
19
  # --- Model Loading ---
20
+ def load_models():
21
+ """Loads both the YOLO and the TensorFlow segmentation models."""
22
+ yolo_model, segmentation_model = None, None
23
  try:
24
  yolo_model = YOLO("best.pt")
25
  print("YOLO model 'best.pt' loaded successfully.")
 
26
  except Exception as e:
27
+ print(f"Warning: Could not load YOLO model. Error: {e}")
 
28
 
29
+ try:
30
+ segmentation_model = tf.keras.models.load_model("segmentation_model.h5")
31
+ print("Segmentation model 'segmentation_model.h5' loaded successfully.")
32
+ except Exception as e:
33
+ print(f"Warning: Could not load segmentation model. Error: {e}")
34
+
35
+ return yolo_model, segmentation_model
36
 
37
+ yolo_model, segmentation_model = load_models()
38
 
39
+ # --- Helper Functions ---
 
 
40
 
41
  def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
42
+ """Detects the primary wound bounding box using the YOLO model on the original quality image."""
43
  if not yolo_model: return None
44
  try:
45
  results = yolo_model.predict(image, verbose=False)
 
51
  print(f"YOLO prediction failed: {e}")
52
  return None
53
 
54
+ def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
55
  """
56
+ Segments the wound using the TF/Keras model.
57
+ It resizes a copy for the model but returns a mask matching the original image's dimensions.
58
  """
59
+ if not segmentation_model:
60
+ print("Segmentation model not loaded, cannot create mask.")
61
+ return None
62
+ try:
63
+ input_shape = segmentation_model.input_shape[1:3]
64
+ # A temporary, resized copy is made for the model's prediction
65
+ img_resized_for_model = cv2.resize(image, (input_shape[1], input_shape[0]))
66
+ img_norm = np.expand_dims(img_resized_for_model.astype(np.float32) / 255.0, axis=0)
67
+
68
+ prediction = segmentation_model.predict(img_norm, verbose=0)
69
+
70
+ while isinstance(prediction, list):
71
+ prediction = prediction[0]
72
+ if isinstance(prediction, tf.Tensor):
73
+ prediction = prediction.numpy()
74
+
75
+ pred_mask = prediction[0]
76
+ # The resulting mask is resized back to the original image's dimensions to ensure perfect alignment
77
+ pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
78
+ return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
79
+ except Exception as e:
80
+ print(f"Segmentation model prediction failed: {e}")
81
+ return None
82
 
83
  def calculate_metrics(mask: np.ndarray) -> dict:
84
+ """Calculates dimensional metrics from the full-resolution wound mask."""
85
  area_pixels = cv2.countNonZero(mask)
86
  if area_pixels == 0:
87
  return {"area_cm2": 0.0, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
88
 
89
  area_cm2 = area_pixels / (PIXELS_PER_CM ** 2)
 
90
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
91
  if not contours:
92
  return {"area_cm2": area_cm2, "length_cm": 0.0, "breadth_cm": 0.0, "depth_cm": 0.0, "moisture": 0.0}
 
102
  return {"area_cm2": area_cm2, "length_cm": length_cm, "breadth_cm": breadth_cm, "depth_cm": depth_cm, "moisture": moisture}
103
 
104
  def create_three_color_heatmap(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
105
+ """Generates and overlays a three-color heatmap, preserving the underlying image quality."""
 
 
 
106
  if cv2.countNonZero(mask) == 0:
107
  return image
108
 
 
109
  lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
110
+ a_channel = lab_image[:, :, 1]
 
 
111
  overlay = np.zeros_like(image)
112
 
113
+ RED_THRESHOLD = 160
114
+ BLUE_THRESHOLD = 145
115
+
116
+ overlay[(a_channel >= RED_THRESHOLD) & (mask == 255)] = (0, 0, 255)
117
+ overlay[(a_channel >= BLUE_THRESHOLD) & (a_channel < RED_THRESHOLD) & (mask == 255)] = (255, 0, 0)
118
+ overlay[(a_channel < BLUE_THRESHOLD) & (mask == 255)] = (0, 255, 0)
119
+
120
+ # Blend the overlay with the original image
 
 
 
 
 
 
 
 
121
  blended_image = cv2.addWeighted(overlay, 0.4, image, 0.6, 0)
122
 
123
+ # Create the final image by taking the original and replacing only the masked area with the blended version
124
  final_image = image.copy()
125
  final_image[mask == 255] = blended_image[mask == 255]
126
 
 
129
  # --- Main API Endpoint ---
130
  @app.post("/analyze_wound")
131
  async def analyze_wound(file: UploadFile = File(...)):
132
+ if not yolo_model or not segmentation_model:
133
+ raise HTTPException(status_code=503, detail="A required model is not available.")
134
 
135
  contents = await file.read()
136
+ # Read the image from buffer into a cv2 object without any lossy conversions
137
  original_image = cv2.imdecode(np.frombuffer(contents, np.uint8), cv2.IMREAD_COLOR)
138
  if original_image is None:
139
+ raise HTTPException(status_code=400, detail="Invalid or corrupt image file")
140
 
141
+ # No preprocessing is done to the original image to preserve quality.
142
+ # A copy of the original image is used for detection.
143
+ bbox = detect_wound_region_yolo(original_image.copy())
144
  if not bbox:
145
  raise HTTPException(status_code=404, detail="No wound detected in the image.")
146
 
147
  xmin, ymin, xmax, ymax = bbox
148
+ # Crop the region of interest from the original, high-quality image
149
  cropped_image_roi = original_image[ymin:ymax, xmin:xmax]
150
 
151
+ # Step 1: Use the DL model on the high-quality crop to get a precise mask
152
+ wound_mask = segment_wound_with_model(cropped_image_roi)
153
+ if wound_mask is None or cv2.countNonZero(wound_mask) == 0:
154
+ raise HTTPException(status_code=404, detail="Segmentation model failed to identify a wound in the detected region.")
155
 
156
+ # Step 2: Calculate metrics based on the full-resolution mask
157
  metrics = calculate_metrics(wound_mask)
158
 
159
+ # Step 3: Generate the heatmap on the high-quality cropped image
160
  heatmap_image = create_three_color_heatmap(cropped_image_roi, wound_mask)
161
 
162
+ # Step 4: Encode the final image into PNG (lossless format) to preserve quality
163
  success, png_data = cv2.imencode(".png", heatmap_image)
164
  if not success:
165
  raise HTTPException(status_code=500, detail="Failed to encode output image")