Spaces:
Build error
Build error
import os | |
import copy | |
import math | |
import torch | |
from torch import nn, Tensor | |
import torch.nn.functional as F | |
def rand_sample(x, divisor, max_len):
    """Randomly keep at most ``max_len`` non-zero coordinates of ``x``.

    The coordinates of every non-zero element of ``x`` (``x.nonzero()``, one
    row per element) are divided by ``divisor`` and returned transposed, so
    each column is one point and each row one coordinate dimension.

    Points are grouped by their raw index along ``x``'s first dimension
    (presumably a mask/instance id -- confirm with callers). Each group gets
    the same total sampling weight ``1 / num_groups``, split evenly over its
    points. ``min(max_len, num_points)`` points are then drawn without
    replacement with ``torch.multinomial``. The kept columns stay in their
    original order.

    Args:
        x: Tensor whose non-zero entries are the candidate points.
        divisor: Scalar or tensor broadcastable against the
            ``[num_points, x.dim()]`` coordinate matrix. Coordinates are
            true-divided by it, so the result is floating point.
        max_len: Maximum number of points to keep.

    Returns:
        Tensor of shape ``[x.dim(), k]`` with ``k <= max_len``.
        If ``x`` has no non-zero entries, ``x.nonzero().t()`` (integer dtype,
        ``k == 0``) is returned as before.
        If ``max_len <= 0``, an empty floating tensor of shape
        ``[x.dim(), 0]`` is returned.
    """
    coords = x.nonzero()  # [num_points, x.dim()], integer indices
    num_points = coords.shape[0]
    if num_points == 0:
        return coords.t()

    points = (coords / divisor).t()

    # Group on the *undivided* first coordinate. Dividing it first (e.g. by a
    # scalar divisor) and truncating with .long() would merge distinct groups
    # and leave zero-probability points, making multinomial raise.
    _, group_of_point, group_sizes = coords[:, 0].unique(
        return_inverse=True, return_counts=True
    )
    num_groups = group_sizes.numel()
    probs = (1.0 / (num_groups * group_sizes[group_of_point])).to(points.dtype)

    num_samples = min(max_len, num_points)
    if num_samples <= 0:
        # torch.multinomial rejects num_samples <= 0.
        return points[:, :0]

    indices = torch.multinomial(probs, num_samples=num_samples, replacement=False).sort()[0]
    return points[:, indices]
def rand_sample_plain(x, max_len):
    """Return at most ``max_len`` randomly chosen columns of ``x``.

    If ``x`` has ``max_len`` columns or fewer, the very same tensor object is
    returned. Otherwise a random subset of columns is kept, in the random
    order produced by ``torch.randperm``.
    """
    num_cols = x.shape[1]
    if num_cols > max_len:
        keep = torch.randperm(num_cols)[:max_len]
        x = x[:, keep]
    return x
def prepare_features(x, num_feature_levels, pe_layer, input_proj, level_embed):
    """Flatten the first ``num_feature_levels`` feature maps into sequence form.

    For each level ``i`` this:
    - records the spatial size ``x[i].shape[-2:]``;
    - computes a positional encoding with ``pe_layer(x[i], None)``;
    - projects the features with ``input_proj[i](x[i])`` and adds the learned
      level embedding ``level_embed.weight[i]`` (broadcast over channels).

    Both tensors are flattened from dim 2 onward and permuted so that
    NxCxHxW becomes HWxNxC. This assumes 4-D inputs, as the original
    flattening note states.

    Returns:
        ``(src, pos, size_list)``: three lists with one entry per level.
    """
    src, pos, size_list = [], [], []
    for level in range(num_feature_levels):
        feat = x[level]
        size_list.append(feat.shape[-2:])
        # No padding mask is passed to the positional encoder; per the
        # original note, disabling it does not affect performance.
        level_pos = pe_layer(feat, None).flatten(2)
        level_src = input_proj[level](feat).flatten(2) + level_embed.weight[level][None, :, None]
        # (N, C, HW) -> (HW, N, C)
        pos.append(level_pos.permute(2, 0, 1))
        src.append(level_src.permute(2, 0, 1))
    return src, pos, size_list