jadechoghari committed on
Commit 420fa3e · verified · 1 Parent(s): d9a0be3

Update app.py

Files changed (1)
  1. app.py (+60, -74)
app.py CHANGED
@@ -1,88 +1,74 @@
- import gradio as gr
- import numpy as np
  import torch
- from PIL import Image, ImageDraw
- import requests
- from transformers import SamModel, SamProcessor
- import cv2
- from typing import List
-
- device = "cuda" if torch.cuda.is_available() else "cpu"
-
- # we load model and processor
- model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
- processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
-
- cache_data = None
-
- def mask_2_dots(mask: np.ndarray) -> List[List[int]]:
-     gray = cv2.cvtColor(mask, cv2.COLOR_RGB2GRAY)
-     _, thresh = cv2.threshold(gray, 127, 255, 0)
-     kernel = np.ones((5,5),np.uint8)
-     closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
-     contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-     points = []
-     for contour in contours:
-         moments = cv2.moments(contour)
-         cx = int(moments['m10']/moments['m00'])
-         cy = int(moments['m01']/moments['m00'])
-         points.append([cx, cy])
-     return [points]
-
- @torch.no_grad()
- def foward_pass(image_input: np.ndarray, points: List[List[int]]) -> np.ndarray:
-     global cache_data
-     image_input = Image.fromarray(image_input)
-     inputs = processor(image_input, input_points=points, return_tensors="pt").to(device)
-     if not cache_data or not torch.equal(inputs['pixel_values'],cache_data[0]):
-         embedding = model.get_image_embeddings(inputs["pixel_values"])
-         pixels = inputs["pixel_values"]
-         cache_data = [pixels, embedding]
-     del inputs["pixel_values"]
-
-     outputs = model.forward(image_embeddings=cache_data[1], **inputs)
-     masks = processor.image_processor.post_process_masks(
-         outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
-     )
-     masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
-
-     return masks
-
- def main_func(inputs) -> List[Image.Image]:
-     dots = inputs['mask']
-     points = mask_2_dots(dots)
-     image_input = inputs['image']
-     masks = foward_pass(image_input, points)
-
-     image_input = Image.fromarray(image_input)
-     draw = ImageDraw.Draw(image_input)
-     for point in points[0]:
-         draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")
-
-     pred_masks = [image_input]
-     for i in range(masks.shape[2]):
-         pred_masks.append(Image.fromarray((masks[:,:,i] * 255).astype(np.uint8)))
-
-     return pred_masks
-
- def reset_data():
-     global cache_data
-     cache_data = None
-
- with gr.Blocks() as demo:
-     gr.Markdown("# How to use")
-     gr.Markdown("To start, input an image, then use the brush to create dots on the object which you want to segment, don't worry if your dots aren't perfect as the code will find the middle of each drawn item. Then press the segment button to create masks for the object that the dots are on.")
-     gr.Markdown("# Demo to run Robust Segment Anything base model")
-     gr.Markdown("""This app uses the [Robust Segment Anything](https://huggingface.co/jadechoghari/robustsam-vit-base) model from Snap Research to get a mask from a points in an image.
-     """)
-     with gr.Tab("Flip Image"):
-         with gr.Row():
-             image_input = gr.ImageEditor()
-             image_output = gr.Gallery()
-
-         image_button = gr.Button("Segment Image")
-
-         image_button.click(main_func, inputs=image_input, outputs=image_output)
-         image_input.upload(reset_data)
-
- demo.launch()
+ # no cpu required
+ #TODO: update to gpu usage
+ from transformers import pipeline, SamModel, SamProcessor
  import torch
+ import numpy as np
+ import spaces
+
+ checkpoint = "google/owlv2-base-patch16-ensemble"
+ detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
+ sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base")
+ sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
+
+ @spaces.GPU
+ def query(image, texts, threshold):
+     texts = texts.split(",")
+
+     predictions = detector(
+         image,
+         candidate_labels=texts,
+         threshold=threshold
+     )
+
+     result_labels = []
+     for pred in predictions:
+
+         box = pred["box"]
+         score = pred["score"]
+         label = pred["label"]
+         box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
+                round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
+
+         inputs = sam_processor(
+             image,
+             input_boxes=[[[box]]],
+             return_tensors="pt"
+         )
+
+         with torch.no_grad():
+             outputs = sam_model(**inputs)
+
+         mask = sam_processor.image_processor.post_process_masks(
+             outputs.pred_masks.cpu(),
+             inputs["original_sizes"].cpu(),
+             inputs["reshaped_input_sizes"].cpu()
+         )[0][0][0].numpy()
+         mask = mask[np.newaxis, ...]
+
+         from PIL import Image, ImageDraw
+         # Convert mask to image format and overlay on the original image
+         mask_image = Image.fromarray((mask[0] * 255).astype(np.uint8))
+         mask_image = mask_image.convert("L")  # Convert to grayscale for transparency
+         mask_image = mask_image.resize(image.size)
+
+         # Create an alpha mask for transparency
+         alpha_mask = Image.new("L", mask_image.size, 128)  # Adjust transparency level here
+         image.paste(mask_image, (0, 0), alpha_mask)  # Overlay the mask on the image
+
+         # Save the annotated image
+         image.save("annotated_image.png")
+         print("saved image")
+         result_labels.append((mask, label))
+     return image, result_labels
+
+ import gradio as gr
+
+ description = "This Space combines OWLv2, the state-of-the-art zero-shot object detection model with SAM, the state-of-the-art mask generation model. SAM normally doesn't accept text input. Combining SAM with OWLv2 makes SAM text promptable. Try the example or input an image and comma separated candidate labels to segment."
+ demo = gr.Interface(
+     query,
+     inputs=[gr.Image(type="pil", label="Image Input"), gr.Textbox(label = "Candidate Labels"), gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
+     outputs="annotatedimage",
+     title="OWL 🤝 SAM",
+     description=description
+ )
+ demo.launch(debug=True)
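
For reference, a minimal standalone sketch of the detection-then-segmentation flow the new app.py wires into Gradio (OWLv2 zero-shot boxes fed to RobustSAM as box prompts), run outside the Space with no `spaces` decorator or UI. The image URL and candidate labels below are illustrative placeholders, not part of the commit.

import requests
import torch
from PIL import Image
from transformers import pipeline, SamModel, SamProcessor

# same checkpoints as the updated app.py
detector = pipeline(model="google/owlv2-base-patch16-ensemble",
                    task="zero-shot-object-detection")
sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base")
sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")

# placeholder test image and labels; substitute your own
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw).convert("RGB")
candidate_labels = ["cat", "remote control"]

# 1) text-prompted detection: OWLv2 turns the labels into scored boxes
predictions = detector(image, candidate_labels=candidate_labels, threshold=0.05)

# 2) box-prompted segmentation: each detected box becomes a RobustSAM prompt
for pred in predictions:
    box = [pred["box"]["xmin"], pred["box"]["ymin"],
           pred["box"]["xmax"], pred["box"]["ymax"]]
    inputs = sam_processor(image, input_boxes=[[[box]]], return_tensors="pt")
    with torch.no_grad():
        outputs = sam_model(**inputs)
    mask = sam_processor.image_processor.post_process_masks(
        outputs.pred_masks.cpu(),
        inputs["original_sizes"].cpu(),
        inputs["reshaped_input_sizes"].cpu(),
    )[0][0][0].numpy()
    print(pred["label"], round(pred["score"], 3), mask.shape)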