Spaces: Running on Zero
Update app.py
app.py CHANGED
@@ -1,88 +1,74 @@
-
 import torch
-
-import
-from transformers import SamModel, SamProcessor
-import cv2
-from typing import List
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-# we load model and processor
-model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base").to(device)
-processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
-
-    kernel = np.ones((5,5),np.uint8)
-    closed = cv2.morphologyEx(thresh, cv2.MORPH_CLOSE, kernel)
-    contours, _ = cv2.findContours(closed, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE)
-    points = []
-    for contour in contours:
-        moments = cv2.moments(contour)
-        cx = int(moments['m10']/moments['m00'])
-        cy = int(moments['m01']/moments['m00'])
-        points.append([cx, cy])
-    return [points]
-
-    if not cache_data or not torch.equal(inputs['pixel_values'], cache_data[0]):
-        embedding = model.get_image_embeddings(inputs["pixel_values"])
-        pixels = inputs["pixel_values"]
-        cache_data = [pixels, embedding]
-    del inputs["pixel_values"]
-
-        outputs.pred_masks.cpu(), inputs["original_sizes"].cpu(), inputs["reshaped_input_sizes"].cpu()
-    )
-    masks = masks[0].squeeze(0).numpy().transpose(1, 2, 0)
-
-    for point in points[0]:
-        draw.ellipse((point[0] - 10, point[1] - 10, point[0] + 10, point[1] + 10), fill="red")
-
-    """)
-    with gr.Tab("Flip Image"):
-        with gr.Row():
-            image_input = gr.ImageEditor()
-            image_output = gr.Gallery()
-
-        image_button = gr.Button("Segment Image")
-
-        image_input.upload(reset_data)
+# no cpu required
+# TODO: update to gpu usage
+from transformers import pipeline, SamModel, SamProcessor
 import torch
+import numpy as np
+import spaces
+import gradio as gr
+from PIL import Image
+
+checkpoint = "google/owlv2-base-patch16-ensemble"
+# zero-shot object detection turns the text labels into candidate boxes
+detector = pipeline(model=checkpoint, task="zero-shot-object-detection")
+sam_model = SamModel.from_pretrained("jadechoghari/robustsam-vit-base")
+sam_processor = SamProcessor.from_pretrained("jadechoghari/robustsam-vit-base")
+
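+# ZeroGPU: a GPU is allocated only while the decorated function runs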
+@spaces.GPU
+def query(image, texts, threshold):
+    # split and trim the comma-separated candidate labels
+    texts = [t.strip() for t in texts.split(",")]
+
+    predictions = detector(
+        image,
+        candidate_labels=texts,
+        threshold=threshold
+    )
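+    # each prediction is a dict with "score", "label", and a "box" of xmin/ymin/xmax/ymax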
+
+    result_labels = []
+    for pred in predictions:
+        label = pred["label"]
+        box = [round(pred["box"]["xmin"], 2), round(pred["box"]["ymin"], 2),
+               round(pred["box"]["xmax"], 2), round(pred["box"]["ymax"], 2)]
+
+        # prompt SAM with the detected box
+        inputs = sam_processor(
+            image,
+            input_boxes=[[[box]]],
+            return_tensors="pt"
+        )
+
+        with torch.no_grad():
+            outputs = sam_model(**inputs)
+
+        # rescale the predicted masks to the original image size,
+        # then keep the first mask for this single box prompt
+        mask = sam_processor.image_processor.post_process_masks(
+            outputs.pred_masks.cpu(),
+            inputs["original_sizes"].cpu(),
+            inputs["reshaped_input_sizes"].cpu()
+        )[0][0][0].numpy()
+        mask = mask[np.newaxis, ...]
+
+        # debugging aid: overlay the mask on the input image and save it to disk
+        mask_image = Image.fromarray((mask[0] * 255).astype(np.uint8))
+        mask_image = mask_image.convert("L")  # grayscale so it can serve as an overlay
+        mask_image = mask_image.resize(image.size)
+
+        # create an alpha mask for transparency
+        alpha_mask = Image.new("L", mask_image.size, 128)  # adjust transparency level here
+        image.paste(mask_image, (0, 0), alpha_mask)  # overlay the mask on the image
+
+        # save the annotated image
+        image.save("annotated_image.png")
+        print("saved image")
+
+        result_labels.append((mask, label))  # (mask, label) pairs for gr.AnnotatedImage
+    return image, result_labels
+
+description = ("This Space combines OWLv2, a state-of-the-art zero-shot object detection model, "
+               "with SAM, a state-of-the-art mask generation model. SAM doesn't normally accept "
+               "text input; pairing it with OWLv2 makes SAM text-promptable. Input an image and "
+               "comma-separated candidate labels to segment.")
+demo = gr.Interface(
+    query,
+    inputs=[gr.Image(type="pil", label="Image Input"),
+            gr.Textbox(label="Candidate Labels"),
+            gr.Slider(0, 1, value=0.05, label="Confidence Threshold")],
+    outputs="annotatedimage",
+    title="OWL 🤝 SAM",
+    description=description,
+)
+demo.launch(debug=True)
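
For a quick local smoke test of the new query() pipeline, a minimal sketch; the module name `app`, the test image `cats.png`, and the label string are illustrative assumptions, and the `@spaces.GPU` decorator is expected to be a no-op outside a ZeroGPU Space:

    from PIL import Image
    from app import query  # hypothetical: assumes this file is saved as app.py

    img = Image.open("cats.png").convert("RGB")  # hypothetical test image
    annotated, labels = query(img, "cat, remote control", 0.1)
    print([label for _, label in labels])  # one label per returned mask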