small fix
modeling_lsg_albert.py  (+33 -28, CHANGED)
@@ -17,7 +17,7 @@ AUTO_MAP = {
 
 class LSGAlbertConfig(AlbertConfig):
     """
-    This class overrides :class:`~transformers.
+    This class overrides :class:`~transformers.AlbertConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -55,7 +55,8 @@ class LSGAlbertConfig(AlbertConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'],
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                    setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -71,7 +72,7 @@ class LSGAlbertConfig(AlbertConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -79,7 +80,17 @@ class LSGAlbertConfig(AlbertConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
 
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
+
 class BaseSelfAttention(nn.Module):
 
     def init_modules(self, config):
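The net effect of the config guards above, as a minimal usage sketch. This assumes the file is importable as `modeling_lsg_albert` and that `sparsity_type`, `num_global_tokens`, `mask_first_token` and `pool_with_global` are accepted as constructor arguments, as the checks above suggest; adjust the import path to your setup.

```python
from modeling_lsg_albert import LSGAlbertConfig  # hypothetical import path

# An unsupported sparsity_type only logs the [WARNING CONFIG] message and
# disables sparse attention instead of raising.
config = LSGAlbertConfig(sparsity_type="random")
print(config.sparsity_type)        # expected: None

# num_global_tokens is clamped to at most 512.
config = LSGAlbertConfig(num_global_tokens=1024)
print(config.num_global_tokens)    # expected: 512

# New guard: mask_first_token=True forces pool_with_global=True.
config = LSGAlbertConfig(mask_first_token=True, pool_with_global=False)
print(config.pool_with_global)     # expected: True
```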
@@ -635,9 +646,6 @@ class LSGAttention(BaseSelfAttention):
         hidden_states,
         attention_mask=None,
         head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
         output_attentions=False,
     ):
 
@@ -655,11 +663,7 @@ class LSGAttention(BaseSelfAttention):
         context = self.output_dropout(context)
         context = self.LayerNorm(context + hidden_states)
 
-
-
-        #if head_mask is not None:
-        #    outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
-        return outputs
+        return (context, ) + outputs[1:]
 
     def not_causal_forward(
         self,
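With this fix the attention module returns a tuple whose first element is the post-LayerNorm context, which is how Albert-style layer code typically consumes it (context at index 0, optional extras such as attention probabilities after). A toy stand-in illustrating that return contract, not the real LSG module:

```python
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    # Toy stand-in: element 0 of the returned tuple is the post-norm context,
    # any extras (here the attention probabilities) follow it.
    def __init__(self, dim):
        super().__init__()
        self.LayerNorm = nn.LayerNorm(dim)

    def forward(self, hidden_states, output_attentions=False):
        scores = hidden_states @ hidden_states.transpose(-1, -2)
        probs = torch.softmax(scores, dim=-1)
        context = self.LayerNorm(probs @ hidden_states + hidden_states)
        outputs = (context, probs) if output_attentions else (context,)
        return (context,) + outputs[1:]   # same shape of contract as the fixed return above

attn = ToyAttention(8)
x = torch.randn(2, 5, 8)
context, probs = attn(x, output_attentions=True)
print(context.shape, probs.shape)   # torch.Size([2, 5, 8]) torch.Size([2, 5, 5])
```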
@@ -751,6 +755,7 @@ class LSGAlbertLayer(AlbertLayer):
 class LSGAlbertLayerGroup(AlbertLayerGroup):
 
     def __init__(self, config):
+
         nn.Module.__init__(self)
 
         self.albert_layers = nn.ModuleList([LSGAlbertLayer(config) for _ in range(config.inner_group_num)])
@@ -759,10 +764,9 @@ class LSGAlbertLayerGroup(AlbertLayerGroup):
 class LSGAlbertTransformer(AlbertTransformer):
 
     def __init__(self, config):
-        nn.Module.__init__(self)
 
-
-
+        super().__init__(config)
+
         self.albert_layer_groups = nn.ModuleList([LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
 
 
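Switching `LSGAlbertTransformer.__init__` to `super().__init__(config)` follows the usual subclass pattern: let the parent constructor register everything it normally sets up, then overwrite only the layer groups. A toy sketch of that pattern, with illustrative names rather than the actual transformers classes:

```python
import torch.nn as nn

class ToyBaseTransformer(nn.Module):
    # stands in for AlbertTransformer: registers extra state besides the layer groups
    def __init__(self, config):
        super().__init__()
        self.config = config
        self.embedding_hidden_mapping_in = nn.Linear(8, 8)
        self.albert_layer_groups = nn.ModuleList([nn.Identity()])

class ToyLSGTransformer(ToyBaseTransformer):
    def __init__(self, config):
        super().__init__(config)          # parent state (config, mapping) is kept
        # only the layer groups are swapped for the LSG variants
        self.albert_layer_groups = nn.ModuleList([nn.Identity() for _ in range(2)])

model = ToyLSGTransformer(config={"num_hidden_groups": 2})
print(hasattr(model, "embedding_hidden_mapping_in"))   # True
```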
@@ -838,6 +842,12 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
         return_dict=None,
     ):
 
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         inputs_ = input_ids if input_ids is not None else inputs_embeds
         n, t = inputs_.size()[:2]
 
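The three added assignments mirror the standard transformers idiom for resolving output flags: a caller-supplied value wins, otherwise the value stored on the config is used, so `None` never reaches the `if not return_dict:` branch further down. The idiom in isolation:

```python
def resolve(flag, config_default):
    # explicit argument wins; otherwise fall back to the value stored on the config
    return flag if flag is not None else config_default

print(resolve(None, True))    # True  -> config default used
print(resolve(False, True))   # False -> caller's explicit value kept
```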
@@ -878,31 +888,26 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
             return_dict=return_dict
         )
 
-
+        sequence_output = encoder_outputs[0]
         if self.pool_with_global:
-
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
 
         diff = t - t_
-        n, _, d =
-
+        n, _, d = sequence_output.size()
+        sequence_output = sequence_output[..., self.num_global_tokens:, :]
 
         # Adapt sequence to initial shape
         if diff < 0:
-
+            sequence_output = sequence_output[:, :t]
 
-        encoder_outputs.last_hidden_state = context
-        sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-
-
-
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
+
+        encoder_outputs.last_hidden_state = sequence_output
+        encoder_outputs.pooler_output = pooled_output
+        return encoder_outputs
 
 
 class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):
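The last hunk reworks how the model copies the global representation, strips its global tokens, and trims the block padding before building the output. A toy rendering of that tensor bookkeeping with made-up sizes (batch 2, hidden 8, one global token, original length t=10, block-padded length t_=16); it mimics the operations only, not the real model:

```python
import torch

num_global_tokens, t, t_ = 1, 10, 16
# encoder output including the global tokens at the front
sequence_output = torch.randn(2, num_global_tokens + t_, 8)

pool_with_global = True
if pool_with_global:
    # copy the global token representation onto the first regular position
    sequence_output[:, num_global_tokens] = sequence_output[:, 0]

diff = t - t_
sequence_output = sequence_output[..., num_global_tokens:, :]  # drop the global tokens

# adapt sequence to the initial shape
if diff < 0:
    sequence_output = sequence_output[:, :t]                    # trim the block padding

print(sequence_output.shape)   # torch.Size([2, 10, 8])
```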