sagar007 commited on
Commit
564688d
·
verified ·
1 Parent(s): 99fdace

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +21 -10
app.py CHANGED
@@ -12,6 +12,24 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
15
  def segment_image(input_image, segment_anything):
16
  try:
17
  if input_image is None:
@@ -44,20 +62,13 @@ def segment_image(input_image, segment_anything):
44
  inputs["reshaped_input_sizes"].cpu()
45
  )
46
 
47
- # Convert mask to numpy array
48
  if segment_anything:
49
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
50
  else:
51
- combined_mask = masks[0][0].numpy() > 0.5
52
-
53
- # Ensure mask is 2D
54
- if combined_mask.ndim > 2:
55
- combined_mask = combined_mask.squeeze()
56
 
57
- # Resize mask to match original image size using PIL
58
- mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
59
- mask_image = mask_image.resize(original_size, Image.NEAREST)
60
- combined_mask = np.array(mask_image) > 0
61
 
62
  # Overlay the mask on the original image
63
  result_image = np.array(input_image)
 
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
+ def process_mask(mask, target_size):
16
+ # Ensure mask is 2D
17
+ if mask.ndim > 2:
18
+ mask = mask.squeeze()
19
+
20
+ # If mask is still not 2D, take the first 2D slice
21
+ if mask.ndim > 2:
22
+ mask = mask[0]
23
+
24
+ # Convert to binary
25
+ mask = (mask > 0.5).astype(np.uint8) * 255
26
+
27
+ # Resize mask to match original image size using PIL
28
+ mask_image = Image.fromarray(mask)
29
+ mask_image = mask_image.resize(target_size, Image.NEAREST)
30
+
31
+ return np.array(mask_image) > 0
32
+
33
  def segment_image(input_image, segment_anything):
34
  try:
35
  if input_image is None:
 
62
  inputs["reshaped_input_sizes"].cpu()
63
  )
64
 
65
+ # Process the mask
66
  if segment_anything:
67
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
68
  else:
69
+ combined_mask = masks[0][0].numpy()
 
 
 
 
70
 
71
+ combined_mask = process_mask(combined_mask, original_size)
 
 
 
72
 
73
  # Overlay the mask on the original image
74
  result_image = np.array(input_image)