Upload ModularStarEncoder
Browse files- README.md +199 -0
- config.json +44 -0
- config.py +81 -0
- model.safetensors +3 -0
- modularStarEncoder.py +356 -0
README.md
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
---
|
2 |
+
library_name: transformers
|
3 |
+
tags: []
|
4 |
+
---
|
5 |
+
|
6 |
+
# Model Card for Model ID
|
7 |
+
|
8 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
9 |
+
|
10 |
+
|
11 |
+
|
12 |
+
## Model Details
|
13 |
+
|
14 |
+
### Model Description
|
15 |
+
|
16 |
+
<!-- Provide a longer summary of what this model is. -->
|
17 |
+
|
18 |
+
This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
|
19 |
+
|
20 |
+
- **Developed by:** [More Information Needed]
|
21 |
+
- **Funded by [optional]:** [More Information Needed]
|
22 |
+
- **Shared by [optional]:** [More Information Needed]
|
23 |
+
- **Model type:** [More Information Needed]
|
24 |
+
- **Language(s) (NLP):** [More Information Needed]
|
25 |
+
- **License:** [More Information Needed]
|
26 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
27 |
+
|
28 |
+
### Model Sources [optional]
|
29 |
+
|
30 |
+
<!-- Provide the basic links for the model. -->
|
31 |
+
|
32 |
+
- **Repository:** [More Information Needed]
|
33 |
+
- **Paper [optional]:** [More Information Needed]
|
34 |
+
- **Demo [optional]:** [More Information Needed]
|
35 |
+
|
36 |
+
## Uses
|
37 |
+
|
38 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
39 |
+
|
40 |
+
### Direct Use
|
41 |
+
|
42 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
43 |
+
|
44 |
+
[More Information Needed]
|
45 |
+
|
46 |
+
### Downstream Use [optional]
|
47 |
+
|
48 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
49 |
+
|
50 |
+
[More Information Needed]
|
51 |
+
|
52 |
+
### Out-of-Scope Use
|
53 |
+
|
54 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
55 |
+
|
56 |
+
[More Information Needed]
|
57 |
+
|
58 |
+
## Bias, Risks, and Limitations
|
59 |
+
|
60 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
61 |
+
|
62 |
+
[More Information Needed]
|
63 |
+
|
64 |
+
### Recommendations
|
65 |
+
|
66 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
67 |
+
|
68 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
69 |
+
|
70 |
+
## How to Get Started with the Model
|
71 |
+
|
72 |
+
Use the code below to get started with the model.
|
73 |
+
|
74 |
+
[More Information Needed]
|
75 |
+
|
76 |
+
## Training Details
|
77 |
+
|
78 |
+
### Training Data
|
79 |
+
|
80 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
81 |
+
|
82 |
+
[More Information Needed]
|
83 |
+
|
84 |
+
### Training Procedure
|
85 |
+
|
86 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
87 |
+
|
88 |
+
#### Preprocessing [optional]
|
89 |
+
|
90 |
+
[More Information Needed]
|
91 |
+
|
92 |
+
|
93 |
+
#### Training Hyperparameters
|
94 |
+
|
95 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
96 |
+
|
97 |
+
#### Speeds, Sizes, Times [optional]
|
98 |
+
|
99 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
100 |
+
|
101 |
+
[More Information Needed]
|
102 |
+
|
103 |
+
## Evaluation
|
104 |
+
|
105 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
106 |
+
|
107 |
+
### Testing Data, Factors & Metrics
|
108 |
+
|
109 |
+
#### Testing Data
|
110 |
+
|
111 |
+
<!-- This should link to a Dataset Card if possible. -->
|
112 |
+
|
113 |
+
[More Information Needed]
|
114 |
+
|
115 |
+
#### Factors
|
116 |
+
|
117 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
118 |
+
|
119 |
+
[More Information Needed]
|
120 |
+
|
121 |
+
#### Metrics
|
122 |
+
|
123 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
124 |
+
|
125 |
+
[More Information Needed]
|
126 |
+
|
127 |
+
### Results
|
128 |
+
|
129 |
+
[More Information Needed]
|
130 |
+
|
131 |
+
#### Summary
|
132 |
+
|
133 |
+
|
134 |
+
|
135 |
+
## Model Examination [optional]
|
136 |
+
|
137 |
+
<!-- Relevant interpretability work for the model goes here -->
|
138 |
+
|
139 |
+
[More Information Needed]
|
140 |
+
|
141 |
+
## Environmental Impact
|
142 |
+
|
143 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
144 |
+
|
145 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
146 |
+
|
147 |
+
- **Hardware Type:** [More Information Needed]
|
148 |
+
- **Hours used:** [More Information Needed]
|
149 |
+
- **Cloud Provider:** [More Information Needed]
|
150 |
+
- **Compute Region:** [More Information Needed]
|
151 |
+
- **Carbon Emitted:** [More Information Needed]
|
152 |
+
|
153 |
+
## Technical Specifications [optional]
|
154 |
+
|
155 |
+
### Model Architecture and Objective
|
156 |
+
|
157 |
+
[More Information Needed]
|
158 |
+
|
159 |
+
### Compute Infrastructure
|
160 |
+
|
161 |
+
[More Information Needed]
|
162 |
+
|
163 |
+
#### Hardware
|
164 |
+
|
165 |
+
[More Information Needed]
|
166 |
+
|
167 |
+
#### Software
|
168 |
+
|
169 |
+
[More Information Needed]
|
170 |
+
|
171 |
+
## Citation [optional]
|
172 |
+
|
173 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
174 |
+
|
175 |
+
**BibTeX:**
|
176 |
+
|
177 |
+
[More Information Needed]
|
178 |
+
|
179 |
+
**APA:**
|
180 |
+
|
181 |
+
[More Information Needed]
|
182 |
+
|
183 |
+
## Glossary [optional]
|
184 |
+
|
185 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
186 |
+
|
187 |
+
[More Information Needed]
|
188 |
+
|
189 |
+
## More Information [optional]
|
190 |
+
|
191 |
+
[More Information Needed]
|
192 |
+
|
193 |
+
## Model Card Authors [optional]
|
194 |
+
|
195 |
+
[More Information Needed]
|
196 |
+
|
197 |
+
## Model Card Contact
|
198 |
+
|
199 |
+
[More Information Needed]
|
config.json
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"architectures": [
|
3 |
+
"ModularStarEncoder"
|
4 |
+
],
|
5 |
+
"attention_dropout": 0.1,
|
6 |
+
"auto_map": {
|
7 |
+
"AutoConfig": "config.ModularStarEncoderConfig",
|
8 |
+
"AutoModel": "modularStarEncoder.ModularStarEncoder"
|
9 |
+
},
|
10 |
+
"bos_token_id": 0,
|
11 |
+
"conditional_size": 4,
|
12 |
+
"embedding_dropout": 0.1,
|
13 |
+
"eos_token_id": 0,
|
14 |
+
"hidden_act": "gelu_pytorch_tanh",
|
15 |
+
"hidden_size": 1024,
|
16 |
+
"initializer_range": 0.018042,
|
17 |
+
"intermediate_size": 12288,
|
18 |
+
"keys_to_ignore_at_inference": "past_key_values",
|
19 |
+
"layer_matryoshka_loss": true,
|
20 |
+
"layer_norm_eps": 1e-05,
|
21 |
+
"matryoshka_layers": [
|
22 |
+
4,
|
23 |
+
9,
|
24 |
+
18,
|
25 |
+
27,
|
26 |
+
36
|
27 |
+
],
|
28 |
+
"max_position_embeddings": 2048,
|
29 |
+
"mlp_type": "default",
|
30 |
+
"model_type": "ModularStarEncoder",
|
31 |
+
"norm_epsilon": 1e-05,
|
32 |
+
"norm_type": "layer_norm",
|
33 |
+
"num_attention_heads": 16,
|
34 |
+
"num_hidden_layers": 36,
|
35 |
+
"num_key_value_heads": 4,
|
36 |
+
"residual_dropout": 0.1,
|
37 |
+
"rope_theta": 999999.4420358813,
|
38 |
+
"sliding_window": null,
|
39 |
+
"torch_dtype": "bfloat16",
|
40 |
+
"transformers_version": "4.39.3",
|
41 |
+
"use_bias": true,
|
42 |
+
"use_cache": false,
|
43 |
+
"vocab_size": 49156
|
44 |
+
}
|
config.py
ADDED
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import PretrainedConfig
|
2 |
+
from typing import List
|
3 |
+
|
4 |
+
|
5 |
+
#STARCODER2_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
|
6 |
+
|
7 |
+
class ModularStarEncoderConfig(PretrainedConfig):
|
8 |
+
model_type = "ModularStarEncoder"
|
9 |
+
keys_to_ignore_at_inference = ["past_key_values"]
|
10 |
+
|
11 |
+
def __init__(
|
12 |
+
self,
|
13 |
+
attention_dropout= 0.1,
|
14 |
+
residual_dropout= 0.1,
|
15 |
+
embedding_dropout= 0.1,
|
16 |
+
bos_token_id= 0,
|
17 |
+
eos_token_id= 0,
|
18 |
+
hidden_act= "gelu_pytorch_tanh",
|
19 |
+
_attn_implementation="flash_attention_2",
|
20 |
+
hidden_size= 1024,
|
21 |
+
conditional_size= 4,
|
22 |
+
initializer_range= 0.018042,
|
23 |
+
intermediate_size= 12288,
|
24 |
+
max_position_embeddings= 2048,
|
25 |
+
mlp_type= "default",
|
26 |
+
model_type= "starcoder2",
|
27 |
+
torch_dtype= "bfloat16",
|
28 |
+
layer_matryoshka_loss= True,
|
29 |
+
matryoshka_layers= [4,9,18,27,36],
|
30 |
+
norm_epsilon= 1e-05,
|
31 |
+
layer_norm_eps=1e-05,
|
32 |
+
norm_type= "layer_norm",
|
33 |
+
num_attention_heads= 16,
|
34 |
+
num_hidden_layers= 36,
|
35 |
+
num_key_value_heads= 4,
|
36 |
+
rope_theta= 999999.4420358813,
|
37 |
+
sliding_window= None,
|
38 |
+
transformers_version= "4.39.3",
|
39 |
+
use_bias= True,
|
40 |
+
use_cache= False,
|
41 |
+
vocab_size= 49156,
|
42 |
+
pad_token_id=0,
|
43 |
+
**kwargs,
|
44 |
+
):
|
45 |
+
if _attn_implementation not in ["flash_attention_2", "sdpa"]:
|
46 |
+
raise ValueError(f"`_attn_implementation` must be 'flash_attention_2', 'sdpa', got {_attn_implementation}.")
|
47 |
+
|
48 |
+
self.attention_dropout=attention_dropout ,
|
49 |
+
self.residual_dropout= residual_dropout,
|
50 |
+
self.embedding_dropout= embedding_dropout,
|
51 |
+
self.bos_token_id= bos_token_id,
|
52 |
+
self.eos_token_id= eos_token_id,
|
53 |
+
self.hidden_act= hidden_act,
|
54 |
+
self._attn_implementation=_attn_implementation,
|
55 |
+
self.hidden_size= hidden_size,
|
56 |
+
self.conditional_size= conditional_size,
|
57 |
+
self.initializer_range= initializer_range,
|
58 |
+
self.intermediate_size= intermediate_size,
|
59 |
+
self.max_position_embeddings= max_position_embeddings,
|
60 |
+
self.mlp_type= mlp_type,
|
61 |
+
self.model_type= model_type,
|
62 |
+
self.torch_dtype= torch_dtype,
|
63 |
+
self.layer_matryoshka_loss= layer_matryoshka_loss,
|
64 |
+
self.matryoshka_layers= matryoshka_layers,
|
65 |
+
self.norm_epsilon= norm_epsilon,
|
66 |
+
self.layer_norm_eps=layer_norm_eps,
|
67 |
+
self.norm_type= norm_type,
|
68 |
+
self.num_attention_heads= num_attention_heads,
|
69 |
+
self.num_hidden_layers= num_hidden_layers,
|
70 |
+
self.num_key_value_heads= num_key_value_heads,
|
71 |
+
self.rope_theta= rope_theta,
|
72 |
+
self.sliding_window= sliding_window,
|
73 |
+
self.transformers_version= transformers_version,
|
74 |
+
self.use_bias= use_bias,
|
75 |
+
self.use_cache= use_cache,
|
76 |
+
self.vocab_size= vocab_size,
|
77 |
+
self.pad_token_id=pad_token_id,
|
78 |
+
super().__init__(
|
79 |
+
bos_token_id=bos_token_id,
|
80 |
+
eos_token_id=eos_token_id,
|
81 |
+
**kwargs)
|
model.safetensors
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:cb424de4f8c7ea7b7bc437b4c1aed15e71176e57ba2be1fd5558cd6887fbb866
|
3 |
+
size 2123859442
|
modularStarEncoder.py
ADDED
@@ -0,0 +1,356 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from transformers import Starcoder2Model
|
2 |
+
import sys
|
3 |
+
from config import ModularStarEncoderConfig
|
4 |
+
import os
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional, Tuple, Union, List
|
7 |
+
import sys
|
8 |
+
import torch
|
9 |
+
import torch.utils.checkpoint
|
10 |
+
from torch import nn
|
11 |
+
from torch.nn import CrossEntropyLoss
|
12 |
+
from transformers.activations import ACT2FN
|
13 |
+
from transformers.modeling_utils import PreTrainedModel
|
14 |
+
from transformers.utils import (
|
15 |
+
ModelOutput,
|
16 |
+
logging,
|
17 |
+
|
18 |
+
)
|
19 |
+
|
20 |
+
logger = logging.get_logger(__name__)
|
21 |
+
|
22 |
+
class StarEncoder2PreTrainedModel(PreTrainedModel):
|
23 |
+
"""
|
24 |
+
An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
|
25 |
+
models.
|
26 |
+
"""
|
27 |
+
|
28 |
+
config_class = ModularStarEncoderConfig
|
29 |
+
base_model_prefix = "ModularStarEncoder"
|
30 |
+
model_type = "ModularStarEncoder"
|
31 |
+
supports_gradient_checkpointing = True
|
32 |
+
_supports_flash_attn_2 = True
|
33 |
+
_supports_sdpa = True
|
34 |
+
_supports_cache_class = True
|
35 |
+
|
36 |
+
|
37 |
+
|
38 |
+
def _init_weights(self, module):
|
39 |
+
"""Initialize the weights"""
|
40 |
+
if isinstance(module, nn.Linear):
|
41 |
+
# Slightly different from the TF version which uses truncated_normal for initialization
|
42 |
+
# cf https://github.com/pytorch/pytorch/pull/5617
|
43 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
44 |
+
if module.bias is not None:
|
45 |
+
module.bias.data.zero_()
|
46 |
+
elif isinstance(module, nn.Embedding):
|
47 |
+
module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
|
48 |
+
if module.padding_idx is not None:
|
49 |
+
module.weight.data[module.padding_idx].zero_()
|
50 |
+
elif isinstance(module, nn.LayerNorm):
|
51 |
+
module.bias.data.zero_()
|
52 |
+
module.weight.data.fill_(1.0)
|
53 |
+
|
54 |
+
class StarEncoder2Pooler(nn.Module):
|
55 |
+
def __init__(self, config):
|
56 |
+
super().__init__()
|
57 |
+
self.dense = nn.Linear(config.hidden_size, config.hidden_size)
|
58 |
+
self.activation = nn.Tanh()
|
59 |
+
|
60 |
+
def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
61 |
+
# We "pool" the model by simply taking the hidden state corresponding
|
62 |
+
# to the last token.
|
63 |
+
last_token_tensor = hidden_states[:, -1]
|
64 |
+
pooled_output = self.dense(last_token_tensor)
|
65 |
+
pooled_output = self.activation(pooled_output)
|
66 |
+
return pooled_output
|
67 |
+
|
68 |
+
@dataclass
|
69 |
+
class ModularStarEncoderOutput(ModelOutput):
|
70 |
+
"""
|
71 |
+
Output type of [`BertForPreTraining`].
|
72 |
+
|
73 |
+
Args:
|
74 |
+
loss (*optional*, returned when `labels` is provided, `torch.FloatTensor` of shape `(1,)`):
|
75 |
+
Total loss as the sum of the masked language modeling loss and the next sequence prediction
|
76 |
+
(classification) loss.
|
77 |
+
prediction_logits (`torch.FloatTensor` of shape `(batch_size, sequence_length, config.vocab_size)`):
|
78 |
+
Prediction scores of the language modeling head (scores for each vocabulary token before SoftMax).
|
79 |
+
seq_relationship_logits (`torch.FloatTensor` of shape `(batch_size, 2)`):
|
80 |
+
Prediction scores of the in context classification (classification) head (scores of True/False continuation
|
81 |
+
before SoftMax).
|
82 |
+
hidden_states (`tuple(torch.FloatTensor)`, *optional*, returned when `output_hidden_states=True` is passed or when `config.output_hidden_states=True`):
|
83 |
+
Tuple of `torch.FloatTensor` (one for the output of the embeddings + one for the output of each layer) of
|
84 |
+
shape `(batch_size, sequence_length, hidden_size)`.
|
85 |
+
|
86 |
+
Hidden-states of the model at the output of each layer plus the initial embedding outputs.
|
87 |
+
attentions (`tuple(torch.FloatTensor)`, *optional*, returned when `output_attentions=True` is passed or when `config.output_attentions=True`):
|
88 |
+
Tuple of `torch.FloatTensor` (one for each layer) of shape `(batch_size, num_heads, sequence_length,
|
89 |
+
sequence_length)`.
|
90 |
+
|
91 |
+
Attentions weights after the attention softmax, used to compute the weighted average in the self-attention
|
92 |
+
heads.
|
93 |
+
"""
|
94 |
+
|
95 |
+
projected_pooled_normalized: Optional[List[torch.FloatTensor]] = None
|
96 |
+
raw_hidden_states: Optional[Tuple[torch.FloatTensor]] = None
|
97 |
+
attentions: Optional[Tuple[torch.FloatTensor]] = None
|
98 |
+
|
99 |
+
|
100 |
+
|
101 |
+
|
102 |
+
|
103 |
+
|
104 |
+
|
105 |
+
|
106 |
+
|
107 |
+
|
108 |
+
|
109 |
+
def forward(self, sequence_output, pooled_output,idx_layer: Optional[torch.Tensor] = None):
|
110 |
+
if self.is_matryoshka:
|
111 |
+
device_sequence = sequence_output.get_device()
|
112 |
+
if device_sequence<0:
|
113 |
+
device_sequence = "cpu"
|
114 |
+
prediction_scores = self.predictions(torch.cat([sequence_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(sequence_output.size()[0],sequence_output.size()[1],-1)],dim=-1))
|
115 |
+
seq_relationship_score = self.seq_relationship(torch.cat([pooled_output , self.conditional_embeddings(torch.tensor(idx_layer,device=device_sequence).int()).expand(pooled_output.size()[0],-1)],dim=-1))
|
116 |
+
else:
|
117 |
+
prediction_scores = self.predictions(sequence_output)
|
118 |
+
seq_relationship_score = self.seq_relationship(pooled_output)
|
119 |
+
return prediction_scores, seq_relationship_score
|
120 |
+
|
121 |
+
|
122 |
+
def normalize(my_tensor):
|
123 |
+
embedding_norms = my_tensor.norm(dim=0)
|
124 |
+
|
125 |
+
normalizing_factor = torch.where( # Only normalize embeddings with norm > 1.0.
|
126 |
+
embedding_norms > 1.0, embedding_norms, torch.tensor(1)
|
127 |
+
)
|
128 |
+
|
129 |
+
normalized_tensor = my_tensor / normalizing_factor
|
130 |
+
return normalized_tensor
|
131 |
+
def pooling(x: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
|
132 |
+
"""Pools a batch of vector sequences into a batch of vector global representations.
|
133 |
+
It does so by taking the average representation of the sequence, as indicated by the mask.
|
134 |
+
|
135 |
+
Args:
|
136 |
+
x (torch.Tensor): Batch of vector sequences with shape [B, T, F].
|
137 |
+
mask (torch.Tensor): Batch of masks with shape [B, T].
|
138 |
+
|
139 |
+
Returns:
|
140 |
+
torch.Tensor: Pooled version of the input batch with shape [B, F].
|
141 |
+
"""
|
142 |
+
|
143 |
+
# Expand the mask to match the feature dimensions for proper masking
|
144 |
+
mask_expanded = mask.unsqueeze(-1) # Shape [B, T, 1]
|
145 |
+
|
146 |
+
# Apply the mask to the input tensor
|
147 |
+
masked_x = x * mask_expanded # Shape [B, T, F]
|
148 |
+
# Sum along the time dimension
|
149 |
+
sum_x = masked_x.sum(dim=1) # Shape [B, F]
|
150 |
+
# Calculate the length of valid (non-padded) elements
|
151 |
+
valid_lengths = mask.sum(dim=1).clamp(min=1).unsqueeze(-1) # Shape [B, 1]
|
152 |
+
# Calculate the average pooling, avoiding division by zero
|
153 |
+
pooled_x = sum_x / valid_lengths # Shape [B, F]
|
154 |
+
|
155 |
+
return pooled_x
|
156 |
+
|
157 |
+
def pool_and_normalize(
|
158 |
+
features_sequence: torch.Tensor,
|
159 |
+
attention_masks: torch.Tensor,
|
160 |
+
return_norms: bool = False,
|
161 |
+
) -> Union[torch.Tensor, List[torch.Tensor]]:
|
162 |
+
"""Temporal ooling of sequences of vectors and projection onto the unit sphere.
|
163 |
+
|
164 |
+
Args:
|
165 |
+
features_sequence (torch.Tensor): Inpute features with shape [B, T, F].
|
166 |
+
attention_masks (torch.Tensor): Pooling masks with shape [B, T, F].
|
167 |
+
return_norms (bool, optional): Whether to additionally return the norms. Defaults to False.
|
168 |
+
|
169 |
+
Returns:
|
170 |
+
Union[torch.Tensor, List[torch.Tensor]]: Pooled and normalized vectors with shape [B, F].
|
171 |
+
"""
|
172 |
+
|
173 |
+
pooled_embeddings = pooling(features_sequence, attention_masks)
|
174 |
+
embedding_norms = pooled_embeddings.norm(dim=1)
|
175 |
+
|
176 |
+
normalizing_factor = torch.where( # Only normalize embeddings with norm > 1.0.
|
177 |
+
embedding_norms > 1.0, embedding_norms, torch.ones_like(embedding_norms)
|
178 |
+
)
|
179 |
+
|
180 |
+
pooled_normalized_embeddings = pooled_embeddings / normalizing_factor[:, None]
|
181 |
+
|
182 |
+
if return_norms:
|
183 |
+
return pooled_normalized_embeddings, embedding_norms
|
184 |
+
else:
|
185 |
+
return pooled_normalized_embeddings
|
186 |
+
|
187 |
+
def get_pooling_mask(
|
188 |
+
input_ids: torch.Tensor, sep_token_id: Union[int, float]
|
189 |
+
) -> torch.Tensor:
|
190 |
+
"""Gets pooling masks. For a sequence of input tokens, the mask will be
|
191 |
+
a sequence of zeros up until the first [SEP] occurrence, and 1 after that.
|
192 |
+
|
193 |
+
Args:
|
194 |
+
input_ids (torch.Tensor): Batch of input ids with shape [B, T].
|
195 |
+
sep_token_id (Union[int, float]): Id for [SEP] token.
|
196 |
+
|
197 |
+
Returns:
|
198 |
+
torch.Tensor: Batch of pooling masks with shape [B, T]
|
199 |
+
"""
|
200 |
+
# idx indicates the first occurrence of sep_token_id per along dim 0 of input_ids
|
201 |
+
idx = (input_ids == sep_token_id).float().flip(1).argmax(1)
|
202 |
+
|
203 |
+
idx = input_ids.size(-1)-idx-1
|
204 |
+
|
205 |
+
repeated_idx = idx.unsqueeze(1).repeat(1, input_ids.size(1))
|
206 |
+
|
207 |
+
ranges = torch.arange(input_ids.size(1)).repeat(input_ids.size(0), 1)
|
208 |
+
|
209 |
+
pooling_mask = (repeated_idx <= ranges).long()
|
210 |
+
|
211 |
+
return pooling_mask
|
212 |
+
|
213 |
+
def adapt_model(model,config,till_layer:int):
|
214 |
+
model = model.starEncoder2
|
215 |
+
|
216 |
+
encoder_config = config
|
217 |
+
layers = encoder_config.matryoshka_layers
|
218 |
+
feature_dim = encoder_config.hidden_size
|
219 |
+
|
220 |
+
model.projection_heads = torch.nn.ModuleList()
|
221 |
+
if till_layer:
|
222 |
+
print(f"ATTENTION: till layer is on, you are pruning the model keeping just the first {till_layer} layers")
|
223 |
+
model.layers = model.layers[:till_layer]
|
224 |
+
model.projection_heads.append(torch.nn.Sequential(
|
225 |
+
torch.nn.Linear(feature_dim, feature_dim),
|
226 |
+
torch.nn.LeakyReLU(),
|
227 |
+
torch.nn.Linear(feature_dim, feature_dim),
|
228 |
+
))
|
229 |
+
else:
|
230 |
+
for layer in layers:
|
231 |
+
model.projection_heads.append(torch.nn.Sequential(
|
232 |
+
torch.nn.Linear(feature_dim, feature_dim),
|
233 |
+
torch.nn.LeakyReLU(),
|
234 |
+
torch.nn.Linear(feature_dim, feature_dim),
|
235 |
+
))
|
236 |
+
#setting off causal masking
|
237 |
+
for layer in model.layers:
|
238 |
+
layer.self_attn.is_causal=False
|
239 |
+
|
240 |
+
model.temperature_coef = torch.nn.Parameter(torch.Tensor([10.0]),requires_grad=False)
|
241 |
+
|
242 |
+
return model
|
243 |
+
|
244 |
+
class ModularStarEncoder(StarEncoder2PreTrainedModel):
|
245 |
+
_tied_weights_keys = ["predictions.decoder.bias", "cls.predictions.decoder.weight"]
|
246 |
+
config_class = ModularStarEncoderConfig
|
247 |
+
def __init__(self, config):
|
248 |
+
super().__init__(config)
|
249 |
+
self.model_type = "ModularStarEncoder"
|
250 |
+
for element in dir(config):
|
251 |
+
value = getattr(config, element) # Get the attribute value
|
252 |
+
if (isinstance(value, tuple) or isinstance(value, list)) and len(value)>0:
|
253 |
+
setattr(config, element, value[0])
|
254 |
+
self.layer_matryoshka_loss = config.layer_matryoshka_loss
|
255 |
+
self.matryoshka_layers = config.matryoshka_layers
|
256 |
+
|
257 |
+
|
258 |
+
self.starEncoder2 = Starcoder2Model(config)
|
259 |
+
|
260 |
+
|
261 |
+
#setting off causal masking
|
262 |
+
for layer in self.starEncoder2.layers:
|
263 |
+
layer.self_attn.is_causal=False
|
264 |
+
# Initialize weights and apply final processing
|
265 |
+
self.post_init()
|
266 |
+
|
267 |
+
self.starEncoder2 = adapt_model(self ,config=config,till_layer=False)
|
268 |
+
|
269 |
+
|
270 |
+
|
271 |
+
|
272 |
+
|
273 |
+
def forward(
|
274 |
+
self,
|
275 |
+
input_ids: Optional[torch.Tensor] = None,
|
276 |
+
attention_mask: Optional[torch.Tensor] = None,
|
277 |
+
#token_type_ids: Optional[torch.Tensor] = None,
|
278 |
+
position_ids: Optional[torch.Tensor] = None,
|
279 |
+
head_mask: Optional[torch.Tensor] = None,
|
280 |
+
inputs_embeds: Optional[torch.Tensor] = None,
|
281 |
+
labels: Optional[torch.Tensor] = None,
|
282 |
+
next_sentence_label: Optional[torch.Tensor] = None,
|
283 |
+
output_attentions: Optional[bool] = None,
|
284 |
+
output_hidden_states: Optional[bool] = None,
|
285 |
+
sep_token_id:Optional[int] = 49152,
|
286 |
+
) -> Union[Tuple[torch.Tensor], ModularStarEncoderOutput]:
|
287 |
+
r"""
|
288 |
+
labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
|
289 |
+
Labels for computing the masked language modeling loss. Indices should be in `[-100, 0, ...,
|
290 |
+
config.vocab_size]` (see `input_ids` docstring) Tokens with indices set to `-100` are ignored (masked),
|
291 |
+
the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`
|
292 |
+
next_sentence_label (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
|
293 |
+
This label is assigned to the in context loss:
|
294 |
+
- 0 indicates sequence B belongs to the same repository of A,
|
295 |
+
- 1 indicates sequence B is a random repository.
|
296 |
+
kwargs (`Dict[str, any]`, optional, defaults to *{}*):
|
297 |
+
Used to hide legacy arguments that have been deprecated.
|
298 |
+
|
299 |
+
|
300 |
+
"""
|
301 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
302 |
+
|
303 |
+
source_embedding = self.starEncoder2(
|
304 |
+
input_ids,
|
305 |
+
attention_mask=attention_mask,
|
306 |
+
position_ids=position_ids,
|
307 |
+
inputs_embeds=inputs_embeds,
|
308 |
+
output_attentions=output_attentions,
|
309 |
+
output_hidden_states=True,
|
310 |
+
return_dict=return_dict,
|
311 |
+
).hidden_states
|
312 |
+
|
313 |
+
|
314 |
+
DEVICE = source_embedding[-1].get_device()
|
315 |
+
|
316 |
+
try:
|
317 |
+
projection_fn = self.starEncoder2.module.projection_heads
|
318 |
+
temp_coef = self.starEncoder2.module.temperature_coef
|
319 |
+
except AttributeError:
|
320 |
+
projection_fn = self.starEncoder2.projection_heads
|
321 |
+
temp_coef = self.starEncoder2.temperature_coef
|
322 |
+
|
323 |
+
for head in projection_fn:
|
324 |
+
head.to(DEVICE)
|
325 |
+
temp_coef.to(DEVICE)
|
326 |
+
|
327 |
+
|
328 |
+
|
329 |
+
|
330 |
+
pooling_mask_source_targtes = get_pooling_mask(
|
331 |
+
input_ids, sep_token_id
|
332 |
+
) # Pooling masks indicate the second [SEP] occurrence, 0 till SEP, then all ones.
|
333 |
+
|
334 |
+
pooled_and_normalized = []
|
335 |
+
for idx,matr_layer in enumerate(self.matryoshka_layers):
|
336 |
+
source_embedding_proj = projection_fn[idx](source_embedding[matr_layer])
|
337 |
+
|
338 |
+
normalized_source_embedding, embedding_norms = pool_and_normalize(
|
339 |
+
source_embedding_proj,
|
340 |
+
pooling_mask_source_targtes,
|
341 |
+
return_norms=True,
|
342 |
+
)
|
343 |
+
|
344 |
+
pooled_and_normalized.append(normalized_source_embedding)
|
345 |
+
|
346 |
+
|
347 |
+
return ModularStarEncoderOutput(
|
348 |
+
projected_pooled_normalized = pooled_and_normalized,
|
349 |
+
raw_hidden_states=source_embedding.hidden_states,
|
350 |
+
attentions=source_embedding.attentions,
|
351 |
+
)
|
352 |
+
|
353 |
+
|
354 |
+
|
355 |
+
|
356 |
+
|