zpn committed (verified)
Commit 94e044d · Parent(s): 504ce99

Update modeling_hf_nomic_bert.py

Files changed (1):
  modeling_hf_nomic_bert.py (+22 -6)
modeling_hf_nomic_bert.py CHANGED
@@ -16,7 +16,7 @@ from einops import rearrange, repeat
 from transformers import GPT2Config, PreTrainedModel
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
-    BertForPreTrainingOutput,
+    MaskedLMOutput,
     SequenceClassifierOutput
 )
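Note: this hunk only swaps the imported output container; the matching change to the return statement is in the last hunk. A quick check, outside the commit, of the field names downstream callers have to move to:

```python
# Not part of the commit: compare the fields of the old and new output classes.
from dataclasses import fields
from transformers.modeling_outputs import MaskedLMOutput
from transformers.models.bert.modeling_bert import BertForPreTrainingOutput

print([f.name for f in fields(BertForPreTrainingOutput)])
# ['loss', 'prediction_logits', 'seq_relationship_logits', 'hidden_states', 'attentions']
print([f.name for f in fields(MaskedLMOutput)])
# ['loss', 'logits', 'hidden_states', 'attentions']
```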
 
@@ -321,7 +321,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-        config.rotary_scaling_factor = rotary_scaling_factor
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+        else:
+            config.rotary_scaling_factor = None
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:
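Note: the new branch checks truthiness, so a scaling factor of 0 / 0.0 is treated like None; either way the attribute is written, so later code can read config.rotary_scaling_factor unconditionally. A minimal sketch with an illustrative factor of 2.0 (not a value from the commit):

```python
# Minimal sketch (illustrative value) of how the popped kwarg lands on the config.
from transformers import GPT2Config

config = GPT2Config()
kwargs = {"rotary_scaling_factor": 2.0}

rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
if rotary_scaling_factor:
    config.rotary_scaling_factor = rotary_scaling_factor
else:
    config.rotary_scaling_factor = None  # attribute still exists, just unset

print(config.rotary_scaling_factor)  # 2.0; prints None when the kwarg is absent
```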
@@ -330,7 +333,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if "add_pooling_layer" in kwargs:
             model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
         else:
-            model = cls(config, *inputs, add_pooling_layer=False)
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
         # TODO: fix this
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap
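Note: the head classes' constructors presumably don't take add_pooling_layer, so the old fallback would only have worked for NomicBertModel. A standalone reproduction with hypothetical stub classes (not the real ones):

```python
# Hypothetical stubs, not the actual classes: show why forwarding
# add_pooling_layer to a head model that doesn't accept it fails.
class NomicBertModelStub:
    def __init__(self, config, add_pooling_layer=True):
        self.add_pooling_layer = add_pooling_layer

class NomicBertForSequenceClassificationStub:
    def __init__(self, config):  # no add_pooling_layer parameter
        self.config = config

try:
    NomicBertForSequenceClassificationStub({}, add_pooling_layer=False)
except TypeError as err:
    print(err)  # unexpected keyword argument 'add_pooling_layer'

# the new dispatch: only the base model gets the kwarg
NomicBertModelStub({}, add_pooling_layer=False)
NomicBertForSequenceClassificationStub({})
```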
@@ -551,6 +557,12 @@ class NomicBertRotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)

         self._seq_len_cached = 0
         self._cos_cached = None
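Note: the added buffer follows the xPos-style per-dimension scale used by flash-attn's rotary embedding, and it is only materialized when scale_base is set. A small worked example of the values for an assumed rotary dim of 8:

```python
# Worked example (assumed dim=8 and scale_base set): the per-dimension scale
# rises from 0.4/1.4 ≈ 0.286 toward 1 across the rotary dimensions.
import torch

dim = 8
scale = (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
print(scale)  # tensor([0.2857, 0.4643, 0.6429, 0.8214])
```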
@@ -616,7 +628,9 @@ class NomicBertRotaryEmbedding(nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if max_seqlen is not None:
+        if seqlen > self._seq_len_cached:
+            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
         elif isinstance(seqlen_offset, int):
             self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
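Note: the new first branch grows the cos/sin cache to the actual input length before the existing max_seqlen / seqlen_offset branches are consulted. A self-contained sketch of that grow-on-demand check, with a placeholder cache update rather than the real table:

```python
# Standalone sketch (placeholder cache contents, not NomicBertRotaryEmbedding)
# of the grow-on-demand check the hunk adds in front of the existing branches.
import torch

class CosSinCacheSketch:
    def __init__(self):
        self._seq_len_cached = 0
        self._cos_cached = None

    def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
        self._seq_len_cached = seqlen
        self._cos_cached = torch.ones(seqlen, device=device, dtype=dtype)  # stand-in table

    def ensure_cache(self, qkv, max_seqlen=None, seqlen_offset=0):
        seqlen = qkv.shape[1]
        if seqlen > self._seq_len_cached:  # new: always cover the actual input length
            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        elif isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)

cache = CosSinCacheSketch()
cache.ensure_cache(torch.zeros(1, 128, 3, 12, 64))  # (batch, seqlen, three, heads, head_dim)
print(cache._seq_len_cached)  # 128
```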
@@ -1133,9 +1147,11 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
             )
             total_loss = masked_lm_loss.float()

-        return BertForPreTrainingOutput(
+        return MaskedLMOutput(
             loss=total_loss,
-            prediction_logits=prediction_scores,
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=None,
         )
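Note: with the return switched to MaskedLMOutput, training code reads .loss and .logits; the commit forwards the encoder's hidden_states and passes attentions=None. A consumer-side sketch using dummy tensors instead of a real forward pass; shapes are illustrative:

```python
# Consumer-side sketch (dummy tensors, illustrative shapes, no real forward pass)
# of the attribute names the new return value exposes.
import torch
from transformers.modeling_outputs import MaskedLMOutput

out = MaskedLMOutput(
    loss=torch.tensor(2.31),
    logits=torch.randn(2, 16, 30528),  # (batch, seq_len, vocab_size)
    hidden_states=None,                # the commit forwards outputs.hidden_states here
    attentions=None,                   # the commit passes None unconditionally
)
print(out.loss.item(), out.logits.shape)
print(out["logits"].shape)  # ModelOutput also supports dict-style access
```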
 
 