sagar007 commited on
Commit
99fdace
·
verified ·
1 Parent(s): 73989e5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -8
app.py CHANGED
@@ -17,7 +17,7 @@ def segment_image(input_image, segment_anything):
17
  if input_image is None:
18
  return None, "Please upload an image before submitting."
19
 
20
- # Convert input_image to PIL Image
21
  input_image = Image.fromarray(input_image).convert("RGB")
22
 
23
  # Store original size
@@ -25,11 +25,10 @@ def segment_image(input_image, segment_anything):
25
  if not original_size or 0 in original_size:
26
  return None, "Invalid image size. Please upload a different image."
27
 
 
28
  if segment_anything:
29
- # Segment everything in the image
30
  inputs = processor(input_image, return_tensors="pt").to(device)
31
  else:
32
- # Use the center of the image as a point prompt
33
  width, height = original_size
34
  center_point = [[width // 2, height // 2]]
35
  inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
@@ -45,20 +44,20 @@ def segment_image(input_image, segment_anything):
45
  inputs["reshaped_input_sizes"].cpu()
46
  )
47
 
48
- # Convert mask to numpy array and resize to match original image
49
  if segment_anything:
50
- # Combine all masks
51
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
52
  else:
53
- # Use the first mask
54
  combined_mask = masks[0][0].numpy() > 0.5
55
 
56
  # Ensure mask is 2D
57
  if combined_mask.ndim > 2:
58
  combined_mask = combined_mask.squeeze()
59
 
60
- # Resize mask to match original image size
61
- combined_mask = cv2.resize(combined_mask.astype(np.uint8), (original_size[0], original_size[1])) > 0
 
 
62
 
63
  # Overlay the mask on the original image
64
  result_image = np.array(input_image)
 
17
  if input_image is None:
18
  return None, "Please upload an image before submitting."
19
 
20
+ # Convert input_image to PIL Image and ensure it's RGB
21
  input_image = Image.fromarray(input_image).convert("RGB")
22
 
23
  # Store original size
 
25
  if not original_size or 0 in original_size:
26
  return None, "Invalid image size. Please upload a different image."
27
 
28
+ # Process the image
29
  if segment_anything:
 
30
  inputs = processor(input_image, return_tensors="pt").to(device)
31
  else:
 
32
  width, height = original_size
33
  center_point = [[width // 2, height // 2]]
34
  inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
 
44
  inputs["reshaped_input_sizes"].cpu()
45
  )
46
 
47
+ # Convert mask to numpy array
48
  if segment_anything:
 
49
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
50
  else:
 
51
  combined_mask = masks[0][0].numpy() > 0.5
52
 
53
  # Ensure mask is 2D
54
  if combined_mask.ndim > 2:
55
  combined_mask = combined_mask.squeeze()
56
 
57
+ # Resize mask to match original image size using PIL
58
+ mask_image = Image.fromarray((combined_mask * 255).astype(np.uint8))
59
+ mask_image = mask_image.resize(original_size, Image.NEAREST)
60
+ combined_mask = np.array(mask_image) > 0
61
 
62
  # Overlay the mask on the original image
63
  result_image = np.array(input_image)