Spaces:
Build error
Build error
"""Layer functions""" | |
import torch | |
import torch.nn.functional as F | |
import cirtorch.layers.functional as CF | |
def smoothing_avg_pooling(feats, kernel_size):
    """Blur a feature map with a stride-1 square average-pooling window.

    Each side is padded by ``kernel_size // 2``. Padded cells are excluded from
    every average (``count_include_pad=False``), so border outputs are the mean
    of the in-bounds neighbours only. With an odd ``kernel_size`` the spatial
    size is preserved; with an even one it grows by one along each spatial axis.

    :param torch.Tensor feats: Feature map handed to ``F.avg_pool2d``
    :param int kernel_size: Side length of the square pooling window
    :return torch.Tensor: Smoothed feature map
    """
    window = (kernel_size,) * 2
    half = kernel_size // 2
    return F.avg_pool2d(
        feats,
        window,
        stride=1,
        padding=half,
        count_include_pad=False,
    )
def weighted_spoc(ms_feats, ms_weights):
    """Aggregate weighted multi-scale feature maps into one global descriptor.

    For every scale, the feature map is multiplied by its weight map and summed
    over its last two axes. The per-scale results are accumulated in place into a
    ``(1, C)`` float32 vector on the device of the first feature map, and that
    vector is L2-normalised with ``CF.l2n``.

    :param list ms_feats: Feature maps, one per scale; ``C`` is read from
        ``shape[1]`` of the first one
    :param list ms_weights: Weight maps, one per scale, paired with ``ms_feats``
        by position
    :return torch.Tensor: L2-normalized global descriptor
    """
    first = ms_feats[0]
    accumulator = torch.zeros((1, first.shape[1]), dtype=torch.float32, device=first.device)
    for feature_map, weight_map in zip(ms_feats, ms_weights):
        pooled = feature_map.mul(weight_map).sum(dim=(-2, -1))
        accumulator += pooled.squeeze()
    return CF.l2n(accumulator)
def how_select_local(ms_feats, ms_masks, *, scales, features_num):
    """Convert multi-scale feature maps with attentions to a list of local descriptors.

    Every position of every (non-scalar) attention map becomes one candidate local
    descriptor. Candidates are ranked by attention value in descending order, and the
    first ``features_num`` are returned.

    :param list ms_feats: Feature maps, one per scale, each expected to have shape
        ``(1, C, H, W)``; ``C`` is read from ``shape[1]`` of the first one
    :param list ms_masks: Attention maps, one per scale, each of shape ``(H, W)``
        matching the corresponding feature map. Zero-dimensional (scalar) masks are
        skipped together with their feature map and scale.
    :param list scales: A list of scales (floats), one per feature map
    :param int features_num: Number of features to return (sorted by attention);
        ``None`` keeps all of them
    :return tuple: ``(desc, atts, locs, scls)``. These are descriptors ``(N, C)``
        float32, attentions ``(N,)`` float32, locations ``(N, 2)`` int16 as
        ``(x_coor, y_coor)``, and scales ``(N,)`` float16. Row ``i`` of each tensor
        describes the same local feature.
    """
    device = ms_feats[0].device
    # Scalar masks are skipped by the loop below, so they must not contribute to the
    # buffer size either (and ``shape[0]`` of a 0-dim tensor would raise IndexError).
    size = sum(mask.numel() for mask in ms_masks if mask.dim() > 0)

    desc = torch.zeros(size, ms_feats[0].shape[1], dtype=torch.float32, device=device)
    atts = torch.zeros(size, dtype=torch.float32, device=device)
    # NOTE(review): int16 coordinates assume feature maps narrower/shorter than 32768.
    locs = torch.zeros(size, 2, dtype=torch.int16, device=device)
    scls = torch.zeros(size, dtype=torch.float16, device=device)

    pointer = 0
    for sc, vs, ms in zip(scales, ms_feats, ms_masks):
        if ms.dim() == 0:
            continue

        height, width = ms.shape
        numel = ms.numel()
        slc = slice(pointer, pointer + numel)
        pointer += numel

        # (1, C, H, W) -> (H*W, C); rows are in row-major order of (y, x)
        desc[slc] = vs.squeeze(0).reshape(vs.shape[1], -1).T
        atts[slc] = ms.reshape(-1)
        # Row-major flat index i corresponds to x = i % width, y = i // width
        width_arr = torch.arange(width, dtype=torch.int16)
        locs[slc, 0] = width_arr.repeat(height).to(device)  # x axis
        height_arr = torch.arange(height, dtype=torch.int16)
        locs[slc, 1] = height_arr.view(-1, 1).repeat(1, width).reshape(-1).to(device)  # y axis
        scls[slc] = sc

    total = atts.shape[0]
    keep_n = total if features_num is None else min(features_num, total)
    idx = atts.sort(descending=True)[1][:keep_n]
    return desc[idx], atts[idx], locs[idx], scls[idx]