sagar007 commited on
Commit
7bee2b4
·
verified ·
1 Parent(s): ac51df9

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -9
app.py CHANGED
@@ -12,12 +12,18 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
- def segment_image(input_image, x, y):
16
  # Convert input_image to PIL Image
17
  input_image = Image.fromarray(input_image)
18
 
19
- # Prepare inputs
20
- inputs = processor(input_image, input_points=np.array([[x, y]]), return_tensors="pt").to(device)
 
 
 
 
 
 
21
 
22
  # Generate masks
23
  with torch.no_grad():
@@ -29,15 +35,19 @@ def segment_image(input_image, x, y):
29
  inputs["original_sizes"].cpu(),
30
  inputs["reshaped_input_sizes"].cpu()
31
  )
32
- scores = outputs.iou_scores
33
 
34
  # Convert mask to numpy array
35
- mask = masks[0][0].numpy()
 
 
 
 
 
36
 
37
  # Overlay the mask on the original image
38
  result_image = np.array(input_image)
39
  mask_rgb = np.zeros_like(result_image)
40
- mask_rgb[mask > 0.5] = [255, 0, 0] # Red color for the mask
41
  result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
42
 
43
  return result_image
@@ -47,12 +57,11 @@ iface = gr.Interface(
47
  fn=segment_image,
48
  inputs=[
49
  gr.Image(type="numpy"),
50
- gr.Slider(minimum=0, maximum=1000, step=1, label="X coordinate"),
51
- gr.Slider(minimum=0, maximum=1000, step=1, label="Y coordinate")
52
  ],
53
  outputs=gr.Image(type="numpy"),
54
  title="Segment Anything Model (SAM) Image Segmentation",
55
- description="Enter X and Y coordinates of the object you want to segment."
56
  )
57
 
58
  # Launch the interface
 
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
+ def segment_image(input_image, segment_anything):
16
  # Convert input_image to PIL Image
17
  input_image = Image.fromarray(input_image)
18
 
19
+ if segment_anything:
20
+ # Segment everything in the image
21
+ inputs = processor(input_image, return_tensors="pt").to(device)
22
+ else:
23
+ # Use the center of the image as a point prompt
24
+ height, width = input_image.size
25
+ center_point = [[width // 2, height // 2]]
26
+ inputs = processor(input_image, input_points=[center_point], return_tensors="pt").to(device)
27
 
28
  # Generate masks
29
  with torch.no_grad():
 
35
  inputs["original_sizes"].cpu(),
36
  inputs["reshaped_input_sizes"].cpu()
37
  )
 
38
 
39
  # Convert mask to numpy array
40
+ if segment_anything:
41
+ # Combine all masks
42
+ combined_mask = np.any(masks[0].numpy() > 0.5, axis=0)
43
+ else:
44
+ # Use the first mask
45
+ combined_mask = masks[0][0].numpy() > 0.5
46
 
47
  # Overlay the mask on the original image
48
  result_image = np.array(input_image)
49
  mask_rgb = np.zeros_like(result_image)
50
+ mask_rgb[combined_mask] = [255, 0, 0] # Red color for the mask
51
  result_image = cv2.addWeighted(result_image, 1, mask_rgb, 0.5, 0)
52
 
53
  return result_image
 
57
  fn=segment_image,
58
  inputs=[
59
  gr.Image(type="numpy"),
60
+ gr.Checkbox(label="Segment Everything")
 
61
  ],
62
  outputs=gr.Image(type="numpy"),
63
  title="Segment Anything Model (SAM) Image Segmentation",
64
+ description="Upload an image and choose whether to segment everything or use a center point."
65
  )
66
 
67
  # Launch the interface