# unit_test/uniperceiver/modeling/encoder/unified_bert_encoder.py
import torch
from torch import nn
from uniperceiver.config import configurable
from ..layers.transformer_encoder_layer import TransformerEncoderLayer
from ..layers.transformer_encoder_moe_layer import MoETransformerEncoderLayer
from .build import ENCODER_REGISTRY
import uniperceiver.utils.comm as comm
__all__ = ["UnifiedBertEncoder"]


def _construct_attention_masks(data, sample_info, task_info):
    mask_type = torch.bool
    device = data.device
    attn_mask = None
    if isinstance(sample_info, list):
        sample_info = sample_info[0]
    if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
        # data_length holds the spe-token length, the image length, and the
        # concatenated (input + target) text length
        spe_length, img_length, text_total_length = sample_info['data_length']
        text_length = text_total_length // 2
        total_length = spe_length + img_length + text_total_length
        # True = attention blocked; start fully masked, then open regions up
        attn_mask = torch.ones((total_length, total_length), dtype=mask_type, device=device)
        # every token may attend to the spe + image prefix
        attn_mask[:, :spe_length + img_length] = False
        # input-text tokens attend causally to themselves
        attn_mask[spe_length + img_length:spe_length + img_length + text_length,
                  spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
                      (text_length, text_length), dtype=mask_type, device=device).triu_(diagonal=1)
        # target-text token i attends to strictly earlier input-text tokens (< i)
        attn_mask[spe_length + img_length + text_length:,
                  spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
                      (text_length, text_length), dtype=mask_type, device=device).triu_(diagonal=0)
        # within the target block, each token attends only to itself
        attn_mask[spe_length + img_length + text_length:,
                  spe_length + img_length + text_length:] = ~torch.eye(
                      text_length, dtype=mask_type, device=device)
    return attn_mask
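
# Illustration (not executed): for spe_length=1, img_length=2,
# text_total_length=4 (so text_length=2), the 7x7 mask built above is,
# with T = blocked and . = attended (nn.MultiheadAttention bool-mask
# convention):
#
#            spe img img txt0 txt1 tgt0 tgt1
#   spe    [  .   .   .   T    T    T    T  ]
#   img    [  .   .   .   T    T    T    T  ]
#   img    [  .   .   .   T    T    T    T  ]
#   txt0   [  .   .   .   .    T    T    T  ]
#   txt1   [  .   .   .   .    .    T    T  ]
#   tgt0   [  .   .   .   T    T    .    T  ]
#   tgt1   [  .   .   .   .    T    T    .  ]
#
# i.e. everything sees the spe+image prefix, input text is causal over
# itself, and each target token sees strictly earlier input-text tokens
# plus its own position.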


@ENCODER_REGISTRY.register()
class UnifiedBertEncoder(nn.Module):

    @configurable
    def __init__(self, *, num_hidden_layers: int, bert_layers,
                 skip_target_encode, word_balance_losses,
                 bookswiki_word_alone, cfg):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.layers = bert_layers
        self.skip_target_encode = skip_target_encode
        self.word_balance_losses = word_balance_losses
        self.bookswiki_word_alone = bookswiki_word_alone
        self.cfg = cfg

    @classmethod
    def from_config(cls, cfg):
        if cfg.MODEL.BERT.DROP_PATH_PROB_FIXED:
            # same drop-path probability for every layer
            dpr = [cfg.MODEL.BERT.DROP_PATH_PROB for _ in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
        else:
            # stochastic depth: drop-path probability increases linearly with depth
            dpr = [x.item() for x in torch.linspace(0, cfg.MODEL.BERT.DROP_PATH_PROB, cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
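        # e.g. (illustrative) DROP_PATH_PROB=0.09 with 4 layers yields
        # dpr = [0.0, 0.03, 0.06, 0.09]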

        layers = []
        for layer_idx in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS):
            if not cfg.MOE.MOE:
                layers.append(
                    TransformerEncoderLayer(
                        d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
                        nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
                        dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
                        dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
                        drop_path_ratio=dpr[layer_idx],
                        activation=cfg.MODEL.BERT.HIDDEN_ACT,
                        layer_scale=cfg.MODEL.LAYER_SCALE,
                        ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
                        batch_first=True,
                        norm_first=True,
                        cfg=cfg,
                    ))
            else:
                attention_moe = False
                ffn_moe = False
                moe_layer_start_idx = cfg.MOE.MOE_LAYER_START_IDX
                moe_layer_end_idx = cfg.MOE.MOE_LAYER_END_IDX
                if cfg.MOE.MOE_EXPERT_LOCATION == 'odd':
                    # note: 'odd' actually selects even layer indices (0, 2, ...)
                    if layer_idx % 2 == 0 and moe_layer_start_idx <= layer_idx < moe_layer_end_idx:
                        moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
                        attention_moe = "SA" in moe_layers
                        ffn_moe = "FFN" in moe_layers
                elif cfg.MOE.MOE_EXPERT_LOCATION == 'four':
                    # every fourth layer within [start_idx, end_idx)
                    if layer_idx % 4 == 0 and moe_layer_start_idx <= layer_idx < moe_layer_end_idx:
                        moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
                        attention_moe = "SA" in moe_layers
                        ffn_moe = "FFN" in moe_layers
                elif cfg.MOE.MOE_EXPERT_LOCATION == 'all':
                    if moe_layer_start_idx <= layer_idx < moe_layer_end_idx:
                        moe_layers = cfg.MOE.MOE_EXPERT_TYPE.split(',')
                        attention_moe = "SA" in moe_layers
                        ffn_moe = "FFN" in moe_layers
                elif cfg.MOE.MOE_EXPERT_LOCATION == 'none':
                    attention_moe = None
                    ffn_moe = None
                else:
                    raise NotImplementedError(
                        f"unsupported cfg.MOE.MOE_EXPERT_LOCATION: {cfg.MOE.MOE_EXPERT_LOCATION}")
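                # e.g. (illustrative) MOE_EXPERT_TYPE='SA,FFN' routes both the
                # self-attention and FFN sub-layers through experts; 'FFN'
                # alone puts experts only in the FFN sub-layer.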
                layers.append(
                    MoETransformerEncoderLayer(
                        d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
                        nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
                        dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
                        dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
                        drop_path_ratio=dpr[layer_idx],
                        activation=cfg.MODEL.BERT.HIDDEN_ACT,
                        layer_scale=cfg.MODEL.LAYER_SCALE,
                        ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
                        # unlike the dense branch above, the MoE layer is
                        # constructed sequence-first
                        batch_first=False,
                        norm_first=True,
                        cfg=cfg,
                        ffn_moe=ffn_moe,
                        attn_moe=attention_moe,
                    ))

        bert_layers = nn.ModuleList(layers)
        return {
            "num_hidden_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
            "skip_target_encode": cfg.MODEL.BERT.SKIP_TARGET_ENCODE,
            "bert_layers": bert_layers,
            "word_balance_losses": cfg.SOLVER.WORD_BALANCE_LOSSESS,
            "bookswiki_word_alone": cfg.MODEL.BW_WORD_ALONE,
            "cfg": cfg,
        }
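
    # Config keys read by from_config (all referenced above; a summary,
    # not a schema):
    #   MODEL.BERT: NUM_HIDDEN_LAYERS, HIDDEN_SIZE, NUM_ATTENTION_HEADS,
    #               INTERMEDIATE_SIZE, HIDDEN_DROPOUT_PROB, HIDDEN_ACT,
    #               DROP_PATH_PROB, DROP_PATH_PROB_FIXED, SKIP_TARGET_ENCODE
    #   MODEL:      LAYER_SCALE, LAYER_SCALE_INIT, BW_WORD_ALONE
    #   MOE:        MOE, MOE_EXPERT_LOCATION, MOE_EXPERT_TYPE,
    #               MOE_LAYER_START_IDX, MOE_LAYER_END_IDX
    #   SOLVER:     WORD_BALANCE_LOSSESS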

    @classmethod
    def add_config(cls, cfg):
        pass

    def forward(self, data, invalid_mask, sample_info, task_info, history_states=None, return_all=False, **kwargs):
        attn_mask = _construct_attention_masks(data, sample_info, task_info)
        kwargs.update({'sample_info': sample_info})

        data_type = kwargs.get('data_type', 'input')
        if data_type == 'target' and self.skip_target_encode:
            # sometimes used when debugging on a single GPU
            return data

        if return_all:
            data_all = [data]
        for l, layer_module in enumerate(self.layers):
            if history_states is None:
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask,
                                    task_info=task_info, **kwargs)
            else:
                # history_states holds one cached state per layer
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask,
                                    history_states=history_states[l], task_info=task_info, **kwargs)
            if return_all:
                data_all.append(data)
        return data if not return_all else data_all
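

# Minimal usage sketch (hedged): assumes a fully populated uniperceiver cfg
# node carrying the keys summarized after from_config above, and the
# detectron2-style @configurable behavior where calling the class with a cfg
# routes construction through from_config. `load_cfg_somehow` is a
# hypothetical stand-in for however the surrounding code builds its config.
#
#   cfg = load_cfg_somehow()
#   encoder = UnifiedBertEncoder(cfg)
#   out = encoder(data, invalid_mask, sample_info, task_info)
#   all_states = encoder(data, invalid_mask, sample_info, task_info, return_all=True)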