small fix with torch.finfo

modeling_lsg_albert.py CHANGED (+63 -94)

@@ -198,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
                 diagonal=-1
                 )
             causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]

        del attention_mask

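Note on this first hunk: the assignment now starts one key column later and uses causal_mask[:, 1:], so the first key column of the causal block is left untouched rather than overwritten. The snippet below is an illustrative, standalone sketch (not code from this file) of the torch.finfo masking idiom the fix standardizes on:

import torch

# torch.finfo(dtype).min is the most negative finite value of the scores' own
# dtype, so the additive mask works for fp16/bf16/fp32 alike without a
# hard-coded constant such as -10000.
scores = torch.randn(1, 1, 4, 4, dtype=torch.float16)
mask = torch.zeros_like(scores[..., :1, :])    # (1, 1, 1, 4), broadcast over queries
mask[..., -1] = torch.finfo(scores.dtype).min  # mask the last key position

probs = torch.softmax(scores + mask, dim=-1)
print(probs[..., -1])  # ~0 for every query: the masked key receives no attention
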
@@ -551,7 +551,8 @@ class LSGAttention(BaseSelfAttention):
        keys = keys.sum(dim=-2) / (mask + 1e-6)
        values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

    def get_sparse_tokens_with_stride(self, keys, values, mask):
@@ -616,7 +617,8 @@ class LSGAttention(BaseSelfAttention):
        keys /= mask + 1e-8
        values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1))
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min

        return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

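In both of the two hunks above, the returned mask changes from a {0, 1} indicator into an additive bias expressed in the mask's own dtype, in line with the torch.finfo convention used elsewhere in the file. A minimal standalone illustration of the same two added lines (the helper name and the values are made up for the example):

import torch

# Empty blocks (count 0) become finfo.min, non-empty blocks become 0, so the
# result can be added directly to attention scores.
def to_additive_bias(counts: torch.Tensor) -> torch.Tensor:
    bias = 1. - counts.clamp(0, 1)
    return bias * torch.finfo(counts.dtype).min

counts = torch.tensor([[0., 2., 5.]], dtype=torch.bfloat16)
print(to_additive_bias(counts))  # [min, 0, 0] (up to the sign of zero)
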
@@ -769,6 +771,63 @@ class LSGAlbertTransformer(AlbertTransformer):

        self.albert_layer_groups = nn.ModuleList([LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[BaseModelOutput, Tuple]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if self.mask_first_token:
+            attention_mask[..., 0] = mask_value
+
+        b = self.block_size * 2
+        pad = t % self.block_size
+
+        # Check if t is multiple of block_size and pad
+        if self.adaptive and t > b and pad > 0:
+            pad_length = self.block_size - pad
+            hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs
+

class LSGAlbertPreTrainedModel(PreTrainedModel):
    """
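The forward added here is where the adaptive block padding now happens. The sketch below (shapes and numbers are illustrative and independent of the model classes) checks the arithmetic: a sequence longer than 2 * block_size is right-padded to the next multiple of block_size, the pad keys get the additive mask value, and the final slice keeps exactly the original t positions after the global tokens:

import torch

block_size, t, d, num_global_tokens = 128, 300, 16, 1
mask_value = torch.finfo(torch.float32).min

pad = t % block_size                                                       # 44
pad_length = block_size - pad if (t > 2 * block_size and pad > 0) else 0   # 84

hidden_states = torch.randn(2, t, d)
attention_mask = torch.zeros(2, 1, 1, t)                    # extended additive mask
hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
print(hidden_states.shape[1] % block_size)                  # 0: length is a block multiple

# A tensor laid out like the encoder output (global tokens first, then the padded sequence):
sequence_output = torch.randn(2, num_global_tokens + hidden_states.shape[1], d)
print(sequence_output[..., num_global_tokens: t + num_global_tokens, :].shape)  # (2, 300, 16)
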
@@ -806,16 +865,6 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
    def __init__(self, config, add_pooling_layer=True):
        AlbertPreTrainedModel.__init__(self, config)

-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global
-
        self.config = config
        self.embeddings = LSGAlbertEmbeddings(config)
        self.encoder = LSGAlbertTransformer(config)
@@ -829,87 +878,7 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
        # Initialize weights and apply final processing
        self.post_init()

-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        sequence_output = encoder_outputs[0]
-        if self.pool_with_global:
-            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
-
-        diff = t - t_
-        n, _, d = sequence_output.size()
-        sequence_output = sequence_output[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            sequence_output = sequence_output[:, :t]
-
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        encoder_outputs.last_hidden_state = sequence_output
-        encoder_outputs.pooler_output = pooled_output
-        return encoder_outputs
-
-
+
class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):

    def __init__(self, config):
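Taken together, the last three hunks relocate the sequence-length handling: LSGAlbertModel loses its LSG-specific __init__ attributes and its custom forward, while LSGAlbertTransformer.forward takes over and operates on hidden_states and the already-extended additive attention mask rather than on raw input_ids. That is consistent with the padding value switching from 0 (for a 0/1 token mask) to torch.finfo(attention_mask.dtype).min (for an additive mask), and with mask_first_token now writing the finfo minimum instead of zero.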