# ODR's picture
# Duplicate from haotiz/glip-zeroshot-demo
# 12e55f6
import cv2
import torch
import re
import numpy as np
from typing import List, Union
import nltk
import inflect
from transformers import AutoTokenizer
from torchvision import transforms as T
import pdb
from maskrcnn_benchmark.modeling.detector import build_detection_model
from maskrcnn_benchmark.utils.checkpoint import DetectronCheckpointer
from maskrcnn_benchmark.structures.image_list import to_image_list
from maskrcnn_benchmark.structures.boxlist_ops import boxlist_iou
from maskrcnn_benchmark.structures.bounding_box import BoxList
from maskrcnn_benchmark import layers as L
from maskrcnn_benchmark.modeling.roi_heads.mask_head.inference import Masker
from maskrcnn_benchmark.utils import cv2_util
# Module-wide inflect engine instance (not referenced by the code visible in this file section).
engine = inflect.engine()

# Download the NLTK resources at import time; presumably these back the
# nltk.word_tokenize / nltk.pos_tag calls made in find_noun_phrases below.
for _nltk_resource in ('punkt', 'averaged_perceptron_tagger'):
    nltk.download(_nltk_resource)
import timeit
class GLIPDemo(object):
    """Single-image inference and visualisation helper built around a detection model.

    The instance owns the model (optional), the image pre-processing pipeline,
    the tokenizer and a handful of drawing helpers. State shared between calls:

    * ``self.entities`` - entity names aligned with the label ids produced for
      the most recent free-text caption (filled by :meth:`run_ner`).
    * ``self.color`` - optional colour override set by
      :meth:`visualize_with_predictions`.
    * ``self.plus`` / ``self.positive_map_label_to_token`` - label offset and
      label-to-token map of the most recent :meth:`compute_prediction` call.
    """

    def __init__(self,
                 cfg,
                 confidence_threshold=0.7,
                 min_image_size=None,
                 show_mask_heatmaps=False,
                 masks_per_dim=5,
                 load_model=True
                 ):
        """Build the predictor.

        Args:
            cfg: project config node; a clone is stored as ``self.cfg``.
            confidence_threshold: either a single number or a sequence of
                per-class thresholds (indexed with ``label - 1``).
            min_image_size: if not None, images are resized to this size
                before being fed to the model.
            show_mask_heatmaps: when True the visualisation methods return a
                mask montage instead of an annotated image.
            masks_per_dim: number of masks per row/column in the montage.
            load_model: when False the model is neither built nor loaded, so
                only the helpers that do not touch ``self.model`` are usable.
        """
        self.cfg = cfg.clone()
        if load_model:
            self.model = build_detection_model(cfg)
            self.model.eval()
            self.device = torch.device(cfg.MODEL.DEVICE)
            self.model.to(self.device)
        self.min_image_size = min_image_size
        self.show_mask_heatmaps = show_mask_heatmaps
        self.masks_per_dim = masks_per_dim

        save_dir = cfg.OUTPUT_DIR
        if load_model:
            checkpointer = DetectronCheckpointer(cfg, self.model, save_dir=save_dir)
            _ = checkpointer.load(cfg.MODEL.WEIGHT)

        self.transforms = self.build_transform()

        # used to make colors for each tokens
        mask_threshold = -1 if show_mask_heatmaps else 0.5
        self.masker = Masker(threshold=mask_threshold, padding=1)
        self.palette = torch.tensor([2 ** 25 - 1, 2 ** 15 - 1, 2 ** 21 - 1])
        self.cpu_device = torch.device("cpu")
        self.confidence_threshold = confidence_threshold

        # Start with no known entities so the drawing helpers can be used
        # before any caption has been parsed.
        self.entities = []

        self.tokenizer = self.build_tokenizer()

    def build_transform(self):
        """
        Creates a basic transformation that was used to train the models
        """
        cfg = self.cfg

        # we are loading images with OpenCV, so we don't need to convert them
        # to BGR, they are already! So all we need to do is to normalize
        # by 255 if we want to convert to BGR255 format, or flip the channels
        # if we want it to be in RGB in [0-1] range.
        if cfg.INPUT.TO_BGR255:
            to_bgr_transform = T.Lambda(lambda x: x * 255)
        else:
            to_bgr_transform = T.Lambda(lambda x: x[[2, 1, 0]])

        normalize_transform = T.Normalize(
            mean=cfg.INPUT.PIXEL_MEAN, std=cfg.INPUT.PIXEL_STD
        )

        transform = T.Compose(
            [
                T.ToPILImage(),
                T.Resize(self.min_image_size) if self.min_image_size is not None else lambda x: x,
                T.ToTensor(),
                to_bgr_transform,
                normalize_transform,
            ]
        )
        return transform

    def build_tokenizer(self):
        """Return the tokenizer selected by ``cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE``.

        Returns None when the configured type is neither ``"bert-base-uncased"``
        nor ``"clip"``.
        """
        cfg = self.cfg
        tokenizer = None
        if cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "bert-base-uncased":
            tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
        elif cfg.MODEL.LANGUAGE_BACKBONE.TOKENIZER_TYPE == "clip":
            from transformers import CLIPTokenizerFast
            if cfg.MODEL.DYHEAD.FUSE_CONFIG.MLM_LOSS:
                tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                              from_slow=True, mask_token='ðŁĴij</w>')
            else:
                tokenizer = CLIPTokenizerFast.from_pretrained("openai/clip-vit-base-patch32",
                                                              from_slow=True)
        return tokenizer

    def run_ner(self, caption):
        """Locate noun phrases in ``caption`` and return their character spans.

        Every occurrence of every noun phrase becomes its own entity.
        ``self.entities`` is rebuilt so that ``self.entities[k]`` is the phrase
        that produced ``tokens_positive[k]``; phrases that cannot be found in
        the caption (e.g. because punctuation was stripped from them)
        contribute to neither list, keeping the two aligned.

        Returns:
            A list with one ``[[start, end]]`` entry per occurrence.
        """
        noun_phrases = find_noun_phrases(caption)
        noun_phrases = [remove_punctuation(phrase) for phrase in noun_phrases]
        noun_phrases = [phrase for phrase in noun_phrases if phrase != '']

        lowered_caption = caption.lower()
        entities = []
        tokens_positive = []
        for phrase in noun_phrases:
            # The phrase is literal text, not a pattern: escape it so regex
            # metacharacters in it cannot change or break the search.
            for m in re.finditer(re.escape(phrase), lowered_caption):
                entities.append(phrase)
                tokens_positive.append([[m.start(), m.end()]])

        self.entities = entities
        return tokens_positive

    def inference(self, original_image, original_caption):
        """Run the model and keep detections above ``self.confidence_threshold``."""
        predictions = self.compute_prediction(original_image, original_caption)
        top_predictions = self._post_process_fixed_thresh(predictions)
        return top_predictions

    def run_on_web_image(self,
                         original_image,
                         original_caption,
                         thresh=0.5,
                         custom_entity=None,
                         alpha=0.0):
        """Detect, filter with ``thresh`` and draw the result on a copy of the image.

        ``custom_entity`` is forwarded to :meth:`compute_prediction`. ``alpha``
        is accepted for interface compatibility but is not used here; boxes are
        drawn with the default blending of :meth:`overlay_boxes`.

        Returns:
            ``(annotated_image, top_predictions)``, or the output of
            :meth:`create_mask_montage` when ``show_mask_heatmaps`` is set.
        """
        predictions = self.compute_prediction(original_image, original_caption, custom_entity)
        top_predictions = self._post_process(predictions, thresh)

        result = original_image.copy()
        if self.show_mask_heatmaps:
            return self.create_mask_montage(result, top_predictions)
        result = self.overlay_boxes(result, top_predictions)
        result = self.overlay_entity_names(result, top_predictions)
        if self.cfg.MODEL.MASK_ON:
            result = self.overlay_mask(result, top_predictions)
        return result, top_predictions

    def visualize_with_predictions(self,
                                   original_image,
                                   predictions,
                                   thresh=0.5,
                                   alpha=0.0,
                                   box_pixel=3,
                                   text_size=1,
                                   text_pixel=2,
                                   text_offset=10,
                                   text_offset_original=4,
                                   color=255):
        """Draw already-computed ``predictions`` on a copy of ``original_image``.

        The predictions are resized to the image size and filtered with
        ``thresh`` first. ``color`` is stored on the instance and overrides the
        per-label colours in :meth:`compute_colors_for_labels` from then on.

        Returns:
            ``(annotated_image, top_predictions)``, or the output of
            :meth:`create_mask_montage` when ``show_mask_heatmaps`` is set.
        """
        self.color = color
        height, width = original_image.shape[:-1]
        predictions = predictions.resize((width, height))
        top_predictions = self._post_process(predictions, thresh)

        result = original_image.copy()
        if self.show_mask_heatmaps:
            return self.create_mask_montage(result, top_predictions)
        result = self.overlay_boxes(result, top_predictions, alpha=alpha, box_pixel=box_pixel)
        result = self.overlay_entity_names(result, top_predictions, text_size=text_size, text_pixel=text_pixel,
                                           text_offset=text_offset, text_offset_original=text_offset_original)
        if self.cfg.MODEL.MASK_ON:
            result = self.overlay_mask(result, top_predictions)
        return result, top_predictions

    def compute_prediction(self, original_image, original_caption, custom_entity=None):
        """Run the model on one image/caption pair.

        Args:
            original_image: image array whose leading two dimensions are
                height and width (``shape[:-1]`` is unpacked as such).
            original_caption: a caption string (noun phrases are extracted
                with :meth:`run_ner`) or a list of category names.
            custom_entity: only ``None`` is supported together with a string
                caption; it is ignored when the caption is a list.

        Returns:
            The prediction for the image, resized to the original image size,
            with masks pasted into place when a ``"mask"`` field is present.

        Raises:
            NotImplementedError: string caption combined with a non-None
                ``custom_entity`` (no spans can be derived in that case).
        """
        # image
        image = self.transforms(original_image)
        image_list = to_image_list(image, self.cfg.DATALOADER.SIZE_DIVISIBILITY)
        image_list = image_list.to(self.device)

        # caption
        if isinstance(original_caption, list):
            # we directly provided a list of category names
            caption_string = ""
            tokens_positive = []
            seperation_tokens = " . "
            for word in original_caption:
                tokens_positive.append([len(caption_string), len(caption_string) + len(word)])
                caption_string += word
                caption_string += seperation_tokens

            tokenized = self.tokenizer([caption_string], return_tensors="pt")
            # All spans are wrapped in one outer list, i.e. they are handed to
            # create_positive_map as a single entry.
            tokens_positive = [tokens_positive]

            original_caption = caption_string
            print(tokens_positive)
        else:
            if custom_entity is not None:
                # Previously this path failed with a NameError because no
                # spans were ever assigned; fail with an explicit message.
                raise NotImplementedError(
                    "custom_entity is not supported for string captions; "
                    "pass custom_entity=None or a list of category names"
                )
            tokenized = self.tokenizer([original_caption], return_tensors="pt")
            tokens_positive = self.run_ner(original_caption)
            print(tokens_positive)

        # process positive map
        positive_map = create_positive_map(tokenized, tokens_positive)

        if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
            plus = 1
        else:
            plus = 0

        positive_map_label_to_token = create_positive_map_label_to_token_from_positive_map(positive_map, plus=plus)
        self.plus = plus
        self.positive_map_label_to_token = positive_map_label_to_token
        tic = timeit.default_timer()

        # compute predictions
        with torch.no_grad():
            predictions = self.model(image_list, captions=[original_caption], positive_map=positive_map_label_to_token)
            predictions = [o.to(self.cpu_device) for o in predictions]
        print("inference time per image: {}".format(timeit.default_timer() - tic))

        # always single image is passed at a time
        prediction = predictions[0]

        # reshape prediction (a BoxList) into the original image size
        height, width = original_image.shape[:-1]
        prediction = prediction.resize((width, height))

        if prediction.has_field("mask"):
            # if we have masks, paste the masks in the right position
            # in the image, as defined by the bounding boxes
            masks = prediction.get_field("mask")
            # always single image is passed at a time
            masks = self.masker([masks], [prediction])[0]
            prediction.add_field("mask", masks)

        return prediction

    def _filter_and_sort(self, predictions, scalar_override=None):
        """Keep detections scoring above their threshold, best first.

        When ``self.confidence_threshold`` is a plain number or a one-element
        sequence, a single threshold applies to every box: ``scalar_override``
        if given, otherwise the configured value. For longer sequences the
        threshold of each box is ``self.confidence_threshold[label - 1]`` and
        ``scalar_override`` is ignored.
        """
        scores = predictions.get_field("scores")
        configured = self.confidence_threshold

        is_number = isinstance(configured, (int, float))
        if is_number or len(configured) == 1:
            if scalar_override is not None:
                scalar = scalar_override
            else:
                scalar = configured if is_number else configured[0]
            thresh = torch.full_like(scores, float(scalar))
        else:
            labels = predictions.get_field("labels").tolist()
            thresh = scores.clone()
            for i, lb in enumerate(labels):
                thresh[i] = configured[lb - 1]

        keep = torch.nonzero(scores > thresh).squeeze(1)
        predictions = predictions[keep]

        scores = predictions.get_field("scores")
        _, idx = scores.sort(0, descending=True)
        return predictions[idx]

    def _post_process_fixed_thresh(self, predictions):
        """Filter with ``self.confidence_threshold`` and sort by descending score."""
        return self._filter_and_sort(predictions)

    def _post_process(self, predictions, threshold=0.5):
        """Filter with ``threshold`` and sort by descending score.

        ``threshold`` is used when ``self.confidence_threshold`` is a number or
        a one-element sequence; longer sequences are treated as per-class
        thresholds and take precedence over ``threshold``.
        """
        return self._filter_and_sort(predictions, scalar_override=threshold)

    def compute_colors_for_labels(self, labels):
        """
        Simple function that adds fixed colors depending on the class

        If ``self.color`` has been set (see visualize_with_predictions) every
        entry is replaced by that value; if it is unset or cannot be combined
        with the colour array, the per-label colours are returned unchanged.
        """
        colors = (300 * (labels[:, None] - 1) + 1) * self.palette
        colors = (colors % 255).numpy().astype("uint8")

        override = getattr(self, "color", None)
        if override is not None:
            try:
                colors = (colors * 0 + override).astype("uint8")
            except (TypeError, ValueError):
                # Best effort: an unusable override keeps the label colours.
                pass
        return colors

    def overlay_boxes(self, image, predictions, alpha=0.5, box_pixel=3):
        """Draw the predicted boxes and blend them onto ``image``.

        Boxes are drawn on a copy, which is then mixed with the untouched
        image using weight ``alpha`` for the drawn copy and ``1 - alpha`` for
        the original.
        """
        labels = predictions.get_field("labels")
        boxes = predictions.bbox

        colors = self.compute_colors_for_labels(labels).tolist()
        new_image = image.copy()
        for box, color in zip(boxes, colors):
            box = box.to(torch.int64)
            top_left, bottom_right = box[:2].tolist(), box[2:].tolist()
            new_image = cv2.rectangle(
                new_image, tuple(top_left), tuple(bottom_right), tuple(color), box_pixel)

        # Following line overlays transparent rectangle over the image
        image = cv2.addWeighted(new_image, alpha, image, 1 - alpha, 0)

        return image

    def overlay_scores(self, image, predictions):
        """Write each score at the left edge of its box, vertically centred."""
        scores = predictions.get_field("scores")
        boxes = predictions.bbox

        for box, score in zip(boxes, scores):
            box = box.to(torch.int64)
            image = cv2.putText(image, '%.3f' % score,
                                (int(box[0]), int((box[1] + box[3]) / 2)),
                                cv2.FONT_HERSHEY_SIMPLEX, 1.0, (255, 255, 255), 2, cv2.LINE_AA)

        return image

    def overlay_entity_names(self, image, predictions, names=None, text_size=0.7, text_pixel=2, text_offset=10,
                             text_offset_original=4):
        """Write ``"<entity>:<score>"`` next to the top-left corner of each box.

        Entity names come from ``self.entities``; labels without a matching
        entity (or when the architecture is not ``"VLDYHEAD"``) are shown as
        ``'object'``. A caption that would land within ``text_offset`` pixels
        of an earlier one is shifted up to reduce overlap. ``names`` is
        accepted but not used.
        """
        scores = predictions.get_field("scores").tolist()
        labels = predictions.get_field("labels").tolist()
        new_labels = []
        if self.cfg.MODEL.RPN_ARCHITECTURE == "VLDYHEAD":
            plus = 1
        else:
            plus = 0
        self.plus = plus

        entities = getattr(self, "entities", None)
        if entities and self.plus:
            for i in labels:
                if i <= len(entities):
                    new_labels.append(entities[i - self.plus])
                else:
                    new_labels.append('object')
        else:
            new_labels = ['object' for i in labels]
        boxes = predictions.bbox

        template = "{}:{:.2f}"
        previous_locations = []
        for box, score, label in zip(boxes, scores, new_labels):
            # Convert to plain Python numbers: unpacking the tensor directly
            # yields views, and the in-place "y -= ..." below would then
            # overwrite the coordinates stored in predictions.bbox.
            x, y = box[:2].tolist()
            s = template.format(label, score).replace("_", " ").replace("(", "").replace(")", "")
            for x_prev, y_prev in previous_locations:
                if abs(x - x_prev) < abs(text_offset) and abs(y - y_prev) < abs(text_offset):
                    y -= text_offset

            cv2.putText(
                image, s, (int(x), int(y) - text_offset_original), cv2.FONT_HERSHEY_SIMPLEX, text_size,
                (255, 255, 255), text_pixel, cv2.LINE_AA
            )
            previous_locations.append((int(x), int(y)))

        return image

    def overlay_mask(self, image, predictions):
        """Draw the contours of every predicted mask on ``image`` and return it."""
        masks = predictions.get_field("mask").numpy()
        labels = predictions.get_field("labels")

        colors = self.compute_colors_for_labels(labels).tolist()

        for mask, color in zip(masks, colors):
            thresh = mask[0, :, :, None].astype(np.uint8)
            contours, hierarchy = cv2_util.findContours(
                thresh, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE
            )
            image = cv2.drawContours(image, contours, -1, color, 2)

        composite = image

        return composite

    def create_mask_montage(self, image, predictions):
        """Tile up to ``masks_per_dim ** 2`` down-scaled masks into one heat-map image.

        ``image`` is not used. Returns ``(montage, None)`` so the result has
        the same two-element shape as the other visualisation return values.
        """
        masks = predictions.get_field("mask")
        masks_per_dim = self.masks_per_dim
        masks = L.interpolate(
            masks.float(), scale_factor=1 / masks_per_dim
        ).byte()
        height, width = masks.shape[-2:]
        max_masks = masks_per_dim ** 2
        masks = masks[:max_masks]
        # handle case where we have less detections than max_masks
        if len(masks) < max_masks:
            masks_padded = torch.zeros(max_masks, 1, height, width, dtype=torch.uint8)
            masks_padded[: len(masks)] = masks
            masks = masks_padded
        masks = masks.reshape(masks_per_dim, masks_per_dim, height, width)
        result = torch.zeros(
            (masks_per_dim * height, masks_per_dim * width), dtype=torch.uint8
        )
        for y in range(masks_per_dim):
            start_y = y * height
            end_y = (y + 1) * height
            for x in range(masks_per_dim):
                start_x = x * width
                end_x = (x + 1) * width
                result[start_y:end_y, start_x:end_x] = masks[y, x]

        return cv2.applyColorMap(result.numpy(), cv2.COLORMAP_JET), None
def create_positive_map_label_to_token_from_positive_map(positive_map, plus=0):
    """Map each label id to the token positions that are non-zero in its row.

    Args:
        positive_map: 2-D tensor (or sequence of 1-D tensors); row ``i``
            belongs to label ``i + plus``.
        plus: offset added to the row index to obtain the label id.

    Returns:
        Dict ``{label_id: [token_index, ...]}`` with one key per row.
    """
    return {
        row_index + plus: torch.nonzero(row, as_tuple=True)[0].tolist()
        for row_index, row in enumerate(positive_map)
    }
def create_positive_map(tokenized, tokens_positive, max_len=256):
    """construct a map such that positive_map[i,j] = True iff box i is associated to token j

    Args:
        tokenized: tokenizer output exposing ``char_to_token(char_index)``,
            which is expected to return a token index or None.
        tokens_positive: one list of ``(begin, end)`` character spans per
            entity; ``end`` is exclusive.
        max_len: number of token columns in the map (defaults to 256).

    Returns:
        Float tensor of shape ``(len(tokens_positive), max_len)``. Each row is
        divided by its sum (plus 1e-6), so rows without any resolved span stay
        all-zero. Spans whose begin or end cannot be mapped to a token, even
        after nudging the character index inward, are skipped.
    """

    def _first_token(char_indices):
        """Return the first token index found for ``char_indices``, else None."""
        try:
            for char_index in char_indices:
                token_index = tokenized.char_to_token(char_index)
                if token_index is not None:
                    return token_index
        except Exception:
            # Best effort: a failing fallback lookup means "no token".
            return None
        return None

    positive_map = torch.zeros((len(tokens_positive), max_len), dtype=torch.float)

    for j, tok_list in enumerate(tokens_positive):
        for (beg, end) in tok_list:
            try:
                beg_pos = tokenized.char_to_token(beg)
                end_pos = tokenized.char_to_token(end - 1)
            except Exception:
                print("beg:", beg, "end:", end)
                print("token_positive:", tokens_positive)
                raise
            # The boundary characters may not belong to any token; retry with
            # the next two characters inward on either side.
            if beg_pos is None:
                beg_pos = _first_token((beg + 1, beg + 2))
            if end_pos is None:
                end_pos = _first_token((end - 2, end - 3))
            if beg_pos is None or end_pos is None:
                continue

            positive_map[j, beg_pos: end_pos + 1].fill_(1)
    return positive_map / (positive_map.sum(-1)[:, None] + 1e-6)
def find_noun_phrases(caption: str) -> List[str]:
    """Return the noun phrases of the lower-cased ``caption``.

    The caption is tokenized and POS-tagged with NLTK, then chunked with the
    grammar ``NP: {<DT>?<JJ.*>*<NN.*>+}``. Each ``NP`` subtree is returned as
    its tokens joined by single spaces, in the order the subtrees are visited.
    """
    tagged_tokens = nltk.pos_tag(nltk.word_tokenize(caption.lower()))
    chunker = nltk.RegexpParser("NP: {<DT>?<JJ.*>*<NN.*>+}")
    parse_tree = chunker.parse(tagged_tokens)

    return [
        ' '.join(leaf[0] for leaf in subtree.leaves())
        for subtree in parse_tree.subtrees()
        if subtree.label() == 'NP'
    ]
# Translation table that deletes every character treated as punctuation by
# remove_punctuation; built once so each call is a single pass over the text.
_PUNCTUATION_TABLE = str.maketrans('', '', ''.join([
    '|', ':', ';', '@', '(', ')', '[', ']', '{', '}', '^',
    '\'', '\"', '’', '`', '?', '$', '%', '#', '!', '&', '*', '+', ',', '.',
]))


def remove_punctuation(text: str) -> str:
    """Delete the listed punctuation characters from ``text`` and strip whitespace.

    Characters outside the list (for example ``-``, ``_`` or ``/``) are kept.
    """
    return text.translate(_PUNCTUATION_TABLE).strip()