sagar007 commited on
Commit
ac51df9
·
verified ·
1 Parent(s): c95f3e0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +5 -4
app.py CHANGED
@@ -12,12 +12,12 @@ device = "cuda" if torch.cuda.is_available() else "cpu"
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
- def segment_image(input_image, points):
16
  # Convert input_image to PIL Image
17
  input_image = Image.fromarray(input_image)
18
 
19
  # Prepare inputs
20
- inputs = processor(input_image, input_points=[points], return_tensors="pt").to(device)
21
 
22
  # Generate masks
23
  with torch.no_grad():
@@ -47,11 +47,12 @@ iface = gr.Interface(
47
  fn=segment_image,
48
  inputs=[
49
  gr.Image(type="numpy"),
50
- gr.Image(type="numpy", tool="sketch", brush_radius=5, label="Click on objects to segment")
 
51
  ],
52
  outputs=gr.Image(type="numpy"),
53
  title="Segment Anything Model (SAM) Image Segmentation",
54
- description="Click on objects in the image to segment them using SAM."
55
  )
56
 
57
  # Launch the interface
 
12
  model = SamModel.from_pretrained("facebook/sam-vit-base").to(device)
13
  processor = SamProcessor.from_pretrained("facebook/sam-vit-base")
14
 
15
+ def segment_image(input_image, x, y):
16
  # Convert input_image to PIL Image
17
  input_image = Image.fromarray(input_image)
18
 
19
  # Prepare inputs
20
+ inputs = processor(input_image, input_points=np.array([[x, y]]), return_tensors="pt").to(device)
21
 
22
  # Generate masks
23
  with torch.no_grad():
 
47
  fn=segment_image,
48
  inputs=[
49
  gr.Image(type="numpy"),
50
+ gr.Slider(minimum=0, maximum=1000, step=1, label="X coordinate"),
51
+ gr.Slider(minimum=0, maximum=1000, step=1, label="Y coordinate")
52
  ],
53
  outputs=gr.Image(type="numpy"),
54
  title="Segment Anything Model (SAM) Image Segmentation",
55
+ description="Enter X and Y coordinates of the object you want to segment."
56
  )
57
 
58
  # Launch the interface