File size: 7,779 Bytes
32b542e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 |
import torch
from torch import nn
from uniperceiver.config import configurable
from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import build_position_encoding
# from uniperceiver.modeling.layers import LayerNorm
from uniperceiver.utils import comm
import copy
from uniperceiver.modeling.layers import FP16LayerNorm
__all__ = ["TokenBaseEmbedding"]
@EMBEDDING_REGISTRY.register()
class TokenBaseEmbedding(nn.Module):
    """Token (word) embedding layer with optional activation, LayerNorm,
    dropout, positional encoding and token-type/segment embeddings.

    The optional sub-modules are built in ``from_config`` and injected via
    keyword arguments; each may be ``None``, in which case the corresponding
    step is skipped in ``_forward``.
    """

    # Vocabulary id of the special "s" token that receives a learnable bias
    # (see set_s_token_bias). NOTE(review): presumably a tokenizer
    # special-token id -- confirm 49410 against the project tokenizer.
    S_TOKEN_ID = 49410

    @configurable
    def __init__(
        self,
        *,
        dim: int,
        vocab_size: int,  # include <BOS>/<EOS>
        **kwargs
    ):
        super(TokenBaseEmbedding, self).__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        # All of the following sub-modules are optional; None disables the step.
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop("embeddings_pos", None)
        self.embeddings_token_type = kwargs.pop('embeddings_token_type', None)
        self.embeddings_token_seg = kwargs.pop('embeddings_token_seg', None)
        self.bw_own_embed = kwargs.pop('bw_own_embed', False)
        # If True, positional encoding is added before act/norm; otherwise after.
        self.pos_before = kwargs.pop('pos_before', True)
        self.cfg = kwargs.pop('cfg', None)
        if self.bw_own_embed:
            # only for debugging: keep independent deep copies of the
            # embedding modules for a separate backward stream
            self.bw_embeddings = copy.deepcopy(self.embeddings)
            self.bw_embeddings_norm = copy.deepcopy(self.embeddings_norm)
            self.bw_embeddings_pos = copy.deepcopy(self.embeddings_pos)
            self.bw_embeddings_token_type = copy.deepcopy(self.embeddings_token_type)
        self.s_token_bias = None

    @classmethod
    def from_config(cls, cfg):
        """Build the kwargs dict consumed by ``__init__`` from a config node."""
        kwargs = {
            "dim": cfg.MODEL.TOKEN_EMBED.DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE
        }

        activation_name = (cfg.MODEL.TOKEN_EMBED.ACTIVATION).lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None

            act_kwargs = {}
            if activation_name in {"elu", "celu"}:
                # ELU-family activations take an alpha hyper-parameter
                act_kwargs["alpha"] = cfg.MODEL.TOKEN_EMBED.ELU_ALPHA
            kwargs['embeddings_act'] = activation(**act_kwargs)

        if cfg.MODEL.TOKEN_EMBED.DROPOUT > 0:
            kwargs['embeddings_dropout'] = nn.Dropout(cfg.MODEL.TOKEN_EMBED.DROPOUT)

        if cfg.MODEL.TOKEN_EMBED.USE_NORM:
            if cfg.SOLVER.FORCE_LN_FP16:
                # force LayerNorm to run in fp16 for mixed-precision training
                embeddings_norm = FP16LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            else:
                embeddings_norm = nn.LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            kwargs['embeddings_norm'] = embeddings_norm

        if (cfg.MODEL.TOKEN_EMBED.POSITION).lower() != 'none':
            kwargs['embeddings_pos'] = build_position_encoding(
                cfg, cfg.MODEL.TOKEN_EMBED.DIM,
                cfg.MODEL.TOKEN_EMBED.POSITION_MAX_LEN)

        if cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE > 0:
            kwargs['embeddings_token_type'] = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)

        if cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE > 0:
            kwargs['embeddings_token_seg'] = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)

        # for debug
        kwargs['bw_own_embed'] = cfg.MODEL.BW_OWD_EMBED
        kwargs['pos_before'] = cfg.MODEL.POS_BEFORE
        kwargs['cfg'] = cfg
        return kwargs

    def get_time_step(self, data, sample_info, task_info=None):
        """Return explicit position ids for tasks whose text input is the
        concatenation of two segments; otherwise ``None`` (positions are then
        derived from ``input_ids`` inside the position-encoding module).

        Args:
            data: (Bs, L) token-id tensor.
            sample_info: dict, or list of dicts (only the first is used).
            task_info: optional dict; ``task_info['task_type']`` selects the
                special-case handling.
        """
        # TODO: the position embedding for caption text should be handled in
        # a different way: 0, 1, ..., n/2, 0, 1, ..., n/2
        if task_info is None:
            task_type = ''
        else:
            task_type = task_info.get('task_type', None)
        time_step = None
        if isinstance(sample_info, list):
            sample_info = sample_info[0]
        if task_type in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
            # two equal halves: position ids restart at the midpoint
            text_length = data.shape[1]
            time_step = torch.cat([
                torch.arange(text_length // 2,
                             dtype=torch.long,
                             device=data.device) for _ in range(2)
            ])
        elif task_type == 'VQA' and sample_info.get('text_spe_cat', False):
            # first L-1 tokens get positions 0..L-2; the last token restarts at 0
            text_length = data.shape[1]
            time_step = torch.cat([
                torch.arange(text_length - 1,
                             dtype=torch.long,
                             device=data.device),
                torch.arange(1, dtype=torch.long, device=data.device)
            ])
        return time_step

    def forward(self, data, sample_info=None, task_info=None, **kwargs):
        """Embed a batch of token ids.

        Args:
            data: (Bs, L) token-id tensor.
            sample_info: optional dict (or list of dicts); default empty.
            task_info: optional dict; default empty.
            kwargs: may contain ``time_step`` (precomputed position ids),
                ``prompt_with_pos`` (unsupported), and the ``type_embed`` /
                ``pos_embed`` boolean switches (both default True).

        Returns:
            The embedded tensor produced by ``_forward``.
        """
        # Fix: avoid mutable default arguments; normalize None to empty dicts
        # so behavior matches the previous `sample_info={}` signature.
        sample_info = {} if sample_info is None else sample_info
        task_info = {} if task_info is None else task_info

        time_step = kwargs.pop('time_step', None)
        if time_step is None:
            time_step = self.get_time_step(data, sample_info, task_info)

        if kwargs.pop("prompt_with_pos", False):
            raise NotImplementedError
        else:
            start_time = 0

        type_embed = kwargs.get('type_embed', True)
        pos_emb = kwargs.get('pos_embed', True)

        data = self._forward(data,
                             type_embed=type_embed,
                             pos_emb=pos_emb,
                             token_seg_ids=None,
                             time_step=time_step,
                             start_time=start_time)
        return data

    def set_s_token_bias(self, s_token_bias):
        """Register a learnable bias added to every S_TOKEN_ID embedding."""
        self.s_token_bias = s_token_bias

    def _forward(self, input_ids, type_embed=True, token_seg_ids=None, time_step=None, pos_emb=True, start_time=0):
        """Core embedding pipeline: lookup, optional bias, positional /
        token-type / segment additions, activation, norm, dropout."""
        embeddings = self.embeddings(input_ids)
        # Fix: guard cfg=None (the __init__ default) -- previously this line
        # raised AttributeError when no cfg was supplied.
        if self.cfg is not None and self.cfg.SOLVER.FORCE_EMBED_FP16:
            embeddings = embeddings.half()

        if self.s_token_bias is not None:
            # learnable bias on the special s-token positions
            s_mask = input_ids == self.S_TOKEN_ID
            embeddings[s_mask] = embeddings[s_mask] + self.s_token_bias

        if self.embeddings_pos is not None and pos_emb and self.pos_before:
            pos_inputs = input_ids if time_step is None else time_step
            position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
            embeddings = embeddings + position_embeddings.to(embeddings.dtype)

        if self.embeddings_token_type is not None and type_embed:
            # always adds the type-0 embedding row, broadcast over batch and length
            embeddings_token_type = self.embeddings_token_type.weight[0].unsqueeze(0).unsqueeze(1)
            embeddings = embeddings + embeddings_token_type.to(embeddings.dtype)

        if (self.embeddings_token_seg is not None) and (token_seg_ids is not None):
            embeddings = embeddings + self.embeddings_token_seg(token_seg_ids)

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if self.embeddings_pos is not None and pos_emb and not self.pos_before:
            pos_inputs = input_ids if time_step is None else time_step
            position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
            embeddings = embeddings + position_embeddings.to(embeddings.dtype)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings
|