import copy

import torch
from torch import nn

from uniperceiver.config import configurable
from uniperceiver.modeling.layers import FP16LayerNorm
from uniperceiver.utils import comm

from ..layers.create_act import get_act_layer
from .build import EMBEDDING_REGISTRY
from .position_embedding import build_position_encoding

__all__ = ["TokenBaseEmbedding"]


@EMBEDDING_REGISTRY.register()
class TokenBaseEmbedding(nn.Module):
    """Token embedding layer with optional activation, LayerNorm, dropout,
    positional encoding, token-type embedding, and token-segment embedding."""

    @configurable
    def __init__(
        self,
        *,
        dim: int,
        vocab_size: int,
        **kwargs
    ):
        super().__init__()
        self.embeddings = nn.Embedding(vocab_size, dim)
        # Optional sub-modules; each is None when disabled in the config.
        self.embeddings_act = kwargs.pop("embeddings_act", None)
        self.embeddings_norm = kwargs.pop("embeddings_norm", None)
        self.embeddings_dropout = kwargs.pop("embeddings_dropout", None)
        self.embeddings_pos = kwargs.pop("embeddings_pos", None)
        self.embeddings_token_type = kwargs.pop("embeddings_token_type", None)
        self.embeddings_token_seg = kwargs.pop("embeddings_token_seg", None)
        self.bw_own_embed = kwargs.pop("bw_own_embed", False)
        self.pos_before = kwargs.pop("pos_before", True)
        self.cfg = kwargs.pop("cfg", None)

        if self.bw_own_embed:
            # Keep dedicated (non-shared) deep copies of the embedding
            # modules for the "bw" path.
            self.bw_embeddings = copy.deepcopy(self.embeddings)
            self.bw_embeddings_norm = copy.deepcopy(self.embeddings_norm)
            self.bw_embeddings_pos = copy.deepcopy(self.embeddings_pos)
            self.bw_embeddings_token_type = copy.deepcopy(self.embeddings_token_type)
        # Optional bias added to the special-token embedding; see set_s_token_bias().
        self.s_token_bias = None

    @classmethod
    def from_config(cls, cfg):
        kwargs = {
            "dim": cfg.MODEL.TOKEN_EMBED.DIM,
            "vocab_size": cfg.MODEL.VOCAB_SIZE,
        }

        activation_name = cfg.MODEL.TOKEN_EMBED.ACTIVATION.lower()
        if activation_name != "none":
            activation = get_act_layer(activation_name)
            assert activation is not None, f"unknown activation: {activation_name}"

            act_kwargs = {}
            if activation_name in {"elu", "celu"}:
                act_kwargs["alpha"] = cfg.MODEL.TOKEN_EMBED.ELU_ALPHA
            kwargs["embeddings_act"] = activation(**act_kwargs)

        if cfg.MODEL.TOKEN_EMBED.DROPOUT > 0:
            kwargs["embeddings_dropout"] = nn.Dropout(cfg.MODEL.TOKEN_EMBED.DROPOUT)

        if cfg.MODEL.TOKEN_EMBED.USE_NORM:
            # Optionally run LayerNorm in fp16 when the solver forces it.
            if cfg.SOLVER.FORCE_LN_FP16:
                embeddings_norm = FP16LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            else:
                embeddings_norm = nn.LayerNorm(cfg.MODEL.TOKEN_EMBED.DIM)
            kwargs["embeddings_norm"] = embeddings_norm

        if cfg.MODEL.TOKEN_EMBED.POSITION.lower() != "none":
            kwargs["embeddings_pos"] = build_position_encoding(
                cfg, cfg.MODEL.TOKEN_EMBED.DIM, cfg.MODEL.TOKEN_EMBED.POSITION_MAX_LEN)

        if cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE > 0:
            kwargs["embeddings_token_type"] = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_VOCAB_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)

        if cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE > 0:
            kwargs["embeddings_token_seg"] = nn.Embedding(
                cfg.MODEL.TOKEN_EMBED.TYPE_SEG_SIZE, cfg.MODEL.TOKEN_EMBED.DIM)

        kwargs["bw_own_embed"] = cfg.MODEL.BW_OWD_EMBED
        kwargs["pos_before"] = cfg.MODEL.POS_BEFORE
        kwargs["cfg"] = cfg
        return kwargs
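
    # Construction sketch (assumes uniperceiver's @configurable behaves like
    # detectron2's, i.e. passing a cfg routes through from_config; the shapes
    # below are illustrative, not part of this module):
    #
    #   embed = TokenBaseEmbedding(cfg)
    #   token_ids = torch.randint(0, cfg.MODEL.VOCAB_SIZE, (2, 16))
    #   out = embed(token_ids)  # -> (2, 16, cfg.MODEL.TOKEN_EMBED.DIM)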

    def get_time_step(self, data, sample_info, task_info=None):
        """Build explicit position (time-step) indices for inputs that
        concatenate two text segments into one sequence.

        Args:
            data: token ids of shape (B, L).
            sample_info: a dict (or list of dicts; only the first is used).
            task_info: optional dict with a 'task_type' entry.

        Returns:
            A 1-D LongTensor of length L, or None when the default
            0..L-1 positions apply.
        """
        if task_info is None:
            task_type = ''
        else:
            task_type = task_info.get('task_type', None)
        time_step = None
        if isinstance(sample_info, list):
            sample_info = sample_info[0]
        if task_type in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
            # Two equal-length text segments are concatenated, so the
            # positions restart at the midpoint: 0..L/2-1, 0..L/2-1.
            text_length = data.shape[1]
            time_step = torch.cat([
                torch.arange(text_length // 2, dtype=torch.long, device=data.device)
                for _ in range(2)
            ])
        elif task_type == 'VQA' and sample_info.get('text_spe_cat', False):
            # The question keeps positions 0..L-2; the single appended token
            # restarts at position 0.
            text_length = data.shape[1]
            time_step = torch.cat([
                torch.arange(text_length - 1, dtype=torch.long, device=data.device),
                torch.arange(1, dtype=torch.long, device=data.device),
            ])

        return time_step
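
    # Worked example (illustrative): for an 8-token 'image_caption' sample with
    # sample_info={'text_spe_cat': True}, the result is [0, 1, 2, 3, 0, 1, 2, 3];
    # for 'VQA' it would be [0, 1, ..., 6, 0].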

    def forward(self, data, sample_info=None, task_info=None, **kwargs):
        # Avoid mutable default arguments.
        sample_info = sample_info if sample_info is not None else {}
        task_info = task_info if task_info is not None else {}

        time_step = kwargs.pop('time_step', None)
        if time_step is None:
            time_step = self.get_time_step(data, sample_info, task_info)

        if kwargs.pop("prompt_with_pos", False):
            raise NotImplementedError
        else:
            start_time = 0

        type_embed = kwargs.get('type_embed', True)
        pos_emb = kwargs.get('pos_embed', True)

        return self._forward(data,
                             type_embed=type_embed,
                             pos_emb=pos_emb,
                             token_seg_ids=None,
                             time_step=time_step,
                             start_time=start_time)
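
    # Usage sketch (keyword names as read above): positional and type
    # embeddings can be toggled per call, and a precomputed time_step
    # overrides get_time_step, e.g.
    #   out = embed(token_ids, pos_embed=False)
    #   out = embed(token_ids, time_step=steps)  # steps: (L,) LongTensor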

    def set_s_token_bias(self, s_token_bias):
        """Store a bias tensor that _forward adds to the embedding of the
        special token (hard-coded id 49410)."""
        self.s_token_bias = s_token_bias

    def _forward(self, input_ids, type_embed=True, token_seg_ids=None,
                 time_step=None, pos_emb=True, start_time=0):
        """Compose the final embedding: token lookup, optional special-token
        bias, positional encoding (before or after normalization, depending on
        pos_before), token-type and segment embeddings, activation, LayerNorm,
        and dropout."""
        embeddings = self.embeddings(input_ids)
        if self.cfg.SOLVER.FORCE_EMBED_FP16:
            embeddings = embeddings.half()

        if self.s_token_bias is not None:
            # Add the externally provided bias only at positions holding the
            # special token (hard-coded id 49410).
            embeddings[input_ids == 49410] = embeddings[input_ids == 49410] + self.s_token_bias

        if self.embeddings_pos is not None and pos_emb and self.pos_before:
            # Explicit time steps (if given) replace the default positions.
            pos_inputs = input_ids if time_step is None else time_step
            position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
            embeddings = embeddings + position_embeddings.to(embeddings.dtype)

        if self.embeddings_token_type is not None and type_embed:
            # Type id 0 is used everywhere: row 0 of the type-embedding table,
            # reshaped to (1, 1, D) and broadcast over (B, L).
            embeddings_token_type = self.embeddings_token_type.weight[0].unsqueeze(0).unsqueeze(1)
            embeddings = embeddings + embeddings_token_type.to(embeddings.dtype)

        if (self.embeddings_token_seg is not None) and (token_seg_ids is not None):
            embeddings_token_seg = self.embeddings_token_seg(token_seg_ids)
            embeddings = embeddings + embeddings_token_seg

        if self.embeddings_act is not None:
            embeddings = self.embeddings_act(embeddings)

        if self.embeddings_norm is not None:
            embeddings = self.embeddings_norm(embeddings)

        if self.embeddings_pos is not None and pos_emb and not self.pos_before:
            # Alternative placement: positional embeddings after activation/norm.
            pos_inputs = input_ids if time_step is None else time_step
            position_embeddings = self.embeddings_pos(pos_inputs, start_time=start_time)
            embeddings = embeddings + position_embeddings.to(embeddings.dtype)

        if self.embeddings_dropout is not None:
            embeddings = self.embeddings_dropout(embeddings)

        return embeddings
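

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). It
    # bypasses the config system: SimpleNamespace stands in for the yacs cfg
    # that _forward reads, all optional sub-modules stay disabled, and it
    # assumes @configurable (detectron2-style) falls through to the plain
    # constructor when no real CfgNode is passed.
    from types import SimpleNamespace

    cfg_stub = SimpleNamespace(SOLVER=SimpleNamespace(FORCE_EMBED_FP16=False))
    embed = TokenBaseEmbedding(dim=8, vocab_size=100, cfg=cfg_stub)
    token_ids = torch.randint(0, 100, (2, 5))
    out = embed(token_ids)
    assert out.shape == (2, 5, 8)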