File size: 7,779 Bytes
32b542e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
import torch
from torch import nn

from uniperceiver.config import configurable
from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import build_position_encoding
# from uniperceiver.modeling.layers import LayerNorm
from uniperceiver.utils import comm
import copy
from uniperceiver.modeling.layers import FP16LayerNorm


__all__ = ["TokenBaseEmbedding"]

@EMBEDDING_REGISTRY.register()
class TokenBaseEmbedding(nn.Module):
    """Token embedding with optional position / token-type / segment embeddings.

    The looked-up token embeddings are optionally summed with position,
    token-type and segment embeddings, then passed through an optional
    activation, normalization and dropout. Every optional sub-module defaults
    to ``None`` and is skipped when absent. Instances are normally built through
    ``from_config`` (via ``@configurable``).
    """

    # Token id whose embedding receives the learnable ``s_token_bias`` (set via
    # ``set_s_token_bias``). NOTE(review): presumably a special separator token
    # of the tokenizer vocabulary -- confirm against the tokenizer.
    S_TOKEN_ID = 49410

    @configurable
    def __init__(
        self,
        *,
        dim: int,
        vocab_size: int, # include <BOS>/<EOS>
        **kwargs
    ):
        """Create the embedding table and store optional sub-modules.

        Args:
            dim: embedding width.
            vocab_size: number of token ids (including <BOS>/<EOS>).
            **kwargs: optional ``embeddings_act``, ``embeddings_norm``,
                ``embeddings_dropout``, ``embeddings_pos``,
                ``embeddings_token_type``, ``embeddings_token_seg`` modules
                (default ``None``), ``bw_own_embed`` (default ``False``),
                ``pos_before`` (default ``True``) and ``cfg`` (default
                ``None``). Unrecognized keys are ignored.
        """
        super(TokenBaseEmbedding, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop("embeddings_pos", None)
        self.embeddings_token_type = kwargs.pop('embeddings_token_type', None)
        self.embeddings_token_seg = kwargs.pop('embeddings_token_seg', None)
        self.bw_own_embed = kwargs.pop('bw_own_embed', False)
        # When True, position embeddings are added before act/norm; otherwise after.
        self.pos_before = kwargs.pop('pos_before', True)
        # Full config; None when constructed directly rather than via from_config.
        self.cfg = kwargs.pop('cfg', None)

        if self.bw_own_embed:
            # only for debugging: independent copies of the embedding sub-modules
            self.bw_embeddings = copy.deepcopy(self.embeddings)
            self.bw_embeddings_norm = copy.deepcopy(self.embeddings_norm)
            self.bw_embeddings_pos = copy.deepcopy(self.embeddings_pos)
            self.bw_embeddings_token_type = copy.deepcopy(self.embeddings_token_type)
        self.s_token_bias = None

    @classmethod
    def from_config(cls, cfg):
        """Translate ``cfg.MODEL.TOKEN_EMBED`` (and related keys) into ``__init__`` kwargs."""
        kwargs = {
            "dim": cfg.MODEL.TOKEN_EMBED.DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE
        }

        activation_name = (cfg.MODEL.TOKEN_EMBED.ACTIVATION).lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None

            act_kwargs = {}
            if activation_name in { "elu", "celu" }:
                act_kwargs["alpha"] = cfg.MODEL.TOKEN_EMBED.ELU_ALPHA
            embeddings_act = activation(**act_kwargs)
            kwargs['embeddings_act'] = embeddings_act

        if cfg.MODEL.TOKEN_EMBED.DROPOUT > 0:
            embeddings_dropout = nn.Dropout(cfg.MODEL.TOKEN_EMBED.DROPOUT)
            kwargs['embeddings_dropout'] = embeddings_dropout

        if cfg.MODEL.TOKEN_EMBED.USE_NORM:
            if cfg.SOLVER.FORCE_LN_FP16:
                embeddings_norm = FP16LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            else:
                embeddings_norm = nn.LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            kwargs['embeddings_norm'] = embeddings_norm

        if (cfg.MODEL.TOKEN_EMBED.POSITION).lower() != 'none':
            embeddings_pos = build_position_encoding(cfg,
                cfg.MODEL.TOKEN_EMBED.DIM, cfg.MODEL.TOKEN_EMBED.POSITION_MAX_LEN)
            kwargs['embeddings_pos'] = embeddings_pos

        if cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE > 0:
            embeddings_token_type = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)
            kwargs['embeddings_token_type'] = embeddings_token_type

        if cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE > 0:
            embeddings_token_seg = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)
            kwargs['embeddings_token_seg'] = embeddings_token_seg

        # for debug
        kwargs['bw_own_embed'] = cfg.MODEL.BW_OWD_EMBED
        kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
        kwargs['cfg'] = cfg
        return kwargs

    def get_time_step(self, data, sample_info, task_info=None):
        """Build custom position indices for concatenated-text tasks.

        Args:
            data: token ids; only ``data.shape[1]`` (sequence length) and
                ``data.device`` are read. Presumably (batch, seq_len) -- confirm.
            sample_info: dict, or list of dicts (the first one is used), read
                for the ``'text_spe_cat'`` flag. ``None`` or an empty list is
                treated as an empty dict.
            task_info: optional dict read for ``'task_type'``.

        Returns:
            ``None`` to use default positions. Otherwise a 1-D ``torch.long``
            tensor on ``data.device``:
            * ``image_caption`` / ``video_caption`` with ``text_spe_cat``:
              ``[0, .., L//2 - 1, 0, .., L//2 - 1]`` (length ``2 * (L // 2)``);
            * ``VQA`` with ``text_spe_cat``: ``[0, .., L - 2, 0]`` (length ``L``),
            where ``L = data.shape[1]``.
        """
        # TODO: the position embedding for caption text should be handled in a different way.  0,1, n/2,0,1, n/2,
        task_type = '' if task_info is None else task_info.get('task_type', None)

        if isinstance(sample_info, list):
            sample_info = sample_info[0] if sample_info else None
        if sample_info is None:
            sample_info = {}

        if task_type in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
            half_length = data.shape[1] // 2
            # Two concatenated halves, each restarting its positions at 0.
            return torch.cat([
                torch.arange(half_length, dtype=torch.long, device=data.device)
                for _ in range(2)
            ])
        if task_type == 'VQA' and sample_info.get('text_spe_cat', False):
            text_length = data.shape[1]
            # Sequential positions for the first L-1 tokens; the last token gets position 0.
            return torch.cat([
                torch.arange(text_length - 1, dtype=torch.long, device=data.device),
                torch.arange(1, dtype=torch.long, device=data.device),
            ])
        return None

    def forward(self, data, sample_info=None, task_info=None, **kwargs):
        """Embed token ids ``data``.

        Args:
            data: token id tensor passed to ``nn.Embedding``.
            sample_info: forwarded to ``get_time_step`` (``None`` means empty).
            task_info: forwarded to ``get_time_step`` (``None`` means empty).
            **kwargs: ``time_step`` (explicit position indices, overrides
                ``get_time_step``), ``type_embed`` / ``pos_embed`` (bools,
                default True). ``prompt_with_pos=True`` is not implemented.

        Raises:
            NotImplementedError: if ``prompt_with_pos`` is truthy.
        """
        time_step = kwargs.pop('time_step', None)
        if time_step is None:
            time_step = self.get_time_step(data, sample_info, task_info)

        if kwargs.pop("prompt_with_pos", False):
            raise NotImplementedError
        start_time = 0

        type_embed = kwargs.get('type_embed', True)
        pos_emb = kwargs.get('pos_embed', True)

        return self._forward(data,
                             type_embed=type_embed,
                             pos_emb=pos_emb,
                             token_seg_ids=None,
                             time_step=time_step,
                             start_time=start_time)

    def set_s_token_bias(self, s_token_bias):
        """Set the bias added to embeddings of tokens equal to ``S_TOKEN_ID`` (``None`` disables it)."""
        self.s_token_bias = s_token_bias

    def _add_position_embeddings(self, embeddings, input_ids, time_step, start_time):
        """Return ``embeddings`` plus position embeddings cast to its dtype.

        Positions come from ``time_step`` when given, otherwise from ``input_ids``.
        """
        pos_inputs = input_ids if time_step is None else time_step
        position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
        return embeddings + position_embeddings.to(embeddings.dtype)

    def _forward(self, input_ids, type_embed=True, token_seg_ids=None, time_step=None, pos_emb=True, start_time=0, ):
        """Compute embeddings for ``input_ids``.

        The steps run in this order: token lookup, optional fp16 cast, s-token
        bias, position embeddings (if ``pos_before``), token-type embedding
        (type 0), segment embedding, activation, normalization, position
        embeddings (if not ``pos_before``), dropout. Each additive term is cast
        to the running embedding dtype.
        """
        embeddings = self.embeddings(input_ids)
        # cfg is None when the module is constructed directly rather than via
        # from_config; treat that as "no forced fp16".
        if self.cfg is not None and self.cfg.SOLVER.FORCE_EMBED_FP16:
            embeddings = embeddings.half()

        if self.s_token_bias is not None:
            # learnable; cast so the index-put target and source dtypes match
            # when embeddings were forced to fp16.
            s_token_mask = input_ids == self.S_TOKEN_ID
            embeddings[s_token_mask] = (
                embeddings[s_token_mask] + self.s_token_bias
            ).to(embeddings.dtype)

        if self.embeddings_pos is not None and pos_emb and self.pos_before:
            embeddings = self._add_position_embeddings(embeddings, input_ids, time_step, start_time)

        if self.embeddings_token_type is not None and type_embed:
            # Always the type-0 row, broadcast to shape (1, 1, dim).
            embeddings_token_type = self.embeddings_token_type.weight[0].unsqueeze(0).unsqueeze(1)
            embeddings = embeddings + embeddings_token_type.to(embeddings.dtype)

        if (self.embeddings_token_seg is not None) and (token_seg_ids is not None):
            embeddings_token_seg = self.embeddings_token_seg(token_seg_ids)
            embeddings = embeddings + embeddings_token_seg.to(embeddings.dtype)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if self.embeddings_pos is not None and pos_emb and not self.pos_before:
            embeddings = self._add_position_embeddings(embeddings, input_ids, time_step, start_time)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings