Ani14 commited on
Commit
ba0f49e
·
verified ·
1 Parent(s): 7ff1d87

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +18 -16
predict.py CHANGED
@@ -7,8 +7,6 @@ import io
7
  from typing import Union
8
 
9
  # --- Model Loading ---
10
- # This section attempts to load the specific deep learning models you are using.
11
-
12
  def load_models():
13
  """Loads TensorFlow and YOLO models using your specified filenames."""
14
  segmentation_model, yolo_detector = None, None
@@ -39,7 +37,7 @@ PIXELS_PER_CM = 50.0
39
  app = FastAPI(
40
  title="Wound Analysis API",
41
  description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
42
- version="9.0.0" # Version with improved visualization (Yellow Heatmap + Boundary)
43
  )
44
 
45
 
@@ -67,7 +65,17 @@ def segment_wound(image: np.ndarray) -> np.ndarray:
67
  model_input_size = segmentation_model.input.shape[1:3]
68
  img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
69
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
70
- pred_mask = segmentation_model.predict(img_norm, verbose=0)[0]
 
 
 
 
 
 
 
 
 
 
71
  pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
72
  mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
73
  if cv2.countNonZero(mask) > 0:
@@ -75,6 +83,7 @@ def segment_wound(image: np.ndarray) -> np.ndarray:
75
  except Exception as e:
76
  print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
77
 
 
78
  pixels = image.reshape((-1, 3)).astype(np.float32)
79
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
80
  _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
@@ -115,28 +124,21 @@ def calculate_all_metrics(mask: np.ndarray, image: np.ndarray) -> dict:
115
  }
116
 
117
  def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
118
- """
119
- Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary.
120
- """
121
- # --- 1. Create the Color Heatmap ---
122
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
123
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
124
  overlay = np.zeros_like(image)
125
 
126
- # **CHANGE**: Use Yellow instead of Red for the most affected area for better visibility.
127
- overlay[dist >= 0.66] = (0, 255, 255) # Yellow in BGR
128
- overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0) # Blue in BGR
129
- overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0) # Green in BGR
130
 
131
  blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
132
  final_image = image.copy()
133
  final_image[mask.astype(bool)] = blended[mask.astype(bool)]
134
 
135
- # --- 2. Draw the Boundary Contour ---
136
- # Find contours from the mask to draw the boundary line.
137
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
138
- # **NEW**: Draw a crisp white boundary on the final image.
139
- cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1) # White color, 1px thickness
140
 
141
  return final_image
142
 
 
7
  from typing import Union
8
 
9
  # --- Model Loading ---
 
 
10
  def load_models():
11
  """Loads TensorFlow and YOLO models using your specified filenames."""
12
  segmentation_model, yolo_detector = None, None
 
37
  app = FastAPI(
38
  title="Wound Analysis API",
39
  description="A comprehensive API to analyze wound images using deep learning and computer vision techniques.",
40
+ version="9.1.0" # Version with fix for model prediction output format
41
  )
42
 
43
 
 
65
  model_input_size = segmentation_model.input.shape[1:3]
66
  img_resized = cv2.resize(image, (model_input_size[1], model_input_size[0]))
67
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
68
+
69
+ prediction = segmentation_model.predict(img_norm, verbose=0)
70
+
71
+ # --- FIX: Handle model output that is a list ---
72
+ # If the prediction is a list, extract the first element which is the actual numpy array.
73
+ if isinstance(prediction, list):
74
+ pred_mask = prediction[0]
75
+ else:
76
+ pred_mask = prediction
77
+ # --- END FIX ---
78
+
79
  pred_mask_resized = cv2.resize(pred_mask, (orig_w, orig_h), interpolation=cv2.INTER_LINEAR)
80
  mask = (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
81
  if cv2.countNonZero(mask) > 0:
 
83
  except Exception as e:
84
  print(f"Model prediction failed, switching to fallback segmentation. Error: {e}")
85
 
86
+ # Fallback Method
87
  pixels = image.reshape((-1, 3)).astype(np.float32)
88
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
89
  _, labels, centers = cv2.kmeans(pixels, 2, None, criteria, 5, cv2.KMEANS_PP_CENTERS)
 
124
  }
125
 
126
  def create_visual_overlay(image: np.ndarray, mask: np.ndarray) -> np.ndarray:
127
+ """Generates a visual overlay with a Yellow/Blue/Green heatmap and a white boundary."""
 
 
 
128
  dist = cv2.distanceTransform(mask, cv2.DIST_L2, 5)
129
  cv2.normalize(dist, dist, 0, 1.0, cv2.NORM_MINMAX)
130
  overlay = np.zeros_like(image)
131
 
132
+ overlay[dist >= 0.66] = (0, 255, 255)
133
+ overlay[(dist >= 0.33) & (dist < 0.66)] = (255, 0, 0)
134
+ overlay[(dist > 0) & (dist < 0.33)] = (0, 255, 0)
 
135
 
136
  blended = cv2.addWeighted(image, 0.7, overlay, 0.3, 0)
137
  final_image = image.copy()
138
  final_image[mask.astype(bool)] = blended[mask.astype(bool)]
139
 
 
 
140
  contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
141
+ cv2.drawContours(final_image, contours, -1, (255, 255, 255), 1)
 
142
 
143
  return final_image
144