import math
from typing import Dict, List, Optional

import torch
import torch.nn as nn
from fairseq import utils
from fairseq.modules import LayerNorm
from fairseq.modules.fairseq_dropout import FairseqDropout
from fairseq.modules.quant_noise import quant_noise
from torch import Tensor

from .unify_multihead_attention import MultiheadAttention


def drop_path(x, drop_prob: float = 0.0, training: bool = False):
    """
    Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks).

    Comment by Ross Wightman: This is the same as the DropConnect impl I created for EfficientNet,
    etc. networks, however, the original name is misleading as 'Drop Connect' is a different form of
    dropout in a separate paper... See discussion:
    https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for changing
    the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
    'survival rate' as the argument.
    """
    if drop_prob == 0.0 or not training:
        return x
    keep_prob = 1 - drop_prob
    # Inputs are (seq_len, batch, embed_dim); draw one keep/drop decision per batch element
    # and broadcast it over the sequence and embedding dimensions.
    shape = (1, x.shape[1], 1)
    random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
    random_tensor.floor_()  # binarize to 0/1
    output = x.div(keep_prob) * random_tensor
    return output
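

# Hedged sanity-check sketch (not part of the original module): illustrates what drop_path does
# to a (seq_len, batch, embed_dim) tensor. The shapes and drop probability below are illustrative
# assumptions only; this helper is never called by the library code.
def _drop_path_sanity_check():
    x = torch.ones(4, 3, 8)  # (seq_len, batch, embed_dim)
    y = drop_path(x, drop_prob=0.5, training=True)
    # Each batch element is either zeroed out entirely or rescaled by 1 / keep_prob = 2.0.
    assert y.shape == x.shape
    assert set(y.unique().tolist()) <= {0.0, 2.0}
    # In eval mode (training=False) the input passes through unchanged.
    assert torch.equal(drop_path(x, drop_prob=0.5, training=False), x)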


def init_bert_weights(module):
    """Initialize the weights."""
    if isinstance(module, (nn.Linear, nn.Embedding)):
        module.weight.data.normal_(mean=0.0, std=0.02)
    elif isinstance(module, nn.LayerNorm):
        module.bias.data.zero_()
        module.weight.data.fill_(1.0)
    if isinstance(module, nn.Linear) and module.bias is not None:
        module.bias.data.zero_()


class Adapter_Layer(torch.nn.Module):
    def __init__(self,
                 d_model=None,
                 down_size=None,
                 dropout=0.0,
                 init_option="bert",
                 adapter_scalar="1.0"):
        super().__init__()
        self.n_embd = d_model
        self.down_size = down_size

        # Scaling applied to the adapter output: either a learnable scalar or a fixed float.
        if adapter_scalar == "learnable_scalar":
            self.scale = nn.Parameter(torch.ones(1))
        else:
            self.scale = float(adapter_scalar)

        # Bottleneck: down-project, non-linearity, up-project back to the model dimension.
        self.down_proj = nn.Linear(self.n_embd, self.down_size)
        self.non_linear_func = nn.ReLU()
        self.up_proj = nn.Linear(self.down_size, self.n_embd)

        self.dropout = dropout
        if init_option == "bert":
            self.apply(init_bert_weights)
        elif init_option == "lora":
            with torch.no_grad():
                nn.init.kaiming_uniform_(self.down_proj.weight, a=math.sqrt(5))
                nn.init.zeros_(self.up_proj.weight)
                nn.init.zeros_(self.down_proj.bias)
                nn.init.zeros_(self.up_proj.bias)

    def forward(self, x, add_residual=True, residual=None):
        residual = x if residual is None else residual

        down = self.down_proj(x)
        down = self.non_linear_func(down)
        down = nn.functional.dropout(down, p=self.dropout, training=self.training)
        up = self.up_proj(down)
        up = up * self.scale
        if add_residual:
            output = up + residual
        else:
            output = up

        return output
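

# Hedged usage sketch (not part of the original module): shows how Adapter_Layer is meant to be
# applied to a (seq_len, batch, embed_dim) activation. The dimensions below are illustrative
# assumptions; this helper is never called by the library code.
def _example_adapter_usage():
    adapter = Adapter_Layer(d_model=16, down_size=4, dropout=0.1)
    x = torch.randn(10, 2, 16)             # (seq_len, batch, embed_dim)
    y = adapter(x)                          # bottleneck output + residual
    assert y.shape == x.shape
    # The residual can also be supplied explicitly, or skipped entirely.
    y_no_res = adapter(x, add_residual=False)
    assert y_no_res.shape == x.shape
    return y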


class VLAdapter_Layer(torch.nn.Module):
    def __init__(self,
                 d_model=None,
                 down_size=None,
                 dropout=0.0,
                 init_option="bert",
                 adapter_scalar="1.0"):
        super().__init__()
        print("load VL adapter")
        # Separate bottleneck adapters for the vision tokens and the language tokens.
        self.v_adapter = Adapter_Layer(d_model=d_model,
                                       down_size=down_size,
                                       dropout=dropout,
                                       init_option=init_option,
                                       adapter_scalar=adapter_scalar)

        self.l_adapter = Adapter_Layer(d_model=d_model,
                                       down_size=down_size,
                                       dropout=dropout,
                                       init_option=init_option,
                                       adapter_scalar=adapter_scalar)

    def forward(self, x, add_residual=True, residual=None, num_image_tokens=None):
        # When the number of image tokens is known, split the (seq_len, batch, embed_dim) input
        # into its vision and language parts and adapt each part with its own adapter.
        if num_image_tokens is not None:
            v_x = x[:num_image_tokens, :, :]
            l_x = x[num_image_tokens:, :, :]
        else:
            v_x = x
            l_x = x

        v_x = self.v_adapter(v_x, add_residual=add_residual, residual=residual)
        l_x = self.l_adapter(l_x, add_residual=add_residual, residual=residual)

        if num_image_tokens is not None:
            x = torch.cat((v_x, l_x), dim=0)
        else:
            # Without a split point, both adapters see the full sequence and their outputs are summed.
            x = v_x + l_x

        return x
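

# Hedged usage sketch (not part of the original module): demonstrates how VLAdapter_Layer splits a
# mixed vision/language sequence at `num_image_tokens` and re-concatenates the adapted halves.
# All dimensions below are illustrative assumptions; this helper is never called by the library code.
def _example_vl_adapter_usage():
    vl_adapter = VLAdapter_Layer(d_model=16, down_size=4)
    x = torch.randn(10, 2, 16)               # 10 tokens, e.g. 6 image tokens followed by 4 text tokens
    y = vl_adapter(x, num_image_tokens=6)     # first 6 go through v_adapter, the rest through l_adapter
    assert y.shape == x.shape
    return y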


class DropPath(nn.Module):
    """Drop paths (Stochastic Depth) per sample (when applied in the main path of residual blocks)."""

    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        return drop_path(x, self.drop_prob, self.training)

    def extra_repr(self) -> str:
        return "p={}".format(self.drop_prob)


class TransformerEncoderLayer(nn.Module):
    """Encoder layer block.

    In the original paper each operation (multi-head attention or FFN) is
    postprocessed with: `dropout -> add residual -> layernorm`. In the
    tensor2tensor code they suggest that learning is more robust when
    preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.encoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
    """

    def __init__(self, args, drop_path_rate=0.0, use_adapter=False, adapter_dim=200, adapter_type='UM'):
        super().__init__()
        self.args = args
        self.use_adapter = use_adapter
        self.embed_dim = args.encoder_embed_dim
        self.adapter_type = adapter_type
        if self.use_adapter:
            # 'VL' uses separate adapters for vision and language tokens; any other value
            # falls back to a single unified adapter.
            if adapter_type == 'VL':
                self.adapter = VLAdapter_Layer(d_model=self.embed_dim, down_size=adapter_dim)
            else:
                self.adapter = Adapter_Layer(d_model=self.embed_dim, down_size=adapter_dim)
        self.quant_noise = getattr(args, 'quant_noise_pq', 0)
        self.quant_noise_block_size = getattr(args, 'quant_noise_pq_block_size', 8) or 8
        self.self_attn = self.build_self_attention(self.embed_dim, args)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim)
        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.activation_fn = utils.get_activation_fn(
            activation=getattr(args, 'activation_fn', 'relu') or "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.encoder_normalize_before
        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.encoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.encoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.ffn_layernorm = LayerNorm(args.encoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.final_layer_norm = LayerNorm(self.embed_dim)

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(
            nn.Linear(input_dim, output_dim), p=q_noise, block_size=qn_block_size
        )

    def build_self_attention(self, embed_dim, args):
        return MultiheadAttention(
            embed_dim,
            args.encoder_attention_heads,
            dropout=args.attention_dropout,
            self_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False),
            qk_norm=getattr(args, 'qk_norm', False),
        )

    def residual_connection(self, x, residual):
        return residual + self.drop_path(x)

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {"0": "self_attn_layer_norm", "1": "final_layer_norm"}
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict["{}.{}.{}".format(name, new, m)] = state_dict[k]
                    del state_dict[k]
                if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
                    state_dict[
                        "{}.{}.{}".format(name, new, m)
                    ] = self.state_dict()["{}.{}".format(new, m)]

        prefix = name + "." if name != "" else ""
        for param_name, param_tensor in self.state_dict().items():
            if (prefix + param_name) not in state_dict:
                state_dict[prefix + param_name] = self.state_dict()[param_name]

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[Tensor],
        attn_mask: Optional[Tensor] = None,
        self_attn_bias: Optional[Tensor] = None,
        prompt_kv: Optional[Tensor] = None,
        num_image_tokens=None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor): binary ByteTensor of shape
                `(batch, seq_len)` where padding elements are indicated by ``1``.
            attn_mask (ByteTensor): binary tensor of shape `(tgt_len, src_len)`,
                where `tgt_len` is the length of output and `src_len` is the
                length of input, though here both are equal to `seq_len`.
                `attn_mask[tgt_i, src_j] = 1` means that when calculating the
                embedding for `tgt_i`, we exclude (mask out) `src_j`. This is
                useful for strided self-attention.

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        # Convert the binary attention mask into an additive mask: masked positions become a
        # large negative value rather than -inf, since -inf can produce NaNs when a query row
        # is entirely padded.
        if attn_mask is not None:
            attn_mask = attn_mask.masked_fill(
                attn_mask.to(torch.bool),
                -1e8 if x.dtype == torch.float32 else -1e4
            )

        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        x, _ = self.self_attn(
            query=x,
            key=x,
            value=x,
            key_padding_mask=encoder_padding_mask,
            need_weights=False,
            attn_mask=attn_mask,
            attn_bias=self_attn_bias,
            prompt_kv=prompt_kv
        )
        if self.attn_ln is not None:
            x = self.attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)
        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.use_adapter:
            if self.adapter_type == 'VL':
                x = self.adapter(x, num_image_tokens=num_image_tokens)
            else:
                x = self.adapter(x)
        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        return x
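

# Hedged construction sketch (not part of the original module): shows roughly how a
# TransformerEncoderLayer can be instantiated and called. The argparse.Namespace fields and sizes
# below are illustrative assumptions (only the attributes read by this file are set), and the sketch
# presumes the companion unify_multihead_attention module is importable. Never called by library code.
def _example_encoder_layer_usage():
    import argparse
    args = argparse.Namespace(
        encoder_embed_dim=16,
        encoder_ffn_embed_dim=32,
        encoder_attention_heads=2,
        encoder_normalize_before=True,   # pre-norm, i.e. the tensor2tensor variant
        dropout=0.1,
        attention_dropout=0.0,
        attn_scale_factor=2,
    )
    layer = TransformerEncoderLayer(args, drop_path_rate=0.1, use_adapter=True, adapter_dim=4)
    x = torch.randn(10, 2, 16)                            # (seq_len, batch, embed_dim)
    padding_mask = torch.zeros(2, 10, dtype=torch.bool)   # no padded positions
    return layer(x, encoder_padding_mask=padding_mask)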


class TransformerDecoderLayer(nn.Module):
    """Decoder layer block.

    In the original paper each operation (multi-head attention, encoder
    attention or FFN) is postprocessed with: `dropout -> add residual ->
    layernorm`. In the tensor2tensor code they suggest that learning is more
    robust when preprocessing each layer with layernorm and postprocessing with:
    `dropout -> add residual`. We default to the approach in the paper, but the
    tensor2tensor approach can be enabled by setting
    *args.decoder_normalize_before* to ``True``.

    Args:
        args (argparse.Namespace): parsed command-line arguments
        no_encoder_attn (bool, optional): whether to attend to encoder outputs
            (default: False).
    """

    def __init__(
        self, args, no_encoder_attn=False, add_bias_kv=False, add_zero_attn=False,
        drop_path_rate=0.0, use_adapter=False, adapter_dim=200
    ):
        super().__init__()
        self.embed_dim = args.decoder_embed_dim
        self.use_adapter = use_adapter
        if use_adapter:
            self.adapter = Adapter_Layer(d_model=self.embed_dim, down_size=adapter_dim)

        self.dropout_module = FairseqDropout(
            args.dropout, module_name=self.__class__.__name__
        )
        self.quant_noise = getattr(args, "quant_noise_pq", 0)
        self.quant_noise_block_size = getattr(args, "quant_noise_pq_block_size", 8)

        self.cross_self_attention = getattr(args, "cross_self_attention", False)

        self.self_attn = self.build_self_attention(
            self.embed_dim,
            args,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
        )
        self.self_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.cross_attn_ln = LayerNorm(self.embed_dim) if getattr(args, 'scale_attn', False) else None
        self.nh = self.self_attn.num_heads
        self.head_dim = self.self_attn.head_dim

        self.activation_fn = utils.get_activation_fn(
            activation=str(args.activation_fn)
            if getattr(args, "activation_fn", None) is not None
            else "relu"
        )
        activation_dropout_p = getattr(args, "activation_dropout", 0) or 0
        if activation_dropout_p == 0:
            # for backwards compatibility with models that use args.relu_dropout
            activation_dropout_p = getattr(args, "relu_dropout", 0) or 0
        self.activation_dropout_module = FairseqDropout(
            float(activation_dropout_p), module_name=self.__class__.__name__
        )
        self.normalize_before = args.decoder_normalize_before

        export = getattr(args, "char_inputs", False)
        self.self_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        if no_encoder_attn:
            self.encoder_attn = None
            self.encoder_attn_layer_norm = None
        else:
            self.encoder_attn = self.build_encoder_attention(self.embed_dim, args)
            self.encoder_attn_layer_norm = LayerNorm(self.embed_dim, export=export)

        self.ffn_layernorm = LayerNorm(args.decoder_ffn_embed_dim) if getattr(args, 'scale_fc', False) else None
        self.w_resid = nn.Parameter(torch.ones(self.embed_dim), requires_grad=True) if getattr(args, 'scale_resids', False) else None

        self.fc1 = self.build_fc1(
            self.embed_dim,
            args.decoder_ffn_embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )
        self.fc2 = self.build_fc2(
            args.decoder_ffn_embed_dim,
            self.embed_dim,
            self.quant_noise,
            self.quant_noise_block_size,
        )

        self.final_layer_norm = LayerNorm(self.embed_dim, export=export)
        self.need_attn = True

        self.onnx_trace = False

        self.drop_path = DropPath(drop_path_rate) if drop_path_rate > 0.0 else nn.Identity()

    def build_fc1(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_fc2(self, input_dim, output_dim, q_noise, qn_block_size):
        return quant_noise(nn.Linear(input_dim, output_dim), q_noise, qn_block_size)

    def build_self_attention(
        self, embed_dim, args, add_bias_kv=False, add_zero_attn=False
    ):
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            dropout=args.attention_dropout,
            add_bias_kv=add_bias_kv,
            add_zero_attn=add_zero_attn,
            self_attention=not getattr(args, "cross_self_attention", False),
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False),
            qk_norm=getattr(args, 'qk_norm', False),
        )

    def build_encoder_attention(self, embed_dim, args):
        return MultiheadAttention(
            embed_dim,
            args.decoder_attention_heads,
            kdim=getattr(args, "encoder_embed_dim", None),
            vdim=getattr(args, "encoder_embed_dim", None),
            dropout=args.attention_dropout,
            encoder_decoder_attention=True,
            q_noise=self.quant_noise,
            qn_block_size=self.quant_noise_block_size,
            scale_factor=args.attn_scale_factor,
            scale_heads=getattr(args, 'scale_heads', False),
            qk_norm=getattr(args, 'qk_norm', False),
        )

    def prepare_for_onnx_export_(self):
        self.onnx_trace = True

    def residual_connection(self, x, residual):
        return residual + self.drop_path(x)

    def forward(
        self,
        x,
        encoder_out: Optional[torch.Tensor] = None,
        encoder_padding_mask: Optional[torch.Tensor] = None,
        incremental_state: Optional[Dict[str, Dict[str, Optional[Tensor]]]] = None,
        prev_self_attn_state: Optional[List[torch.Tensor]] = None,
        prev_attn_state: Optional[List[torch.Tensor]] = None,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_attn: bool = False,
        need_head_weights: bool = False,
        self_attn_bias: Optional[Tensor] = None,
        cross_attn_bias: Optional[Tensor] = None,
        prompt_kv: Optional[Tensor] = None,
    ):
        """
        Args:
            x (Tensor): input to the layer of shape `(seq_len, batch, embed_dim)`
            encoder_padding_mask (ByteTensor, optional): binary
                ByteTensor of shape `(batch, src_len)` where padding
                elements are indicated by ``1``.
            need_attn (bool, optional): return attention weights
            need_head_weights (bool, optional): return attention weights
                for each head (default: return average over heads).

        Returns:
            encoded output of shape `(seq_len, batch, embed_dim)`
        """
        if need_head_weights:
            need_attn = True

        # Self-attention block (with optional cached key/value state for incremental decoding).
        residual = x
        if self.normalize_before:
            x = self.self_attn_layer_norm(x)
        if prev_self_attn_state is not None:
            prev_key, prev_value = prev_self_attn_state[:2]
            saved_state: Dict[str, Optional[Tensor]] = {
                "prev_key": prev_key,
                "prev_value": prev_value,
            }
            if len(prev_self_attn_state) >= 3:
                saved_state["prev_key_padding_mask"] = prev_self_attn_state[2]
            assert incremental_state is not None
            self.self_attn._set_input_buffer(incremental_state, saved_state)
        _self_attn_input_buffer = self.self_attn._get_input_buffer(incremental_state)
        if self.cross_self_attention and not (
            incremental_state is not None
            and _self_attn_input_buffer is not None
            and "prev_key" in _self_attn_input_buffer
        ):
            if self_attn_mask is not None:
                assert encoder_out is not None
                self_attn_mask = torch.cat(
                    (x.new_zeros(x.size(0), encoder_out.size(0)), self_attn_mask), dim=1
                )
            if self_attn_padding_mask is not None:
                if encoder_padding_mask is None:
                    assert encoder_out is not None
                    encoder_padding_mask = self_attn_padding_mask.new_zeros(
                        encoder_out.size(1), encoder_out.size(0)
                    )
                self_attn_padding_mask = torch.cat(
                    (encoder_padding_mask, self_attn_padding_mask), dim=1
                )
            assert encoder_out is not None
            y = torch.cat((encoder_out, x), dim=0)
        else:
            y = x

        x, attn = self.self_attn(
            query=x,
            key=y,
            value=y,
            key_padding_mask=self_attn_padding_mask,
            incremental_state=incremental_state,
            need_weights=False,
            attn_mask=self_attn_mask,
            attn_bias=self_attn_bias,
            prompt_kv=prompt_kv
        )
        if self.self_attn_ln is not None:
            x = self.self_attn_ln(x)
        x = self.dropout_module(x)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.self_attn_layer_norm(x)

        # Encoder-decoder (cross) attention block.
        if self.encoder_attn is not None and encoder_out is not None:
            residual = x
            if self.normalize_before:
                x = self.encoder_attn_layer_norm(x)
            if prev_attn_state is not None:
                prev_key, prev_value = prev_attn_state[:2]
                saved_state: Dict[str, Optional[Tensor]] = {
                    "prev_key": prev_key,
                    "prev_value": prev_value,
                }
                if len(prev_attn_state) >= 3:
                    saved_state["prev_key_padding_mask"] = prev_attn_state[2]
                assert incremental_state is not None
                self.encoder_attn._set_input_buffer(incremental_state, saved_state)

            x, attn = self.encoder_attn(
                query=x,
                key=encoder_out,
                value=encoder_out,
                key_padding_mask=encoder_padding_mask,
                incremental_state=incremental_state,
                static_kv=True,
                need_weights=need_attn or (not self.training and self.need_attn),
                need_head_weights=need_head_weights,
                attn_bias=cross_attn_bias
            )
            if self.cross_attn_ln is not None:
                x = self.cross_attn_ln(x)
            x = self.dropout_module(x)
            x = self.residual_connection(x, residual)
            if not self.normalize_before:
                x = self.encoder_attn_layer_norm(x)

        # Feed-forward block, optionally followed by a bottleneck adapter.
        residual = x
        if self.normalize_before:
            x = self.final_layer_norm(x)

        x = self.activation_fn(self.fc1(x))
        x = self.activation_dropout_module(x)
        if self.ffn_layernorm is not None:
            x = self.ffn_layernorm(x)
        x = self.fc2(x)
        x = self.dropout_module(x)
        if self.use_adapter:
            x = self.adapter(x)

        if self.w_resid is not None:
            residual = torch.mul(self.w_resid, residual)
        x = self.residual_connection(x, residual)
        if not self.normalize_before:
            x = self.final_layer_norm(x)
        if self.onnx_trace and incremental_state is not None:
            saved_state = self.self_attn._get_input_buffer(incremental_state)
            assert saved_state is not None
            if self_attn_padding_mask is not None:
                self_attn_state = [
                    saved_state["prev_key"],
                    saved_state["prev_value"],
                    saved_state["prev_key_padding_mask"],
                ]
            else:
                self_attn_state = [saved_state["prev_key"], saved_state["prev_value"]]
            return x, attn, self_attn_state
        return x, attn, None

    def make_generation_fast_(self, need_attn: bool = False, **kwargs):
        self.need_attn = need_attn

    def upgrade_state_dict_named(self, state_dict, name):
        """
        Rename layer norm states from `...layer_norms.0.weight` to
        `...self_attn_layer_norm.weight` and `...layer_norms.1.weight` to
        `...final_layer_norm.weight`
        """
        layer_norm_map = {
            "0": "self_attn_layer_norm",
            "1": "encoder_attn_layer_norm",
            "2": "final_layer_norm",
        }
        for old, new in layer_norm_map.items():
            for m in ("weight", "bias"):
                k = "{}.layer_norms.{}.{}".format(name, old, m)
                if k in state_dict:
                    state_dict[
                        "{}.{}.{}".format(name, new, m)
                    ] = state_dict[k]
                    del state_dict[k]
                if "{}.{}.{}".format(name, new, m) not in state_dict and "{}.{}".format(new, m) in self.state_dict():
                    state_dict[
                        "{}.{}.{}".format(name, new, m)
                    ] = self.state_dict()["{}.{}".format(new, m)]

        prefix = name + "." if name != "" else ""
        for param_name, param_tensor in self.state_dict().items():
            if (prefix + param_name) not in state_dict:
                state_dict[prefix + param_name] = self.state_dict()[param_name]