Spaces:
Runtime error
Runtime error
# coding=utf-8 | |
# Copyright 2024 The Google Research Authors. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Find objects.""" | |
# pylint: disable=g-importing-member | |
import numpy as np | |
import scipy | |
from scipy import ndimage | |
from scipy.linalg import eigh | |
from scipy.ndimage import label | |
import torch | |
import torch.nn.functional as F | |
def ncut(
    feats,
    dims,
    scales,
    init_image_size,
    tau=0,
    eps=1e-5,
    no_binary_graph=False,
):
    """Bipartition the patches of one image with a Normalized Cut.

    Args:
      feats: token features; the first batch element is used and token 0 is
        dropped (presumably the [CLS] token -- confirm with the caller).
      dims: 2-D shape of the patch map; must match the number of kept tokens.
      scales: (y, x) factors used by `detect_box` to map patch coordinates to
        image coordinates.
      init_image_size: image size; `init_image_size[1:]` is forwarded to
        `detect_box` to clip the predicted box.
      tau: similarity threshold used to build the graph.
      eps: weight given to edges whose similarity does not pass `tau`.
      no_binary_graph: if True keep the raw similarities as edge weights
        (only sub-threshold entries are replaced by `eps`); otherwise edges
        are binarised to 1 / `eps`.

    Returns:
      A 6-tuple `(pred, objects, mask, seed, None, eigenvec)` where `pred` is
      the box from `detect_box` as an array, `objects` is the connected
      component label map, `mask` is a `dims`-shaped array set to 1 on the
      component containing the seed, `seed` is the flat index of the patch
      with the largest absolute eigenvector entry, and `eigenvec` is the
      second smallest generalized eigenvector reshaped to `dims` (sign-flipped
      so that the seed lies on the foreground side).
    """
    # L2-normalise each patch descriptor so the Gram matrix holds cosines.
    patch_feats = F.normalize(feats[0, 1:, :], p=2, dim=1)
    affinity = (patch_feats @ patch_feats.transpose(1, 0)).cpu().numpy()

    if no_binary_graph:
        affinity[affinity < tau] = eps
    else:
        affinity = np.where(affinity > tau, 1.0, eps)

    degree = np.diag(affinity.sum(axis=1))

    # Solve (D - A) v = lambda * D v, keeping only the second and third
    # smallest eigenpairs; column 0 is therefore the second smallest one.
    _, eigenvectors = eigh(degree - affinity, degree, subset_by_index=[1, 2])
    fiedler = eigenvectors[:, 0]
    eigenvec = fiedler.copy()

    # Split the patches around the mean value of the eigenvector.
    bipartition = fiedler > fiedler.sum() / len(fiedler)

    # The patch with the strongest response must end up in the foreground.
    seed = np.argmax(np.abs(fiedler))
    if not bipartition[seed]:
        eigenvec = -eigenvec
        bipartition = np.logical_not(bipartition)
    bipartition = bipartition.reshape(dims).astype(float)

    # Only the box of the component holding the seed is extracted.
    pred, _, objects, component_idx = detect_box(
        bipartition,
        seed,
        dims,
        scales=scales,
        initial_im_size=init_image_size[1:],
    )
    mask = np.zeros(dims)
    mask[component_idx[0], component_idx[1]] = 1
    return np.asarray(pred), objects, mask, seed, None, eigenvec.reshape(dims)
def grad_obj_discover_on_attn(attn, gradcam, dims, topk=1, threshold=0.6):
    """Threshold the attention rows of the patches with the highest Grad-CAM.

    The `topk` patches with the largest Grad-CAM value act as seeds. For each
    seed, its patch-to-patch attention row is normalised and only the most
    attended patches that carry `threshold` of the mass are kept. The binary
    maps of all seeds are merged with a logical OR.

    Args:
      attn: attention map from ViT averaged across all heads, shape: [1,
        (1+num_patches), (1+num_patches)].
      gradcam: gradcam map from ViT, shape: [1, 1, H, W]; it is flattened and
        its indices are used directly as patch indices.
      dims: (w_featmap, h_featmap) shape used to reshape the patch axis.
      topk: number of seed patches taken from the Grad-CAM ranking.
      threshold: fraction of the attention mass to keep for each seed.

    Returns:
      th_attn: float tensor of zeros and ones with shape
        [1, 1, w_featmap, h_featmap].
    """
    w_featmap, h_featmap = dims

    # Rank patches by Grad-CAM response; the best `topk` become seeds.
    seeds = torch.argsort(gradcam.flatten(), descending=True)[:topk]

    # Drop row/column 0 (the non-patch token) and gather the seed rows.
    seed_attn = attn.squeeze()[1:, 1:][seeds]
    num_seeds = seed_attn.shape[0]

    # Sort ascending so the cumulative sum discards the weakest patches first.
    sorted_vals, sort_idx = torch.sort(seed_attn)
    sorted_vals /= torch.sum(sorted_vals, dim=1, keepdim=True)
    keep_sorted = torch.cumsum(sorted_vals, dim=1) > (1 - threshold)

    # Undo the sort so each flag lines up with its original patch position.
    keep = torch.gather(keep_sorted, 1, torch.argsort(sort_idx))
    keep = keep.reshape(num_seeds, w_featmap, h_featmap).float()

    # Union over seeds: a patch kept by several seeds still counts once.
    merged = torch.clamp(keep.sum(0), max=1)
    return merged[None, None]
def grad_obj_discover(feats, gradcam, dims):
    """Return the feature similarity of every patch to the Grad-CAM seed.

    The seed is the patch with the largest Grad-CAM value; its row of the
    patch-to-patch dot-product matrix is returned as a map.

    Args:
      feats: the pixel/patch features of an image. Shape: [1, HW, C]
      gradcam: the grad cam map; the flat index of its maximum selects the
        seed patch.
      dims: dimension of the map from which the features are used

    Returns:
      mask: similarity of each patch to the seed patch (raw dot products, not
        thresholded), shape [1, 1, *dims].
    """
    # Pairwise dot products between all patch descriptors.
    similarity = torch.matmul(feats, feats.transpose(1, 2)).squeeze()
    seed = torch.argmax(gradcam)
    return similarity[seed].view(1, 1, *dims)
def lost(feats, dims, scales, init_image_size, k_patches=100):
    """Implementation of LOST method.

    Args:
      feats: the pixel/patch features of an image. Shape: [1, C, H, W]
      dims: dimension of the map from which the features are used
      scales: from image to map scale
      init_image_size: size of the image; `init_image_size[1:]` is passed to
        `detect_box` to clip the box.
      k_patches: number of k patches retrieved that are compared to the seed
        at seed expansion.

    Returns:
      The fourth value returned by `detect_box`: a tuple of index arrays (as
      produced by `np.where`) locating, in the `dims`-shaped patch map, the
      connected component that contains the seed patch.
    """
    # Compute the similarity between all pairs of patches.
    feats = feats.flatten(2).transpose(1, 2)
    a = (feats @ feats.transpose(1, 2)).squeeze()

    # Compute the inverse degree centrality measure per patch.
    sorted_patches, _ = patch_scoring(a)

    # Select the initial seed: the patch with the lowest degree.
    seed = sorted_patches[0]

    # Seed expansion: among the k lowest-degree patches keep those positively
    # correlated with the seed and accumulate their similarity rows.
    potentials = sorted_patches[:k_patches]
    similars = potentials[a[seed, potentials] > 0.0]
    m = torch.sum(a[similars, :], dim=0)

    # `detect_box` labels connected components of a 2-D map and looks the
    # seed up with `np.unravel_index(seed, dims)`, so the flat correlation
    # vector has to be reshaped to `dims` first. It is also binarised here
    # because `ndimage.label` treats every non-zero entry (including negative
    # correlations) as foreground.
    correl = (m.reshape(dims) > 0.0).cpu().numpy()

    # Box extraction.
    _, _, _, mask = detect_box(
        correl,
        int(seed),
        dims,
        scales=scales,
        initial_im_size=init_image_size[1:],
    )
    return mask
def patch_scoring(m, threshold=0.0):
    """Rank patches by inverse degree in the thresholded affinity graph.

    Args:
      m: square affinity matrix (torch tensor); it is not modified.
      threshold: an entry counts as an edge when it is strictly greater than
        this value.

    Returns:
      A pair `(sel, cent)`: `cent` is the negated number of edges of each
      patch as float32, and `sel` holds the patch indices sorted by
      decreasing `cent`, i.e. the patches with the fewest edges come first.
    """
    # Work on a copy: self-similarity and negative entries are zeroed out.
    affinity = m.clone().fill_diagonal_(0).clamp_(min=0)

    degree = (affinity > threshold).sum(dim=1).type(torch.float32)
    cent = -degree
    sel = torch.argsort(cent, descending=True)
    return sel, cent
def detect_box(
    bipartition,
    seed,
    dims,
    initial_im_size=None,
    scales=None,
    principle_object=True,
):
    """Extract the box of the connected component that contains the seed.

    Args:
      bipartition: 2-D map whose non-zero entries are foreground.
      seed: flat index of the seed patch inside a map of shape `dims`.
      dims: shape of the patch map, used to unravel `seed`.
      initial_im_size: optional (height, width); when given, the right and
        bottom edges of the box are clipped to it (useful with padding).
      scales: (y, x) factors mapping patch coordinates to image coordinates.
      principle_object: only True is supported.

    Returns:
      A tuple `(pred, pred_feats, objects, mask)`: `pred` is
      `[xmin, ymin, xmax, ymax]` in image space, `pred_feats` is
      `[ymin, xmin, ymax, xmax]` in patch space, `objects` is the labelled
      component map and `mask` is the `np.where` index tuple of the selected
      component.

    Raises:
      NotImplementedError: if `principle_object` is False.
    """
    # Label connected components and find the one under the seed patch.
    objects, _ = ndimage.label(bipartition)
    cc = objects[np.unravel_index(seed, dims)]

    if not principle_object:
        raise NotImplementedError

    mask = np.where(objects == cc)
    rows, cols = mask[0], mask[1]

    # Upper bounds are exclusive, hence the +1.
    ymin, ymax = rows.min(), rows.max() + 1
    xmin, xmax = cols.min(), cols.max() + 1

    # Rescale from patch space to image space.
    pred = [
        scales[1] * xmin,
        scales[0] * ymin,
        scales[1] * xmax,
        scales[0] * ymax,
    ]

    # Keep the box inside the original image (matters when it was padded).
    if initial_im_size:
        pred[2] = min(pred[2], initial_im_size[1])
        pred[3] = min(pred[3], initial_im_size[0])

    # Patch-space coordinates use (row, col) order, unlike the image box.
    pred_feats = [ymin, xmin, ymax, xmax]
    return pred, pred_feats, objects, mask
# This function is modified from | |
# https://github.com/facebookresearch/dino/blob/main/visualize_attention.py | |
# Ref: https://github.com/facebookresearch/dino. | |
def dino_seg(attn, dims, patch_size, head=0):
    """Extraction of boxes based on the DINO segmentation method proposed in DINO.

    Args:
      attn: attention tensor indexed as [image, head, query, key]; only the
        first image and the row of query token 0 (the [CLS] token) towards
        the patch tokens are used.
      dims: (w_featmap, h_featmap) shape of the patch map.
      patch_size: side of a patch in pixels, used to rescale the box.
      head: index of the attention head whose thresholded map is used.

    Returns:
      pred: `[xmin, ymin, xmax, ymax]` box, in pixels, around the largest
        connected foreground component of the thresholded attention map. When
        the map has no foreground at all, the box of the background (label 0)
        is returned.
    """
    w_featmap, h_featmap = dims
    nh = attn.shape[1]
    official_th = 0.6

    # We keep only the output patch attention.
    # Get the attentions corresponding to [CLS] token.
    attentions = attn[0, :, 0, 1:].reshape(nh, -1)

    # We keep only a certain percentage of the mass: sort ascending, then
    # drop the weakest patches until `official_th` of the mass remains.
    val, idx = torch.sort(attentions)
    val /= torch.sum(val, dim=1, keepdim=True)
    cumval = torch.cumsum(val, dim=1)
    th_attn = cumval > (1 - official_th)

    # Undo the sort so the flags line up with the original patch order.
    idx2 = torch.argsort(idx)
    for h in range(nh):
        th_attn[h] = th_attn[h][idx2[h]]
    th_attn = th_attn.reshape(nh, w_featmap, h_featmap).float()

    # Connected components.
    labeled_array, _ = scipy.ndimage.label(th_attn[head].cpu().numpy())

    # Size of every component. Labels run from 0 (background) up to and
    # including `np.max(labeled_array)`, hence the `+ 1`; without it the last
    # component would never be considered.
    size_components = [
        np.sum(labeled_array == c) for c in range(np.max(labeled_array) + 1)
    ]

    if len(size_components) > 1:
        # Select the biggest component avoiding component 0 corresponding
        # to background.
        biggest_component = np.argmax(size_components[1:]) + 1
    else:
        # No foreground component at all: fall back to the background.
        biggest_component = 0

    # Mask corresponding to connected component.
    mask = np.where(labeled_array == biggest_component)

    # Add +1 because excluded max.
    ymin, ymax = min(mask[0]), max(mask[0]) + 1
    xmin, xmax = min(mask[1]), max(mask[1]) + 1

    # Rescale to image.
    r_xmin, r_xmax = xmin * patch_size, xmax * patch_size
    r_ymin, r_ymax = ymin * patch_size, ymax * patch_size
    pred = [r_xmin, r_ymin, r_xmax, r_ymax]
    return pred
def get_feats(feat_out, shape):
    """Extract the key features from a fused qkv projection output.

    Args:
      feat_out: mapping whose "qkv" entry is reshaped here to
        [nb_im, nb_tokens, 3, nh, head_dim].
      shape: sequence whose first three entries are (batch size, number of
        heads, number of tokens).

    Returns:
      The keys with all heads concatenated, shape [nb_im, nb_tokens, -1].
    """
    # Batch size, Number of heads, Number of tokens
    nb_im, nh, nb_tokens = shape[0:3]
    # The per-head dimension is inferred by `reshape` (-1). The previous
    # `-1 // nh` only worked because floor division yields -1 as well.
    qkv = (
        feat_out["qkv"]
        .reshape(nb_im, nb_tokens, 3, nh, -1)
        .permute(2, 0, 3, 1, 4)
    )
    # Index 1 along the leading axis selects the keys (0: queries, 2: values).
    k = qkv[1]
    # [nb_im, nh, nb_tokens, head_dim] -> [nb_im, nb_tokens, nh * head_dim]
    k = k.transpose(1, 2).reshape(nb_im, nb_tokens, -1)
    return k
def get_instances(masks, return_largest=False):
    """Split every mask of a batch into its connected instances.

    Args:
      masks: iterable of masks; each one gets a leading axis added before it
        is handed to `get_instances_single`.
      return_largest: forwarded to `get_instances_single`.

    Returns:
      A list with one `get_instances_single` result per input mask.
    """
    per_mask_instances = []
    for single_mask in masks:
        per_mask_instances.append(
            get_instances_single(
                single_mask[None], return_largest=return_largest
            )
        )
    return per_mask_instances
def get_instances_single(mask, return_largest=False):
    """Split one mask into its connected components.

    Args:
      mask: torch tensor whose non-zero entries are foreground. The code
        sums component sizes over axes (1, 2), so a [1, H, W] layout is
        assumed -- confirm with the caller.
      return_largest: if True return only the largest foreground component.

    Returns:
      With `return_largest=False`, a float tensor stacking one binary mask
      per foreground component along axis 0 (background excluded). With
      `return_largest=True`, a float binary mask of the largest foreground
      component, or of the background when there is no foreground.
    """
    labels, _ = label(mask.cpu().numpy())
    num_labels = np.max(labels) + 1

    # One boolean map per label, label 0 (background) first.
    instances = np.concatenate(
        [labels == idx for idx in range(num_labels)], axis=0
    )

    if not return_largest:
        return torch.from_numpy(instances[1:]).float()

    sizes = np.sum(instances, axis=(1, 2))
    # Skip entry 0 (background) when picking the largest component; fall
    # back to label 0 when the mask holds no foreground at all.
    largest = np.argmax(sizes[1:]) + 1 if len(sizes) > 1 else 0
    return torch.from_numpy(labels == largest).float()