Added optional input embeddings to bypass NeoBERT.encoder
model.py (CHANGED)
@@ -6,7 +6,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.functional import scaled_dot_product_attention
 
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 
 from xformers.ops import SwiGLU
@@ -234,11 +234,12 @@ class NeoBERT(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         position_ids: torch.Tensor = None,
         max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
         output_attentions: bool = False,
         **kwargs,
@@ -246,6 +247,9 @@ class NeoBERT(NeoBERTPreTrainedModel):
         # Initialize
         hidden_states, attentions = [], []
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
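The hunks above make `input_ids` optional and add `inputs_embeds`, with the XOR check rejecting every other combination. A minimal sketch of the resulting call contract, assuming the model is loaded from the Hub with `trust_remote_code=True` and the repository's dependencies (e.g. xformers) are installed; the checkpoint name and example text are illustrative, not part of this commit:

```python
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint name; any NeoBERT checkpoint shipping this model.py behaves the same way.
model = AutoModel.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)

ids = tokenizer("Exactly one of input_ids or inputs_embeds.", return_tensors="pt")["input_ids"]
embeds = model.encoder(ids)  # self.encoder is the token-embedding layer: (batch, seq_len, hidden_size)

model(input_ids=ids)         # ok: token ids, the original path
model(inputs_embeds=embeds)  # ok: precomputed embeddings, the embedding lookup is bypassed
# model(input_ids=ids, inputs_embeds=embeds)  # ValueError: exactly one must be given
# model()                                     # ValueError as well
```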
@@ -257,14 +261,22 @@ class NeoBERT(NeoBERTPreTrainedModel):
             ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
             assert not output_attentions, "Output attentions is not supported when sequences are packed."
             assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
-            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
-            assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
+            assert (input_ids if input_ids is not None else inputs_embeds).shape[
+                0
+            ] == 1, "Cumulative sequence lengths are provided but inputs are not packed."
+            assert (
+                input_ids if input_ids is not None else inputs_embeds
+            ).is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
 
         # RoPE
-        freqs_cis = self.freqs_cis[position_ids] if position_ids is not None else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
+        freqs_cis = (
+            self.freqs_cis[position_ids]
+            if position_ids is not None
+            else self.freqs_cis[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
+        )
 
         # Embedding
-        x = self.encoder(input_ids)
+        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds
 
         # Transformer encoder
         for layer in self.transformer_encoder:
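With `x = self.encoder(input_ids) if input_ids is not None else inputs_embeds`, any tensor of shape (batch, seq_len, hidden_size) can be fed straight into the transformer stack, and the RoPE slice is taken from whichever input is present. A minimal sketch that perturbs the model's own embeddings before the forward pass; the checkpoint name, example text, and noise scale are illustrative:

```python
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)

ids = tokenizer("Perturb embeddings, not token ids.", return_tensors="pt")["input_ids"]

with torch.no_grad():
    embeds = model.encoder(ids)                       # (batch, seq_len, hidden_size)
    noisy = embeds + 0.01 * torch.randn_like(embeds)  # small Gaussian perturbation

out = model(inputs_embeds=noisy)  # the embedding lookup is skipped entirely
```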
@@ -356,7 +368,7 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         position_ids: torch.Tensor = None,
         max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
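The last hunk applies the same relaxation to `NeoBERTForSequenceClassification.forward`; the rest of that change falls outside the shown hunks. A common reason to bypass the embedding lookup is to take gradients with respect to the embeddings themselves, e.g. for token saliency or adversarial perturbation. A sketch under two assumptions not established by this diff: the checkpoint name is illustrative, and the forward output is assumed to expose the final hidden state as its first field (check the output dataclass in model.py):

```python
import torch
from transformers import AutoModel, AutoTokenizer

model = AutoModel.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)  # illustrative checkpoint
tokenizer = AutoTokenizer.from_pretrained("chandar-lab/NeoBERT", trust_remote_code=True)

ids = tokenizer("Gradients flow into the embeddings.", return_tensors="pt")["input_ids"]

# Detach the lookup from the graph and track gradients on the embeddings directly.
embeds = model.encoder(ids).detach().requires_grad_(True)

out = model(inputs_embeds=embeds)
hidden = out[0]  # assumption: the first field of the returned output is the final hidden state
hidden.sum().backward()

saliency = embeds.grad.norm(dim=-1)  # (batch, seq_len) per-token gradient magnitude
```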