Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved.
# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
import torch | |
import torch.nn as nn | |
from mmpretrain.registry import MODELS | |
class SparseHelper:
    """Helper that emulates sparse operations with dense PyTorch modules.

    It provides the forward functions for sparse convolution / pooling and
    sparse batch norm. They take the module as an explicit ``self`` argument
    so that a module class can adopt one of them with
    ``forward = SparseHelper.sp_conv_forward``.
    """

    # Mask of the currently active (non-masked) positions with shape
    # (B, 1, f, f). It has to be assigned before a sparse forward is run,
    # because the helpers below read its shape and values.
    _cur_active: torch.Tensor = None

    @staticmethod
    def _get_active_map_or_index(H: int,
                                 returning_active_map: bool = True
                                 ) -> torch.Tensor:
        """Get the current active map upsampled to height ``H``.

        Args:
            H (int): Height of the feature map the mask is needed for.
            returning_active_map (bool): If True, return the upsampled mask
                itself. If False, return the indices of its non-zero
                entries. Defaults to True.

        Returns:
            torch.Tensor: The mask with shape (B, 1, H', W') when
            ``returning_active_map`` is True, otherwise the tuple of index
            tensors (batch, row, column) produced by
            ``nonzero(as_tuple=True)``.
        """
        # _cur_active with shape (B, 1, f, f)
        downsample_ratio = H // SparseHelper._cur_active.shape[-1]
        # repeat every mask entry along both spatial dims
        active_ex = SparseHelper._cur_active.repeat_interleave(
            downsample_ratio, 2).repeat_interleave(downsample_ratio, 3)
        if returning_active_map:
            return active_ex
        return active_ex.squeeze(1).nonzero(as_tuple=True)

    @staticmethod
    def _dense_forward(module: nn.Module, sparse_forward,
                       x: torch.Tensor) -> torch.Tensor:
        """Run the dense parent ``forward`` that ``sparse_forward`` overrides.

        ``super(type(module), module)`` recurses forever as soon as a sparse
        class is subclassed: ``type(module)`` is then the subclass and the
        ``forward`` of its parent is the sparse forward again. Anchoring
        ``super`` at the class that adopted ``sparse_forward`` avoids that.

        Args:
            module (nn.Module): The module whose forward is being computed.
            sparse_forward (Callable): The sparse forward function that was
                assigned as ``forward`` of one of the module's classes.
            x (torch.Tensor): The input passed on to the dense forward.

        Returns:
            torch.Tensor: The output of the dense forward.
        """
        # fall back to the old behaviour if no class in the MRO adopted it
        owner = type(module)
        for klass in type(module).__mro__:
            # keep the base-most adopter so the lookup starts past all of them
            if vars(klass).get('forward') is sparse_forward:
                owner = klass
        return super(owner, module).forward(x)

    @staticmethod
    def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse convolution forward function.

        Runs the dense forward of the parent class and zeroes the output at
        inactive positions.

        Args:
            x (torch.Tensor): The input tensor with shape (b, c, h, w).

        Returns:
            torch.Tensor: The masked output of the dense forward.
        """
        x = SparseHelper._dense_forward(self, SparseHelper.sp_conv_forward, x)
        # (b, c, h, w) *= (b, 1, h, w), mask the output of conv in place
        x *= SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=True)
        return x

    @staticmethod
    def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse batch norm forward function.

        Normalizes only the features at active positions; all other
        positions of the output are zero.

        Args:
            x (torch.Tensor): The input tensor with shape (b, c, h, w).

        Returns:
            torch.Tensor: The output tensor with shape (b, c, h, w).
        """
        active_index = SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=False)

        # (b, c, h, w) -> (b, h, w, c)
        x_permuted = x.permute(0, 2, 3, 1)

        # select the features on non-masked positions to form flatten
        # features with shape (n, c)
        x_flattened = x_permuted[active_index]

        # use BN1d to normalize this flatten feature (n, c)
        x_flattened = SparseHelper._dense_forward(
            self, SparseHelper.sp_bn_forward, x_flattened)

        # scatter the normalized features back; inactive positions stay zero
        output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
        output[active_index] = x_flattened

        # (b, h, w, c) -> (b, c, h, w)
        output = output.permute(0, 3, 1, 2)
        return output
class SparseConv2d(nn.Conv2d):
    """``nn.Conv2d`` with a sparse forward.

    Hack: ``forward`` is replaced by ``SparseHelper.sp_conv_forward``, which
    runs the dense parent forward and then multiplies the output by the
    current active map.
    """

    forward = SparseHelper.sp_conv_forward
class SparseMaxPooling(nn.MaxPool2d):
    """``nn.MaxPool2d`` with a sparse forward.

    Hack: ``forward`` is replaced by ``SparseHelper.sp_conv_forward``, which
    runs the dense parent forward and then multiplies the output by the
    current active map.
    """

    forward = SparseHelper.sp_conv_forward
class SparseAvgPooling(nn.AvgPool2d):
    """``nn.AvgPool2d`` with a sparse forward.

    Hack: ``forward`` is replaced by ``SparseHelper.sp_conv_forward``, which
    runs the dense parent forward and then multiplies the output by the
    current active map.
    """

    forward = SparseHelper.sp_conv_forward
class SparseBatchNorm2d(nn.BatchNorm1d):
    """Sparse batch norm for 2D feature maps, built on ``nn.BatchNorm1d``.

    Hack: ``forward`` is replaced by ``SparseHelper.sp_bn_forward``, which
    gathers the features at active positions into a flat (n, c) tensor,
    normalizes that with the parent forward and scatters the result back
    into a zero-filled output.
    """

    forward = SparseHelper.sp_bn_forward
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """Sparse batch norm for 2D feature maps, built on ``nn.SyncBatchNorm``.

    Hack: ``forward`` is replaced by ``SparseHelper.sp_bn_forward``, which
    gathers the features at active positions into a flat (n, c) tensor,
    normalizes that with the parent forward and scatters the result back
    into a zero-filled output.
    """

    forward = SparseHelper.sp_bn_forward
class SparseLayerNorm2D(nn.LayerNorm):
    """Implementation of sparse LayerNorm on channels for 2d images.

    Only the features at the active positions reported by
    ``SparseHelper._get_active_map_or_index`` are normalized; every other
    position of the output is zero.
    """

    def forward(self,
                x: torch.Tensor,
                data_format='channel_first') -> torch.Tensor:
        """Sparse layer norm forward function with 2D data.

        Args:
            x (torch.Tensor): The input tensor.
            data_format (str): The format of the input tensor. If
                ``"channel_first"``, the shape of the input tensor should be
                (B, C, H, W). If ``"channel_last"``, the shape of the input
                tensor should be (B, H, W, C). Defaults to "channel_first".

        Returns:
            torch.Tensor: The normalized tensor in the same layout as the
            input, with zeros at inactive positions.

        Raises:
            NotImplementedError: If ``data_format`` is neither
                ``"channel_first"`` nor ``"channel_last"``.
        """
        assert x.dim() == 4, (
            f'LayerNorm2d only supports inputs with shape '
            f'(N, C, H, W), but got tensor with shape {x.shape}')
        if data_format not in ('channel_first', 'channel_last'):
            raise NotImplementedError(
                f'Unsupported data_format {data_format!r}, expected '
                f'"channel_first" or "channel_last".')

        # both layouts share one channel-last code path:
        # (b, c, h, w) -> (b, h, w, c) when needed
        channel_first = data_format == 'channel_first'
        feats = x.permute(0, 2, 3, 1) if channel_first else x

        index = SparseHelper._get_active_map_or_index(
            H=feats.shape[1], returning_active_map=False)

        # select the features on non-masked positions to form flatten
        # features with shape (n, c) and normalize them with LayerNorm
        normed = super().forward(feats[index])

        # scatter the normalized features back; inactive positions stay zero
        out = torch.zeros_like(feats, dtype=normed.dtype)
        out[index] = normed

        if channel_first:
            # (b, h, w, c) -> (b, c, h, w)
            out = out.permute(0, 3, 1, 2).contiguous()
        return out