# Copyright (c) OpenMMLab. All rights reserved.
# Copyright (c) ByteDance, Inc. and its affiliates. All rights reserved.
# Modified from https://github.com/keyu-tian/SparK/blob/main/encoder.py
import torch
import torch.nn as nn

from mmpretrain.registry import MODELS


class SparseHelper:
    """The helper to compute sparse operations with PyTorch, such as sparse
    convolution, sparse batch norm, etc."""

    # Binary mask with shape (B, 1, f, f); non-zero entries mark the visible
    # (un-masked) patches. It is shared by all sparse layers below.
    _cur_active: torch.Tensor = None

    @staticmethod
    def _get_active_map_or_index(H: int,
                                 returning_active_map: bool = True
                                 ) -> torch.Tensor:
        """Get current active map with (B, 1, f, f) shape or index format."""
        # _cur_active with shape (B, 1, f, f)
        downsample_raito = H // SparseHelper._cur_active.shape[-1]
        active_ex = SparseHelper._cur_active.repeat_interleave(
            downsample_raito, 2).repeat_interleave(downsample_raito, 3)
        return active_ex if returning_active_map else active_ex.squeeze(
            1).nonzero(as_tuple=True)
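
    # Illustrative note (not part of the original file): with a patch mask of
    # shape (B, 1, 4, 4) and a feature map of resolution H = 8, the ratio is
    # 2, so each mask entry is repeated twice along both spatial dims,
    # yielding a pixel-level active map of shape (B, 1, 8, 8).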

    @staticmethod
    def sp_conv_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse convolution forward function."""
        x = super(type(self), self).forward(x)

        # (b, c, h, w) *= (b, 1, h, w), mask the output of conv
        x *= SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=True)
        return x

    @staticmethod
    def sp_bn_forward(self, x: torch.Tensor) -> torch.Tensor:
        """Sparse batch norm forward function."""
        active_index = SparseHelper._get_active_map_or_index(
            H=x.shape[2], returning_active_map=False)

        # (b, c, h, w) -> (b, h, w, c)
        x_permuted = x.permute(0, 2, 3, 1)

        # select the features on non-masked positions to form flatten features
        # with shape (n, c)
        x_flattened = x_permuted[active_index]

        # use BN1d to normalize this flatten feature (n, c)
        x_flattened = super(type(self), self).forward(x_flattened)

        # generate output
        output = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
        output[active_index] = x_flattened

        # (b, h, w, c) -> (b, c, h, w)
        output = output.permute(0, 3, 1, 2)
        return output


class SparseConv2d(nn.Conv2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward


class SparseMaxPooling(nn.MaxPool2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward


class SparseAvgPooling(nn.AvgPool2d):
    """hack: override the forward function.

    See `sp_conv_forward` above for more details.
    """
    forward = SparseHelper.sp_conv_forward


@MODELS.register_module()
class SparseBatchNorm2d(nn.BatchNorm1d):
    """hack: override the forward function.

    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward


@MODELS.register_module()
class SparseSyncBatchNorm2d(nn.SyncBatchNorm):
    """hack: override the forward function.

    See `sp_bn_forward` above for more details.
    """
    forward = SparseHelper.sp_bn_forward


@MODELS.register_module('SparseLN2d')
class SparseLayerNorm2D(nn.LayerNorm):
    """Implementation of sparse LayerNorm on channels for 2d images."""

    def forward(self,
                x: torch.Tensor,
                data_format: str = 'channel_first') -> torch.Tensor:
        """Sparse layer norm forward function with 2D data.

        Args:
            x (torch.Tensor): The input tensor.
            data_format (str): The format of the input tensor. If
                ``"channel_first"``, the shape of the input tensor should be
                (B, C, H, W). If ``"channel_last"``, the shape of the input
                tensor should be (B, H, W, C). Defaults to "channel_first".

        Returns:
            torch.Tensor: The normalized tensor in the same layout as the
            input, with masked positions set to zero.
        """
        assert x.dim() == 4, (
            'SparseLayerNorm2D only supports 4-D inputs, '
            f'but got tensor with shape {x.shape}')
        if data_format == 'channel_last':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[1], returning_active_map=False)

            # select the features on non-masked positions to form flatten
            # features with shape (n, c)
            x_flattened = x[index]
            # use LayerNorm to normalize this flatten feature (n, c)
            x_flattened = super().forward(x_flattened)

            # generate output
            x = torch.zeros_like(x, dtype=x_flattened.dtype)
            x[index] = x_flattened
        elif data_format == 'channel_first':
            index = SparseHelper._get_active_map_or_index(
                H=x.shape[2], returning_active_map=False)
            x_permuted = x.permute(0, 2, 3, 1)

            # select the features on non-masked positions to form flatten
            # features with shape (n, c)
            x_flattened = x_permuted[index]
            # use LayerNorm to normalize this flatten feature (n, c)
            x_flattened = super().forward(x_flattened)

            # generate output
            x = torch.zeros_like(x_permuted, dtype=x_flattened.dtype)
            x[index] = x_flattened
            x = x.permute(0, 3, 1, 2).contiguous()
        else:
            raise NotImplementedError(
                f'Invalid data_format: {data_format!r}, should be '
                '"channel_first" or "channel_last"')
        return x
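

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes a (B, 1, f, f) boolean patch mask in which non-zero entries mark
# the visible (un-masked) patches; every Sparse* layer reads the mask from
# ``SparseHelper._cur_active``. The concrete sizes below are arbitrary.
if __name__ == '__main__':
    B, C, H, W = 2, 3, 8, 8

    # Keep the top half of the 4x4 patch grid active.
    mask = torch.zeros(B, 1, 4, 4, dtype=torch.bool)
    mask[:, :, :2, :] = True
    SparseHelper._cur_active = mask

    conv = SparseConv2d(C, 8, kernel_size=3, padding=1)
    bn = SparseBatchNorm2d(8)
    out = bn(conv(torch.randn(B, C, H, W)))

    # Masked positions stay exactly zero after every sparse layer.
    active = SparseHelper._get_active_map_or_index(H=H)
    assert out.shape == (B, 8, H, W)
    assert torch.all(out * (~active) == 0)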