Update modeling_hf_nomic_bert.py
modeling_hf_nomic_bert.py  CHANGED  (+22 -6)
@@ -16,7 +16,7 @@ from einops import rearrange, repeat
 from transformers import GPT2Config, PreTrainedModel
 from transformers.models.bert.modeling_bert import (
     BaseModelOutputWithPoolingAndCrossAttentions,
-
+    MaskedLMOutput,
     SequenceClassifierOutput
 )

@@ -321,7 +321,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         ignore_mismatched_shapes = kwargs.pop("ignore_mismatched_sizes", False)
         num_labels = kwargs.pop("num_labels", None)
         rotary_scaling_factor = kwargs.pop("rotary_scaling_factor", None)
-
+        if rotary_scaling_factor:
+            config.rotary_scaling_factor = rotary_scaling_factor
+        else:
+            config.rotary_scaling_factor = None
         if config.n_positions <= 0 and config.rotary_emb_fraction > 0:
             config.n_positions = 2048
         if num_labels:

@@ -330,7 +333,10 @@ class NomicBertPreTrainedModel(PreTrainedModel):
         if "add_pooling_layer" in kwargs:
             model = cls(config, *inputs, add_pooling_layer=kwargs.pop("add_pooling_layer"))
         else:
-
+            if cls == NomicBertModel:
+                model = cls(config, *inputs, add_pooling_layer=False)
+            else:
+                model = cls(config, *inputs)
         # TODO: fix this
         # Assuming we know what we're doing when loading from disk
         # Prob a bad assumption but i'm tired and want to train this asap

@@ -551,6 +557,12 @@ class NomicBertRotaryEmbedding(nn.Module):
         self.register_buffer("inv_freq", inv_freq, persistent=False)
         self.interleaved = interleaved
         self.scale_base = scale_base
+        scale = (
+            (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)
+            if scale_base is not None
+            else None
+        )
+        self.register_buffer("scale", scale, persistent=False)

         self._seq_len_cached = 0
         self._cos_cached = None

@@ -616,7 +628,9 @@ class NomicBertRotaryEmbedding(nn.Module):
         Apply rotary embedding *inplace* to qkv and / or kv.
         """
         seqlen = qkv.shape[1]
-        if max_seqlen is not None:
+        if seqlen > self._seq_len_cached:
+            self._update_cos_sin_cache(seqlen, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
             self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
         elif isinstance(seqlen_offset, int):
             self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)

@@ -1133,9 +1147,11 @@ class NomicBertForPreTraining(NomicBertPreTrainedModel):
         )
         total_loss = masked_lm_loss.float()

-        return
+        return MaskedLMOutput(
             loss=total_loss,
-
+            logits=prediction_scores,
+            hidden_states=outputs.hidden_states,
+            attentions=None,
         )
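
The `rotary_scaling_factor` handling added in the second hunk is popped from the `from_pretrained` kwargs and written onto the config before the model class is instantiated. A minimal sketch of how a caller might exercise it, assuming this modeling file backs a `trust_remote_code` checkpoint and that the extra kwarg is forwarded through `AutoModel.from_pretrained` to the overridden classmethod (the repo id below is a placeholder, not taken from the diff):

```python
from transformers import AutoConfig, AutoModel

# Placeholder repo id for a checkpoint that ships this modeling file.
model_id = "nomic-ai/nomic-bert-2048"

config = AutoConfig.from_pretrained(model_id, trust_remote_code=True)

# The new branch pops this kwarg and sets config.rotary_scaling_factor;
# leaving it out keeps the attribute as None.
model = AutoModel.from_pretrained(
    model_id,
    config=config,
    trust_remote_code=True,
    rotary_scaling_factor=2.0,
)
```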
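The `NomicBertRotaryEmbedding.__init__` hunk registers an extra `scale` buffer whenever `scale_base` is set, one entry per rotary frequency pair, using the `(i + 0.4 * dim) / (1.4 * dim)` ramp familiar from xPos-style rotary scaling. A standalone sketch of just that expression (the helper name is mine, for illustration):

```python
import torch

def rotary_scale(dim: int) -> torch.Tensor:
    # Same expression as the buffer registered in NomicBertRotaryEmbedding.__init__
    # when scale_base is not None: one value per (even, odd) rotary dimension pair,
    # ramping from 0.4 / 1.4 toward 1 across the head dimension.
    return (torch.arange(0, dim, 2, dtype=torch.float32) + 0.4 * dim) / (1.4 * dim)

print(rotary_scale(8))
# tensor([0.2857, 0.4643, 0.6429, 0.8214])
```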
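With the last hunk, `NomicBertForPreTraining.forward` returns a standard `MaskedLMOutput` (hence the new import in the first hunk), so callers read the loss and logits by attribute rather than by tuple position. A hedged sketch of the consuming side, assuming a BERT-style forward signature with `input_ids`, `attention_mask`, and `labels` (the signature itself is not shown in this diff):

```python
# `model` is assumed to be a NomicBertForPreTraining instance and `batch` a
# masked-LM batch with labels set to -100 at unmasked positions (standard HF convention).
outputs = model(
    input_ids=batch["input_ids"],
    attention_mask=batch["attention_mask"],
    labels=batch["labels"],
)

loss = outputs.loss      # total_loss: the float-cast masked-LM loss
logits = outputs.logits  # prediction_scores over the vocabulary
```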