|
from logging import warn |
|
from transformers.models.albert.modeling_albert import * |
|
import torch |
|
import torch.nn as nn |
|
from transformers.models.albert.configuration_albert import AlbertConfig |
|
import sys |
|
|
|
# Maps each transformers Auto* class name to the "<module>.<class>" path of the LSG
# implementation in this file. LSGAlbertConfig stores it as `auto_map`, and the
# registration code at the bottom of the file iterates over it.
AUTO_MAP = dict(
    AutoModel="modeling_lsg_albert.LSGAlbertModel",
    AutoModelForMaskedLM="modeling_lsg_albert.LSGAlbertForMaskedLM",
    AutoModelForPreTraining="modeling_lsg_albert.LSGAlbertForPreTraining",
    AutoModelForMultipleChoice="modeling_lsg_albert.LSGAlbertForMultipleChoice",
    AutoModelForQuestionAnswering="modeling_lsg_albert.LSGAlbertForQuestionAnswering",
    AutoModelForSequenceClassification="modeling_lsg_albert.LSGAlbertForSequenceClassification",
    AutoModelForTokenClassification="modeling_lsg_albert.LSGAlbertForTokenClassification",
)
|
|
|
class LSGAlbertConfig(AlbertConfig):
    """
    This class overrides :class:`~transformers.AlbertConfig`. Please check the superclass for the appropriate
    documentation alongside usage examples.

    LSG-specific arguments are stored as attributes. Some are validated and coerced in ``__init__``:
    an unknown ``sparsity_type`` becomes None, ``num_global_tokens`` is clamped to [1, 512], and
    ``pool_with_global`` is forced to True when ``mask_first_token`` is set.
    """

    base_model_prefix = "lsg"
    model_type = "albert"

    def __init__(
        self,
        adaptive=True,
        base_model_prefix="lsg",
        block_size=128,
        lsh_num_pre_rounds=1,
        mask_first_token=False,
        num_global_tokens=1,
        pool_with_global=True,
        sparse_block_size=128,
        sparsity_factor=2,
        sparsity_type="norm",
        **kwargs
        ):
        """Constructs LSGAlbertConfig.

        Raises:
            AssertionError: if ``sparsity_factor > 0`` and ``block_size`` is not divisible by it.
        """
        super().__init__(**kwargs)

        self.adaptive = adaptive
        # Store a copy: assigning the module-level dict directly would let any in-place edit of
        # one config's auto_map mutate AUTO_MAP, which every other instance shares.
        self.auto_map = dict(AUTO_MAP)
        self.base_model_prefix = base_model_prefix
        self.block_size = block_size
        self.lsh_num_pre_rounds = lsh_num_pre_rounds
        self.mask_first_token = mask_first_token
        self.num_global_tokens = num_global_tokens
        self.pool_with_global = pool_with_global
        self.sparse_block_size = sparse_block_size
        self.sparsity_factor = sparsity_factor
        self.sparsity_type = sparsity_type

        if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride", "bos_pooling"]:
            # Implicit concatenation (not a backslash continuation) so no indentation leaks into the message.
            logger.warning(
                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride', 'bos_pooling'], "
                "setting sparsity_type=None, computation will skip sparse attention")
            self.sparsity_type = None

        if self.sparsity_type in ["stride", "block_stride"]:
            if self.sparsity_factor > self.num_attention_heads:
                logger.warning(
                    "[WARNING CONFIG]: sparsity_factor > num_attention_heads is not recommended for stride/block_stride sparsity"
                )

        # Global tokens are embedded by a fixed table of 512 rows, hence the upper bound.
        if self.num_global_tokens < 1:
            logger.warning(
                "[WARNING CONFIG]: num_global_tokens < 1 is not compatible, setting num_global_tokens=1"
            )
            self.num_global_tokens = 1
        elif self.num_global_tokens > 512:
            logger.warning(
                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
            )
            self.num_global_tokens = 512

        if self.sparsity_factor > 0:
            assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
            assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"

        if self.mask_first_token and not pool_with_global:
            logger.warning(
                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
            self.pool_with_global = True

        if hasattr(self, "position_embedding_type"):
            if self.position_embedding_type != "absolute":
                logger.warning(
                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
|
|
|
|
|
class BaseSelfAttention(nn.Module):
    """
    Shared Q/K/V projection and head-reshaping helpers for self-attention modules.

    This class defines no ``__init__``; subclasses call ``init_modules(config)`` to create
    the projection layers.
    """

    def init_modules(self, config):
        """
        Create the query/key/value projections and the attention dropout from ``config``.

        Raises:
            ValueError: if ``hidden_size`` is not a multiple of ``num_attention_heads`` and the
                config has no ``embedding_size`` attribute.
        """
        heads = config.num_attention_heads
        divisible = config.hidden_size % heads == 0
        if not divisible and not hasattr(config, "embedding_size"):
            raise ValueError(
                "The hidden size (%d) is not a multiple of the number of attention "
                "heads (%d)" % (config.hidden_size, config.num_attention_heads)
            )

        self.num_attention_heads = heads
        self.attention_head_size = int(config.hidden_size / heads)
        self.all_head_size = self.num_attention_heads * self.attention_head_size

        width = config.hidden_size
        self.query = nn.Linear(width, self.all_head_size)
        self.key = nn.Linear(width, self.all_head_size)
        self.value = nn.Linear(width, self.all_head_size)

        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def transpose_for_scores(self, x):
        """Split the last dim into heads and move heads forward: (b, s, H*d) -> (b, H, s, d)."""
        split_shape = (*x.size()[:-1], self.num_attention_heads, self.attention_head_size)
        return x.view(split_shape).permute(0, 2, 1, 3)

    def reshape_output(self, context_layer):
        """Inverse of ``transpose_for_scores``: (b, H, s, d) -> (b, s, H*d)."""
        merged = context_layer.permute(0, 2, 1, 3).contiguous()
        return merged.view(*merged.size()[:-2], self.all_head_size)

    def project_QKV(self, hidden_states):
        """Return the per-head (query, key, value) projections of ``hidden_states``."""
        return tuple(
            self.transpose_for_scores(projection(hidden_states))
            for projection in (self.query, self.key, self.value)
        )
|
|
|
|
|
class BaseAttentionProduct(nn.Module):
    """Plain scaled dot-product attention with an optional additive mask."""

    def __init__(self, config):
        """
        Compute attention: softmax(Q @ K.T / sqrt(d) + mask) @ V, with dropout on the probabilities.
        """
        super().__init__()
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None):
        """
        Args:
            query_layer, key_layer, value_layer: tensors whose last dim is the head size.
            attention_mask: optional additive mask, broadcast against the score matrix.

        Returns:
            The attention-weighted values.
        """
        scale = math.sqrt(query_layer.shape[-1])
        scores = query_layer @ key_layer.transpose(-1, -2) / scale

        # Release the inputs early; these tensors can be large.
        del query_layer, key_layer

        if attention_mask is not None:
            scores = scores + attention_mask
            del attention_mask

        probabilities = torch.softmax(scores, dim=-1)
        return self.dropout(probabilities) @ value_layer
|
|
|
|
|
class CausalAttentionProduct(nn.Module):
    """Scaled dot-product attention that also applies a causal mask over the trailing keys."""

    def __init__(self, config):
        """
        Compute attention: softmax(Q @ K.T / sqrt(d) + mask + causal_mask) @ V

        Reads ``attention_probs_dropout_prob`` and ``block_size`` from ``config``.
        """
        super().__init__()
        self.dropout = nn.Dropout(config.attention_probs_dropout_prob)
        self.block_size = config.block_size

    def forward(self, query_layer, key_layer, value_layer, attention_mask=None, causal_shape=None):
        """
        Args:
            query_layer, key_layer, value_layer: tensors whose last dim is the head size.
            attention_mask: additive mask whose last dim is the key length. The causal mask is
                only applied when this is given.
            causal_shape: shape of the causal window; defaults to (block_size, block_size).

        Returns:
            The attention-weighted values.
        """
        d = query_layer.shape[-1]

        attention_scores = query_layer @ key_layer.transpose(-1, -2) / math.sqrt(d)

        del query_layer
        del key_layer

        if attention_mask is not None:
            # Strictly lower-triangular ones; transposed below so key j > query i gets masked.
            causal_shape = (self.block_size, self.block_size) if causal_shape is None else causal_shape
            causal_mask = torch.tril(
                torch.ones(*causal_shape, device=attention_mask.device, dtype=attention_scores.dtype),
                diagonal=-1
                )

            dtype_min = torch.tensor(
                torch.finfo(attention_scores.dtype).min, device=attention_scores.device, dtype=attention_scores.dtype
                )

            causal_mask = causal_mask.T * dtype_min
            # The causal window sits on the last keys. Left-pad with zeros by the key length minus the
            # window's actual width (not self.block_size), so a custom causal_shape still lines up.
            causal_mask = torch.nn.functional.pad(
                causal_mask, (attention_mask.size()[-1] - causal_mask.size(-1), 0), value=0
                )
            # Clamp at dtype min so summing two large negatives cannot overflow to -inf.
            attention_mask = torch.max(attention_mask + causal_mask.unsqueeze(0).unsqueeze(0).unsqueeze(0), dtype_min)

            attention_scores = attention_scores + attention_mask
            del attention_mask
            del causal_mask

        attention_probs = nn.Softmax(dim=-1)(attention_scores)

        context_layer = self.dropout(attention_probs) @ value_layer

        return context_layer
|
|
|
|
|
class LSGAttentionProduct(nn.Module):
    """
    Block-wise LSG attention product.

    Queries are processed in blocks of ``block_size`` tokens. Each query block attends to a key
    sequence built by concatenating, along dim -2, the global tokens, an optional set of sparse
    tokens, and a local window of neighbouring blocks extracted with overlapping ``unfold`` windows.
    """

    def __init__(self, config, block_size=None, sparse_block_size=None, sparsity_factor=4, is_causal=False):
        """
        Compute block or overlapping blocks attention products

        Args:
            config: model config. ``block_size`` / ``sparse_block_size`` fall back to the config values
                when None, and the config is passed to the attention product module.
            block_size: query block length; also the stride of the local windows.
            sparse_block_size: number of sparse tokens sliced on each side of a local window
                (a falsy value disables the sparse branch).
            sparsity_factor: ``block_size // sparsity_factor`` is the stride of the sparse windows.
            is_causal: use CausalAttentionProduct and keep only the leading part of each window.
        """
        super().__init__()

        self.block_size = block_size
        self.sparse_block_size = sparse_block_size
        self.sparsity_factor = sparsity_factor
        self.is_causal = is_causal

        if self.block_size is None:
            self.block_size = config.block_size

        if self.sparse_block_size is None:
            self.sparse_block_size = config.sparse_block_size

        # (window size, stride) pairs passed to Tensor.unfold.
        # Local windows span 3 blocks and advance by 1 block.
        self.local_shapes = (self.block_size*3, self.block_size)
        # sparse_shapes only exists when the sparse branch is enabled.
        if self.sparse_block_size and self.sparsity_factor > 0:
            self.sparse_shapes = (self.sparse_block_size*3, self.block_size//self.sparsity_factor)

        if is_causal:
            self.attention = CausalAttentionProduct(config)
        else:
            self.attention = BaseAttentionProduct(config)

    def build_lsg_inputs(self, hidden_states, sparse_hidden_states, global_hidden_states, is_attn_mask=False):
        """
        Build the per-block tensor [global | sparse | local] for keys, values, or (with
        ``is_attn_mask=True``) the additive mask.
        """

        # Overlapping local windows, one per query block
        local_hidden_states = self.reshape_to_local_block(hidden_states, is_attn_mask)
        del hidden_states

        # Sparse tokens, when present, are windowed per block as well
        if sparse_hidden_states is not None:
            sparse_hidden_states = self.reshape_to_sparse_block(sparse_hidden_states, is_attn_mask)

        return self.cat_global_sparse_local_tokens(global_hidden_states, sparse_hidden_states, local_hidden_states)

    def forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask=None,
        sparse_key=None,
        sparse_value=None,
        sparse_mask=None,
        global_key=None,
        global_value=None,
        global_mask=None
        ):
        """
        Run LSG attention for every query block.

        ``query_layer`` is unpacked as 4 dims (n, h, t, d); presumably (batch, heads, seq, head_dim),
        with keys/values in the same layout -- TODO confirm. ``t`` must be a multiple of
        ``block_size`` (asserted). ``global_mask`` is transposed unconditionally, so despite its
        None default it must be provided.

        Returns:
            The per-block context reshaped to (n, h, -1, d).
        """

        n, h, t, d = query_layer.size()
        n_blocks = t // self.block_size
        assert t % self.block_size == 0

        key_layer = self.build_lsg_inputs(
            key_layer,
            sparse_key,
            global_key
            )
        del sparse_key
        del global_key

        value_layer = self.build_lsg_inputs(
            value_layer,
            sparse_value,
            global_value
            )
        del sparse_value
        del global_value

        # The mask is windowed with its key axis moved to dim -2 (like keys), then transposed back so the
        # key axis is last again for broadcasting against the scores.
        attention_mask = self.build_lsg_inputs(
            attention_mask,
            sparse_mask,
            global_mask.transpose(-1, -2),
            is_attn_mask=True
            ).transpose(-1, -2)
        del sparse_mask
        del global_mask

        # Queries are split into n_blocks chunks along the sequence so each chunk meets its own key window
        context_layer = self.attention(
            query_layer=self.chunk(query_layer, n_blocks),
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask
            )

        return context_layer.reshape(n, h, -1, d)

    def reshape_to_local_block(self, hidden_states, is_attn_mask=False):
        """
        Unfold dim -2 into windows of ``3 * block_size`` with stride ``block_size``.

        The sequence is padded by ``block_size`` on each side (0 for states, dtype min for masks), so
        window i covers blocks i-1, i and i+1. In causal mode only the first two thirds (blocks i-1
        and i) are kept.
        """

        size, step = self.local_shapes
        s = (size - step) // 2

        # Pad before block reshaping
        if is_attn_mask:
            pad_value = torch.finfo(hidden_states.dtype).min
            # Presumably the mask arrives with the key axis last; move it to dim -2 like states -- TODO confirm
            hidden_states = hidden_states.transpose(-1, -2)
        else:
            pad_value = 0

        # Pad dim -2 by transposing it to the last position, where F.pad's (s, s) applies
        hidden_states = torch.nn.functional.pad(
            hidden_states.transpose(-1, -2),
            pad=(s, s),
            value=pad_value
            ).transpose(-1, -2)

        # Make blocks: unfold appends the window as a new last dim, transpose it back before the feature dim
        hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)

        # Skip third block if causal
        if self.is_causal:
            return hidden_states[..., :size*2//3, :]

        return hidden_states

    def reshape_to_sparse_block(self, hidden_states, is_attn_mask=False):
        """
        Window the sparse tokens so each query block gets up to ``sparse_block_size`` of them per side.

        Windows of ``2 * sparse_shapes[0]`` tokens are unfolded with stride
        ``block_size // sparsity_factor``. Two slices of ``sparse_block_size`` tokens are then taken on
        either side of the centre span of length ``3 * block_size // sparsity_factor``; only the leading
        slice is kept in causal mode.
        """

        size, step = self.sparse_shapes

        # Odd strides need one extra token of padding on each side
        odd_offset = (step % 2)

        # Double the window so both side slices fit around the centre span
        size = size*2
        s = (size - step) // 2 + odd_offset

        # Pad before block reshaping
        if is_attn_mask:
            pad_value = torch.finfo(hidden_states.dtype).min
            hidden_states = hidden_states.transpose(-1, -2)
        else:
            pad_value = 0

        hidden_states = torch.nn.functional.pad(
            hidden_states.transpose(-1, -2),
            pad=(s, s),
            value=pad_value
            ).transpose(-1, -2)

        # Make blocks
        hidden_states = hidden_states.unfold(-2, size=size, step=step).transpose(-1, -2)

        # The extra odd padding produces one window too many for stride 1; drop the last one.
        # NOTE(review): for odd strides > 1 the window count appears to already equal the number of
        # blocks before this slice, leaving one window too few -- TODO verify.
        if odd_offset:
            hidden_states = hidden_states[..., :-1, :, :]

        # u: end of the left slice / offset of the centre span inside each window
        u = (size - self.block_size * 3 // self.sparsity_factor) // 2 + odd_offset
        s = self.sparse_block_size

        # Causal case: only the tokens before the centre span
        if self.is_causal:
            return hidden_states[..., u-s:u, :]

        u_ = u + odd_offset
        return torch.cat([hidden_states[..., u-s:u, :], hidden_states[..., -u_:-u_+s, :]], dim=-2)

    def cat_global_sparse_local_tokens(self, x_global, x_sparse=None, x_local=None, dim=-2):
        """
        Concatenate [global, sparse, local] along ``dim``, broadcasting global tokens to every block.

        ``x_local`` is unpacked as 5 dims; its dim -3 (block count) is used to expand ``x_global``.
        """

        n, h, b, t, d = x_local.size()
        x_global = x_global.unsqueeze(-3).expand(-1, -1, b, -1, -1)
        if x_sparse is not None:
            return torch.cat([x_global, x_sparse, x_local], dim=dim)
        return torch.cat([x_global, x_local], dim=dim)

    def chunk(self, x, n_blocks):
        """Split dim -2 of ``x`` into ``n_blocks`` equal chunks: (..., t, d) -> (..., n_blocks, t // n_blocks, d)."""

        t, d = x.size()[-2:]
        return x.reshape(*x.size()[:-2], n_blocks, -1, d)
|
|
|
|
|
class LSGAlbertEmbeddings(AlbertEmbeddings):
    """
    ALBERT embeddings (word + token type + optional absolute position) with learned global tokens
    prepended to every sequence before LayerNorm and dropout.
    """

    def __init__(self, config):
        super().__init__(config)

        self.num_global_tokens = config.num_global_tokens

        # 512-row table: the config clamps num_global_tokens to at most 512.
        self.global_embeddings = nn.Embedding(512, embedding_dim=config.embedding_size)

        self.block_size = config.block_size

    def forward(
        self,
        input_ids: Optional[torch.LongTensor] = None,
        token_type_ids: Optional[torch.LongTensor] = None,
        position_ids: Optional[torch.LongTensor] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        past_key_values_length: int = 0,
    ) -> torch.Tensor:
        """
        Embed the inputs and prepend ``num_global_tokens`` global embeddings along the sequence dim.

        Returns:
            Tensor of length ``num_global_tokens + seq_length`` along dim -2.
        """
        input_shape = input_ids.size() if input_ids is not None else inputs_embeds.size()[:-1]
        seq_length = input_shape[1]

        if position_ids is None:
            start = past_key_values_length
            position_ids = self.position_ids[:, start: seq_length + start]

        if token_type_ids is None:
            if not hasattr(self, "token_type_ids"):
                token_type_ids = torch.zeros(input_shape, dtype=torch.long, device=self.position_ids.device)
            else:
                # Reuse the registered all-zero buffer, broadcast to the batch size.
                token_type_ids = self.token_type_ids[:, :seq_length].expand(input_shape[0], seq_length)

        if inputs_embeds is None:
            inputs_embeds = self.word_embeddings(input_ids)

        embeddings = inputs_embeds + self.token_type_embeddings(token_type_ids)
        if self.position_embedding_type == "absolute":
            embeddings += self.position_embeddings(position_ids)

        batch_size, _, width = embeddings.size()

        global_ids = torch.arange(self.num_global_tokens, device=embeddings.device).reshape(1, -1)
        global_tokens = self.global_embeddings(global_ids).expand(batch_size, -1, width)
        embeddings = torch.cat([global_tokens, embeddings], dim=-2)

        embeddings = self.LayerNorm(embeddings)
        return self.dropout(embeddings)
|
|
|
|
|
class LSGSelfAttention(BaseSelfAttention):
    '''
    Compute local attention with overlapping blocks, attention to global tokens, and attention to
    a reduced ("sparse") set of keys/values chosen by ``config.sparsity_type``
    (norm, pooling, lsh, stride, block_stride or bos_pooling).
    The output goes through a dense projection, dropout and a residual LayerNorm.
    '''
    def __init__(self, config):
        super().__init__()

        self.init_modules(config)

        self.output_dropout = nn.Dropout(config.hidden_dropout_prob)
        self.dense = nn.Linear(config.hidden_size, config.hidden_size)
        self.LayerNorm = nn.LayerNorm(config.hidden_size, eps=config.layer_norm_eps)

        self.block_size = config.block_size
        self.sparse_block_size = config.sparse_block_size
        self.num_global_tokens = config.num_global_tokens
        self.sparsity_factor = config.sparsity_factor
        self.is_causal = config.is_decoder
        self.is_decoder = config.is_decoder

        self.attention = LSGAttentionProduct(
            config,
            block_size=config.block_size,
            sparse_block_size=config.sparse_block_size,
            sparsity_factor=self.sparsity_factor,
            is_causal=self.is_causal
            )

        if self.is_causal:
            self.causal_attention = CausalAttentionProduct(config)
        self.full_attention = BaseAttentionProduct(config)

        sparse_functions = {
            "norm": self.get_sparse_tokens_with_norm,
            "pooling": self.get_sparse_tokens_with_pooling,
            "lsh": self.get_sparse_tokens_with_lsh,
            "stride": self.get_sparse_tokens_with_stride,
            "block_stride": self.get_sparse_tokens_with_block_stride,
            "bos_pooling": self.get_sparse_tokens_with_bos_pooling
            }

        # Unknown / None sparsity types disable the sparse branch
        self.sparsity_type = config.sparsity_type
        self.get_sparse_elements = sparse_functions.get(self.sparsity_type, lambda w, x, y, z: (None, None, None))

        if config.sparsity_type == "lsh":
            self.lsh_num_pre_rounds = config.lsh_num_pre_rounds

    def get_sparse_tokens_with_norm(self, queries, keys, values, mask):
        """
        Within each chunk of min(block_size, sparse_block_size) keys, keep the
        ``chunk // sparsity_factor`` keys with the largest L2 norm (masked keys get norm 0).
        """

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        with torch.no_grad():

            block_size = min(self.block_size, self.sparse_block_size)
            key_norm = keys.detach().norm(dim=-1, keepdim=True)
            # Additive mask is non-zero where masked, so ~bool() keeps only valid positions
            key_norm = key_norm * ~mask.transpose(-1, -2).bool()
            key_norm = self.chunk(key_norm, block_size)

            n, h, b, t, d = key_norm.size()

            idx = key_norm.argsort(dim=-2)
            del key_norm
            # Turn per-chunk indexes into flat sequence indexes
            idx += (torch.arange(b, device=keys.device)*t).reshape(1, 1, b, 1, 1)

            # argsort is ascending: the last part holds the largest norms
            split = (t - block_size // self.sparsity_factor, block_size // self.sparsity_factor)
            sparse_idx = idx.split(split, -2)[-1].reshape(n, h, -1, 1)

        d = keys.size()[-1]
        keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)

        return keys, values, mask

    def get_sparse_tokens_with_pooling(self, queries, keys, values, mask):
        """Average each group of ``sparsity_factor`` consecutive keys/values, ignoring masked positions."""

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        keys = self.chunk(keys, self.sparsity_factor)
        values = self.chunk(values, self.sparsity_factor)

        n, h, b, t, d = keys.size()
        mask = mask.reshape(n, 1, b, 1, t)
        mask = ~mask.transpose(-1, -2).bool()

        keys = keys * mask
        values = values * mask

        # mask now counts valid tokens per group
        mask = mask.sum(dim=-2)
        keys = keys.sum(dim=-2) / (mask + 1e-6)
        values = values.sum(dim=-2) / (mask + 1e-6)

        # Groups without any valid token are masked out
        mask = (1. - mask.clamp(0, 1))
        mask *= torch.finfo(mask.dtype).min
        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

    def get_sparse_tokens_with_stride(self, queries, keys, values, mask):
        """Each head h keeps every ``sparsity_factor``-th position, starting at offset ``h % sparsity_factor``."""

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        n, h, t, d = keys.size()
        sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device) * self.sparsity_factor
        sparse_idx = sparse_idx.reshape(1, 1, -1, 1) + (torch.arange(h, device=keys.device) % self.sparsity_factor).reshape(1, h, 1, 1)
        sparse_idx = sparse_idx.expand(n, h, -1, 1)

        keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)

        return keys, values, mask

    def get_sparse_tokens_with_block_stride(self, queries, keys, values, mask):
        """
        Within each block, head h keeps a contiguous span of ``block_size // sparsity_factor``
        positions starting at ``h * block_size // sparsity_factor`` (modulo block_size).
        """

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        n, h, t, d = keys.size()

        t, b = self.block_size, t // self.block_size
        sparse_idx = torch.arange(t // self.sparsity_factor, device=keys.device)
        sparse_idx = sparse_idx.reshape(1, 1, 1, -1, 1) + torch.arange(h, device=keys.device).reshape(1, h, 1, 1, 1) * (t // self.sparsity_factor)
        sparse_idx = (sparse_idx % t)
        sparse_idx = sparse_idx + torch.arange(b, device=keys.device).reshape(1, 1, -1, 1, 1) * t
        sparse_idx = sparse_idx.reshape(1, h, -1, 1).expand(n, h, -1, 1)

        keys = keys.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        values = values.gather(dim=-2, index=sparse_idx.expand(-1, -1, -1, d))
        mask = mask.expand(-1, h, -1, -1).transpose(-1, -2).gather(dim=-2, index=sparse_idx).transpose(-1, -2)

        return keys, values, mask

    def get_sparse_tokens_with_lsh(self, queries, keys, values, mask):
        """
        Within chunks of min(block_size, sparse_block_size) keys, hash keys into buckets with random
        projections and average the keys/values that share a bucket.
        """

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        if self.sparsity_factor == self.sparse_block_size:
            return self.get_sparse_tokens_with_bos_pooling(queries, keys, values, mask)

        block_size = min(self.block_size, self.sparse_block_size)
        keys = self.chunk(keys, block_size)
        values = self.chunk(values, block_size)

        n, h, b, t, d = keys.size()
        mask = mask.reshape(n, 1, b, 1, t)
        mask = ~mask.transpose(-1, -2).bool()

        keys = keys * mask
        values = values * mask
        # mask becomes a float count of valid tokens (1/0 before bucketing)
        mask = mask.expand(-1, h, -1, -1, -1).float()

        extra_factor = 1

        for _ in range(self.lsh_num_pre_rounds):
            keys, values, mask = self.lsh_round(keys, values, mask, t*extra_factor)

        keys, values, mask = self.lsh_round(keys, values, mask, t//self.sparsity_factor)
        # Bucket sums -> bucket means
        keys /= mask + 1e-8
        values /= mask + 1e-8

        mask = (1. - mask.clamp(0, 1))
        mask *= torch.finfo(mask.dtype).min

        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

    def lsh_round(self, keys, values, mask, output_size):
        """
        Scatter-add keys/values/mask into ``output_size`` buckets chosen by random-projection hashing.

        Returns the summed keys, values and mask counts per bucket.
        """

        with torch.no_grad():

            n_hashes = output_size // 2
            n, h, b, t, d = keys.size()
            binary_mask = mask.clamp(0, 1)

            projected = torch.nn.functional.normalize(keys, dim=-1) * binary_mask
            # Match the projection dtype to its operand so matmul works for any default/model dtype
            indexes = projected @ torch.randn(1, h, 1, d, n_hashes, device=keys.device, dtype=projected.dtype)
            # Angular LSH: bucket = argmax over [xR, -xR]
            indexes = torch.cat([indexes, -indexes], dim=-1).argmax(dim=-1, keepdim=True)

        n, h, b, t, d = keys.size()

        # scatter_add requires self and src to share a dtype; allocating with the default dtype
        # fails for fp16/bf16 (e.g. AMP) keys, so match the source tensors.
        x_ = torch.zeros(n, h, b, output_size, d, device=keys.device, dtype=keys.dtype)
        mask_ = torch.zeros(n, h, b, output_size, 1, device=keys.device, dtype=mask.dtype)
        keys = torch.scatter_add(x_, dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=keys)
        values = torch.scatter_add(x_.to(values.dtype), dim=-2, index=indexes.expand(-1, -1, -1, -1, d), src=values)
        mask = torch.scatter_add(mask_, dim=-2, index=indexes, src=mask)

        return keys[..., :output_size, :], values[..., :output_size, :], mask[..., :output_size, :]

    def get_sparse_tokens_with_bos_pooling(self, queries, keys, values, mask):
        """
        Pool each group of ``sparsity_factor`` keys/values with softmax weights from the first query
        attending to the group.
        """

        if self.sparsity_factor == 1:
            return keys, values, mask.expand(-1, keys.size()[1], -1, -1)

        queries = queries.unsqueeze(-3)
        mask = self.chunk(mask.transpose(-1, -2), self.sparsity_factor).transpose(-1, -2)
        keys = self.chunk(keys, self.sparsity_factor)
        values = self.chunk(values, self.sparsity_factor)

        n, h, b, t, d = keys.size()
        scores = (queries[..., :1, :] @ keys.transpose(-1, -2)) / math.sqrt(d)
        if mask is not None:
            scores = scores + mask

        scores = torch.softmax(scores, dim=-1)
        keys = scores @ keys
        values = scores @ values
        # A pooled token is masked only if every token in its group is masked. amax gives the same
        # result as the former mean-based test but cannot overflow: summing several finfo.min values
        # reaches -inf with float32 accumulation, which unmasked fully padded groups.
        mask = mask.amax(dim=-1)
        mask[mask != torch.finfo(mask.dtype).min] = 0

        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=None,
        output_attentions=False,
        ):
        """
        LSG self-attention followed by dense projection, dropout and residual LayerNorm.

        ``head_mask`` and ``output_attentions`` are accepted for interface compatibility; this method
        does not use ``head_mask``, and ``not_causal_forward`` returns no attention weights.
        """

        query_layer, key_layer, value_layer = self.project_QKV(hidden_states)
        outputs = self.not_causal_forward(
            query_layer,
            key_layer,
            value_layer,
            attention_mask=attention_mask,
            output_attentions=output_attentions
            )

        context = outputs[0]
        context = self.dense(context)
        context = self.output_dropout(context)
        context = self.LayerNorm(context + hidden_states)

        return (context, ) + outputs[1:]

    def not_causal_forward(
        self,
        query_layer,
        key_layer,
        value_layer,
        attention_mask=None,
        output_attentions=False,
        ):
        """
        Non-causal LSG attention.

        Short sequences (t <= 2 * block_size + num_global_tokens) use full attention. Otherwise the
        global queries attend to everything, and the remaining queries use block-local + sparse +
        global attention.
        """

        n, h, t, d = query_layer.size()

        # Add global tokens (unmasked) to the front of the mask
        attention_mask = torch.nn.functional.pad(attention_mask, (self.num_global_tokens, 0), value=0)

        # Short sequences: plain full attention
        if t <= 2 * self.block_size + self.num_global_tokens:
            context_layer = self.full_attention(
                query_layer=query_layer,
                key_layer=key_layer,
                value_layer=value_layer,
                attention_mask=attention_mask
                )
            return (self.reshape_output(context_layer), )

        # Split global and regular queries
        split = (self.num_global_tokens, t - self.num_global_tokens)
        global_query, query_layer = query_layer.split(split, dim=-2)

        # Global queries attend to the whole sequence
        bos = self.full_attention(
            query_layer=global_query,
            key_layer=key_layer,
            value_layer=value_layer,
            attention_mask=attention_mask
            )

        # Split global and regular keys/values/mask
        global_key, key_layer = key_layer.split(split, dim=-2)
        global_value, value_layer = value_layer.split(split, dim=-2)
        global_mask, attention_mask = attention_mask.split(split, dim=-1)

        n, h, t, d = key_layer.size()

        # Sparse keys/values, when enabled
        sparse_key, sparse_value, sparse_mask = (None, None, None)

        if self.sparse_block_size and self.sparsity_factor > 0:
            sparse_key, sparse_value, sparse_mask = self.get_sparse_elements(query_layer, key_layer, value_layer, attention_mask)

        # Expand masks to all heads
        attention_mask = attention_mask.expand(-1, h, -1, -1)
        global_mask = global_mask.expand(-1, h, -1, -1)

        context_layer = self.attention(
            query_layer,
            key_layer,
            value_layer,
            attention_mask,
            sparse_key=sparse_key,
            sparse_value=sparse_value,
            sparse_mask=sparse_mask,
            global_key=global_key,
            global_value=global_value,
            global_mask=global_mask
            )

        # Put the global-token outputs back in front
        context_layer = torch.cat([bos, context_layer], dim=-2)
        context_layer = self.reshape_output(context_layer)

        return (context_layer,)

    def chunk(self, x, chunk_size):
        """Reshape (n, h, t, d) into (n, h, t // chunk_size, chunk_size, d)."""

        n, h, t, d = x.size()
        return x.reshape(n, h, -1, chunk_size, d)
|
|
|
|
|
class LSGAlbertLayer(AlbertLayer):
    """ALBERT layer whose attention submodule is an LSGSelfAttention."""

    def __init__(self, config):
        super().__init__(config)
        # Replace the stock attention built by AlbertLayer with the LSG variant.
        lsg_attention = LSGSelfAttention(config)
        self.attention = lsg_attention
|
|
|
|
|
class LSGAlbertLayerGroup(AlbertLayerGroup):
    """ALBERT layer group made of ``config.inner_group_num`` LSGAlbertLayer instances."""

    def __init__(self, config):
        # Call nn.Module.__init__ directly rather than AlbertLayerGroup.__init__, presumably so the
        # stock AlbertLayers are never constructed.
        nn.Module.__init__(self)
        layers = [LSGAlbertLayer(config) for _ in range(config.inner_group_num)]
        self.albert_layers = nn.ModuleList(layers)
|
|
|
|
|
class LSGAlbertTransformer(AlbertTransformer):
    """
    ALBERT transformer using LSG layer groups.

    The forward pass pads the sequence to a multiple of ``block_size`` when needed, optionally
    copies the global-token output onto the first real token, and strips global tokens and
    padding from ``last_hidden_state``.
    """

    def __init__(self, config):
        super().__init__(config)

        groups = [LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)]
        self.albert_layer_groups = nn.ModuleList(groups)

        assert hasattr(config, "num_global_tokens")
        self.num_global_tokens = config.num_global_tokens
        self.pad_idx = config.pad_token_id

        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
        self.block_size = config.block_size
        self.adaptive = config.adaptive
        self.mask_first_token = config.mask_first_token
        self.pool_with_global = config.pool_with_global

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.FloatTensor] = None,
        head_mask: Optional[torch.FloatTensor] = None,
        output_attentions: bool = False,
        output_hidden_states: bool = False,
        return_dict: bool = True,
    ) -> Union[BaseModelOutput, Tuple]:
        """
        Run the encoder.

        ``attention_mask`` is required despite its None default: it is unpacked as a 4-D additive mask
        whose last dim ``t`` is the sequence length without global tokens. With ``mask_first_token``,
        the mask is modified in place.
        """
        mask_value = torch.finfo(attention_mask.dtype).min
        _, _, _, t = attention_mask.size()

        if self.mask_first_token:
            attention_mask[..., 0] = mask_value

        remainder = t % self.block_size
        # Only pad when the sequence is long enough to use block attention and is not block-aligned.
        needs_padding = self.adaptive and t > self.block_size * 2 and remainder > 0
        if needs_padding:
            pad_length = self.block_size - remainder
            hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)

        encoder_outputs = super().forward(
            hidden_states=hidden_states,
            attention_mask=attention_mask,
            head_mask=head_mask,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict
        )

        g = self.num_global_tokens
        sequence_output = encoder_outputs[0]
        if self.pool_with_global:
            # Overwrite the first real token with the first global token's output.
            sequence_output[:, g] = sequence_output[:, 0]

        # Drop the global tokens and any block padding.
        sequence_output = sequence_output[..., g: t + g, :]

        if return_dict:
            encoder_outputs.last_hidden_state = sequence_output
            return encoder_outputs
        return (sequence_output, ) + encoder_outputs[1:]
|
|
|
|
|
class LSGAlbertPreTrainedModel(PreTrainedModel):
    """
    Abstract base that wires LSGAlbertConfig into PreTrainedModel and defines weight
    initialization, plus the shared interface for downloading and loading pretrained models.
    """

    config_class = LSGAlbertConfig
    load_tf_weights = load_tf_weights_in_albert
    base_model_prefix = "albert"

    def _init_weights(self, module):
        """Initialize Linear/Embedding weights from N(0, initializer_range); reset LayerNorm to identity."""
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
            return

        if isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.padding_idx is not None:
                # The padding row stays zero.
                module.weight.data[module.padding_idx].zero_()
            return

        if isinstance(module, nn.LayerNorm):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
|
|
|
|
|
class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
    """ALBERT model built from LSG embeddings and the LSG transformer, with an optional tanh pooler."""

    config_class = LSGAlbertConfig
    base_model_prefix = "albert"

    def __init__(self, config, add_pooling_layer=True):
        # Initialize through AlbertPreTrainedModel instead of AlbertModel.__init__, presumably to avoid
        # constructing the stock embeddings/encoder that are replaced below.
        AlbertPreTrainedModel.__init__(self, config)

        self.config = config
        self.embeddings = LSGAlbertEmbeddings(config)
        self.encoder = LSGAlbertTransformer(config)

        self.pooler = nn.Linear(config.hidden_size, config.hidden_size) if add_pooling_layer else None
        self.pooler_activation = nn.Tanh() if add_pooling_layer else None

        self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
        if self._use_flash_attention_2:
            logger.warning(
                "[WARNING flash-attention]: LSG doesnt support flash-attention currently"
            )

        self.post_init()
|
|
|
|
|
class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):
    """LSG-ALBERT with the masked-LM and sentence-order-prediction pretraining heads."""

    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        # Bypass AlbertForPreTraining.__init__; submodules are registered below in the same order.
        LSGAlbertPreTrainedModel.__init__(self, config)

        self.albert = LSGAlbertModel(config)
        self.predictions = AlbertMLMHead(config)
        self.sop_classifier = AlbertSOPHead(config)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
class LSGAlbertForMaskedLM(LSGAlbertPreTrainedModel, AlbertForMaskedLM):
    """LSG-ALBERT (without pooler) topped with the ALBERT masked-LM head."""

    _tied_weights_keys = ["predictions.decoder.bias", "predictions.decoder.weight"]

    def __init__(self, config):
        LSGAlbertPreTrainedModel.__init__(self, config)

        # The MLM head works on token states, so no pooler is built.
        self.albert = LSGAlbertModel(config, add_pooling_layer=False)
        self.predictions = AlbertMLMHead(config)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
class LSGAlbertForSequenceClassification(LSGAlbertPreTrainedModel, AlbertForSequenceClassification):
    """LSG-ALBERT with dropout and a linear classifier producing ``config.num_labels`` logits."""

    def __init__(self, config):
        LSGAlbertPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels
        self.config = config

        self.albert = LSGAlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
class LSGAlbertForTokenClassification(LSGAlbertPreTrainedModel, AlbertForTokenClassification):
    """LSG-ALBERT (without pooler) with per-token dropout and a linear classifier."""

    def __init__(self, config):
        LSGAlbertPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.albert = LSGAlbertModel(config, add_pooling_layer=False)

        # Fall back to hidden_dropout_prob when no classifier-specific dropout is configured.
        if config.classifier_dropout_prob is not None:
            classifier_dropout_prob = config.classifier_dropout_prob
        else:
            classifier_dropout_prob = config.hidden_dropout_prob
        self.dropout = nn.Dropout(classifier_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, self.config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
class LSGAlbertForQuestionAnswering(LSGAlbertPreTrainedModel, AlbertForQuestionAnswering):
    """LSG-ALBERT (without pooler) with a linear span head of ``config.num_labels`` outputs per token."""

    def __init__(self, config):
        LSGAlbertPreTrainedModel.__init__(self, config)
        self.num_labels = config.num_labels

        self.albert = LSGAlbertModel(config, add_pooling_layer=False)
        hidden = config.hidden_size
        self.qa_outputs = nn.Linear(hidden, config.num_labels)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
class LSGAlbertForMultipleChoice(LSGAlbertPreTrainedModel, AlbertForMultipleChoice):
    """LSG-ALBERT with dropout and a single-logit scorer applied to each choice."""

    def __init__(self, config):
        LSGAlbertPreTrainedModel.__init__(self, config)

        self.albert = LSGAlbertModel(config)
        self.dropout = nn.Dropout(config.classifier_dropout_prob)
        # One score per (sequence, choice) pair.
        self.classifier = nn.Linear(config.hidden_size, 1)

        # Initialize weights and apply final processing.
        self.post_init()
|
|
|
|
|
def str_to_class(classname):
    """Return the attribute named ``classname`` from this module (raises AttributeError if absent)."""
    this_module = sys.modules[__name__]
    return getattr(this_module, classname)
|
|
|
|
|
# Register the config and every model class in AUTO_MAP for transformers' Auto* custom-code
# support, so save_pretrained(...) also exports this modeling file.
try:
    LSGAlbertConfig.register_for_auto_class()
    for key, value in AUTO_MAP.items():
        str_to_class(value.split(".")[-1]).register_for_auto_class(key)
except Exception:
    # Best effort: on failure, warn instead of breaking the import. Catch Exception rather than
    # using a bare except, so KeyboardInterrupt / SystemExit still propagate.
    warn("AutoRegister isn't available, you'll have to manually copy modeling.py after .save_pretrained(...).")
    warn("Update to transformers >= 4.36.1 to fix.")