andreagurioli1995 committed · verified
Commit 79c1c23 · 1 Parent(s): 8eb0731

Upload ModularStarEncoder

Files changed (1): modularStarEncoder.py (+5 -11)
modularStarEncoder.py CHANGED
@@ -1,4 +1,4 @@
-from transformers import AutoConfig, Starcoder2Model, Starcoder2Config
+from transformers import Starcoder2Model
 import sys
 from config import ModularStarEncoderConfig
 import os
@@ -13,7 +13,6 @@ from transformers.activations import ACT2FN
 from transformers.modeling_utils import PreTrainedModel
 from transformers.utils import (
     ModelOutput,
-
     logging,

 )
@@ -34,9 +33,6 @@ class StarEncoder2PreTrainedModel(PreTrainedModel):
     _supports_sdpa = True
     _supports_cache_class = True

-    # def __init__(self):
-    #     self._supports_flash_attn_2 = True
-    #     super().__init__()


     def _init_weights(self, module):
@@ -81,7 +77,7 @@ class ModularStarEncoderOutput(ModelOutput):
         prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
             Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
         seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
-            Prediction scores of the next sequence prediction (classification) head (scores of True/False continuation
+            Prediction scores of the in-context classification head (scores of True/False continuation
             before SoftMax).
         hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
             Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
@@ -249,11 +245,9 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
             config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
             the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
         next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
-            Labels for computing the next sequence prediction (classification) loss. Input should be a sequence
-            pair (see `input_ids` docstring). Indices should be in `[0, 1]`:
-
-            - 0 indicates sequence B is a continuation of sequence A,
-            - 1 indicates sequence B is a random sequence.
+            Labels for computing the in-context classification loss. Indices should be in `[0, 1]`:
+            - 0 indicates sequence B belongs to the same repository as A,
+            - 1 indicates sequence B comes from a random repository.
         kwargs (`Dict[str, any]`, optional, defaults to *{}*):
             Used to hide legacy arguments that have been deprecated.
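For readers of this commit: the renamed head replaces BERT-style next-sentence prediction with an in-context (same-repository) classification objective over a pair of sequences. Below is a minimal usage sketch, assuming the custom class is reachable through `AutoModel` with `trust_remote_code=True`; the repository id, the example snippets, and the label construction are illustrative assumptions, not part of this commit:

import torch
from transformers import AutoModel, AutoTokenizer

# Hypothetical Hub id for illustration; substitute the actual repository path.
repo_id = "andreagurioli1995/ModularStarEncoder"

tokenizer = AutoTokenizer.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModel.from_pretrained(repo_id, trust_remote_code=True)
model.eval()

# Encode a pair of code snippets; the in-context head scores whether
# sequence B comes from the same repository as sequence A.
inputs = tokenizer(
    "def add(a, b):\n    return a + b",             # sequence A
    "def test_add():\n    assert add(1, 2) == 3",   # sequence B
    return_tensors="pt",
)

with torch.no_grad():
    outputs = model(**inputs)

# Per the docstring above, seq_relationship_logits has shape (batch_size, 2):
# index 0 = same repository, index 1 = random repository.
probs = torch.softmax(outputs.seq_relationship_logits, dim=-1)
print("P(same repository):", probs[0, 0].item())

# During training, the loss for this head is requested by passing
# next_sentence_label (0 = same repository, 1 = random repository):
labels = torch.tensor([0])
outputs = model(**inputs, next_sentence_label=labels)

The calling pattern is deliberately the same as for BERT's NSP head; only the label semantics change in this commit, from "continuation vs. random sequence" to "same repository vs. random repository".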