Ani14 commited on
Commit
ecbbf4c
·
verified ·
1 Parent(s): c727c1c

Update predict.py

Browse files
Files changed (1) hide show
  1. predict.py +10 -15
predict.py CHANGED
@@ -63,34 +63,29 @@ def detect_wound_region_yolo(image: np.ndarray) -> Union[tuple, None]:
63
  return None
64
 
65
  def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
66
- """Segments the wound using the U-Net model."""
67
  if not segmentation_model:
68
  return None
69
  try:
70
- input_shape = segmentation_model.input.shape[1:3]
71
  img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
72
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
73
-
74
  prediction = segmentation_model.predict(img_norm, verbose=0)
75
-
76
- # --- FIX APPLIED HERE ---
77
- # Handle cases where the model returns a list of outputs
78
- if isinstance(prediction, list):
79
- pred_mask = prediction[0]
80
- else:
81
- pred_mask = prediction
82
-
83
- # The output of predict() is batched, so get the first item.
84
- pred_mask = pred_mask[0]
85
- # --- END OF FIX ---
86
 
 
 
 
 
 
 
 
87
  pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
88
-
89
  return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
90
  except Exception as e:
91
  print(f"Segmentation model prediction failed: {e}")
92
  return None
93
 
 
94
  def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
95
  pixels = image.reshape((-1, 3)).astype(np.float32)
96
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
 
63
  return None
64
 
65
  def segment_wound_with_model(image: np.ndarray) -> Union[np.ndarray, None]:
 
66
  if not segmentation_model:
67
  return None
68
  try:
69
+ input_shape = segmentation_model.input_shape[1:3]
70
  img_resized = cv2.resize(image, (input_shape[1], input_shape[0]))
71
  img_norm = np.expand_dims(img_resized.astype(np.float32) / 255.0, axis=0)
72
+
73
  prediction = segmentation_model.predict(img_norm, verbose=0)
 
 
 
 
 
 
 
 
 
 
 
74
 
75
+ # FIX: Handle nested list output or Tensor
76
+ while isinstance(prediction, list):
77
+ prediction = prediction[0]
78
+ if isinstance(prediction, tf.Tensor):
79
+ prediction = prediction.numpy()
80
+
81
+ pred_mask = prediction[0]
82
  pred_mask_resized = cv2.resize(pred_mask, (image.shape[1], image.shape[0]))
 
83
  return (pred_mask_resized.squeeze() >= 0.5).astype(np.uint8) * 255
84
  except Exception as e:
85
  print(f"Segmentation model prediction failed: {e}")
86
  return None
87
 
88
+
89
  def segment_wound_with_fallback(image: np.ndarray) -> np.ndarray:
90
  pixels = image.reshape((-1, 3)).astype(np.float32)
91
  criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)