Feature Extraction
Transformers
Safetensors
ModularStarEncoder
custom_code
andreagurioli1995 commited on
Commit
8ec5243
·
verified ·
1 Parent(s): 0849be3

Upload ModularStarEncoder

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +44 -0
  3. config.py +81 -0
  4. model.safetensors +3 -0
  5. modularStarEncoder.py +366 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "ModularStarEncoder"
4
+ ],
5
+ "attention_dropout": 0.1,
6
+ "auto_map": {
7
+ "AutoConfig": "config.ModularStarEncoderConfig",
8
+ "AutoModel": "modularStarEncoder.ModularStarEncoder"
9
+ },
10
+ "bos_token_id": 0,
11
+ "conditional_size": 4,
12
+ "embedding_dropout": 0.1,
13
+ "eos_token_id": 0,
14
+ "hidden_act": "gelu_pytorch_tanh",
15
+ "hidden_size": 1024,
16
+ "initializer_range": 0.018042,
17
+ "intermediate_size": 12288,
18
+ "keys_to_ignore_at_inference": "past_key_values",
19
+ "layer_matryoshka_loss": true,
20
+ "layer_norm_eps": 1e-05,
21
+ "matryoshka_layers": [
22
+ 4,
23
+ 9,
24
+ 18,
25
+ 27,
26
+ 36
27
+ ],
28
+ "max_position_embeddings": 2048,
29
+ "mlp_type": "default",
30
+ "model_type": "ModularStarEncoder",
31
+ "norm_epsilon": 1e-05,
32
+ "norm_type": "layer_norm",
33
+ "num_attention_heads": 16,
34
+ "num_hidden_layers": 36,
35
+ "num_key_value_heads": 4,
36
+ "residual_dropout": 0.1,
37
+ "rope_theta": 999999.4420358813,
38
+ "sliding_window": null,
39
+ "torch_dtype": "bfloat16",
40
+ "transformers_version": "4.39.3",
41
+ "use_bias": true,
42
+ "use_cache": false,
43
+ "vocab_size": 49156
44
+ }
config.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from typing import List
3
+
4
+
5
+ #STARCODER2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
6
+
7
+ class ModularStarEncoderConfig(PretrainedConfig):
8
+ model_type = "ModularStarEncoder"
9
+ keys_to_ignore_at_inference = ["past_key_values"]
10
+
11
+ def __init__(
12
+ self,
13
+ attention_dropout= 0.1,
14
+ residual_dropout= 0.1,
15
+ embedding_dropout= 0.1,
16
+ bos_token_id= 0,
17
+ eos_token_id= 0,
18
+ hidden_act= "gelu_pytorch_tanh",
19
+ _attn_implementation="flash_attention_2",
20
+ hidden_size= 1024,
21
+ conditional_size= 4,
22
+ initializer_range= 0.018042,
23
+ intermediate_size= 12288,
24
+ max_position_embeddings= 2048,
25
+ mlp_type= "default",
26
+ model_type= "starcoder2",
27
+ torch_dtype= "bfloat16",
28
+ layer_matryoshka_loss= True,
29
+ matryoshka_layers= [4,9,18,27,36],
30
+ norm_epsilon= 1e-05,
31
+ layer_norm_eps=1e-05,
32
+ norm_type= "layer_norm",
33
+ num_attention_heads= 16,
34
+ num_hidden_layers= 36,
35
+ num_key_value_heads= 4,
36
+ rope_theta= 999999.4420358813,
37
+ sliding_window= None,
38
+ transformers_version= "4.39.3",
39
+ use_bias= True,
40
+ use_cache= False,
41
+ vocab_size= 49156,
42
+ pad_token_id=0,
43
+ **kwargs,
44
+ ):
45
+ if _attn_implementation not in ["flash_attention_2", "sdpa"]:
46
+ raise ValueError(f"`_attn_implementation` must be 'flash_attention_2', 'sdpa', got {_attn_implementation}.")
47
+
48
+ self.attention_dropout=attention_dropout ,
49
+ self.residual_dropout= residual_dropout,
50
+ self.embedding_dropout= embedding_dropout,
51
+ self.bos_token_id= bos_token_id,
52
+ self.eos_token_id= eos_token_id,
53
+ self.hidden_act= hidden_act,
54
+ self._attn_implementation=_attn_implementation,
55
+ self.hidden_size= hidden_size,
56
+ self.conditional_size= conditional_size,
57
+ self.initializer_range= initializer_range,
58
+ self.intermediate_size= intermediate_size,
59
+ self.max_position_embeddings= max_position_embeddings,
60
+ self.mlp_type= mlp_type,
61
+ self.model_type= model_type,
62
+ self.torch_dtype= torch_dtype,
63
+ self.layer_matryoshka_loss= layer_matryoshka_loss,
64
+ self.matryoshka_layers= matryoshka_layers,
65
+ self.norm_epsilon= norm_epsilon,
66
+ self.layer_norm_eps=layer_norm_eps,
67
+ self.norm_type= norm_type,
68
+ self.num_attention_heads= num_attention_heads,
69
+ self.num_hidden_layers= num_hidden_layers,
70
+ self.num_key_value_heads= num_key_value_heads,
71
+ self.rope_theta= rope_theta,
72
+ self.sliding_window= sliding_window,
73
+ self.transformers_version= transformers_version,
74
+ self.use_bias= use_bias,
75
+ self.use_cache= use_cache,
76
+ self.vocab_size= vocab_size,
77
+ self.pad_token_id=pad_token_id,
78
+ super().__init__(
79
+ bos_token_id=bos_token_id,
80
+ eos_token_id=eos_token_id,
81
+ **kwargs)
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0390c498cc4365f211adc1dbab66d6fd2aa1ddd8dadd8ac99f924ddd32760cf2
3
+ size 327340210
modularStarEncoder.py ADDED
@@ -0,0 +1,366 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import Starcoder2Model
2
+ import sys
3
+ from config import ModularStarEncoderConfig
4
+ import os
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Tuple, Union, List
7
+ import sys
8
+ import torch
9
+ import torch.utils.checkpoint
10
+ from torch import nn
11
+ from torch.nn import CrossEntropyLoss
12
+ from transformers.activations import ACT2FN
13
+ from transformers.modeling_utils import PreTrainedModel
14
+ from transformers.utils import (
15
+ ModelOutput,
16
+ logging,
17
+
18
+ )
19
+
20
+ logger = logging.get_logger(__name__)
21
+
22
+ class StarEncoder2PreTrainedModel(PreTrainedModel):
23
+ """
24
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
25
+ models.
26
+ """
27
+
28
+ config_class = ModularStarEncoderConfig
29
+ base_model_prefix = "ModularStarEncoder"
30
+ model_type = "ModularStarEncoder"
31
+ supports_gradient_checkpointing = True
32
+ _supports_flash_attn_2 = True
33
+ _supports_sdpa = True
34
+ _supports_cache_class = True
35
+
36
+
37
+
38
+ def _init_weights(self, module):
39
+ """Initialize the weights"""
40
+ if isinstance(module, nn.Linear):
41
+ # Slightly different from the TF version which uses truncated_normal for initialization
42
+ # cf https://github.com/pytorch/pytorch/pull/5617
43
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
44
+ if module.bias is not None:
45
+ module.bias.data.zero_()
46
+ elif isinstance(module, nn.Embedding):
47
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
48
+ if module.padding_idx is not None:
49
+ module.weight.data[module.padding_idx].zero_()
50
+ elif isinstance(module, nn.LayerNorm):
51
+ module.bias.data.zero_()
52
+ module.weight.data.fill_(1.0)
53
+
54
+ class StarEncoder2Pooler(nn.Module):
55
+ def __init__(self, config):
56
+ super().__init__()
57
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
58
+ self.activation = nn.Tanh()
59
+
60
+ def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
61
+ # We "pool" the model by simply taking the hidden state corresponding
62
+ # to the last token.
63
+ last_token_tensor = hidden_states[:, -1]
64
+ pooled_output = self.dense(last_token_tensor)
65
+ pooled_output = self.activation(pooled_output)
66
+ return pooled_output
67
+
68
+ @dataclass
69
+ class ModularStarEncoderOutput(ModelOutput):
70
+ """
71
+ Output type of [`BertForPreTraining`].
72
+
73
+ Args:
74
+ loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
75
+ Total loss as the sum of the masked language modeling loss and the next sequence prediction
76
+ (classification) loss.
77
+ prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
78
+ Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
79
+ seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
80
+ Prediction scores of the in context classification (classification) head (scores of True/False continuation
81
+ before SoftMax).
82
+ hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
83
+ Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
84
+ shape `(batch_size, sequence_length, hidden_size)`.
85
+
86
+ Hidden-states of the model at the output of each layer plus the initial embedding outputs.
87
+ attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
88
+ Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
89
+ sequence_length)`.
90
+
91
+ Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
92
+ heads.
93
+ """
94
+
95
+ projected_pooled_normalized: Optional[List[torch.FloatTensor]] = None
96
+ raw_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
97
+ attentions: Optional[Tuple[torch.FloatTensor]] = None
98
+
99
+
100
+
101
+
102
+
103
+
104
+
105
+
106
+
107
+
108
+
109
+ def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
110
+ if self.is_matryoshka:
111
+ device_sequence = sequence_output.get_device()
112
+ if device_sequence<0:
113
+ device_sequence = "cpu"
114
+ prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
115
+ seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(pooled_output.size()[0],-1)],dim=-1))
116
+ else:
117
+ prediction_scores = self.predictions(sequence_output)
118
+ seq_relationship_score = self.seq_relationship(pooled_output)
119
+ return prediction_scores, seq_relationship_score
120
+
121
+
122
+ def normalize(my_tensor):
123
+ embedding_norms = my_tensor.norm(dim=0)
124
+
125
+ normalizing_factor = torch.where( # Only normalize embeddings with norm > 1.0.
126
+ embedding_norms > 1.0, embedding_norms, torch.tensor(1)
127
+ )
128
+
129
+ normalized_tensor = my_tensor / normalizing_factor
130
+ return normalized_tensor
131
+ def pooling(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
132
+ """Pools a batch of vector sequences into a batch of vector global representations.
133
+ It does so by taking the average representation of the sequence, as indicated by the mask.
134
+
135
+ Args:
136
+ x (torch.Tensor): Batch of vector sequences with shape [B, T, F].
137
+ mask (torch.Tensor): Batch of masks with shape [B, T].
138
+
139
+ Returns:
140
+ torch.Tensor: Pooled version of the input batch with shape [B, F].
141
+ """
142
+
143
+ # Expand the mask to match the feature dimensions for proper masking
144
+ mask_expanded = mask.unsqueeze(-1) # Shape [B, T, 1]
145
+
146
+ # Apply the mask to the input tensor
147
+ masked_x = x * mask_expanded # Shape [B, T, F]
148
+ # Sum along the time dimension
149
+ sum_x = masked_x.sum(dim=1) # Shape [B, F]
150
+ # Calculate the length of valid (non-padded) elements
151
+ valid_lengths = mask.sum(dim=1).clamp(min=1).unsqueeze(-1) # Shape [B, 1]
152
+ # Calculate the average pooling, avoiding division by zero
153
+ pooled_x = sum_x / valid_lengths # Shape [B, F]
154
+
155
+ return pooled_x
156
+
157
+ def pool_and_normalize(
158
+ features_sequence: torch.Tensor,
159
+ attention_masks: torch.Tensor,
160
+ return_norms: bool = False,
161
+ ) -> Union[torch.Tensor, List[torch.Tensor]]:
162
+ """Temporal ooling of sequences of vectors and projection onto the unit sphere.
163
+
164
+ Args:
165
+ features_sequence (torch.Tensor): Inpute features with shape [B, T, F].
166
+ attention_masks (torch.Tensor): Pooling masks with shape [B, T, F].
167
+ return_norms (bool, optional): Whether to additionally return the norms. Defaults to False.
168
+
169
+ Returns:
170
+ Union[torch.Tensor, List[torch.Tensor]]: Pooled and normalized vectors with shape [B, F].
171
+ """
172
+
173
+ pooled_embeddings = pooling(features_sequence, attention_masks)
174
+ embedding_norms = pooled_embeddings.norm(dim=1)
175
+
176
+ normalizing_factor = torch.where( # Only normalize embeddings with norm > 1.0.
177
+ embedding_norms > 1.0, embedding_norms, torch.ones_like(embedding_norms)
178
+ )
179
+
180
+ pooled_normalized_embeddings = pooled_embeddings / normalizing_factor[:, None]
181
+
182
+ if return_norms:
183
+ return pooled_normalized_embeddings, embedding_norms
184
+ else:
185
+ return pooled_normalized_embeddings
186
+
187
+ def get_pooling_mask(
188
+ input_ids: torch.Tensor, sep_token_id: Union[int, float]
189
+ ) -> torch.Tensor:
190
+ """Gets pooling masks. For a sequence of input tokens, the mask will be
191
+ a sequence of zeros up until the first [SEP] occurrence, and 1 after that.
192
+
193
+ Args:
194
+ input_ids (torch.Tensor): Batch of input ids with shape [B, T].
195
+ sep_token_id (Union[int, float]): Id for [SEP] token.
196
+
197
+ Returns:
198
+ torch.Tensor: Batch of pooling masks with shape [B, T]
199
+ """
200
+ # idx indicates the first occurrence of sep_token_id per along dim 0 of input_ids
201
+ idx = (input_ids == sep_token_id).float().flip(1).argmax(1)
202
+
203
+ idx = input_ids.size(-1)-idx-1
204
+
205
+ repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
206
+
207
+ ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
208
+
209
+ pooling_mask = (repeated_idx <= ranges).long()
210
+
211
+ return pooling_mask
212
+
213
+ def adapt_model(model,config,till_layer:27):
214
+ model = model.starEncoder2
215
+
216
+ encoder_config = config
217
+ layers = encoder_config.matryoshka_layers
218
+ feature_dim = encoder_config.hidden_size
219
+
220
+ model.projection_heads = torch.nn.ModuleList()
221
+ if till_layer:
222
+ print(f"ATTENTION: till layer is on, you are pruning the model keeping just the first {till_layer} layers")
223
+ model.layers = model.layers[:till_layer]
224
+ model.projection_heads.append(torch.nn.Sequential(
225
+ torch.nn.Linear(feature_dim, feature_dim),
226
+ torch.nn.LeakyReLU(),
227
+ torch.nn.Linear(feature_dim, feature_dim),
228
+ ))
229
+ else:
230
+ for layer in layers:
231
+ model.projection_heads.append(torch.nn.Sequential(
232
+ torch.nn.Linear(feature_dim, feature_dim),
233
+ torch.nn.LeakyReLU(),
234
+ torch.nn.Linear(feature_dim, feature_dim),
235
+ ))
236
+ #setting off causal masking
237
+ for layer in model.layers:
238
+ layer.self_attn.is_causal=False
239
+
240
+ model.temperature_coef = torch.nn.Parameter(torch.Tensor([10.0]),requires_grad=False)
241
+
242
+ return model
243
+
244
+ class ModularStarEncoder(StarEncoder2PreTrainedModel):
245
+ _tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
246
+ config_class = ModularStarEncoderConfig
247
+ def __init__(self, config):
248
+ super().__init__(config)
249
+ self.model_type = "ModularStarEncoder"
250
+ for element in dir(config):
251
+ value = getattr(config, element) # Get the attribute value
252
+ if (isinstance(value, tuple) or isinstance(value, list)) and len(value)>0:
253
+ setattr(config, element, value[0])
254
+ self.layer_matryoshka_loss = config.layer_matryoshka_loss
255
+ self.matryoshka_layers = config.matryoshka_layers
256
+
257
+
258
+ self.starEncoder2 = Starcoder2Model(config)
259
+
260
+
261
+ #setting off causal masking
262
+ for layer in self.starEncoder2.layers:
263
+ layer.self_attn.is_causal=False
264
+ # Initialize weights and apply final processing
265
+ self.post_init()
266
+ self.till_layer= 4
267
+ self.starEncoder2 = adapt_model(self ,config=config,till_layer=self.till_layer)
268
+
269
+
270
+
271
+
272
+
273
+ def forward(
274
+ self,
275
+ input_ids: Optional[torch.Tensor] = None,
276
+ attention_mask: Optional[torch.Tensor] = None,
277
+ #token_type_ids: Optional[torch.Tensor] = None,
278
+ position_ids: Optional[torch.Tensor] = None,
279
+ head_mask: Optional[torch.Tensor] = None,
280
+ inputs_embeds: Optional[torch.Tensor] = None,
281
+ labels: Optional[torch.Tensor] = None,
282
+ next_sentence_label: Optional[torch.Tensor] = None,
283
+ output_attentions: Optional[bool] = None,
284
+ output_hidden_states: Optional[bool] = None,
285
+ sep_token_id:Optional[int] = 49152,
286
+ ) -> Union[Tuple[torch.Tensor], ModularStarEncoderOutput]:
287
+ r"""
288
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
289
+ Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
290
+ config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
291
+ the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
292
+ next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
293
+ This label is assigned to the in context loss:
294
+ - 0 indicates sequence B belongs to the same repository of A,
295
+ - 1 indicates sequence B is a random repository.
296
+ kwargs (`Dict[str, any]`, optional, defaults to *{}*):
297
+ Used to hide legacy arguments that have been deprecated.
298
+
299
+
300
+ """
301
+
302
+ source_embedding = self.starEncoder2(
303
+ input_ids,
304
+ attention_mask=attention_mask,
305
+ position_ids=position_ids,
306
+ inputs_embeds=inputs_embeds,
307
+ output_attentions=output_attentions,
308
+ output_hidden_states=True,
309
+ return_dict=True,
310
+ )
311
+
312
+
313
+ DEVICE = source_embedding.hidden_states[-1].get_device()
314
+ if DEVICE<0:
315
+ DEVICE = "cpu"
316
+
317
+ try:
318
+ projection_fn = self.starEncoder2.module.projection_heads
319
+ temp_coef = self.starEncoder2.module.temperature_coef
320
+ except AttributeError:
321
+ projection_fn = self.starEncoder2.projection_heads
322
+ temp_coef = self.starEncoder2.temperature_coef
323
+
324
+ for head in projection_fn:
325
+ head.to(DEVICE)
326
+ temp_coef.to(DEVICE)
327
+
328
+
329
+
330
+
331
+ pooling_mask_source_targtes = get_pooling_mask(
332
+ input_ids, sep_token_id
333
+ ) # Pooling masks indicate the second [SEP] occurrence, 0 till SEP, then all ones.
334
+
335
+ if self.till_layer:
336
+ self.matryoshka_layers=[self.till_layer]
337
+
338
+ pooled_and_normalized = []
339
+ for idx,matr_layer in enumerate(self.matryoshka_layers):
340
+ source_embedding_proj = projection_fn[idx](source_embedding.hidden_states[matr_layer])
341
+
342
+ normalized_source_embedding, embedding_norms = pool_and_normalize(
343
+ source_embedding_proj,
344
+ pooling_mask_source_targtes,
345
+ return_norms=True,
346
+ )
347
+
348
+ pooled_and_normalized.append(normalized_source_embedding)
349
+
350
+ if not self.till_layer:
351
+ return ModularStarEncoderOutput(
352
+ projected_pooled_normalized = pooled_and_normalized,
353
+ raw_hidden_states=source_embedding.hidden_states,
354
+ attentions=source_embedding.attentions,
355
+ )
356
+ else:
357
+ return ModularStarEncoderOutput(
358
+ projected_pooled_normalized = pooled_and_normalized[0],
359
+ raw_hidden_states=source_embedding.hidden_states,
360
+ attentions=source_embedding.attentions,
361
+ )
362
+
363
+
364
+
365
+
366
+