iSushant commited on
Commit
529397f
·
verified ·
1 Parent(s): fa5df4a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -10
app.py CHANGED
@@ -19,8 +19,7 @@ def dice_loss(y_true, y_pred):
19
 
20
  # Load the model from the same directory as the script
21
  model_filename = "model.h5" # Replace with your actual model filename
22
- model_path = os.path.join(os.path.dirname(__file__), model_filename) # Construct the full path
23
-
24
 
25
  def load_model(model_path):
26
  try:
@@ -37,8 +36,8 @@ def perform_inference(image):
37
  if model is None:
38
  print("Model not loaded properly.")
39
  return None, None, None
40
-
41
- # Preprocess the image
42
  original_shape = image.shape[:2]
43
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
44
  image_resized = cv2.resize(image, (256, 256))
@@ -50,18 +49,27 @@ def perform_inference(image):
50
  mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
51
  mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
52
 
 
 
53
 
54
- # Apply the mask to the original image (for better visualization)
55
  heatmap_img = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
56
  segmented_image = cv2.addWeighted(image, 0.7, heatmap_img, 0.3, 0)
 
57
 
58
- #Convert back to BGR
59
- segmented_image = cv2.cvtColor(segmented_image, cv2.COLOR_RGB2BGR)
60
 
61
- # Convert results to PIL Images
 
 
 
 
 
 
 
 
62
  return (Image.fromarray(image),
63
- Image.fromarray(mask_binary.astype(np.uint8)), #Mask is already multiplied by 255
64
- Image.fromarray(segmented_image))
65
 
66
 
67
  # Gradio app
 
19
 
20
  # Load the model from the same directory as the script
21
  model_filename = "model.h5" # Replace with your actual model filename
22
+ model_path = os.path.join(os.path.dirname(__file__), model_filename)
 
23
 
24
  def load_model(model_path):
25
  try:
 
36
  if model is None:
37
  print("Model not loaded properly.")
38
  return None, None, None
39
+
40
+ # Preprocess the image
41
  original_shape = image.shape[:2]
42
  image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
43
  image_resized = cv2.resize(image, (256, 256))
 
49
  mask_resized = cv2.resize(mask, (original_shape[1], original_shape[0]))
50
  mask_binary = (mask_resized > 0.5).astype(np.uint8) * 255
51
 
52
+ # Find contours in the binary mask
53
+ contours, _ = cv2.findContours(mask_binary, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
54
 
55
+ # Apply the mask to the original image (heatmap for visualization)
56
  heatmap_img = cv2.applyColorMap(mask_binary, cv2.COLORMAP_JET)
57
  segmented_image = cv2.addWeighted(image, 0.7, heatmap_img, 0.3, 0)
58
+ segmented_image_with_box = segmented_image.copy() #make a copy for drawing box
59
 
 
 
60
 
61
+ # Get bounding boxes for all contours
62
+ for contour in contours:
63
+ x, y, w, h = cv2.boundingRect(contour)
64
+ cv2.rectangle(segmented_image_with_box, (x, y), (x + w, y + h), (0, 0, 255), 2)
65
+
66
+
67
+ #Convert back to BGR
68
+ segmented_image_with_box = cv2.cvtColor(segmented_image_with_box, cv2.COLOR_RGB2BGR)
69
+ #Return image with bounding box
70
  return (Image.fromarray(image),
71
+ Image.fromarray(mask_binary.astype(np.uint8)),
72
+ Image.fromarray(segmented_image_with_box))
73
 
74
 
75
  # Gradio app