ccdv committed
Commit 71321ab · 1 Parent(s): 0cb6ffb

small fix with torch.finfo

Files changed (1):
  1. modeling_lsg_albert.py +63 -94
modeling_lsg_albert.py CHANGED
@@ -198,7 +198,7 @@ class CausalAttentionProduct(nn.Module):
                 diagonal=-1
             )
             causal_mask = causal_mask.T * torch.finfo(attention_scores.dtype).min
-            attention_scores[..., -causal_shape[0]:, -causal_shape[1]:] = causal_mask
+            attention_scores[..., -causal_shape[0]:, -causal_shape[1] + 1:] = causal_mask[:, 1:]

             del attention_mask

@@ -551,7 +551,8 @@ class LSGAttention(BaseSelfAttention):
         keys = keys.sum(dim=-2) / (mask + 1e-6)
         values = values.sum(dim=-2) / (mask + 1e-6)

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min
         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.expand(-1, h, -1, -1).transpose(-1, -2)

     def get_sparse_tokens_with_stride(self, keys, values, mask):

@@ -616,7 +617,8 @@ class LSGAttention(BaseSelfAttention):
         keys /= mask + 1e-8
         values /= mask + 1e-8

-        mask = (1. - mask.clamp(0, 1)) * torch.finfo(mask.dtype).min
+        mask = (1. - mask.clamp(0, 1))
+        mask *= torch.finfo(mask.dtype).min

         return keys.reshape(n, h, -1, d), values.reshape(n, h, -1, d), mask.transpose(-1, -2).reshape(n, h, 1, -1)

@@ -769,6 +771,63 @@ class LSGAlbertTransformer(AlbertTransformer):

         self.albert_layer_groups = nn.ModuleList([LSGAlbertLayerGroup(config) for _ in range(config.num_hidden_groups)])

+        assert hasattr(config, "num_global_tokens")
+        self.num_global_tokens = config.num_global_tokens
+        self.pad_idx = config.pad_token_id
+
+        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
+        self.block_size = config.block_size
+        self.adaptive = config.adaptive
+        self.mask_first_token = config.mask_first_token
+        self.pool_with_global = config.pool_with_global
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        head_mask: Optional[torch.FloatTensor] = None,
+        output_attentions: bool = False,
+        output_hidden_states: bool = False,
+        return_dict: bool = True,
+    ) -> Union[BaseModelOutput, Tuple]:
+
+        mask_value = torch.finfo(attention_mask.dtype).min
+        n, _, __, t = attention_mask.size()
+
+        if self.mask_first_token:
+            attention_mask[..., 0] = mask_value
+
+        b = self.block_size * 2
+        pad = t % self.block_size
+
+        # Check if t is multiple of block_size and pad
+        if self.adaptive and t > b and pad > 0:
+            pad_length = self.block_size - pad
+            hidden_states = torch.nn.functional.pad(hidden_states.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
+            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=mask_value)
+
+        encoder_outputs = super().forward(
+            hidden_states=hidden_states,
+            attention_mask=attention_mask,
+            head_mask=head_mask,
+            output_attentions=output_attentions,
+            output_hidden_states=output_hidden_states,
+            return_dict=return_dict
+        )
+
+        sequence_output = encoder_outputs[0]
+        if self.pool_with_global:
+            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
+
+        # Adapt sequence to initial shape
+        sequence_output = sequence_output[..., self.num_global_tokens: t + self.num_global_tokens, :]
+
+        if not return_dict:
+            return (sequence_output, ) + encoder_outputs[1:]
+
+        encoder_outputs.last_hidden_state = sequence_output
+        return encoder_outputs
+

 class LSGAlbertPreTrainedModel(PreTrainedModel):
     """

@@ -806,16 +865,6 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
     def __init__(self, config, add_pooling_layer=True):
         AlbertPreTrainedModel.__init__(self, config)

-        assert hasattr(config, "num_global_tokens")
-        self.num_global_tokens = config.num_global_tokens
-        self.pad_idx = config.pad_token_id
-
-        assert hasattr(config, "block_size") and hasattr(config, "adaptive")
-        self.block_size = config.block_size
-        self.adaptive = config.adaptive
-        self.mask_first_token = config.mask_first_token
-        self.pool_with_global = config.pool_with_global
-
         self.config = config
         self.embeddings = LSGAlbertEmbeddings(config)
         self.encoder = LSGAlbertTransformer(config)

@@ -829,87 +878,7 @@ class LSGAlbertModel(LSGAlbertPreTrainedModel, AlbertModel):
         # Initialize weights and apply final processing
         self.post_init()

-    def forward(
-        self,
-        input_ids=None,
-        attention_mask=None,
-        token_type_ids=None,
-        position_ids=None,
-        head_mask=None,
-        inputs_embeds=None,
-        output_attentions=None,
-        output_hidden_states=None,
-        return_dict=None,
-    ):
-
-        output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-        output_hidden_states = (
-            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-        )
-        return_dict = return_dict if return_dict is not None else self.config.use_return_dict
-
-        inputs_ = input_ids if input_ids is not None else inputs_embeds
-        n, t = inputs_.size()[:2]
-
-        if attention_mask is None:
-            attention_mask = torch.ones(n, t, device=inputs_.device, dtype=inputs_.dtype)
-        if self.mask_first_token:
-            attention_mask[:,0] = 0
-
-        b = self.block_size * 2
-        pad = t % self.block_size
-
-        # Check if t is multiple of block_size and pad
-        if self.adaptive and t > b and pad > 0:
-            pad_length = self.block_size - pad
-            if input_ids is not None:
-                input_ids = torch.nn.functional.pad(input_ids, (0, pad_length), value=self.pad_idx)
-            else:
-                inputs_embeds = torch.nn.functional.pad(inputs_embeds.transpose(-1, -2), (0, pad_length), value=0.).transpose(-1, -2)
-
-            attention_mask = torch.nn.functional.pad(attention_mask, (0, pad_length), value=0)
-
-            if token_type_ids is not None:
-                token_type_ids = torch.nn.functional.pad(token_type_ids, (0, pad_length), value=0)
-            if position_ids is not None:
-                position_ids = torch.nn.functional.pad(position_ids, (0, pad_length), value=0)
-
-        n, t_ = attention_mask.size()
-
-        encoder_outputs = super().forward(
-            input_ids=input_ids,
-            attention_mask=attention_mask,
-            token_type_ids=token_type_ids,
-            position_ids=position_ids,
-            head_mask=head_mask,
-            inputs_embeds=inputs_embeds,
-            output_attentions=output_attentions,
-            output_hidden_states=output_hidden_states,
-            return_dict=return_dict
-        )
-
-        sequence_output = encoder_outputs[0]
-        if self.pool_with_global:
-            sequence_output[:, self.num_global_tokens] = sequence_output[:, 0]
-
-        diff = t - t_
-        n, _, d = sequence_output.size()
-        sequence_output = sequence_output[..., self.num_global_tokens:, :]
-
-        # Adapt sequence to initial shape
-        if diff < 0:
-            sequence_output = sequence_output[:, :t]
-
-        pooled_output = self.pooler(sequence_output) if self.pooler is not None else None
-
-        if not return_dict:
-            return (sequence_output, pooled_output) + encoder_outputs[1:]
-
-        encoder_outputs.last_hidden_state = sequence_output
-        encoder_outputs.pooler_output = pooled_output
-        return encoder_outputs
-
-
+
 class LSGAlbertForPreTraining(LSGAlbertPreTrainedModel, AlbertForPreTraining):

     def __init__(self, config):
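
For context, here is a minimal, illustrative sketch (not part of this commit or of the repository; tensor names are invented for the example) of the additive masking pattern the hunks above rely on: positions to be ignored get torch.finfo(dtype).min added to their attention scores, so their softmax weight collapses to (near) zero.

import torch

# Illustrative sketch only: an additive mask built from torch.finfo(dtype).min
# drives the softmax probability of masked key positions to ~0.
scores = torch.randn(1, 2, 4, 4)              # (batch, heads, queries, keys)
keep = torch.tensor([1., 1., 1., 0.])         # 1 = keep, 0 = mask out the last key
additive = (1. - keep.clamp(0, 1)) * torch.finfo(scores.dtype).min
probs = torch.softmax(scores + additive, dim=-1)
print(probs[..., -1].max())                   # ~0: masked column receives no weight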