ngthanhtinqn's picture
fix text prompt
6b44c63
raw
history blame
5.48 kB
import argparse
import os
import copy
import numpy as np
import torch
from PIL import Image, ImageDraw, ImageFont
import PIL
# OwlViT Detection
from transformers import OwlViTProcessor, OwlViTForObjectDetection
# segment anything
from segment_anything import build_sam, SamPredictor
import cv2
import numpy as np
import matplotlib.pyplot as plt
import gc
def show_mask(mask, ax, random_color=False):
    """Overlay one mask on ``ax`` as a translucent RGBA image.

    Args:
        mask: array whose last two axes are (height, width) and that holds
            exactly height * width elements (e.g. shape ``(1, H, W)``);
            its values scale the overlay colour.
        ax: object with an ``imshow`` method (a matplotlib Axes).
        random_color: when True, use a random RGB colour; otherwise use a
            fixed blue. Alpha is 0.6 in both cases.
    """
    if not random_color:
        rgba = np.array([30 / 255, 144 / 255, 255 / 255, 0.6])
    else:
        rgba = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    height, width = mask.shape[-2:]
    # Broadcast (H, W, 1) * (1, 1, 4) -> (H, W, 4) coloured overlay.
    overlay = mask.reshape(height, width, 1) * rgba[None, None, :]
    ax.imshow(overlay)
def show_box(box, ax):
    """Outline an ``(x0, y0, x1, y1)`` box on ``ax`` as an unfilled green rectangle.

    Args:
        box: indexable whose first four entries are the left, top, right and
            bottom coordinates.
        ax: object with an ``add_patch`` method (a matplotlib Axes).
    """
    left, top = box[0], box[1]
    right, bottom = box[2], box[3]
    outline = plt.Rectangle(
        (left, top),
        right - left,
        bottom - top,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),  # fully transparent interior
        lw=2,
    )
    ax.add_patch(outline)
def plot_boxes_to_image(image_pil, tgt):
    """Draw labelled boxes onto ``image_pil`` and build a matching box mask.

    Args:
        image_pil: PIL image to annotate. It is modified in place and also
            returned.
        tgt: dict with ``"boxes"`` (iterable of ``(x0, y0, x1, y1)`` pixel
            coordinates) and ``"labels"`` (one label per box, rendered with
            ``str()``). Other keys (such as ``"size"``) are ignored.

    Returns:
        ``(image_pil, mask)`` where ``mask`` is a single-channel ("L") image of
        the same size, 255 inside every box and 0 elsewhere.

    Raises:
        ValueError: if the number of boxes and labels differ.
    """
    boxes = tgt["boxes"]
    labels = tgt["labels"]
    # Explicit check rather than ``assert``: asserts vanish under ``python -O``
    # and ``zip`` below would then silently drop the unmatched entries.
    if len(boxes) != len(labels):
        raise ValueError("boxes and labels must have same length")

    draw = ImageDraw.Draw(image_pil)
    mask = Image.new("L", image_pil.size, 0)
    mask_draw = ImageDraw.Draw(mask)
    # Loop-invariant: the same font sizes the label background and draws the
    # label text, so the filled rectangle always matches the text drawn on it.
    font = ImageFont.load_default()

    for box, label in zip(boxes, labels):
        text = str(label)
        # Random RGB colour per box; the upper bound is exclusive, so 256
        # makes 255 reachable.
        color = tuple(np.random.randint(0, 256, size=3).tolist())
        x0, y0, x1, y1 = (int(v) for v in box)

        draw.rectangle([x0, y0, x1, y1], outline=color, width=6)

        # Label: a filled background in the box colour with white text on top.
        if hasattr(font, "getbbox"):
            bbox = draw.textbbox((x0, y0), text, font)
        else:
            # Older Pillow fonts lack getbbox; fall back to textsize.
            w, h = draw.textsize(text, font)
            bbox = (x0, y0, x0 + w, y0 + h)
        draw.rectangle(bbox, fill=color)
        draw.text((x0, y0), text, fill="white", font=font)

        mask_draw.rectangle([x0, y0, x1, y1], fill=255)
    return image_pil, mask
# Use the GPU if available. Take the current CUDA device instead of a hard-coded
# ordinal so the script also runs on hosts with few GPUs; pick a specific card
# with the CUDA_VISIBLE_DEVICES environment variable.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Model / checkpoint locations.
OWLVIT_CHECKPOINT = "google/owlvit-base-patch32"
SAM_CHECKPOINT = "./sam_vit_h_4b8939.pth"

# Load the OWL-ViT detector and put it in inference mode.
owlvit_model = OwlViTForObjectDetection.from_pretrained(OWLVIT_CHECKPOINT).to(device)
owlvit_model.eval()
owlvit_processor = OwlViTProcessor.from_pretrained(OWLVIT_CHECKPOINT)

# Segment Anything (SAM) predictor. The model is not moved here; query_image
# reads sam_predictor.device when it builds its box tensors.
sam_predictor = SamPredictor(build_sam(checkpoint=SAM_CHECKPOINT))
def _to_rgb_pil(img):
    """Return ``img`` as a new RGB PIL image (accepts a PIL image or an array)."""
    if isinstance(img, PIL.Image.Image):
        # convert() returns a copy, so later in-place drawing never touches
        # the caller's image.
        return img.convert('RGB')
    return Image.fromarray(np.uint8(img)).convert('RGB')


def _detect_best_boxes(pil_img, queries, box_threshold=0.0):
    """Run OWL-ViT and return the highest-scoring box for each query.

    Returns:
        ``(boxes, labels)``: ``boxes`` is a NumPy array with one
        ``(x0, y0, x1, y1)`` row per query, in pixel coordinates of
        ``pil_img``; ``labels`` holds the query string that post-processing
        assigned to each selected box.
    """
    with torch.no_grad():
        inputs = owlvit_processor(text=[queries], images=pil_img, return_tensors="pt").to(device)
        outputs = owlvit_model(**inputs)
        # (height, width) per image, used to rescale boxes to pixel coordinates.
        target_sizes = torch.Tensor([pil_img.size[::-1]]).to(device)
        results = owlvit_processor.post_process_object_detection(
            outputs=outputs, threshold=box_threshold, target_sizes=target_sizes
        )[0]
        # Per-query scores for the single image; assumed (num_boxes, num_queries).
        scores = torch.sigmoid(outputs.logits[0])
        # For every query, the index of its top-scoring box.
        # NOTE: indexing ``results`` with raw box indices assumes post-processing
        # kept every box (true while all sigmoid scores exceed box_threshold).
        best = torch.topk(scores, k=1, dim=0).indices[0]
        boxes = results["boxes"][best].detach().cpu().numpy()
        labels = [queries[idx] for idx in results["labels"][best].tolist()]
    return boxes, labels


def _render_segmentation(image, masks, boxes):
    """Render ``image`` with SAM ``masks`` and ``boxes`` overlaid, as an RGB PIL image."""
    import io

    fig = plt.figure(figsize=(10, 10))
    try:
        ax = fig.gca()
        ax.imshow(image)
        for mask in masks:
            show_mask(mask.cpu().numpy(), ax, random_color=True)
        for box in boxes:
            show_box(box, ax)
        ax.axis('off')
        buf = io.BytesIO()
        fig.savefig(buf)
    finally:
        # pyplot keeps open figures alive until closed; without this every
        # request would leak a figure.
        plt.close(fig)
    buf.seek(0)
    return Image.open(buf).convert('RGB')


def query_image(img, text_prompt):
    """Detect one box per text query with OWL-ViT, then segment each box with SAM.

    Args:
        img: a ``PIL.Image.Image`` or anything ``np.uint8(...)`` accepts
            (presumably an H x W x 3 RGB array from the UI — confirm).
        text_prompt: comma-separated object queries, e.g. ``"cat, dog"``.
            Whitespace around each query is stripped; empty entries are ignored.

    Returns:
        ``(owlvit_segment_image, image_with_box)``: an RGB PIL image with the
        SAM masks and boxes rendered by matplotlib, and an RGB PIL image with
        the labelled detection boxes drawn on it.

    Raises:
        ValueError: if ``text_prompt`` contains no non-empty query.
    """
    pil_img = _to_rgb_pil(img)
    queries = [q.strip() for q in text_prompt.split(",") if q.strip()]
    if not queries:
        raise ValueError("text_prompt must contain at least one non-empty query")

    boxes, labels = _detect_best_boxes(pil_img, queries)
    width, height = pil_img.size
    pred_dict = {
        "boxes": boxes.copy(),  # absolute pixel coordinates
        "size": [height, width],  # H, W
        "labels": labels,
    }

    # Release cached memory from the detection pass before running SAM.
    gc.collect()
    torch.cuda.empty_cache()

    # pil_img is RGB, so its array is already RGB; no channel swap is needed.
    image = np.array(pil_img)
    with torch.no_grad():
        sam_predictor.set_image(image)
        boxes_t = torch.tensor(boxes, device=sam_predictor.device)
        transformed_boxes = sam_predictor.transform.apply_boxes_torch(boxes_t, image.shape[:2])
        masks, _, _ = sam_predictor.predict_torch(
            point_coords=None,
            point_labels=None,
            boxes=transformed_boxes,
            multimask_output=False,
        )

    owlvit_segment_image = _render_segmentation(image, masks, boxes)
    # Draws onto pil_img in place; pil_img is a private copy, see _to_rgb_pil.
    image_with_box = plot_boxes_to_image(pil_img, pred_dict)[0]
    return owlvit_segment_image, image_with_box