sagar007 commited on
Commit
1d4c655
·
verified ·
1 Parent(s): 4db07a0

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -6
app.py CHANGED
@@ -19,12 +19,10 @@ def segment_everything(image):
19
  def segment_box(image, x1, y1, x2, y2):
20
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
21
  cropped_image = image[y1:y2, x1:x2]
22
-
23
  inputs = processor(text=["object"], images=[cropped_image], padding="max_length", return_tensors="pt")
24
  with torch.no_grad():
25
  outputs = model(**inputs)
26
  preds = outputs.logits.squeeze().sigmoid()
27
-
28
  segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
29
  segmentation[y1:y2, x1:x2] = (preds.numpy() * 255).astype(np.uint8)
30
  return Image.fromarray(segmentation)
@@ -32,9 +30,23 @@ def segment_box(image, x1, y1, x2, y2):
32
  def update_image(image, segmentation):
33
  if segmentation is None:
34
  return image
35
- image_pil = Image.fromarray((image * 255).astype(np.uint8))
 
 
 
 
 
 
 
36
  seg_pil = Image.fromarray(segmentation).convert('RGBA')
 
 
 
 
 
 
37
  blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
 
38
  return np.array(blended)
39
 
40
  with gr.Blocks() as demo:
@@ -52,19 +64,17 @@ with gr.Blocks() as demo:
52
  box_btn = gr.Button("Box")
53
  with gr.Column(scale=1):
54
  output_image = gr.Image(label="Segmentation Result")
55
-
56
  everything_btn.click(
57
  fn=segment_everything,
58
  inputs=[input_image],
59
  outputs=[output_image]
60
  )
61
-
62
  box_btn.click(
63
  fn=segment_box,
64
  inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
65
  outputs=[output_image]
66
  )
67
-
68
  output_image.change(
69
  fn=update_image,
70
  inputs=[input_image, output_image],
 
19
  def segment_box(image, x1, y1, x2, y2):
20
  x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
21
  cropped_image = image[y1:y2, x1:x2]
 
22
  inputs = processor(text=["object"], images=[cropped_image], padding="max_length", return_tensors="pt")
23
  with torch.no_grad():
24
  outputs = model(**inputs)
25
  preds = outputs.logits.squeeze().sigmoid()
 
26
  segmentation = np.zeros((image.shape[0], image.shape[1]), dtype=np.uint8)
27
  segmentation[y1:y2, x1:x2] = (preds.numpy() * 255).astype(np.uint8)
28
  return Image.fromarray(segmentation)
 
30
  def update_image(image, segmentation):
31
  if segmentation is None:
32
  return image
33
+
34
+ # Ensure image is in the correct format (PIL Image)
35
+ if isinstance(image, np.ndarray):
36
+ image_pil = Image.fromarray((image * 255).astype(np.uint8))
37
+ else:
38
+ image_pil = image
39
+
40
+ # Convert segmentation to RGBA
41
  seg_pil = Image.fromarray(segmentation).convert('RGBA')
42
+
43
+ # Resize segmentation to match input image if necessary
44
+ if image_pil.size != seg_pil.size:
45
+ seg_pil = seg_pil.resize(image_pil.size, Image.NEAREST)
46
+
47
+ # Blend images
48
  blended = Image.blend(image_pil.convert('RGBA'), seg_pil, 0.5)
49
+
50
  return np.array(blended)
51
 
52
  with gr.Blocks() as demo:
 
64
  box_btn = gr.Button("Box")
65
  with gr.Column(scale=1):
66
  output_image = gr.Image(label="Segmentation Result")
67
+
68
  everything_btn.click(
69
  fn=segment_everything,
70
  inputs=[input_image],
71
  outputs=[output_image]
72
  )
 
73
  box_btn.click(
74
  fn=segment_box,
75
  inputs=[input_image, x1_input, y1_input, x2_input, y2_input],
76
  outputs=[output_image]
77
  )
 
78
  output_image.change(
79
  fn=update_image,
80
  inputs=[input_image, output_image],