from typing import List, Dict, Union, Tuple

import copy
import hashlib
import os
import pickle

import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw, ImageFilter, ImageOps, ImageEnhance

import spacy
import torch
import torchvision
import torchvision.transforms as transforms
import clip
import alpha_clip
from transformers import BertTokenizer, RobertaTokenizerFast
import ruamel.yaml as yaml
import pycocotools.mask as mask_utils
from segment_anything import sam_model_registry, SamPredictor

from interpreter import Box
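
# Executors score each candidate box against a referring expression.
# Executor handles box/mask preparation, feature caching, and score aggregation
# across box-representation methods ("full", "blur", "gray"); ClipExecutor
# implements the scoring with CLIP, or with Alpha-CLIP plus SAM-generated masks
# when clip_type contains "aclip".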


class Executor:
    def __init__(self, device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None) -> None:
        IMPLEMENTED_METHODS = ["blur", "full", "gray"]
        if any(m not in IMPLEMENTED_METHODS for m in box_representation_method.split(",")):
            raise NotImplementedError
        IMPLEMENTED_AGGREGATORS = ["max", "sum"]
        if method_aggregator not in IMPLEMENTED_AGGREGATORS:
            raise NotImplementedError
        self.box_representation_method = box_representation_method
        self.method_aggregator = method_aggregator
        self.enlarge_boxes = enlarge_boxes
        self.device = device
        self.expand_position_embedding = expand_position_embedding
        self.square_size = square_size
        self.blur_std_dev = blur_std_dev
        self.cache_path = cache_path

    def preprocess_image(self, image: Image.Image) -> List[torch.Tensor]:
        return [preprocess(image) for preprocess in self.preprocesses]

    def preprocess_mask(self, mask: Image.Image) -> torch.Tensor:
        preprocess = self.preprocesses[0]
        return preprocess.transforms[1](preprocess.transforms[0](mask))

    def preprocess_text(self, text: str) -> torch.Tensor:
        raise NotImplementedError

    def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> torch.Tensor:
        raise NotImplementedError

    def tensorize_inputs(self, caption: str, image: Image.Image, boxes: List[Box], image_name: str = None, image_pth: str = None) -> Tuple[List[torch.Tensor], torch.Tensor]:
        images = []
        for preprocess in self.preprocesses:
            images.append([])

        if 'aclip' in self.clip_type:
            # For Alpha-CLIP, obtain one SAM mask per candidate box, loading cached
            # RLE-encoded masks from disk when available.
            self.all_masks = []
            read_save = False
            if self.mask_path is not None:
                file_name = image_pth.split('/')[-1].split('.')[0] + '.pkl'
                if os.path.exists(os.path.join(self.mask_path, file_name)):
                    all_rles = pickle.load(open(os.path.join(self.mask_path, file_name), 'rb'))
                    for rle in all_rles:
                        mask = np.array(mask_utils.decode(rle), dtype=bool)
                        self.all_masks.append(mask)
                    read_save = True
            if not read_save:
                # Run SAM with each (optionally enlarged) box as the prompt.
                self.predictor.set_image(np.array(image.convert('RGB')))
                all_rles = []
                for i in range(len(boxes)):
                    box = [
                        max(boxes[i].left - self.enlarge_boxes, 0),
                        max(boxes[i].top - self.enlarge_boxes, 0),
                        min(boxes[i].right + self.enlarge_boxes, image.width),
                        min(boxes[i].bottom + self.enlarge_boxes, image.height)
                    ]
                    input_box = np.array(box)
                    masks, _, _ = self.predictor.predict(
                        point_coords=None,
                        point_labels=None,
                        box=input_box[None, :],
                        multimask_output=False,
                    )
                    self.all_masks.append(masks[0])
                    rle = mask_utils.encode(np.array(masks[0][:, :, None], order='F', dtype="uint8"))[0]
                    rle["counts"] = rle["counts"].decode("utf-8")
                    all_rles.append(rle)
                if self.mask_path is not None:
                    os.makedirs(self.mask_path, exist_ok=True)
                    pickle.dump(all_rles, open(os.path.join(self.mask_path, file_name), 'wb'))

        # Build the per-box image variants only if some features are missing from the cache.
        if self.cache_path is None or any([not os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name, method_name+".pt")) for model_name in self.model_names for method_name in self.box_representation_method.split(',')]):
            if "full" in self.box_representation_method:
                # 'full': the uncropped image once per box; with Alpha-CLIP the per-box
                # mask is supplied as the alpha channel in call_model.
                for i in range(len(boxes)):
                    image_i = image.copy()
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "blur" in self.box_representation_method:
                # 'blur': blur everything outside the box (or outside the SAM mask for Alpha-CLIP).
                for i in range(len(boxes)):
                    image_i = image.copy()
                    mask = Image.new('L', image_i.size, 0)
                    draw = ImageDraw.Draw(mask)
                    box = (
                        max(boxes[i].left - self.enlarge_boxes, 0),
                        max(boxes[i].top - self.enlarge_boxes, 0),
                        min(boxes[i].right + self.enlarge_boxes, image_i.width),
                        min(boxes[i].bottom + self.enlarge_boxes, image_i.height)
                    )
                    if 'aclip' in self.clip_type:
                        width, height = image.size
                        for y in range(height):
                            for x in range(width):
                                if self.all_masks[i][y][x] == 1:
                                    draw.point((x, y), fill=255)
                    else:
                        draw.rectangle([box[:2], box[2:]], fill=255)
                    blurred = image_i.filter(ImageFilter.GaussianBlur(self.blur_std_dev))
                    blurred.paste(image_i, mask=mask)
                    preprocessed_images = self.preprocess_image(blurred)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            if "gray" in self.box_representation_method:
                # 'gray': desaturate everything outside the SAM mask (requires the Alpha-CLIP masks).
                for i in range(len(boxes)):
                    image_i = image.copy()
                    mask_i = self.all_masks[i]
                    width, height = image.size
                    pixels = image_i.load()
                    for y in range(height):
                        for x in range(width):
                            if mask_i[y][x] == 0:
                                pixel_value = pixels[x, y]
                                gray_value = int(0.2989 * pixel_value[0] + 0.5870 * pixel_value[1] + 0.1140 * pixel_value[2])
                                pixels[x, y] = (gray_value, gray_value, gray_value)
                    preprocessed_images = self.preprocess_image(image_i)
                    for j, img in enumerate(preprocessed_images):
                        images[j].append(img.to(self.device))
            imgs = [torch.stack(image_list) for image_list in images]
        else:
            imgs = [[] for _ in self.models]
        text_tensor = self.preprocess_text(caption.lower()).to(self.device)
        return imgs, text_tensor

    @torch.no_grad()
    def __call__(self, caption: str, image: Image.Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
        images, text_tensor = self.tensorize_inputs(caption, image, boxes, image_name, image_pth)
        all_logits_per_image = []
        all_logits_per_text = []
        box_representation_methods = self.box_representation_method.split(',')
        caption_hash = hashlib.md5(caption.encode('utf-8')).hexdigest()
        for model, images_t, model_name in zip(self.models, images, self.model_names):
            self.image_feat_path = ""
            if self.cache_path is not None:
                text_cache_path = os.path.join(self.cache_path, "refcoco_val", model_name, "text"+("_shade" if self.box_representation_method == "shade" else ""))
                image_feat_path = os.path.join(self.cache_path, "refcoco_val", model_name, "image", image_name)
                self.image_feat_path = image_feat_path
            image_features = None
            text_features = None
            # Load cached text/image features for this model when they exist.
            if self.cache_path is not None and os.path.exists(os.path.join(self.cache_path, "refcoco_val", model_name)):
                if os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")):
                    text_features = torch.load(os.path.join(text_cache_path, caption_hash+".pt"), map_location=self.device)
                if os.path.exists(image_feat_path):
                    if all([os.path.exists(os.path.join(image_feat_path, method_name+".pt")) for method_name in box_representation_methods]):
                        image_features = []
                        for method_name in box_representation_methods:
                            features = torch.load(os.path.join(image_feat_path, method_name+".pt"), map_location=self.device)
                            image_features.append(torch.stack([
                                features[(box.x, box.y, box.w, box.h)]
                                for box in boxes
                            ]))
                        image_features = torch.stack(image_features)
                        image_features = image_features.view(-1, image_features.shape[-1])
            logits_per_image, logits_per_text, image_features, text_features = self.call_model(model, images_t, text_tensor, image_features=image_features, text_features=text_features, boxes=boxes, image_pth=image_pth)
            all_logits_per_image.append(logits_per_image)
            all_logits_per_text.append(logits_per_text)
            # Write freshly computed features back to the cache.
            if self.cache_path is not None and image_name is not None and image_features is not None:
                image_features = image_features.view(len(box_representation_methods), len(boxes), image_features.shape[-1])
                if not os.path.exists(image_feat_path):
                    os.makedirs(image_feat_path)
                for i in range(image_features.shape[0]):
                    method_name = box_representation_methods[i]
                    if not os.path.exists(os.path.join(image_feat_path, method_name+".pt")):
                        image_features_dict = {(box.x, box.y, box.w, box.h): image_features[i, j, :].cpu() for j, box in enumerate(boxes)}
                        torch.save(image_features_dict, os.path.join(image_feat_path, method_name+".pt"))
            if self.cache_path is not None and not os.path.exists(os.path.join(text_cache_path, caption_hash+".pt")) and text_features is not None:
                assert text_features.shape[0] == 1
                if not os.path.exists(text_cache_path):
                    os.makedirs(text_cache_path)
                torch.save(text_features.cpu(), os.path.join(text_cache_path, caption_hash+".pt"))

        # Sum scores over models, then aggregate over box-representation methods.
        all_logits_per_image = torch.stack(all_logits_per_image).sum(0)
        all_logits_per_text = torch.stack(all_logits_per_text).sum(0)
        if self.method_aggregator == "max":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).max(dim=0, keepdim=True)[0]
        elif self.method_aggregator == "sum":
            all_logits_per_text = all_logits_per_text.view(-1, len(boxes)).sum(dim=0, keepdim=True)
        return all_logits_per_text.view(-1)


class ClipExecutor(Executor):
    def __init__(self, clip_model: str = "ViT-B/32", device: str = "cpu", box_representation_method: str = "crop", method_aggregator: str = "max", enlarge_boxes: int = 0, expand_position_embedding: bool = False, square_size: bool = False, blur_std_dev: int = 100, cache_path: str = None, input_file: str = None, clip_type: str = None) -> None:
        super().__init__(device, box_representation_method, method_aggregator, enlarge_boxes, expand_position_embedding, square_size, blur_std_dev, cache_path)
        self.clip_models = clip_model.split(",")
        self.model_names = [model_name.replace("/", "_") for model_name in self.clip_models]
        self.models = []
        self.preprocesses = []
        self.data_name = input_file.split('/')[-1].split('.')[0]
        self.mask_path = None
        self.clip_type = clip_type
        if self.cache_path is not None:
            self.mask_path = os.path.join(self.cache_path, "refcoco_val", 'det_masks')

        # SAM turns candidate boxes into masks for the Alpha-CLIP representations.
        sam_checkpoint = "./ckpt/sam_vit_h_4b8939.pth"
        model_type = "vit_h"
        sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
        sam.to(device=device)
        self.predictor = SamPredictor(sam)

        for model_name in self.clip_models:
            if 'aclip' in self.clip_type:
                self.mask_transform = transforms.Compose([
                    transforms.ToTensor(),
                    transforms.Resize((224, 224)),
                    transforms.Normalize(0.5, 0.26)
                ])
                if model_name == 'ViT-B/16':
                    model, preprocess = alpha_clip.load("ViT-B/16", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_b16_grit+mim_fultune_4xe.pth", device=device)
                elif model_name == 'ViT-L/14':
                    model, preprocess = alpha_clip.load("ViT-L/14", alpha_vision_ckpt_pth="./ckpt/grit1m/clip_l14_grit+mim_fultune_6xe.pth", device=device)
            else:
                model, preprocess = clip.load(model_name, device=device, jit=False)
            self.models.append(model)
            if self.square_size:
                print("Square size!")
                preprocess.transforms[0] = transforms.Resize((model.visual.input_resolution, model.visual.input_resolution), interpolation=transforms.InterpolationMode.BICUBIC)
            self.preprocesses.append(preprocess)
        self.models = torch.nn.ModuleList(self.models)

    def preprocess_text(self, text: str) -> torch.Tensor:
        if "aclip" in self.box_representation_method:
            return alpha_clip.tokenize([text.lower()])
        if "shade" in self.box_representation_method:
            return clip.tokenize([text.lower()+" is in red color."])
        return clip.tokenize(["a photo of "+text.lower()])

    def call_model(self, model: torch.nn.Module, images: torch.Tensor, text: torch.Tensor, image_features: torch.Tensor = None, text_features: torch.Tensor = None, boxes=None, image_pth=None) -> torch.Tensor:
        if image_features is None:
            print('computing image features')
            if 'aclip' not in self.clip_type:
                image_features = model.encode_image(images)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            else:
                # Alpha-CLIP: encode each representation method's batch of images together
                # with the per-box SAM masks passed as the alpha channel.
                image_features = []
                if 'full' in self.box_representation_method:
                    aclip_images = images[:len(boxes)]
                    alphas = []
                    if os.path.exists(os.path.join(self.image_feat_path, 'full.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'full.pt'), map_location=self.device)
                        aclip_image_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        aclip_images = aclip_images.half()
                        aclip_image_features = model.visual(aclip_images, alphas)
                    # Consume this method's slice of the image batch before moving on.
                    images = images[len(boxes):]
                    image_features.append(aclip_image_features)

                if 'blur' in self.box_representation_method:
                    if os.path.exists(os.path.join(self.image_feat_path, 'blur.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'blur.pt'), map_location=self.device)
                        ablur_images_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        ablur_images = images[:len(boxes)]
                        alphas = []
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        ablur_images = ablur_images.half()
                        ablur_images_features = model.visual(ablur_images, alphas)
                    images = images[len(boxes):]
                    image_features.append(ablur_images_features)

                if 'gray' in self.box_representation_method:
                    if os.path.exists(os.path.join(self.image_feat_path, 'gray.pt')):
                        features = torch.load(os.path.join(self.image_feat_path, 'gray.pt'), map_location=self.device)
                        gray_images_features = torch.stack([
                            features[(box.x, box.y, box.w, box.h)]
                            for box in boxes
                        ])
                    else:
                        gray_images = images[:len(boxes)]
                        alphas = []
                        for i in range(len(self.all_masks)):
                            binary_mask = self.all_masks[i]
                            alpha = self.mask_transform((binary_mask * 255).astype(np.uint8))
                            alpha = alpha.half().cuda().unsqueeze(dim=0)
                            alphas.append(alpha)
                        alphas = torch.cat(alphas, dim=0)
                        gray_images = gray_images.half()
                        gray_images_features = model.visual(gray_images, alphas)
                    images = images[len(boxes):]
                    image_features.append(gray_images_features)

                image_features = torch.cat(image_features, dim=0)
                image_features = image_features / image_features.norm(dim=-1, keepdim=True)

        if text_features is None:
            print('computing text features')
            text_features = model.encode_text(text)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        logit_scale = model.logit_scale.exp()
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()
        return logits_per_image, logits_per_text, image_features, text_features

    def __call__(self, caption: str, image: Image.Image, boxes: List[Box], image_name: str = None, image_pth=None) -> torch.Tensor:
        if self.expand_position_embedding:
            # Temporarily interpolate the visual positional embeddings so the model can
            # consume the full-resolution image, and swap in matching preprocessing.
            original_preprocesses = self.preprocesses
            new_preprocesses = []
            original_position_embeddings = []
            for model_name, model, preprocess in zip(self.clip_models, self.models, self.preprocesses):
                if "RN" in model_name:
                    model_spatial_dim = int((model.visual.attnpool.positional_embedding.shape[0]-1)**0.5)
                    patch_size = model.visual.input_resolution // model_spatial_dim
                    original_positional_embedding = model.visual.attnpool.positional_embedding.clone()
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
                        model.visual.attnpool.positional_embedding[1:, :].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
                        size=(image.height // patch_size, image.width // patch_size),
                        mode='bicubic',
                        align_corners=False
                    ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(torch.cat((
                        original_positional_embedding[:1, :],
                        model.visual.attnpool.positional_embedding
                    ), dim=0))
                else:
                    model_spatial_dim = int((model.visual.positional_embedding.shape[0]-1)**0.5)
                    patch_size = model.visual.input_resolution // model_spatial_dim
                    original_positional_embedding = model.visual.positional_embedding.clone()
                    model.visual.positional_embedding = torch.nn.Parameter(torch.nn.functional.interpolate(
                        model.visual.positional_embedding[1:, :].permute(1, 0).view(1, -1, model_spatial_dim, model_spatial_dim),
                        size=(image.height // patch_size, image.width // patch_size),
                        mode='bicubic',
                        align_corners=False
                    ).squeeze(0).permute(1, 2, 0).view(-1, original_positional_embedding.shape[-1]))
                    model.visual.positional_embedding = torch.nn.Parameter(torch.cat((
                        original_positional_embedding[:1, :],
                        model.visual.positional_embedding
                    ), dim=0))
                transform = transforms.Compose([
                    transforms.Resize(((image.height // patch_size)*patch_size, (image.width // patch_size)*patch_size), interpolation=Image.BICUBIC),
                    lambda image: image.convert("RGB"),
                    transforms.ToTensor(),
                    transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
                ])
                new_preprocesses.append(transform)
                original_position_embeddings.append(original_positional_embedding)
            self.preprocesses = new_preprocesses
        result = super().__call__(caption, image, boxes, image_name, image_pth)
        if self.expand_position_embedding:
            # Restore the original preprocessing and positional embeddings.
            self.preprocesses = original_preprocesses
            for model, model_name, pos_embedding in zip(self.models, self.clip_models, original_position_embeddings):
                if "RN" in model_name:
                    model.visual.attnpool.positional_embedding = torch.nn.Parameter(pos_embedding)
                else:
                    model.visual.positional_embedding = torch.nn.Parameter(pos_embedding)
        return result
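

# Example usage (a minimal sketch, not part of the evaluation pipeline): this assumes
# the SAM and Alpha-CLIP checkpoints referenced above exist under ./ckpt/, that
# `caption`, `image`, `boxes`, `image_name`, and `image_pth` come from the surrounding
# dataset loop, and that the file path and variable names below are illustrative only.
#
#     executor = ClipExecutor(
#         clip_model="ViT-B/16",
#         device="cuda",
#         box_representation_method="full,blur,gray",
#         method_aggregator="sum",
#         cache_path=None,
#         input_file="data/refcoco_val.jsonl",
#         clip_type="aclip",
#     )
#     scores = executor(caption, image, boxes, image_name=image_name, image_pth=image_pth)
#     best_box = boxes[scores.argmax().item()]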