sagar007 commited on
Commit
54db35c
·
verified ·
1 Parent(s): a3ee867

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +13 -8
app.py CHANGED
@@ -17,29 +17,34 @@ def segment_everything(image):
17
  return Image.fromarray(segmentation)
18
 
19
  def segment_box(image, box):
 
 
20
  x1, y1, x2, y2 = map(int, box)
21
- mask = Image.new('L', image.size, 0)
22
  draw = ImageDraw.Draw(mask)
23
  draw.rectangle([x1, y1, x2, y2], fill=255)
24
 
25
- inputs = processor(text=["object in box"], images=[image], mask_pixels=mask, padding="max_length", return_tensors="pt")
26
  with torch.no_grad():
27
  outputs = model(**inputs)
28
  preds = outputs.logits.squeeze().sigmoid()
29
  segmentation = (preds.numpy() * 255).astype(np.uint8)
30
  return Image.fromarray(segmentation)
31
 
32
- def update_image(image, segmentation, tool):
33
  if segmentation is None:
34
  return image
35
- blended = Image.blend(image.convert('RGBA'), segmentation.convert('RGBA'), 0.5)
36
- return blended
 
 
37
 
38
  with gr.Blocks() as demo:
39
  gr.Markdown("# Segment Anything-like Demo")
40
  with gr.Row():
41
  with gr.Column(scale=1):
42
- input_image = gr.Image(label="Input Image", tool="select")
 
43
  with gr.Row():
44
  everything_btn = gr.Button("Everything")
45
  box_btn = gr.Button("Box")
@@ -54,13 +59,13 @@ with gr.Blocks() as demo:
54
 
55
  box_btn.click(
56
  fn=segment_box,
57
- inputs=[input_image, input_image.sel],
58
  outputs=[output_image]
59
  )
60
 
61
  output_image.change(
62
  fn=update_image,
63
- inputs=[input_image, output_image, gr.State("last_tool")],
64
  outputs=[output_image]
65
  )
66
 
 
17
  return Image.fromarray(segmentation)
18
 
19
  def segment_box(image, box):
20
+ if box is None:
21
+ return image
22
  x1, y1, x2, y2 = map(int, box)
23
+ mask = Image.new('L', (image.shape[1], image.shape[0]), 0)
24
  draw = ImageDraw.Draw(mask)
25
  draw.rectangle([x1, y1, x2, y2], fill=255)
26
 
27
+ inputs = processor(text=["object in box"], images=[image], mask_pixels=np.array(mask), padding="max_length", return_tensors="pt")
28
  with torch.no_grad():
29
  outputs = model(**inputs)
30
  preds = outputs.logits.squeeze().sigmoid()
31
  segmentation = (preds.numpy() * 255).astype(np.uint8)
32
  return Image.fromarray(segmentation)
33
 
34
+ def update_image(image, segmentation):
35
  if segmentation is None:
36
  return image
37
+ image_pil = Image.fromarray((image * 255).astype(np.uint8))
38
+ seg_pil = Image.fromarray(segmentation).convert('RGBA')
39
+ blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
40
+ return np.array(blended)
41
 
42
  with gr.Blocks() as demo:
43
  gr.Markdown("# Segment Anything-like Demo")
44
  with gr.Row():
45
  with gr.Column(scale=1):
46
+ input_image = gr.Image(label="Input Image")
47
+ box_input = gr.Box(label="Select Box")
48
  with gr.Row():
49
  everything_btn = gr.Button("Everything")
50
  box_btn = gr.Button("Box")
 
59
 
60
  box_btn.click(
61
  fn=segment_box,
62
+ inputs=[input_image, box_input],
63
  outputs=[output_image]
64
  )
65
 
66
  output_image.change(
67
  fn=update_image,
68
+ inputs=[input_image, output_image],
69
  outputs=[output_image]
70
  )
71