# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved.
# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
import torch
import torch.nn as nn
from mmpretrain.registry import MODELS
class SparseHelper:
    """Helper to compute sparse operations with plain PyTorch, such as
    sparse convolution and sparse batch norm.

    The current binary activity mask is stored on the class attribute
    ``_cur_active`` and shared by every sparse layer.
    """

    # Current active (non-masked) positions, shape (B, 1, f, f).
    _cur_active: torch.Tensor = None

    @staticmethod
    def _get_active_map_or_index(H: int,
                                 returning_active_map: bool = True
                                 ) -> torch.Tensor:
        """Get the current active map upsampled to spatial size ``H``, or
        the index of the active positions.

        Args:
            H (int): Spatial size of the feature map to be masked. Assumed
                to be an integer multiple of ``_cur_active.shape[-1]``.
            returning_active_map (bool): If True, return the binary map
                with shape (B, 1, H, H); otherwise return the index tuple
                of active positions (as from ``nonzero(as_tuple=True)``).
                Defaults to True.
        """
        # _cur_active with shape (B, 1, f, f); upsample to (B, 1, H, H)
        # by nearest-neighbor repetition along both spatial dims.
        downsample_ratio = H // SparseHelper._cur_active.shape[-1]
        active_ex = SparseHelper._cur_active.repeat_interleave(
            downsample_ratio, 2).repeat_interleave(downsample_ratio, 3)
        return active_ex if returning_active_map else active_ex.squeeze(
            1).nonzero(as_tuple=True)

    @staticmethod
    def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse convolution forward function.

        Runs the dense parent forward, then zeroes the output at masked
        positions.
        """
        x = super(type(self), self).forward(x)
        # (b, c, h, w) *= (b, 1, h, w), mask the output of conv
        x *= SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=True)
        return x

    @staticmethod
    def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse batch norm forward function.

        Normalizes only the active positions with the parent (1d) norm;
        masked positions are left as zeros in the output.
        """
        active_index = SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=False)
        # (b, c, h, w) -> (b, h, w, c)
        x_permuted = x.permute(0, 2, 3, 1)
        # select the features on non-masked positions to form flatten features
        # with shape (n, c)
        x_flattened = x_permuted[active_index]
        # use BN1d to normalize this flatten feature (n, c)
        x_flattened = super(type(self), self).forward(x_flattened)
        # scatter the normalized features back; masked positions stay zero
        output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
        output[active_index] = x_flattened
        # (b, h, w, c) -> (b, c, h, w)
        output = output.permute(0, 3, 1, 2)
        return output
class SparseConv2d(nn.Conv2d):
    """``nn.Conv2d`` with sparse output masking.

    hack: override the forward function with
    ``SparseHelper.sp_conv_forward``, which runs the dense convolution and
    then zeroes masked positions. See `sp_conv_forward` above for more
    details.
    """
    forward = SparseHelper.sp_conv_forward
class SparseMaxPooling(nn.MaxPool2d):
    """``nn.MaxPool2d`` with sparse output masking.

    hack: override the forward function with
    ``SparseHelper.sp_conv_forward``, which runs the dense pooling and then
    zeroes masked positions. See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward
class SparseAvgPooling(nn.AvgPool2d):
    """``nn.AvgPool2d`` with sparse output masking.

    hack: override the forward function with
    ``SparseHelper.sp_conv_forward``, which runs the dense pooling and then
    zeroes masked positions. See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward
@MODELS.register_module()
class SparseBatchNorm2d(nn.BatchNorm1d):
    """Sparse batch norm for 2d feature maps.

    Inherits ``nn.BatchNorm1d`` because ``sp_bn_forward`` flattens the
    active positions to a (n, c) tensor before normalizing.

    hack: override the forward function with ``SparseHelper.sp_bn_forward``.
    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward
@MODELS.register_module()
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """Sparse synchronized batch norm for 2d feature maps.

    hack: override the forward function with ``SparseHelper.sp_bn_forward``,
    which normalizes only the active (non-masked) positions.
    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward
@MODELS.register_module('SparseLN2d')
class SparseLayerNorm2D(nn.LayerNorm):
    """Implementation of sparse LayerNorm on channels for 2d images.

    Only the active (non-masked) positions, given by
    ``SparseHelper._cur_active``, are normalized; masked positions are set
    to zero in the output.
    """

    def forward(self,
                x: torch.Tensor,
                data_format: str = 'channel_first') -> torch.Tensor:
        """Sparse layer norm forward function with 2D data.

        Args:
            x (torch.Tensor): The input tensor.
            data_format (str): The format of the input tensor. If
                ``"channel_first"``, the shape of the input tensor should be
                (B, C, H, W). If ``"channel_last"``, the shape of the input
                tensor should be (B, H, W, C). Defaults to "channel_first".

        Returns:
            torch.Tensor: The normalized tensor with the same layout as
            ``x``.

        Raises:
            NotImplementedError: If ``data_format`` is neither
                ``"channel_last"`` nor ``"channel_first"``.
        """
        assert x.dim() == 4, (
            f'LayerNorm2d only supports inputs with shape '
            f'(N, C, H, W), but got tensor with shape {x.shape}')
        if data_format == 'channel_last':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[1], returning_active_map=False)
            # select the features on non-masked positions to form flatten
            # features with shape (n, c)
            x_flattened = x[index]
            # use LayerNorm to normalize this flatten feature (n, c)
            x_flattened = super().forward(x_flattened)
            # scatter the normalized features back; masked positions stay 0
            x = torch.zeros_like(x, dtype=x_flattened.dtype)
            x[index] = x_flattened
        elif data_format == 'channel_first':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[2], returning_active_map=False)
            # (b, c, h, w) -> (b, h, w, c)
            x_permuted = x.permute(0, 2, 3, 1)
            # select the features on non-masked positions to form flatten
            # features with shape (n, c)
            x_flattened = x_permuted[index]
            # use LayerNorm to normalize this flatten feature (n, c)
            x_flattened = super().forward(x_flattened)
            # scatter the normalized features back; masked positions stay 0
            x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
            x[index] = x_flattened
            # (b, h, w, c) -> (b, c, h, w)
            x = x.permute(0, 3, 1, 2).contiguous()
        else:
            # Fix: the original raised a bare NotImplementedError with no
            # diagnostic; keep the exception type (callers may catch it)
            # but name the offending value.
            raise NotImplementedError(
                f'Unsupported data_format: {data_format!r}, expected '
                f'"channel_first" or "channel_last"')
        return x