"""Generate zero-shot segments from a counterfactual world model (CWM) teacher,
seeded with DINO-based "predominance" maps computed via a MaskCut-style
normalized cut. Results are saved incrementally as a dict of per-image masks.
"""
import argparse
import glob
import json
import os
import sys
import time
from functools import partial

import matplotlib.pyplot as plt
import numpy as np
import pycocotools.mask as mask_util
import torch
from PIL import Image
from scipy import ndimage
from torch.nn import functional as F
from torchvision.io import read_image
from tqdm import tqdm

import bbnet.models.error as error_generator
import bbnet.models.preprocessor as preprocessor
import bbnet.models.teachers as teachers
import bbnet.trainval.validator as validator
import big_models as big_transformers
import modeling_pretrain as vmae_transformers_old
import modeling_pretrain_cleaned as vmae_transformers
import positional_vmae as pos_transformers

# CutLER / MaskCut / TokenCut live outside this package; make them importable.
sys.path.append('/ccn2/u/honglinc/CutLER')
sys.path.append('/ccn2/u/honglinc/CutLER/maskcut')
sys.path.append('/ccn2/u/honglinc/CutLER/third_party')

import dino
from crf import densecrf
from maskcut import (get_affinity_matrix, second_smallest_eigenvector,
                     get_salient_areas, check_num_fg_corners,
                     get_masked_affinity_matrix)
from third_party.TokenCut.unsupervised_saliency_detection import utils, metric
from third_party.TokenCut.unsupervised_saliency_detection.object_discovery import detect_box
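
# NOTE: the sys.path entries above are site-specific absolute paths; point
# them at a local checkout of CutLER (https://github.com/facebookresearch/CutLER),
# which provides the dino, maskcut, crf, and TokenCut modules imported above.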


# DINO ViT-B/8 backbone used to compute the initial saliency ("predominance") maps.
vit_arch = 'base'
vit_feat = 'k'  # use key features from the last attention block
patch_size = 8

url = "https://dl.fbaipublicfiles.com/dino/dino_vitbase8_pretrain/dino_vitbase8_pretrain.pth"
feat_dim = 768
dino_backbone = dino.ViTFeat(url, feat_dim, vit_arch, vit_feat, patch_size)
dino_backbone = dino_backbone.eval().requires_grad_(False).cuda()
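
# Sanity-check sketch for the backbone (a guess at the interface: CutLER's
# ViTFeat is assumed here to return per-patch features of shape
# [B, feat_dim, n_patches], which is what get_affinity_matrix below consumes):
#   probe = torch.rand(1, 3, 224, 224).cuda()
#   feats = dino_backbone(probe)   # expected: [1, 768, 28 * 28]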


def get_dino_predominance(images, dims=[28, 28], current_mask=None, painting=None, img_size=[224, 224]):
    """Compute a per-pixel foreground saliency map from DINO features via a normalized cut."""
    # Normalize with ImageNet statistics, as expected by the DINO backbone.
    input_dino = images
    input_dino = input_dino - torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1).to(input_dino.device)
    input_dino = input_dino / torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1).to(input_dino.device)

    input_dino = torch.nn.functional.interpolate(input_dino, size=img_size, mode='bilinear')
    features = dino_backbone(input_dino)

    predominance_map = []

    for i in range(features.shape[0]):
        feats = features[i]
        if current_mask is None:
            painting = torch.from_numpy(np.zeros(dims))
            painting = painting.to(feats)
        else:
            # Paint out patches already claimed by previous segments.
            feats, painting = get_masked_affinity_matrix(painting, feats, current_mask, ps=dims[0])

        # Normalized-cut step: build the patch affinity matrix and take the
        # second-smallest generalized eigenvector as a soft foreground score.
        A, D = get_affinity_matrix(feats, tau=0.15)
        _, second_smallest_vec = second_smallest_eigenvector(A, D)
        bipartition = get_salient_areas(second_smallest_vec)

        # Decide which side of the bipartition is foreground: flip it if two or
        # more image corners land in it, or if the strongest patch does not.
        seed = np.argmax(np.abs(second_smallest_vec))
        nc = check_num_fg_corners(bipartition, dims)
        if nc >= 2:
            reverse = True
        else:
            reverse = bipartition[seed] != 1
        if reverse:
            second_smallest_vec = 1 - second_smallest_vec

        # Upsample patch-level scores to image resolution and rescale to [0, 1].
        second_smallest_vec = torch.tensor(second_smallest_vec).to(images.device).contiguous()
        saliency = torch.nn.functional.interpolate(
            second_smallest_vec.reshape(1, 1, dims[0], dims[1]), size=img_size, mode='bilinear')
        saliency -= saliency.min()
        saliency /= saliency.max()
        predominance_map.append(saliency)

    init_dist = torch.cat(predominance_map, dim=0).detach()
    return init_dist, A, feats, painting
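
# Minimal usage sketch (assumes a CUDA device and an RGB batch in [0, 1]):
#   imgs = torch.rand(1, 3, 224, 224).cuda()
#   saliency, A, feats, painting = get_dino_predominance(imgs)
#   # saliency: [1, 1, 224, 224]; higher values mark likelier foreground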


def interpolate_pos_encoding(pos_embed, n_frames, h, w):
    """Resize a (frame-factored) patch positional embedding to an h x w token grid."""
    N = pos_embed.shape[1]
    if N == (h * w * n_frames):
        return pos_embed
    old_h = old_w = int((N / n_frames) ** 0.5)
    patch_pos_embed = pos_embed.view(1, n_frames, old_h, old_w, -1).flatten(0, 1).permute(0, 3, 1, 2)

    patch_pos_embed = F.interpolate(
        patch_pos_embed,
        size=(h, w),
        mode='bicubic',
    )
    return patch_pos_embed.permute(0, 2, 3, 1).flatten(0, 2).unsqueeze(0)
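
# Shape sketch: for a 3-frame model pretrained at, e.g., 224x224 with 8x8
# patches (a 28x28 token grid per frame), resizing to the 60x60 grid used
# below maps pos_embed from [1, 3*28*28, C] to [1, 3*60*60, C]. The reshape
# above assumes pos_embed carries patch tokens only (no class token).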


def vis_results(x, targets_list, annotation, name):
    """Show the input frame alongside each discovered segment (on a white background)."""
    img = x[0, 0].permute(1, 2, 0).cpu()
    fig, axs = plt.subplots(1, 1 + len(targets_list), figsize=(3 * (1 + len(targets_list)), 3))
    axs[0].imshow(img)
    axs[0].set_title('Image')

    for i, v in enumerate(targets_list):
        v = v[0, 0]
        axs[1 + i].imshow((v[..., None] * img) + (~v[..., None] * torch.ones_like(img)))
        axs[1 + i].set_title(f'Segment {i}', fontsize=10)

    for ax in axs:
        ax.set_axis_off()

    plt.show()
    plt.close()
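
# Note: vis_results indexes each mask as mask[0, 0], i.e. it expects boolean
# tensors of shape [1, 1, H, W] (as produced by the main loop below); the
# `annotation` and `name` arguments are accepted but currently unused.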


if __name__ == "__main__":
    parser = argparse.ArgumentParser('Generate zero-shot segments from CWM model', add_help=False)
    parser.add_argument('--input_pattern', default='/ccn2/u/honglinc/datasets/coco/images/val2017/*', nargs='+', type=str, help='glob pattern or explicit list of input images')
    parser.add_argument('--output', default='./output.pt', type=str, help='output path for saving the results')
    parser.add_argument('--num_iter', default=1, type=int, help='number of segment-extraction iterations per image')
    parser.add_argument('--visualize', action='store_true', help='visualize the results')
    args = parser.parse_args()

    # The default is a glob pattern string; anything passed on the command line
    # arrives as a list (nargs='+'), e.g. an unquoted glob already expanded by
    # the shell, and is used as an explicit list of paths.
    image_list = glob.glob(args.input_pattern) if isinstance(args.input_pattern, str) else args.input_pattern
    thresh = 0.5
    visualize = args.visualize
    save_dict = {}
    image_size = [480, 480]
    patch_size = 8
    dims = [s // patch_size for s in image_size]  # token-grid size, e.g. [60, 60]

    # Load the CWM teacher (3-frame ViT-B with 8x8 patches).
    default_model_dir = '/ccn2/u/honglinc/cwm_checkpoints/'
    model_func = vmae_transformers.vitb_8x8patch_3frames
    ckpt_path = 'ablation_3frame_no_clumping_mr0.90_extra_data_ep400'
    label = '3 frame 8x8'
    teacher_func = teachers.iteration_segment_teacher_with_filter

    teacher = teacher_func(
        model_func=model_func,
        model_path=teachers.get_load_path(os.path.join(default_model_dir, ckpt_path), model_checkpoint=-1),
        visualization_mode=visualize,
        initial_sampling_distribution_kwargs={'num_samples': 20, 'num_active_patches': 1, 'num_passive_patches': 1},
    ).requires_grad_(False).cuda()

    # The teacher was pretrained at a smaller resolution; resize its positional
    # embeddings to match the 480x480 (60x60-token) inputs used here.
    teacher.predictor.encoder.pos_embed = interpolate_pos_encoding(
        teacher.predictor.encoder.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.pos_embed = interpolate_pos_encoding(
        teacher.predictor.pos_embed, 3, dims[0], dims[1])
    teacher.predictor.image_size = image_size
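
    # The CWM teacher is used as a black box below: given a 3-frame clip and a
    # per-pixel sampling distribution, teacher(x, sampling_distribution=...) is
    # assumed to return a soft segment mask (the first element of its outputs)
    # that is thresholded at `thresh`; see bbnet.models.teachers for details.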

    start = time.time()

    # Resume from partial results if the output file already exists.
    if os.path.exists(args.output):
        print('Load partial results from: ', args.output)
        save_dict = torch.load(args.output)
        print('Length of existing dict: ', len(save_dict))

    for image_path in image_list:
        image_name = image_path.split('/')[-1]
        if image_name in save_dict:
            continue

        image = read_image(image_path)
        if image.shape[0] == 1:  # grayscale -> RGB
            image = image.expand(3, -1, -1)

        # The CWM teacher expects a 3-frame clip; replicate the still image.
        x = torch.stack([image] * 3, dim=0)
        x = torch.nn.functional.interpolate(x.float(), size=image_size, mode='bicubic')[None] / 255.
        _x = x.to(torch.float16).cuda()

        targets_list = []

        for n in range(args.num_iter):
            # Compute the DINO predominance map; after the first iteration,
            # patches covered by the previous mask are painted out so the cut
            # focuses on the remaining regions.
            if n == 0:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(), dims=dims, img_size=image_size)
            else:
                predominance, _, feats, painting = get_dino_predominance(x[:, :, 0].cuda(),
                                                                         current_mask=current_mask.cuda(),
                                                                         painting=painting, dims=dims,
                                                                         img_size=image_size)

            if visualize:
                plt.imshow(predominance[0, 0].cpu())
                plt.title(f'Predominance (max:{predominance[0, 0].max()})')
                plt.show()

            # Zero out pixels already assigned to earlier segments.
            if n > 0:
                for mask in targets_list:
                    predominance[0, 0][mask[0, 0].cuda()] = 0

            with torch.cuda.amp.autocast(enabled=True):
                targets = teacher(_x, sampling_distribution=predominance)[0]
            if n == 0:
                targets_list = [targets.cpu() >= thresh]
            else:
                ratio = targets.mean()
                mask = targets.cpu() >= thresh

                # Suppress near-duplicates (IoU > 0.2 with an existing segment)
                # and degenerate masks (< 1% of the image).
                iou = 0
                match_idx = None
                for idx, existing_mask in enumerate(targets_list):
                    _iou = metric.IoU(mask[0, 0], existing_mask[0, 0])
                    if _iou > iou:
                        iou = _iou
                        match_idx = idx

                if iou > 0.2 or ratio <= 0.01:
                    mask = torch.zeros_like(mask)

                targets_list.append(mask)

            # Downsample the current mask to the token grid for the next round.
            current_mask = F.interpolate(targets, size=dims, mode='bilinear') >= thresh

        vid_name = image_path
        save_dict[image_name] = targets_list
        if visualize:
            vis_results(x, targets_list, None, vid_name.split('/')[-2] + '.png')

        # Log progress and checkpoint the results after every image. (When
        # resuming, previously saved images are counted in the average time.)
        total = len(image_list)
        num_completed = len(save_dict)
        avg_time = (time.time() - start) / num_completed
        eta = (total - num_completed) * avg_time / 60.
        print(f'{num_completed} / {total} completed, avg. time per image: {avg_time:.2f} sec, eta: {eta:.1f} mins')
        torch.save(save_dict, args.output)

    torch.save(save_dict, args.output)
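
# Example invocation (script name and paths are illustrative; an unquoted glob
# is expanded by the shell into an explicit list of image paths):
#   python generate_cwm_segments.py --input_pattern /path/to/coco/val2017/* \
#       --output ./cwm_segments.pt --num_iter 4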