jadechoghari committed
Commit 5017de6 · verified · 1 Parent(s): 0e6a2e9

Update app.py

Files changed (1)
  1. app.py +37 -40
app.py CHANGED
@@ -1,13 +1,13 @@
-# no cpu required
-#TODO: update to gpu usage
+# no gpu required
 from transformers import pipeline, SamModel, SamProcessor
 import torch
 import numpy as np
 import spaces
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 checkpoint = "google/owlv2-base-patch16-ensemble"
-detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
-sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base")
+detector = pipeline(model=checkpoint, task="zero-shot-object-detection", device=device)
+sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
 sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
 
 
@@ -23,57 +23,54 @@ def query(image, texts, threshold):
     result_labels = []
     for pred in predictions:
 
-        box = pred["box"]
+
         score = pred["score"]
-        label = pred["label"]
-        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
-               round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
 
-        inputs = sam_processor(
-            image,
-            input_boxes=[[[box]]],
-            return_tensors="pt"
-        )
+        if score > 0.5:
+            box = pred["box"]
+            label = pred["label"]
+            box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
+                   round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
 
-        with torch.no_grad():
-            outputs = sam_model(**inputs)
+            inputs = sam_processor(
+                image,
+                input_boxes=[[[box]]],
+                return_tensors="pt"
+            ).to(device)
 
-        mask = sam_processor.image_processor.post_process_masks(
-            outputs.pred_masks.cpu(),
-            inputs["original_sizes"].cpu(),
-            inputs["reshaped_input_sizes"].cpu()
-        )[0][0][0].numpy()
-        mask = mask[np.newaxis, ...]
+            with torch.no_grad():
+                outputs = sam_model(**inputs)
 
-        from PIL import Image, ImageDraw
-        # Convert mask to image format and overlay on the original image
-        mask_image = Image.fromarray((mask[0] * 255).astype(np.uint8))
-        mask_image = mask_image.convert("L") # Convert to grayscale for transparency
-        mask_image = mask_image.resize(image.size)
-
-        # Create an alpha mask for transparency
-        alpha_mask = Image.new("L", mask_image.size, 128) # Adjust transparency level here
-        image.paste(mask_image, (0, 0), alpha_mask) # Overlay the mask on the image
-
-        # Save the annotated image
-        image.save("annotated_image.png")
-        print("saved image")
-        result_labels.append((mask, label))
+            mask = sam_processor.image_processor.post_process_masks(
+                outputs.pred_masks.cpu(),
+                inputs["original_sizes"].cpu(),
+                inputs["reshaped_input_sizes"].cpu()
+            )[0][0][0].numpy()
+            mask = mask[np.newaxis, ...]
+            result_labels.append((mask, label))
+
     return image, result_labels
 
 import gradio as gr
 
-description = "This Space combines OWLv2, the state-of-the-art zero-shot object detection model with SAM, the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with OWLv2 makes SAM text promptable. Try the example or input an image and comma separated candidate labels to segment."
+description = (
+    "Welcome to RobustSAM by Snap Research."
+    "This Space uses RobustSAM, an enhanced version of the Segment Anything Model (SAM) with improved performance on low-quality images while maintaining zero-shot segmentation capabilities. "
+    "Thanks to its integration with OWLv2, RobustSAM becomes text-promptable, allowing for flexible and accurate segmentation, even with degraded image quality. Try the example or input an image with comma-separated candidate labels to see the enhanced segmentation results."
+)
+
 demo = gr.Interface(
     query,
     inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
-    # outputs="annotatedimage", #comment this out - it looks weird
     outputs=gr.AnnotatedImage(label="Segmented Image"),
-    title="OWL 🤝 SAM",
+    title="RobustSAM",
     description=description,
     examples=[
-        ["./cats.png", "cat", 0.1],
+        ["./blur.jpg", "insect", 0.1],
+        ["./lowlight.jpg", "bus, window", 0.1],
+        ["./rain.jpg", "tree, leafs", 0.1],
+        ["./haze.jpg", "", 0.1],
    ],
     cache_examples=True
 )
-demo.launch(debug=True)
+demo.launch()
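
For reference, a minimal sketch of how the updated pieces presumably fit together end to end. Only lines 23 onward of query() are visible in the hunk, so the @spaces.GPU decorator, the label parsing, and the detector(...) call are assumptions; the model loading, the SamProcessor call, and the mask post-processing mirror the diff above.

# Sketch of the OWLv2 -> RobustSAM flow after this commit.
# Assumed (not shown in the hunk): @spaces.GPU, label parsing, detector call.
import numpy as np
import spaces
import torch
from transformers import SamModel, SamProcessor, pipeline

device = "cuda" if torch.cuda.is_available() else "cpu"
detector = pipeline(model="google/owlv2-base-patch16-ensemble",
                    task="zero-shot-object-detection", device=device)
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

@spaces.GPU  # assumed: requests GPU time on a ZeroGPU Space
def query(image, texts, threshold):
    # Assumed: split comma-separated labels and run zero-shot detection with OWLv2
    candidate_labels = [t.strip() for t in texts.split(",")]
    predictions = detector(image, candidate_labels=candidate_labels, threshold=threshold)

    result_labels = []
    for pred in predictions:
        if pred["score"] > 0.5:  # same extra cutoff as in the diff
            box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
                   round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
            # Prompt RobustSAM with the detected box (as in the diff)
            inputs = sam_processor(image, input_boxes=[[[box]]],
                                   return_tensors="pt").to(device)
            with torch.no_grad():
                outputs = sam_model(**inputs)
            mask = sam_processor.image_processor.post_process_masks(
                outputs.pred_masks.cpu(),
                inputs["original_sizes"].cpu(),
                inputs["reshaped_input_sizes"].cpu()
            )[0][0][0].numpy()
            result_labels.append((mask[np.newaxis, ...], pred["label"]))
    return image, result_labels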