Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 committed on
Commit
1bdca4b
·
verified ·
1 Parent(s): 664c618

Upload ModularStarEncoder

Browse files
Files changed (1) hide show
  1. modularStarEncoder.py +5 -47
modularStarEncoder.py CHANGED
@@ -1,39 +1,21 @@
1
  from transformers import AutoConfig, Starcoder2Model, Starcoder2Config
2
  import sys
 
3
  import os
4
- from .config import ModularStarEncoderConfig
5
- import math
6
- import os
7
- import warnings
8
  from dataclasses import dataclass
9
- from typing import List, Optional, Tuple, Union
10
  import sys
11
  import torch
12
  import torch.utils.checkpoint
13
  from torch import nn
14
- from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
15
-
16
  from transformers.activations import ACT2FN
17
- from transformers.modeling_outputs import (
18
- BaseModelOutputWithPastAndCrossAttentions,
19
- BaseModelOutputWithPoolingAndCrossAttentions,
20
- CausalLMOutputWithCrossAttentions,
21
- MaskedLMOutput,
22
- MultipleChoiceModelOutput,
23
- NextSentencePredictorOutput,
24
- QuestionAnsweringModelOutput,
25
- SequenceClassifierOutput,
26
- TokenClassifierOutput,
27
- )
28
  from transformers.modeling_utils import PreTrainedModel
29
- from transformers.pytorch_utils import apply_chunking_to_forward, find_pruneable_heads_and_indices, prune_linear_layer
30
  from transformers.utils import (
31
  ModelOutput,
32
- add_code_sample_docstrings,
33
- add_start_docstrings,
34
- add_start_docstrings_to_model_forward,
35
  logging,
36
- replace_return_docstrings,
37
  )
38
 
39
  logger = logging.get_logger(__name__)
@@ -243,11 +225,7 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
243
  # Initialize weights and apply final processing
244
  self.post_init()
245
 
246
- # def get_output_embeddings(self):
247
- # return self.cls.predictions.decoder
248
 
249
- # def set_output_embeddings(self, new_embeddings):
250
- # self.cls.predictions.decoder = new_embeddings
251
 
252
 
253
 
@@ -279,40 +257,20 @@ class ModularStarEncoder(StarEncoder2PreTrainedModel):
279
  kwargs (`Dict[str, any]`, optional, defaults to *{}*):
280
  Used to hide legacy arguments that have been deprecated.
281
 
282
- Returns:
283
 
284
- Example:
285
-
286
- ```python
287
- >>> from transformers import AutoTokenizer, BertForPreTraining
288
- >>> import torch
289
-
290
- >>> tokenizer = AutoTokenizer.from_pretrained("google-bert/bert-base-uncased")
291
- >>> model = BertForPreTraining.from_pretrained("google-bert/bert-base-uncased")
292
-
293
- >>> inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
294
- >>> outputs = model(**inputs)
295
-
296
- >>> prediction_logits = outputs.prediction_logits
297
- >>> seq_relationship_logits = outputs.seq_relationship_logits
298
- ```
299
  """
300
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
301
 
302
  outputs = self.starEncoder2(
303
  input_ids,
304
  attention_mask=attention_mask,
305
- # token_type_ids=token_type_ids,
306
  position_ids=position_ids,
307
- # head_mask=head_mask,
308
  inputs_embeds=inputs_embeds,
309
  output_attentions=output_attentions,
310
  output_hidden_states=True,
311
  return_dict=return_dict,
312
  )
313
 
314
-
315
- #TODO FIX FOR EFFICIENCY, COMPUTE FORWARD PASS JUST ON MATRYOSKA LAYERS
316
  #if layer matryoshka on, compute the scores for all the heads
317
  if self.layer_matryoshka_loss:
318
  prediction_scores = []
 
1
  from transformers import AutoConfig, Starcoder2Model, Starcoder2Config
2
  import sys
3
+ from config import ModularStarEncoderConfig
4
  import os
 
 
 
 
5
  from dataclasses import dataclass
6
+ from typing import Optional, Tuple, Union
7
  import sys
8
  import torch
9
  import torch.utils.checkpoint
10
  from torch import nn
11
+ from torch.nn import CrossEntropyLoss
 
12
  from transformers.activations import ACT2FN
 
 
 
 
 
 
 
 
 
 
 
13
  from transformers.modeling_utils import PreTrainedModel
 
14
  from transformers.utils import (
15
  ModelOutput,
16
+
 
 
17
  logging,
18
+
19
  )
20
 
21
  logger = logging.get_logger(__name__)
 
225
  # Initialize weights and apply final processing
226
  self.post_init()
227
 
 
 
228
 
 
 
229
 
230
 
231
 
 
257
  kwargs (`Dict[str, any]`, optional, defaults to *{}*):
258
  Used to hide legacy arguments that have been deprecated.
259
 
 
260
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
261
  """
262
  return_dict = return_dict if return_dict is not None else self.config.use_return_dict
263
 
264
  outputs = self.starEncoder2(
265
  input_ids,
266
  attention_mask=attention_mask,
 
267
  position_ids=position_ids,
 
268
  inputs_embeds=inputs_embeds,
269
  output_attentions=output_attentions,
270
  output_hidden_states=True,
271
  return_dict=return_dict,
272
  )
273
 
 
 
274
  #if layer matryoshka on, compute the scores for all the heads
275
  if self.layer_matryoshka_loss:
276
  prediction_scores = []