File size: 7,561 Bytes
32b542e |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 |
import torch
from torch import nn
from uniperceiver.config import configurable
from ..layers.transformer_encoder_layer import TransformerEncoderLayer
from ..layers.transformer_encoder_moe_layer import MoETransformerEncoderLayer
from .build import ENCODER_REGISTRY
import uniperceiver.utils.comm as comm
__all__ = ["UnifiedBertEncoder"]
def _construct_attention_masks( data, sample_info, task_info):
mask_type = torch.bool
device = data.device
attn_mask = None
if isinstance(sample_info, list):
sample_info = sample_info[0]
if task_info['task_type'] in ['image_caption', 'video_caption'] and sample_info.get('text_spe_cat', False):
# the extra 1 length for spe token
spe_length, img_length, text_total_length = sample_info['data_length']
text_length = text_total_length//2
attn_mask = torch.ones((spe_length + img_length + text_total_length,
spe_length + img_length + text_total_length), dtype=mask_type, device=device)
attn_mask[:spe_length + img_length + text_total_length, :spe_length+img_length] = False
attn_mask[spe_length + img_length:spe_length + img_length + text_length, spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
(text_length, text_length), dtype=mask_type, device=device).triu_(diagonal=1)
attn_mask[spe_length + img_length + text_length:, spe_length + img_length:spe_length + img_length + text_length] = torch.ones(
(text_length, text_length),
dtype=mask_type,
device=device).triu_(diagonal=0)
attn_mask[spe_length + img_length + text_length:,
spe_length + img_length + text_length:] = ~torch.ones(
(text_length), dtype=mask_type,
device=device).diag()
return attn_mask
@ENCODER_REGISTRY.register()
class UnifiedBertEncoder(nn.Module):
    """A stack of Transformer encoder layers, optionally with Mixture-of-Experts.

    Built from ``cfg`` via ``from_config``: when ``cfg.MOE.MOE`` is off, plain
    ``TransformerEncoderLayer``s are used; otherwise ``MoETransformerEncoderLayer``s
    whose per-layer expert placement is decided by ``_moe_flags``.
    """

    @configurable
    def __init__(self, *, num_hidden_layers: int, bert_layers,
                 skip_target_encode, word_balance_losses,
                 bookswiki_word_alone, cfg):
        super().__init__()
        self.num_hidden_layers = num_hidden_layers
        self.layers = bert_layers  # nn.ModuleList of encoder layers
        # When True, `forward` returns target-type inputs unchanged (debug aid).
        self.skip_target_encode = skip_target_encode
        # Stored for use elsewhere; not read inside this class as visible here.
        self.word_balance_losses = word_balance_losses
        self.bookswiki_word_alone = bookswiki_word_alone
        self.cfg = cfg

    @staticmethod
    def _moe_flags(cfg, layer_idx):
        """Return ``(attention_moe, ffn_moe)`` for layer ``layer_idx``.

        ``(None, None)`` for location 'none'; ``(False, False)`` when the layer
        is not selected; otherwise booleans derived from cfg.MOE.MOE_EXPERT_TYPE
        (comma-separated, containing "SA" and/or "FFN").

        Raises:
            NotImplementedError: for an unrecognized cfg.MOE.MOE_EXPERT_LOCATION.
        """
        location = cfg.MOE.MOE_EXPERT_LOCATION
        if location == 'none':
            return None, None
        start = cfg.MOE.MOE_LAYER_START_IDX
        end = cfg.MOE.MOE_LAYER_END_IDX
        in_range = start <= layer_idx < end
        if location == 'odd':
            # NOTE(review): despite the name 'odd', this selects EVEN indices
            # (every other layer starting at 0), matching the original code.
            selected = in_range and layer_idx % 2 == 0
        elif location == 'four':
            selected = in_range and layer_idx % 4 == 0
        elif location == 'all':
            selected = in_range
        else:
            raise NotImplementedError('cfg.MOE.MOE_EXPERT_LOCATION')
        if not selected:
            return False, False
        expert_types = cfg.MOE.MOE_EXPERT_TYPE.split(',')
        return "SA" in expert_types, 'FFN' in expert_types

    @classmethod
    def from_config(cls, cfg):
        """Build constructor kwargs (consumed by @configurable) from ``cfg``."""
        # Stochastic-depth (drop path) schedule: constant, or linear 0 -> max.
        if cfg.MODEL.BERT.DROP_PATH_PROB_FIXED:
            dpr = [cfg.MODEL.BERT.DROP_PATH_PROB for _ in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
        else:
            dpr = [x.item() for x in torch.linspace(0, cfg.MODEL.BERT.DROP_PATH_PROB, cfg.MODEL.BERT.NUM_HIDDEN_LAYERS)]
        layers = []
        for layer_idx in range(cfg.MODEL.BERT.NUM_HIDDEN_LAYERS):
            if not cfg.MOE.MOE:
                layers.append(
                    TransformerEncoderLayer(
                        d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
                        nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
                        dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
                        dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
                        drop_path_ratio=dpr[layer_idx],
                        activation=cfg.MODEL.BERT.HIDDEN_ACT,
                        layer_scale=cfg.MODEL.LAYER_SCALE,
                        ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
                        batch_first=True,
                        norm_first=True,
                        cfg=cfg,
                    ))
            else:
                attention_moe, ffn_moe = cls._moe_flags(cfg, layer_idx)
                layers.append(
                    MoETransformerEncoderLayer(
                        d_model=cfg.MODEL.BERT.HIDDEN_SIZE,
                        nhead=cfg.MODEL.BERT.NUM_ATTENTION_HEADS,
                        dim_feedforward=cfg.MODEL.BERT.INTERMEDIATE_SIZE,
                        dropout=cfg.MODEL.BERT.HIDDEN_DROPOUT_PROB,
                        drop_path_ratio=dpr[layer_idx],
                        activation=cfg.MODEL.BERT.HIDDEN_ACT,
                        layer_scale=cfg.MODEL.LAYER_SCALE,
                        ls_init_values=cfg.MODEL.LAYER_SCALE_INIT,
                        # NOTE(review): batch_first=False here vs True for the
                        # non-MoE layer — preserved from the original; confirm
                        # the MoE layer expects sequence-first input.
                        batch_first=False,
                        norm_first=True,
                        cfg=cfg,
                        ffn_moe=ffn_moe,
                        attn_moe=attention_moe,
                    ))
        bert_layers = nn.ModuleList(layers)
        return {
            "num_hidden_layers": cfg.MODEL.BERT.NUM_HIDDEN_LAYERS,
            "skip_target_encode": cfg.MODEL.BERT.SKIP_TARGET_ENCODE,
            "bert_layers": bert_layers,
            "word_balance_losses": cfg.SOLVER.WORD_BALANCE_LOSSESS,
            "bookswiki_word_alone": cfg.MODEL.BW_WORD_ALONE,
            "cfg": cfg
        }

    @classmethod
    def add_config(cls, cfg):
        """No encoder-specific config keys to register."""
        pass

    def forward(self, data, invalid_mask, sample_info, task_info, history_states=None, return_all=False, **kwargs):
        """Run ``data`` through all layers.

        Args:
            data: input embeddings (layout depends on layer's batch_first — TODO confirm).
            invalid_mask: key-padding mask forwarded to every layer.
            sample_info / task_info: metadata; used to build the caption attention mask.
            history_states: optional per-layer cached states (indexed by layer).
            return_all: when True, return a list of [input] + every layer's output.

        Returns:
            The final layer output, or the list of all intermediate outputs.
        """
        attn_mask = _construct_attention_masks(data, sample_info, task_info)
        kwargs.update({'sample_info': sample_info})
        data_type = kwargs.get('data_type', 'input')
        if data_type == 'target' and self.skip_target_encode:
            # used for debugging with single gpu sometimes
            return data
        if return_all:
            data_all = [data]
        for l, layer_module in enumerate(self.layers):
            if history_states is None:
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, task_info=task_info, **kwargs)
            else:
                data = layer_module(src=data, src_mask=attn_mask, src_key_padding_mask=invalid_mask, history_states=history_states[l], task_info=task_info, **kwargs)
            if return_all:
                data_all.append(data)
        return data if not return_all else data_all
|