File size: 8,060 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
import torch
from torch import nn

from uniperceiver.config import configurable
from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import NNEmbeddingEncoding
from einops import rearrange, repeat
from uniperceiver.modeling.layers import FP16LayerNorm


# Names exported by `from ... import *`; Divide_ST_POS below is not listed here.
__all__ = ["VideoBaseEmbedding"]

@EMBEDDING_REGISTRY.register()
class VideoBaseEmbedding(nn.Module):
    """Turn video (or image) pixels into a sequence of token embeddings.

    Two patchify paths exist, selected in ``__init__``:

    * ``embeddings_pos == 'divide_st_pos'``: each frame is patchified by a
      ``Conv2d`` (kernel = stride = ``patch_size``) and a factorised
      space-time positional table (``Divide_ST_POS``) is added.
    * otherwise: non-overlapping ``time_span x patch_size x patch_size``
      tubes are flattened and projected by an ``nn.Linear``; an optional
      positional module (``embeddings_pos``) is added when configured.

    An optional token-type embedding, activation, norm and dropout follow.
    """

    @configurable
    def __init__(
        self,
        *,
        cfg: dict,
        in_dim: int,
        out_dim: int,
        patch_size: int,
        time_span: int,
        max_time_len: int,
        max_spatial_size = 196,
        **kwargs
    ):
        """
        Args:
            cfg: config node; ``cfg.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED`` is read.
            in_dim: input size of the linear projection; must equal the
                flattened tube size ``time_span * C * patch_size**2``.
            out_dim: embedding size.
            patch_size: spatial patch edge length.
            time_span: number of consecutive frames merged into one token.
            max_time_len: rows of the temporal positional table.
            max_spatial_size: rows of the spatial positional table.
            **kwargs: optional stages built by ``from_config``
                (``embeddings_act``, ``embeddings_norm``,
                ``embeddings_dropout``, ``embeddings_pos``,
                ``embeddings_token_type``, ``random_temporal_pos``,
                ``pos_before``); unrecognised keys are ignored.
        """
        super(VideoBaseEmbedding, self).__init__()
        self.cfg = cfg
        self.embeddings = nn.Linear(in_dim, out_dim)
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop('embeddings_pos', None)
        self.embeddings_type = kwargs.pop("embeddings_token_type", None)
        self.random_temporal_pos = kwargs.pop("random_temporal_pos", True)
        self.patch_size = patch_size
        self.time_span = time_span
        # True: add positional embeddings before act/norm; False: after norm.
        self.pos_before = kwargs.pop('pos_before', True)
        self.add_type_embedding = cfg.MODEL.VIDEO_EMBED.ADD_TYPE_EMBED
        if self.add_type_embedding:
            assert self.embeddings_type is not None

        self.embeddings_st_pos = None
        self.max_spatial_size = max_spatial_size
        if isinstance(self.embeddings_pos, str):
            if self.embeddings_pos == 'divide_st_pos':
                self.embeddings_st_pos = Divide_ST_POS(
                    max_spatial_size, max_time_len, out_dim,
                    self.random_temporal_pos)
                self.embeddings_pos = None
                # The ST path patchifies frames with a conv instead of the
                # linear projection created above.
                del self.embeddings
                self.embeddings = nn.Conv2d(in_dim//(self.patch_size**2), out_dim, kernel_size=self.patch_size, stride=self.patch_size)

    def replace_weight(self, visual_embed):
        """Reuse ``visual_embed.patch_embed.proj`` as this module's patch projection.

        Does nothing when ``visual_embed`` is None.
        """
        if visual_embed is not None:
            del self.embeddings
            self.embeddings = visual_embed.patch_embed.proj

    def share_spatial_pos(self, visual_embed):
        """Reuse ``visual_embed.patch_embed.pos_embed`` as the spatial positional table.

        Only applies when the divided space-time embedding is in use and
        ``visual_embed`` is given. If the shared table's row count differs
        from ours, row 0 is assumed to be the image tokenizer's cls token and
        spatial lookups start at row 1.
        """
        if self.embeddings_st_pos is not None and visual_embed is not None:

            if self.embeddings_st_pos.spatial_pos_embed.weight.shape[0] == visual_embed.patch_embed.pos_embed.weight.shape[0]:
                self.embeddings_st_pos.spatial_pos_embed_index = 0
            else:
                # cls token in image patch tokenizer
                self.embeddings_st_pos.spatial_pos_embed_index = 1
            del self.embeddings_st_pos.spatial_pos_embed
            self.embeddings_st_pos.spatial_pos_embed = visual_embed.patch_embed.pos_embed

    @classmethod
    def from_config(cls, cfg):
        """Build the ``__init__`` keyword arguments from ``cfg``."""
        kwargs = {
            "in_dim": cfg.MODEL.VIDEO_EMBED.IN_DIM,
            "out_dim": cfg.MODEL.VIDEO_EMBED.OUT_DIM,
            "patch_size": cfg.MODEL.PATCH_SIZE,
            "time_span": cfg.MODEL.VIDEO_EMBED.PATCH_SIZE_T,
            "max_time_len": cfg.MODEL.VIDEO_EMBED.MAX_FRAMES,
        }
        # Number of patches in one frame (patches per side, squared).
        max_spatial_size = int(cfg.MODEL.IMG_INPUT_SIZE/cfg.MODEL.PATCH_SIZE)**2
        kwargs['max_spatial_size'] = max_spatial_size
        activation_name = (cfg.MODEL.VIDEO_EMBED.ACTIVATION).lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None

            act_kwargs = {}
            if activation_name in { "elu", "celu" }:
                act_kwargs["alpha"] = cfg.MODEL.VIDEO_EMBED.ELU_ALPHA
            embeddings_act = activation(**act_kwargs)
            kwargs['embeddings_act'] = embeddings_act

        if cfg.MODEL.VIDEO_EMBED.DROPOUT > 0:
            embeddings_dropout = nn.Dropout(cfg.MODEL.VIDEO_EMBED.DROPOUT)
            kwargs['embeddings_dropout'] = embeddings_dropout

        if cfg.MODEL.VIDEO_EMBED.USE_NORM:
            if cfg.SOLVER.FORCE_LN_FP16:
                embeddings_norm = FP16LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            else:
                embeddings_norm = nn.LayerNorm(cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            kwargs['embeddings_norm'] = embeddings_norm

        # A string marker; __init__ replaces it with a Divide_ST_POS module.
        if cfg.MODEL.VIDEO_EMBED.DIVIDE_ST_POS:
            kwargs['embeddings_pos'] = "divide_st_pos"

        elif cfg.MODEL.VIDEO_EMBED.POSITION.lower() != 'none':
            embeddings_pos = NNEmbeddingEncoding(cfg.MODEL.VIDEO_EMBED.OUT_DIM, cfg.MODEL.VIDEO_EMBED.MAX_LENGTH)
            kwargs['embeddings_pos'] = embeddings_pos

        if cfg.MODEL.VIDEO_EMBED.TYPE_SIZE > 0:
            embeddings_token_type = nn.Embedding(
            cfg.MODEL.VIDEO_EMBED.TYPE_SIZE, cfg.MODEL.VIDEO_EMBED.OUT_DIM)
            kwargs['embeddings_token_type'] = embeddings_token_type
        kwargs['random_temporal_pos'] = cfg.MODEL.VIDEO_EMBED.POS_RANDOM
        kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
        kwargs['cfg'] = cfg
        return kwargs

    def forward(self, data, **kwargs):
        """Embed ``data`` into a token sequence.

        Args:
            data: 5-D video batch laid out as ``(b, t, c, h, w)``, or a 4-D
                image batch ``(b, c, h, w)`` treated as a one-frame video.
            **kwargs: accepted for interface compatibility; unused.

        Returns:
            Tensor of shape ``(b, n_tokens, dim)``.
        """
        if data.dim() == 4:
            # Image batch: insert a singleton time axis -> (b, 1, c, h, w).
            data = data.unsqueeze(1)

        # Stays None when no positional module is configured.
        embeddings_pos = None

        if self.embeddings_st_pos is not None:
            bs = data.size(0)
            # Patchify every frame independently: (b*t, dim, h', w').
            x = self.embeddings(data.flatten(0, 1))
            x = x.flatten(2)  # (b*t, dim, h'*w')
            embeddings = rearrange(x, '(b t s) c hw -> b t hw (s c)', b=bs, s=self.time_span)
            # Factorised table (t, hw, dim) -> (1, t*hw, dim), broadcast over batch.
            embeddings_pos = self.embeddings_st_pos(embeddings).unsqueeze(
                0).flatten(1, 2)
            embeddings = embeddings.flatten(1, 2)
        else:
            # Flatten non-overlapping time_span x patch x patch tubes and project.
            x = rearrange(data, 'b (t s) c (h p1) (w p2) -> b (t h w) (s c p1 p2)', s = self.time_span, p1 = self.patch_size, p2 = self.patch_size)
            embeddings = self.embeddings(x)
            if self.embeddings_pos is not None:
                embeddings_pos = self.embeddings_pos(x).unsqueeze(0)

        if self.pos_before and embeddings_pos is not None:
            embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.add_type_embedding:
            # Type row 0, shaped (1, 1, dim) to broadcast over batch and tokens.
            embeddings = embeddings + self.embeddings_type.weight[0].unsqueeze(0).unsqueeze(1).to(embeddings.dtype)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if not self.pos_before and embeddings_pos is not None:
            # Cast like the pos_before branch so the add keeps the activations' dtype.
            embeddings = embeddings + embeddings_pos.to(embeddings.dtype)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings



class Divide_ST_POS(nn.Module):
    """Divided (factorised) space-time positional embedding.

    ``forward`` returns a ``(T, S, out_dim)`` table formed as the broadcast
    sum of a learned per-frame vector and a learned per-patch vector.
    """

    def __init__(self, num_patches, max_time_len, out_dim,
                 random_temporal_pos):
        super().__init__()
        # Creation order (spatial, then temporal) is kept so parameter
        # initialisation consumes the RNG in the same sequence.
        self.spatial_pos_embed = nn.Embedding(num_patches, out_dim)
        self.temporal_pos_embed = nn.Embedding(max_time_len, out_dim)
        # First spatial row used for lookups; may be set to 1 when the spatial
        # table carries a leading cls-token row (sometimes image has cls_token).
        self.spatial_pos_embed_index = 0
        self.max_frames = max_time_len
        # In training mode, shift frame ids by a random offset when enabled.
        self.random_temporal_pos = random_temporal_pos

    def forward(self, x):
        """Build the positional table matching ``x``.

        Args:
            x: tensor whose sizes at dims 1 and 2 give the number of frames
                ``T`` and spatial positions ``S``; its dtype and device are
                used for the result.

        Returns:
            Tensor of shape ``(T, S, out_dim)`` in ``x.dtype``.
        """
        n_frames, n_patches = x.size(1), x.size(2)
        out_dtype, dev = x.dtype, x.device

        frame_ids = torch.arange(n_frames, dtype=torch.long, device=dev)
        if self.random_temporal_pos and self.training:
            # One shared offset in [0, max_frames - n_frames] for all frames.
            shift = torch.randint(0, self.max_frames - n_frames + 1, size=(1,), dtype=torch.long, device=dev)
            frame_ids = frame_ids + shift

        first = self.spatial_pos_embed_index
        patch_ids = torch.arange(first, first + n_patches, dtype=torch.long, device=dev)

        per_frame = self.temporal_pos_embed(frame_ids).to(dtype=out_dtype)   # (T, D)
        per_patch = self.spatial_pos_embed(patch_ids).to(dtype=out_dtype)    # (S, D)
        # (T, 1, D) + (1, S, D) -> (T, S, D)
        return per_frame[:, None, :] + per_patch[None, :, :]