jadechoghari committed
Commit c73b59b · verified · 1 Parent(s): c35872a

Update app.py


gr.AnnotatedImage from Gradio seems to be broken in HF Spaces at the moment, so we replace it with gr.Image.
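For context, the two output components expect different return values from the inference function: gr.AnnotatedImage takes the base image plus a list of (mask, label) pairs and draws the overlays itself, while gr.Image takes a single image, so the masks now have to be composited manually (which is what the new apply_mask helper in the diff below does). A rough, hypothetical sketch of the difference; the function names, toy mask, and overlay coordinates are illustrative only, not part of this commit:

    import numpy as np
    import gradio as gr
    from PIL import Image

    def annotated_fn(img):
        # gr.AnnotatedImage expects (base_image, [(mask_or_box, label), ...])
        mask = np.zeros((img.height, img.width), dtype=np.uint8)
        mask[10:50, 10:50] = 1  # toy mask, purely illustrative
        return img, [(mask, "example")]

    def plain_fn(img):
        # gr.Image expects a single image, so any overlay must be drawn in first
        arr = np.array(img).copy()
        arr[10:50, 10:50, 0] = 255  # toy red overlay
        return Image.fromarray(arr)

    # the two interfaces differ only in the output component they use
    demo_annotated = gr.Interface(annotated_fn, gr.Image(type="pil"), gr.AnnotatedImage())
    demo_plain = gr.Interface(plain_fn, gr.Image(type="pil"), gr.Image(type="pil"))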

Files changed (1)
  1. app.py +55 -40
app.py CHANGED
@@ -1,68 +1,82 @@
-# no gpu required
 from transformers import pipeline, SamModel, SamProcessor
 import torch
 import numpy as np
-import spaces
+import gradio as gr
+from PIL import Image
 
+# check if cuda is available
 device = "cuda" if torch.cuda.is_available() else "cpu"
+
+# we initialize model and processor
 checkpoint = "google/owlv2-base-patch16-ensemble"
 detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)
 sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
 sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
 
-
-def query(image, texts, threshold):
-    texts = texts.split(",")
-
-    predictions = detector(
-        image,
-        candidate_labels=texts,
-        threshold=threshold
-    )
-
-    result_labels = []
-    for pred in predictions:
-
-        score = pred["score"]
-
-        if score > 0.5:
-            box = pred["box"]
-            label = pred["label"]
-            box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
-                   round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
-
-            inputs = sam_processor(
-                image,
-                input_boxes=[[[box]]],
-                return_tensors="pt"
-            ).to(device)
-
-            with torch.no_grad():
-                outputs = sam_model(**inputs)
-
-            mask = sam_processor.image_processor.post_process_masks(
-                outputs.pred_masks.cpu(),
-                inputs["original_sizes"].cpu(),
-                inputs["reshaped_input_sizes"].cpu()
-            )[0][0][0].numpy()
-            mask = mask[np.newaxis, ...]
-            result_labels.append((mask, label))
-
-    return image, result_labels
-
-import gradio as gr
+def apply_mask(image, mask, color):
+    """Apply a mask to an image with a specific color."""
+    for c in range(3):  # iterate over rgb channels
+        image[:, :, c] = np.where(mask, color[c], image[:, :, c])
+    return image
+
+def query(image, texts, threshold):
+    texts = texts.split(",")
+    predictions = detector(
+        image,
+        candidate_labels=texts,
+        threshold=threshold
+    )
+
+    image = np.array(image).copy()
+
+    colors = [
+        (255, 0, 0),    # Red
+        (0, 255, 0),    # Green
+        (0, 0, 255),    # Blue
+        (255, 255, 0),  # Yellow
+        (255, 165, 0),  # Orange
+        (255, 0, 255)   # Magenta
+    ]
+
+    for i, pred in enumerate(predictions):
+        score = pred["score"]
+        if score > 0.5:
+            box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
+                   round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
+
+            inputs = sam_processor(
+                image,
+                input_boxes=[[[box]]],
+                return_tensors="pt"
+            ).to(device)
+
+            with torch.no_grad():
+                outputs = sam_model(**inputs)
+
+            mask = sam_processor.image_processor.post_process_masks(
+                outputs.pred_masks.cpu(),
+                inputs["original_sizes"].cpu(),
+                inputs["reshaped_input_sizes"].cpu()
+            )[0][0][0].numpy()
+
+            # we apply the mask with the corresponding color
+            color = colors[i % len(colors)]  # we cycle through colors
+            image = apply_mask(image, mask > 0.5, color)
+
+    result_image = Image.fromarray(image)
+
+    return result_image
 
 description = (
     "Welcome to RobustSAM by Snap Research."
-    "This Space uses RobustSAM, an enhanced version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities. "
+    "This Space uses RobustSAM, a robust version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities. "
     "Thanks to its integration with OWLv2, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality. Try the example or input an image with comma-separated candidate labels to see the enhanced segmentation results."
 )
 
 demo = gr.Interface(
     query,
-    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
-    outputs=gr.AnnotatedImage(label="Segmented Image"),
+    inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label="Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
+    outputs=gr.Image(type="pil", label="Segmented Image"),
     title="RobustSAM",
     description=description,
     examples=[
@@ -73,4 +87,5 @@ demo = gr.Interface(
     ],
     cache_examples=True
 )
+
 demo.launch()
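For reference, a minimal sketch of exercising the updated query function outside Gradio. This assumes the model setup, apply_mask, and query definitions from the new app.py have been pasted into the current script, and that a local test image example.jpg exists; both are assumptions, not part of the commit:

    from PIL import Image

    # `query` is assumed to be defined as in the updated app.py above
    img = Image.open("example.jpg").convert("RGB")  # hypothetical local test image
    result = query(img, "dog,cat", 0.05)            # comma-separated labels, UI default threshold
    result.save("segmented.jpg")                    # saves the composited overlay image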