sagar007 commited on
Commit
4f39124
·
verified ·
1 Parent(s): 72f4c5c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +22 -76
app.py CHANGED
@@ -21,110 +21,56 @@ def fig2img(fig):
21
  img = Image.open(buf)
22
  return img
23
 
24
- def plot(annotations, prompt_process, mask_random_color=True, better_quality=True, retina=True, with_contours=True):
 
 
 
25
  for ann in annotations:
26
- image = ann.orig_img[..., ::-1] # BGR to RGB
27
- original_h, original_w = ann.orig_shape
28
- fig = plt.figure(figsize=(original_w / 100, original_h / 100))
29
- plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
30
- plt.margins(0, 0)
31
- plt.gca().xaxis.set_major_locator(plt.NullLocator())
32
- plt.gca().yaxis.set_major_locator(plt.NullLocator())
33
- plt.imshow(image)
34
-
35
- if ann.masks is not None:
36
- masks = ann.masks.data
37
- if better_quality:
38
- if isinstance(masks[0], torch.Tensor):
39
- masks = np.array(masks.cpu())
40
- for i, mask in enumerate(masks):
41
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, np.ones((3, 3), np.uint8))
42
- masks[i] = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, np.ones((8, 8), np.uint8))
43
-
44
- prompt_process.fast_show_mask(
45
- masks,
46
- plt.gca(),
47
- random_color=mask_random_color,
48
- bbox=None,
49
- points=None,
50
- pointlabel=None,
51
- retinamask=retina,
52
- target_height=original_h,
53
- target_width=original_w,
54
- )
55
 
56
- if with_contours:
57
- contour_all = []
58
- temp = np.zeros((original_h, original_w, 1))
59
- for i, mask in enumerate(masks):
60
- mask = mask.astype(np.uint8)
61
- if not retina:
62
- mask = cv2.resize(mask, (original_w, original_h), interpolation=cv2.INTER_NEAREST)
63
- contours, _ = cv2.findContours(mask, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
64
- contour_all.extend(iter(contours))
65
- cv2.drawContours(temp, contour_all, -1, (255, 255, 255), 2)
66
- color = np.array([0 / 255, 0 / 255, 1.0, 0.8])
67
- contour_mask = temp / 255 * color.reshape(1, 1, -1)
68
- plt.imshow(contour_mask)
69
 
70
- plt.axis("off")
71
- plt.close()
72
- return fig2img(fig)
73
-
74
- def segment_image(input_image, object_name):
75
  try:
76
  if input_image is None:
77
  return None, "Please upload an image before submitting."
78
 
79
  input_image = Image.fromarray(input_image).convert("RGB")
80
 
81
- # Run FastSAM model with adjusted parameters
82
- everything_results = model(input_image, retina_masks=True, imgsz=1024, conf=0.25, iou=0.7)
83
 
84
  # Prepare a Prompt Process object
85
  prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
86
 
87
- # Use text prompt to segment the specified object
88
- results = prompt_process.text_prompt(text=object_name)
89
-
90
- if not results:
91
- return input_image, f"Could not find '{object_name}' in the image."
92
-
93
- # Post-process the masks
94
- for ann in results:
95
- if ann.masks is not None:
96
- masks = ann.masks.data
97
- if isinstance(masks[0], torch.Tensor):
98
- masks = np.array(masks.cpu())
99
- for i, mask in enumerate(masks):
100
- # Apply more aggressive morphological operations
101
- kernel = np.ones((5,5), np.uint8)
102
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_CLOSE, kernel)
103
- mask = cv2.morphologyEx(mask.astype(np.uint8), cv2.MORPH_OPEN, kernel)
104
- masks[i] = cv2.dilate(mask, kernel, iterations=2)
105
- ann.masks.data = masks
106
 
107
  # Plot the results
108
- result_image = plot(annotations=results, prompt_process=prompt_process)
109
 
110
- return result_image, f"Segmented '{object_name}' in the image."
111
 
112
  except Exception as e:
113
  return None, f"An error occurred: {str(e)}"
114
 
115
  # Create Gradio interface
116
  iface = gr.Interface(
117
- fn=segment_image,
118
  inputs=[
119
- gr.Image(type="numpy", label="Upload an image"),
120
- gr.Textbox(label="Specify object to segment (e.g., dog, cat, grass)")
121
  ],
122
  outputs=[
123
  gr.Image(type="pil", label="Segmented Image"),
124
  gr.Textbox(label="Status")
125
  ],
126
- title="FastSAM Segmentation with Object Specification",
127
- description="Upload an image and specify an object to segment using FastSAM."
128
  )
129
 
130
  # Launch the interface
 
21
  img = Image.open(buf)
22
  return img
23
 
24
+ def plot_masks(annotations, output_shape):
25
+ fig, ax = plt.subplots(figsize=(10, 10))
26
+ ax.imshow(annotations[0].orig_img)
27
+
28
  for ann in annotations:
29
+ for mask in ann.masks.data:
30
+ mask = cv2.resize(mask.cpu().numpy().astype('uint8'), output_shape[::-1])
31
+ masked = np.ma.masked_where(mask == 0, mask)
32
+ ax.imshow(masked, alpha=0.5, cmap=plt.cm.get_cmap('jet'))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
 
34
+ ax.axis('off')
35
+ plt.close()
36
+ return fig2img(fig)
 
 
 
 
 
 
 
 
 
 
37
 
38
+ def segment_everything(input_image):
 
 
 
 
39
  try:
40
  if input_image is None:
41
  return None, "Please upload an image before submitting."
42
 
43
  input_image = Image.fromarray(input_image).convert("RGB")
44
 
45
+ # Run FastSAM model in "everything" mode
46
+ everything_results = model(input_image, device=device, retina_masks=True, imgsz=1024, conf=0.25, iou=0.9, agnostic_nms=True)
47
 
48
  # Prepare a Prompt Process object
49
  prompt_process = FastSAMPrompt(input_image, everything_results, device=device)
50
 
51
+ # Get everything segmentation
52
+ ann = prompt_process.everything_prompt()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
53
 
54
  # Plot the results
55
+ result_image = plot_masks(ann, input_image.size)
56
 
57
+ return result_image, f"Segmented everything in the image. Found {len(ann[0].masks)} objects."
58
 
59
  except Exception as e:
60
  return None, f"An error occurred: {str(e)}"
61
 
62
  # Create Gradio interface
63
  iface = gr.Interface(
64
+ fn=segment_everything,
65
  inputs=[
66
+ gr.Image(type="numpy", label="Upload an image")
 
67
  ],
68
  outputs=[
69
  gr.Image(type="pil", label="Segmented Image"),
70
  gr.Textbox(label="Status")
71
  ],
72
+ title="FastSAM Everything Segmentation",
73
+ description="Upload an image to segment all objects using FastSAM."
74
  )
75
 
76
  # Launch the interface