Added optional input embeddings to bypass NeoBERT.encoder

#7 · opened by Lolalb
Files changed (1)
  1. model.py +20 -9
model.py CHANGED
@@ -6,7 +6,7 @@ from torch import nn
 from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
 from torch.nn.functional import scaled_dot_product_attention
 
-from typing import Optional
+from typing import Optional, Tuple
 import numpy as np
 
 from xformers.ops import SwiGLU
@@ -190,7 +190,7 @@ class EncoderBlock(nn.Module):
                 query=xq.transpose(1, 2),
                 key=xk.transpose(1, 2),
                 value=xv.transpose(1, 2),
-                attn_mask=attention_mask.bool(),
+                attn_mask=attention_mask,
                 dropout_p=0,
             ).transpose(1, 2)
 
@@ -199,7 +199,6 @@ class EncoderBlock(nn.Module):
 
 class NeoBERTPreTrainedModel(PreTrainedModel):
     config_class = NeoBERTConfig
-    base_model_prefix = "model"
     _supports_cache_class = True
 
     def _init_weights(self, module):
@@ -234,11 +233,12 @@ class NeoBERT(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
         position_ids: torch.Tensor = None,
        max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
         attention_mask: torch.Tensor = None,
+        inputs_embeds: Optional[torch.Tensor] = None,
         output_hidden_states: bool = False,
         output_attentions: bool = False,
         **kwargs,
@@ -246,6 +246,9 @@ class NeoBERT(NeoBERTPreTrainedModel):
         # Initialize
         hidden_states, attentions = [], []
 
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError("You must specify exactly one of input_ids or inputs_embeds")
+
         # Expand and repeat: (Batch, Length) -> (Batch, Heads, Length, Length)
         if attention_mask is not None:
             attention_mask = attention_mask.unsqueeze(1).unsqueeze(1).repeat(1, self.config.num_attention_heads, attention_mask.size(-1), 1)
@@ -257,14 +260,22 @@ class NeoBERT(NeoBERTPreTrainedModel):
             ), "Flash-attention is not available. Please ''pip install flash_attn'', or provide un-packed sequences."
             assert not output_attentions, "Output attentions is not supported when sequences are packed."
             assert max_seqlen is not None, "Missing max_seqlen. It must be provided when cu_seqlens are not None."
-            assert input_ids.shape[0] == 1, "Cumulative sequence lengths are provided but input_ids are not packed."
-            assert input_ids.is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
+            assert (input_ids if input_ids is not None else inputs_embeds).shape[
+                0
+            ] == 1, "Cumulative sequence lengths are provided but inputs are not packed."
+            assert (
+                input_ids if input_ids is not None else inputs_embeds
+            ).is_cuda, "Packing uses an implementation of flash-attention and is only supported on GPU."
 
         # RoPE
-        freqs_cis = self.freqs_cis[position_ids] if position_ids is not None else self.freqs_cis[: input_ids.shape[1]].unsqueeze(0)
+        freqs_cis = (
+            self.freqs_cis[position_ids]
+            if position_ids is not None
+            else self.freqs_cis[: (input_ids if input_ids is not None else inputs_embeds).shape[1]].unsqueeze(0)
+        )
 
         # Embedding
-        x = self.encoder(input_ids)
+        x = self.encoder(input_ids) if input_ids is not None else inputs_embeds
 
         # Transformer encoder
         for layer in self.transformer_encoder:
@@ -356,7 +367,7 @@ class NeoBERTForSequenceClassification(NeoBERTPreTrainedModel):
 
     def forward(
         self,
-        input_ids: torch.Tensor,
+        input_ids: Optional[torch.Tensor] = None,
        position_ids: torch.Tensor = None,
         max_seqlen: int = None,
         cu_seqlens: torch.Tensor = None,
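
A minimal usage sketch of the new argument, assuming a NeoBERT checkpoint loaded with trust_remote_code=True (the checkpoint id "chandar-lab/NeoBERT" and the last_hidden_state output field are illustrative assumptions, not taken from this diff): precompute token embeddings with the model's encoder embedding table and pass them back through inputs_embeds, bypassing the lookup inside forward().

import torch
from transformers import AutoModel, AutoTokenizer

# Illustrative checkpoint id; any NeoBERT checkpoint using this model.py applies.
model_id = "chandar-lab/NeoBERT"
tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval()

input_ids = tokenizer("NeoBERT now accepts precomputed embeddings.", return_tensors="pt").input_ids

with torch.no_grad():
    # Standard path: forward() embeds the ids internally via self.encoder.
    out_ids = model(input_ids=input_ids)

    # New path: look up the embeddings yourself (model.encoder is the
    # nn.Embedding table) and bypass it by passing inputs_embeds instead.
    embeds = model.encoder(input_ids)
    out_embeds = model(inputs_embeds=embeds)

# Both paths should produce the same hidden states (assuming the output
# object exposes last_hidden_state).
torch.testing.assert_close(out_ids.last_hidden_state, out_embeds.last_hidden_state)

Passing both input_ids and inputs_embeds, or neither, raises the ValueError added in this change.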