Lolalb committed
Commit e7504ed · verified · 1 Parent(s): a4fbc49

Added optional input embeddings to bypass NeoBERT.encoder

Files changed (1)
  1. model.py +19 -7
model.py CHANGED
@@ -6,7 +6,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.functional import scaled_dot_product_attention
 
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 
 from xformers.ops import SwiGLU
@@ -234,11 +234,12 @@ class NeoBERT(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         position_ids: torch.Tensor = None,
         max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
         output_attentions: bool = False,
         **kwargs,
@@ -246,6 +247,9 @@ class NeoBERT(NeoBERTPreTrainedModel):
         # Initialize
         hidden_states, attentions = [], []
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
@@ -257,14 +261,22 @@ class NeoBERT(NeoBERTPreTrainedModel):
             ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
             assert not output_attentions, "Output attentions is not supported when sequences are packed."
             assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
-            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
-            assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
+            assert (input_ids if input_ids is not None else inputs_embeds).shape[
+                0
+            ] == 1, "Cumulative sequence lengths are provided but inputs are not packed."
+            assert (
+                input_ids if input_ids is not None else inputs_embeds
+            ).is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
 
         # RoPE
-        freqs_cis = self.freqs_cis[position_ids] if position_ids is not None else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
+        freqs_cis = (
+            self.freqs_cis[position_ids]
+            if position_ids is not None
+            else self.freqs_cis[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
+        )
 
         # Embedding
-        x = self.encoder(input_ids)
+        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds
 
         # Transformer encoder
         for layer in self.transformer_encoder:
@@ -356,7 +368,7 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         position_ids: torch.Tensor = None,
         max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
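With this change, NeoBERT.forward accepts either token ids or precomputed embeddings, but not both. The sketch below shows the two code paths; it assumes a NeoBERTConfig class with a vocab_size attribute and default construction (names not shown in the diff are illustrative, not part of this commit):

import torch

from model import NeoBERT, NeoBERTConfig  # assumed import path for this repository's model.py

config = NeoBERTConfig()  # assumption: default config is enough for a quick smoke test
model = NeoBERT(config).eval()

input_ids = torch.randint(0, config.vocab_size, (1, 16))  # assumes config exposes vocab_size

# Standard path: the model looks up token embeddings internally via self.encoder.
out_from_ids = model(input_ids=input_ids)

# New path: compute (or perturb) the embeddings yourself and bypass self.encoder.
inputs_embeds = model.encoder(input_ids)  # (Batch, Length, Hidden)
out_from_embeds = model(inputs_embeds=inputs_embeds)

# Supplying both arguments, or neither, raises the ValueError added in this commit.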