ccdv committed
Commit db91e1c · 1 Parent(s): 2d1da30

Files changed (1)
  1. modeling_lsg_albert.py  +33 -28

modeling_lsg_albert.py CHANGED
@@ -17,7 +17,7 @@ AUTO_MAP = {
 
 class LSGAlbertConfig(AlbertConfig):
     """
-    This class overrides :class:`~transformers.LSGAlbertConfig`. Please check the superclass for the appropriate
+    This class overrides :class:`~transformers.AlbertConfig`. Please check the superclass for the appropriate
     documentation alongside usage examples.
     """
 
@@ -55,7 +55,8 @@ class LSGAlbertConfig(AlbertConfig):
 
         if sparsity_type not in [None, "none", "norm", "lsh", "pooling", "stride", "block_stride"]:
             logger.warning(
-                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], setting sparsity_type=None, computation will skip sparse attention")
+                "[WARNING CONFIG]: sparsity_mode not in [None, 'none', 'norm', 'lsh', 'pooling', 'stride', 'block_stride'], \
+                setting sparsity_type=None, computation will skip sparse attention")
             self.sparsity_type = None
 
         if self.sparsity_type in ["stride", "block_stride"]:
@@ -71,7 +72,7 @@ class LSGAlbertConfig(AlbertConfig):
             self.num_global_tokens = 1
         elif self.num_global_tokens > 512:
             logger.warning(
-                "[WARNING CONFIG]: num_global_tokens > 512 is not compatible, setting num_global_tokens=512"
+                "[WARNING CONFIG]: num_global_tokens > 512 is not allowed, setting num_global_tokens=512"
             )
             self.num_global_tokens = 512
 
@@ -79,7 +80,17 @@ class LSGAlbertConfig(AlbertConfig):
         assert self.block_size % self.sparsity_factor == 0, "[ERROR CONFIG]: block_size must be divisible by sparsity_factor"
         assert self.block_size//self.sparsity_factor >= 1, "[ERROR CONFIG]: make sure block_size >= sparsity_factor"
 
+        if self.mask_first_token and not pool_with_global:
+            logger.warning(
+                "[WARNING CONFIG]: pool_with_global==False is not compatible with mask_first_token==True. Setting pool_with_global to True.")
+            self.pool_with_global = True
 
+        if hasattr(self, "position_embedding_type"):
+            if self.position_embedding_type != "absolute":
+                logger.warning(
+                    "[WARNING CONFIG]: LSG Attention is not compatible with relative positional embedding and will skip its computation. Set position_embedding_type='absolute' to remove this warning.")
+
+
 class BaseSelfAttention(nn.Module):
 
     def init_modules(self, config):
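
As a quick illustration of the checks above, a minimal sketch of a config that would exercise them. Parameter names are taken from this diff; the exact constructor signature and the local import path are assumptions, not part of the commit:

    # Hypothetical example, assuming the class is importable from the modeling file.
    from modeling_lsg_albert import LSGAlbertConfig

    config = LSGAlbertConfig(
        sparsity_type="pooling",    # values outside the allowed list fall back to None
        num_global_tokens=600,      # > 512 is clamped to 512 with a warning
        block_size=128,
        sparsity_factor=4,          # block_size must be divisible by sparsity_factor
        mask_first_token=True,
        pool_with_global=False,     # now forced back to True when mask_first_token=True
    )
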
@@ -635,9 +646,6 @@ class LSGAttention(BaseSelfAttention):
         hidden_states,
         attention_mask=None,
         head_mask=None,
-        encoder_hidden_states=None,
-        encoder_attention_mask=None,
-        past_key_value=None,
         output_attentions=False,
     ):
 
@@ -655,11 +663,7 @@ class LSGAttention(BaseSelfAttention):
         context = self.output_dropout(context)
         context = self.LayerNorm(context + hidden_states)
 
-        outputs = (context, ) + outputs[1:]
-
-        #if head_mask is not None:
-        #    outputs = (outputs[0] * head_mask[:, :, :1, :1], ) + outputs[1:]
-        return outputs
+        return (context, ) + outputs[1:]
 
     def not_causal_forward(
         self,
@@ -751,6 +755,7 @@ class LSGAlbertLayer(AlbertLayer):
 class LSGAlbertLayerGroup(AlbertLayerGroup):
 
     def __init__(self, config):
+
         nn.Module.__init__(self)
 
         self.albert_layers = nn.ModuleList([LSGAlbertLayer(config) for _ in range(config.inner_group_num)])
@@ -759,10 +764,9 @@ class LSGAlbertLayerGroup(AlbertLayerGroup):
 class LSGAlbertTransformer(AlbertTransformer):
 
     def __init__(self, config):
-        nn.Module.__init__(self)
 
-        self.config = config
-        self.embedding_hidden_mapping_in = nn.Linear(config.embedding_size, config.hidden_size)
+        super().__init__(config)
+
         self.albert_layer_groups = nn.ModuleList([LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])
 
 
@@ -838,6 +842,12 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
         return_dict=None,
     ):
 
+        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
+        output_hidden_states = (
+            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
+        )
+        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
+
         inputs_ = input_ids if input_ids is not None else inputs_embeds
         n, t = inputs_.size()[:2]
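
The lines added here follow the usual transformers fallback to the config defaults. A small sketch of the observable difference; it assumes an already-instantiated LSGAlbertModel `model` and tokenized `inputs`, neither of which comes from this commit:

    out = model(**inputs)                            # honours config.use_return_dict -> ModelOutput by default
    out_tuple = model(**inputs, return_dict=False)   # explicit override -> plain tuple
    hidden = out.last_hidden_state                   # equivalent to out_tuple[0]
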
 
@@ -878,31 +888,26 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
             return_dict=return_dict
         )
 
-        context = encoder_outputs[0]
+        sequence_output = encoder_outputs[0]
         if self.pool_with_global:
-            context[:, self.num_global_tokens] = context[:, 0]
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
 
         diff = t - t_
-        n, _, d = context.size()
-        context = context[..., self.num_global_tokens:, :]
+        n, _, d = sequence_output.size()
+        sequence_output = sequence_output[..., self.num_global_tokens:, :]
 
         # Adapt sequence to initial shape
         if diff < 0:
-            context = context[:, :t]
+            sequence_output = sequence_output[:, :t]
 
-        encoder_outputs.last_hidden_state = context
-        sequence_output = encoder_outputs[0]
         pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
 
         if not return_dict:
             return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        return BaseModelOutputWithPooling(
-            last_hidden_state=sequence_output,
-            pooler_output=pooled_output,
-            hidden_states=encoder_outputs.hidden_states,
-            attentions=encoder_outputs.attentions,
-        )
+
+        encoder_outputs.last_hidden_state = sequence_output
+        encoder_outputs.pooler_output = pooled_output
+        return encoder_outputs
 
 
 class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):
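
As an end-to-end sanity check of the reworked output handling, a usage sketch. The checkpoint id is a placeholder, and trust_remote_code is assumed because the model ships custom code registered through AUTO_MAP:

    import torch
    from transformers import AutoModel, AutoTokenizer

    name = "path-or-id-of-an-lsg-albert-checkpoint"   # placeholder, not from this commit
    tokenizer = AutoTokenizer.from_pretrained(name)
    model = AutoModel.from_pretrained(name, trust_remote_code=True)

    inputs = tokenizer("some long document " * 200, return_tensors="pt")
    with torch.no_grad():
        out = model(**inputs)

    # The global tokens are stripped and the sequence is cut back to the input length,
    # so the hidden states line up with the input tokens again.
    assert out.last_hidden_state.shape[1] == inputs["input_ids"].shape[1]
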
 