import argparse
import copy
import gc
import io
import os

import matplotlib.pyplot as plt
import numpy as np
import PIL
import torch
from PIL import Image, ImageDraw, ImageFont
from segment_anything import build_sam, SamPredictor
from transformers import OwlViTProcessor, OwlViTForObjectDetection

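# Pipeline overview: OWL-ViT proposes boxes for free-text queries, and SAM
# converts each box prompt into a pixel mask; query_image() runs both stages
# and returns the mask overlay plus the image annotated with boxes.
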
def show_mask(mask, ax, random_color=False):
    """Overlay a binary mask on a matplotlib axis as a translucent RGBA image."""
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])  # DodgerBlue at 60% opacity
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_box(box, ax):
    """Draw an (x0, y0, x1, y1) box outline on a matplotlib axis."""
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0, 0, 0, 0), lw=2))

def plot_boxes_to_image(image_pil, tgt):
    """Draw labeled boxes from tgt onto image_pil; also return a binary box mask."""
    H, W = tgt["size"]
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    assert len(boxes) == len(labels), "boxes and labels must have same length"

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)

    for box, label in zip(boxes, labels):
        # Random color per box.
        color = tuple(np.random.randint(0, 255, size=3).tolist())

        x0, y0, x1, y1 = (int(v) for v in box)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        # Measure the label text; draw.textsize was removed in Pillow 10,
        # so fall back to it only on older versions.
        font = ImageFont.load_default()
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), str(label), font)
        else:
            w, h = draw.textsize(str(label), font)
            bbox = (x0, y0, w + x0, y0 + h)

        # Filled background behind the label, then the label itself in white.
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), str(label), fill="white")

        mask_draw.rectangle([x0, y0, x1, y1], fill=255, width=6)

    return image_pil, mask

# Run on GPU 4 when available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:4")
else:
    device = torch.device("cpu")

owlvit_model = OwlViTForObjectDetection.from_pretrained("google/owlvit-base-patch32").to(device)
owlvit_model.eval()
owlvit_processor = OwlViTProcessor.from_pretrained("google/owlvit-base-patch32")

# The ViT-H SAM checkpoint is expected next to this script; keep SAM on the
# same device as the detector.
sam_predictor = SamPredictor(build_sam(checkpoint="./sam_vit_h_4b8939.pth").to(device))

def query_image(img, text_prompt):
    """Detect text_prompt objects with OWL-ViT, then segment them with SAM."""
    # Accept either a numpy array or a PIL image.
    if not isinstance(img, PIL.Image.Image):
        pil_img = Image.fromarray(np.uint8(img)).convert('RGB')
    else:
        pil_img = img.convert('RGB')

    # A comma-separated prompt becomes one list of text queries (batch of one).
    texts = [text_prompt.split(",")]

    # Keep every candidate box; top-k selection happens below.
    box_threshold = 0.0

    # Run the detector; gradients are not needed at inference time.
    with torch.no_grad():
        inputs = owlvit_processor(text=texts, images=pil_img, return_tensors="pt").to(device)
        outputs = owlvit_model(**inputs)

    # PIL size is (width, height); the post-processor expects (height, width).
    target_sizes = torch.Tensor([pil_img.size[::-1]])

    results = owlvit_processor.post_process_object_detection(
        outputs=outputs, threshold=box_threshold, target_sizes=target_sizes.to(device)
    )
    # With threshold 0.0 every candidate box survives post-processing, so
    # patch indices taken from the raw logits still line up with results[i].
    scores = torch.sigmoid(outputs.logits)
    topk_scores, topk_idxs = torch.topk(scores, k=1, dim=1)

    i = 0  # single image in the batch
    text = texts[i]

    # Keep the single highest-scoring box per text query.
    topk_idxs = topk_idxs.squeeze(1).tolist()
    topk_boxes = results[i]['boxes'][topk_idxs]
    topk_scores = topk_scores.view(len(text), -1)
    topk_labels = results[i]["labels"][topk_idxs]
    # The nested index list leaves a leading batch dimension; drop it so the
    # downstream code sees boxes of shape (num_queries, 4).
    boxes, scores, labels = topk_boxes.squeeze(0), topk_scores, topk_labels.squeeze(0)

    boxes = boxes.cpu().detach().numpy()
    normalized_boxes = copy.deepcopy(boxes)

    size = pil_img.size  # (width, height)
    pred_dict = {
        "boxes": normalized_boxes,
        "size": [size[1], size[0]],  # [height, width]
        "labels": [text[idx] for idx in labels],
    }

    # Free the detector's intermediates before running SAM.
    gc.collect()
    torch.cuda.empty_cache()

    # PIL already yields RGB, which is what SamPredictor.set_image expects,
    # so no cv2 color conversion is needed here.
    image = np.array(pil_img)
    sam_predictor.set_image(image)

    # Feed the detected boxes to SAM as (x0, y0, x1, y1) prompts.
    boxes = torch.tensor(boxes, device=sam_predictor.device)
    transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes, image.shape[:2])

    masks, _, _ = sam_predictor.predict_torch(
        point_coords=None,
        point_labels=None,
        boxes=transformed_boxes,
        multimask_output=False,
    )

    # Render the mask overlay into an in-memory PNG.
    plt.figure(figsize=(10, 10))
    plt.imshow(image)
    for mask in masks:
        show_mask(mask.cpu().numpy(), plt.gca(), random_color=True)
    for box in boxes:
        show_box(box.cpu().numpy(), plt.gca())
    plt.axis('off')

    buf = io.BytesIO()
    plt.savefig(buf)
    buf.seek(0)
    owlvit_segment_image = Image.open(buf).convert('RGB')
    plt.close()  # release the figure so repeated calls do not leak memory

    # Draw the detection boxes and labels directly on the input image.
    image_with_box = plot_boxes_to_image(pil_img, pred_dict)[0]

    return owlvit_segment_image, image_with_box
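
# Usage sketch (an assumption, not part of the original interface): the script
# imports argparse but never builds a CLI, so the flags below are illustrative
# of one plausible way to wire query_image() up from the command line.
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text-prompted detection (OWL-ViT) + segmentation (SAM)")
    parser.add_argument("--image", required=True, help="path to an input image")
    parser.add_argument("--prompt", required=True, help="comma-separated queries, e.g. 'cat,remote'")
    parser.add_argument("--out", default="owlvit_sam_out.png", help="output path for the mask overlay")
    args = parser.parse_args()

    input_image = Image.open(args.image).convert("RGB")
    segment_img, boxed_img = query_image(input_image, args.prompt)
    segment_img.save(args.out)
    boxed_img.save(os.path.splitext(args.out)[0] + "_boxes.png")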