import torch
from torch import nn

from uniperceiver.config import configurable
from ..layers.transformer_encoder_layer import TransformerEncoderLayer
from .build import ENCODER_REGISTRY
import uniperceiver.utils.comm as comm

__all__ = ["StandardViT", "TextEncoder", "VisualEncoder"]


@ENCODER_REGISTRY.register()
class StandardViT(nn.Module):
    """A plain stack of pre-norm Transformer encoder layers.

    With @configurable, construction from a config is routed through
    `from_config` below. `TextEncoder` and `VisualEncoder` reuse this backbone
    and differ only in how they build attention masks.
    """

    @configurable
    def __init__(self, *, num_hidden_layers: int, bert_layers, cfg):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.layers = bert_layers
        self.cfg = cfg
        self.name = cfg.NAME

    @classmethod
    def from_config(cls, cfg, global_cfg):
        # Stochastic depth (drop path): either a constant rate for every layer,
        # or a rate that grows linearly with depth.
        if cfg.DROP_PATH_PROB_FIXED:
            dpr = [cfg.DROP_PATH_PROB for _ in range(cfg.NUM_HIDDEN_LAYERS)]
        else:
            dpr = [x.item() for x in torch.linspace(0, cfg.DROP_PATH_PROB, cfg.NUM_HIDDEN_LAYERS)]

        layers = []
        for i in range(cfg.NUM_HIDDEN_LAYERS):
            layers.append(
                TransformerEncoderLayer(
                    d_model=cfg.HIDDEN_SIZE,
                    nhead=cfg.NUM_ATTENTION_HEADS,
                    dim_feedforward=cfg.INTERMEDIATE_SIZE,
                    dropout=cfg.HIDDEN_DROPOUT_PROB,
                    drop_path_ratio=dpr[i],
                    activation=cfg.HIDDEN_ACT,
                    layer_scale=global_cfg.MODEL.LAYER_SCALE,
                    ls_init_values=global_cfg.MODEL.LAYER_SCALE_INIT,
                    batch_first=True,
                    norm_first=True,
                    cfg=cfg,
                ))

        bert_layers = nn.ModuleList(layers)
        return {
            "num_hidden_layers": cfg.NUM_HIDDEN_LAYERS,
            "bert_layers": bert_layers,
            "cfg": cfg,
        }
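
    # Worked example (added for illustration, not in the original file): with
    # DROP_PATH_PROB = 0.1, DROP_PATH_PROB_FIXED = False and
    # NUM_HIDDEN_LAYERS = 4, the linear schedule above yields
    # dpr == [0.0, 0.0333..., 0.0666..., 0.1].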

    @classmethod
    def add_config(cls, cfg):
        pass

    def _forward(self, x, attn_mask=None, key_padding_masks=None, history_states=None, **kwargs):
        # Run the input through each encoder layer in turn. `history_states`
        # is accepted for interface compatibility but is unused here.
        for layer_module in self.layers:
            x = layer_module(
                src=x, src_mask=attn_mask, src_key_padding_mask=key_padding_masks
            )
        return x

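    # Shape note (an assumption mirroring torch.nn conventions, since
    # TransformerEncoderLayer here is uniperceiver's own implementation):
    # with batch_first=True, `src` is (batch, seq, d_model), `src_mask` is
    # (seq, seq), and `src_key_padding_mask` is (batch, seq), True = ignore.
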
    def forward(self, batched_inputs, return_all=False):
        # Subclasses implement the task-specific forward pass.
        raise NotImplementedError


@ENCODER_REGISTRY.register()
class VisualEncoder(StandardViT):

    @staticmethod
    def _construct_attention_masks(data, sample_info, task_info):
        # Visual tokens attend to one another without restriction.
        return None

    def forward(self, data, invalid_mask, sample_info, task_info, **kwargs):
        attn_mask = self._construct_attention_masks(data, sample_info, task_info)
        history_states = kwargs.pop('history_states', None)
        out = self._forward(data,
                            attn_mask,
                            invalid_mask,
                            history_states=history_states,
                            **kwargs)
        return out


@ENCODER_REGISTRY.register()
class TextEncoder(StandardViT):

    @staticmethod
    def _construct_attention_masks(data, sample_info, task_info):
        mask_type = torch.bool
        device = data.device

        attn_mask = None
        if isinstance(sample_info, list):
            sample_info = sample_info[0]
        if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
            # Captioning with source and target text concatenated: the sequence
            # splits into two halves. True entries mark positions that may NOT
            # be attended to.
            total_length = data.shape[1]
            attn_mask = torch.ones((total_length, total_length), dtype=mask_type, device=device)
            # Top-left block: causal mask within the first half; each token
            # sees itself and its predecessors.
            attn_mask[:total_length // 2, :total_length // 2] = torch.ones(
                (total_length // 2, total_length // 2), dtype=mask_type, device=device).triu_(diagonal=1)
            # Bottom-left block: token i of the second half sees only the
            # strictly earlier tokens of the first half.
            attn_mask[total_length // 2:, :total_length // 2] = torch.ones(
                (total_length // 2, total_length // 2),
                dtype=mask_type,
                device=device).triu_(diagonal=0)
            # Bottom-right block: each second-half token sees only itself.
            attn_mask[total_length // 2:, total_length // 2:] = ~torch.eye(
                total_length // 2, dtype=mask_type, device=device)

        return attn_mask
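
    # Illustrative sketch (added, not in the original file): for
    # total_length = 6 the mask above has the block structure
    #
    #     [[0, 1, 1, 1, 1, 1],      first half: causal
    #      [0, 0, 1, 1, 1, 1],
    #      [0, 0, 0, 1, 1, 1],
    #      [1, 1, 1, 0, 1, 1],      second half: strictly earlier first-half
    #      [0, 1, 1, 1, 0, 1],      tokens plus the token itself
    #      [0, 0, 1, 1, 1, 0]]
    #
    # where 1 = masked (attention disallowed) and 0 = visible.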

    def forward(self, data, invalid_mask, sample_info, task_info, **kwargs):
        attn_mask = self._construct_attention_masks(data, sample_info, task_info)
        history_states = kwargs.pop('history_states', None)
        out = self._forward(data,
                            attn_mask,
                            invalid_mask,
                            history_states=history_states,
                            **kwargs)
        return out
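
# ---------------------------------------------------------------------------
# Minimal usage sketch (an assumption, not part of the original module):
# `cfg` / `global_cfg` are yacs-style CfgNode objects carrying the fields read
# in `StandardViT.from_config`, and `text_embedder` / `padding_mask` /
# `token_ids` are hypothetical stand-ins for the embedding stage that lives
# elsewhere in the model.
#
#   encoder = TextEncoder(cfg, global_cfg)  # @configurable invokes from_config
#   x = text_embedder(token_ids)            # (batch, seq_len, cfg.HIDDEN_SIZE)
#   out = encoder(x, invalid_mask=padding_mask,
#                 sample_info={}, task_info={'task_type': 'image_caption'})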