sagar007 commited on
Commit
3ba1061
·
verified ·
1 Parent(s): 7bee2b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -5
app.py CHANGED
@@ -13,15 +13,21 @@ model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
  def segment_image(input_image, segment_anything):
 
 
 
16
  # Convert input_image to PIL Image
17
  input_image = Image.fromarray(input_image)
18
 
 
 
 
19
  if segment_anything:
20
  # Segment everything in the image
21
  inputs = processor(input_image, return_tensors="pt").to(device)
22
  else:
23
  # Use the center of the image as a point prompt
24
- height, width = input_image.size
25
  center_point = [[width // 2, height // 2]]
26
  inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
27
 
@@ -36,7 +42,7 @@ def segment_image(input_image, segment_anything):
36
  inputs["reshaped_input_sizes"].cpu()
37
  )
38
 
39
- # Convert mask to numpy array
40
  if segment_anything:
41
  # Combine all masks
42
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
@@ -44,22 +50,28 @@ def segment_image(input_image, segment_anything):
44
  # Use the first mask
45
  combined_mask = masks[0][0].numpy() > 0.5
46
 
 
 
 
47
  # Overlay the mask on the original image
48
  result_image = np.array(input_image)
49
  mask_rgb = np.zeros_like(result_image)
50
  mask_rgb[combined_mask] = [255, 0, 0] # Red color for the mask
51
  result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
52
 
53
- return result_image
54
 
55
  # Create Gradio interface
56
  iface = gr.Interface(
57
  fn=segment_image,
58
  inputs=[
59
- gr.Image(type="numpy"),
60
  gr.Checkbox(label="Segment Everything")
61
  ],
62
- outputs=gr.Image(type="numpy"),
 
 
 
63
  title="Segment Anything Model (SAM) Image Segmentation",
64
  description="Upload an image and choose whether to segment everything or use a center point."
65
  )
 
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
  def segment_image(input_image, segment_anything):
16
+ if input_image is None:
17
+ return None, "Please upload an image before submitting."
18
+
19
  # Convert input_image to PIL Image
20
  input_image = Image.fromarray(input_image)
21
 
22
+ # Store original size
23
+ original_size = input_image.size
24
+
25
  if segment_anything:
26
  # Segment everything in the image
27
  inputs = processor(input_image, return_tensors="pt").to(device)
28
  else:
29
  # Use the center of the image as a point prompt
30
+ width, height = input_image.size
31
  center_point = [[width // 2, height // 2]]
32
  inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
33
 
 
42
  inputs["reshaped_input_sizes"].cpu()
43
  )
44
 
45
+ # Convert mask to numpy array and resize to match original image
46
  if segment_anything:
47
  # Combine all masks
48
  combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
 
50
  # Use the first mask
51
  combined_mask = masks[0][0].numpy() > 0.5
52
 
53
+ # Resize mask to match original image size
54
+ combined_mask = cv2.resize(combined_mask.astype(np.uint8), original_size[::-1]) > 0
55
+
56
  # Overlay the mask on the original image
57
  result_image = np.array(input_image)
58
  mask_rgb = np.zeros_like(result_image)
59
  mask_rgb[combined_mask] = [255, 0, 0] # Red color for the mask
60
  result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
61
 
62
+ return result_image, "Segmentation completed successfully."
63
 
64
  # Create Gradio interface
65
  iface = gr.Interface(
66
  fn=segment_image,
67
  inputs=[
68
+ gr.Image(type="numpy", label="Upload an image"),
69
  gr.Checkbox(label="Segment Everything")
70
  ],
71
+ outputs=[
72
+ gr.Image(type="numpy", label="Segmented Image"),
73
+ gr.Textbox(label="Status")
74
+ ],
75
  title="Segment Anything Model (SAM) Image Segmentation",
76
  description="Upload an image and choose whether to segment everything or use a center point."
77
  )