andrecornman commited on
Commit
5a7d048
·
verified ·
1 Parent(s): 99d28e3

Upload gLM2ForEmbedding

Browse files
Files changed (5) hide show
  1. README.md +199 -0
  2. config.json +20 -0
  3. configuration_glm2.py +51 -0
  4. model.safetensors +3 -0
  5. modeling_glm2.py +619 -0
README.md ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ library_name: transformers
3
+ tags: []
4
+ ---
5
+
6
+ # Model Card for Model ID
7
+
8
+ <!-- Provide a quick summary of what the model is/does. -->
9
+
10
+
11
+
12
+ ## Model Details
13
+
14
+ ### Model Description
15
+
16
+ <!-- Provide a longer summary of what this model is. -->
17
+
18
+ This is the model card of a 🤗 transformers model that has been pushed on the Hub. This model card has been automatically generated.
19
+
20
+ - **Developed by:** [More Information Needed]
21
+ - **Funded by [optional]:** [More Information Needed]
22
+ - **Shared by [optional]:** [More Information Needed]
23
+ - **Model type:** [More Information Needed]
24
+ - **Language(s) (NLP):** [More Information Needed]
25
+ - **License:** [More Information Needed]
26
+ - **Finetuned from model [optional]:** [More Information Needed]
27
+
28
+ ### Model Sources [optional]
29
+
30
+ <!-- Provide the basic links for the model. -->
31
+
32
+ - **Repository:** [More Information Needed]
33
+ - **Paper [optional]:** [More Information Needed]
34
+ - **Demo [optional]:** [More Information Needed]
35
+
36
+ ## Uses
37
+
38
+ <!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
39
+
40
+ ### Direct Use
41
+
42
+ <!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
43
+
44
+ [More Information Needed]
45
+
46
+ ### Downstream Use [optional]
47
+
48
+ <!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
49
+
50
+ [More Information Needed]
51
+
52
+ ### Out-of-Scope Use
53
+
54
+ <!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
55
+
56
+ [More Information Needed]
57
+
58
+ ## Bias, Risks, and Limitations
59
+
60
+ <!-- This section is meant to convey both technical and sociotechnical limitations. -->
61
+
62
+ [More Information Needed]
63
+
64
+ ### Recommendations
65
+
66
+ <!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
67
+
68
+ Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
69
+
70
+ ## How to Get Started with the Model
71
+
72
+ Use the code below to get started with the model.
73
+
74
+ [More Information Needed]
75
+
76
+ ## Training Details
77
+
78
+ ### Training Data
79
+
80
+ <!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
81
+
82
+ [More Information Needed]
83
+
84
+ ### Training Procedure
85
+
86
+ <!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
87
+
88
+ #### Preprocessing [optional]
89
+
90
+ [More Information Needed]
91
+
92
+
93
+ #### Training Hyperparameters
94
+
95
+ - **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
96
+
97
+ #### Speeds, Sizes, Times [optional]
98
+
99
+ <!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
100
+
101
+ [More Information Needed]
102
+
103
+ ## Evaluation
104
+
105
+ <!-- This section describes the evaluation protocols and provides the results. -->
106
+
107
+ ### Testing Data, Factors & Metrics
108
+
109
+ #### Testing Data
110
+
111
+ <!-- This should link to a Dataset Card if possible. -->
112
+
113
+ [More Information Needed]
114
+
115
+ #### Factors
116
+
117
+ <!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
118
+
119
+ [More Information Needed]
120
+
121
+ #### Metrics
122
+
123
+ <!-- These are the evaluation metrics being used, ideally with a description of why. -->
124
+
125
+ [More Information Needed]
126
+
127
+ ### Results
128
+
129
+ [More Information Needed]
130
+
131
+ #### Summary
132
+
133
+
134
+
135
+ ## Model Examination [optional]
136
+
137
+ <!-- Relevant interpretability work for the model goes here -->
138
+
139
+ [More Information Needed]
140
+
141
+ ## Environmental Impact
142
+
143
+ <!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
144
+
145
+ Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
146
+
147
+ - **Hardware Type:** [More Information Needed]
148
+ - **Hours used:** [More Information Needed]
149
+ - **Cloud Provider:** [More Information Needed]
150
+ - **Compute Region:** [More Information Needed]
151
+ - **Carbon Emitted:** [More Information Needed]
152
+
153
+ ## Technical Specifications [optional]
154
+
155
+ ### Model Architecture and Objective
156
+
157
+ [More Information Needed]
158
+
159
+ ### Compute Infrastructure
160
+
161
+ [More Information Needed]
162
+
163
+ #### Hardware
164
+
165
+ [More Information Needed]
166
+
167
+ #### Software
168
+
169
+ [More Information Needed]
170
+
171
+ ## Citation [optional]
172
+
173
+ <!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
174
+
175
+ **BibTeX:**
176
+
177
+ [More Information Needed]
178
+
179
+ **APA:**
180
+
181
+ [More Information Needed]
182
+
183
+ ## Glossary [optional]
184
+
185
+ <!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
186
+
187
+ [More Information Needed]
188
+
189
+ ## More Information [optional]
190
+
191
+ [More Information Needed]
192
+
193
+ ## Model Card Authors [optional]
194
+
195
+ [More Information Needed]
196
+
197
+ ## Model Card Contact
198
+
199
+ [More Information Needed]
config.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "gLM2ForEmbedding"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "configuration_glm2.gLM2EmbedConfig",
7
+ "AutoModel": "modeling_glm2.gLM2ForEmbedding"
8
+ },
9
+ "depth": 16,
10
+ "dim": 1280,
11
+ "ffn_dim_multiplier": null,
12
+ "heads": 20,
13
+ "model_type": "gLM2Embed",
14
+ "norm_eps": 1e-05,
15
+ "projection_dim": 512,
16
+ "swiglu_multiple_of": 256,
17
+ "torch_dtype": "float32",
18
+ "transformers_version": "4.44.1",
19
+ "vocab_size": 37
20
+ }
configuration_glm2.py ADDED
@@ -0,0 +1,51 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """gLM2 model configuration"""
2
+
3
+ from typing import Optional
4
+ from transformers import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+ logger = logging.get_logger(__name__)
8
+
9
+
10
+ class gLM2Config(PretrainedConfig):
11
+ model_type = "gLM2"
12
+
13
+ def __init__(
14
+ self,
15
+ dim: int = 640,
16
+ depth: int = 30,
17
+ heads: int = 10,
18
+ vocab_size: int = 37,
19
+ swiglu_multiple_of: int = 256,
20
+ ffn_dim_multiplier: Optional[float] = None,
21
+ norm_eps: float = 1e-5,
22
+ **kwargs
23
+ ):
24
+ super().__init__(**kwargs)
25
+ self.dim = dim
26
+ self.depth = depth
27
+ self.heads = heads
28
+ self.vocab_size = vocab_size
29
+ self.swiglu_multiple_of = swiglu_multiple_of
30
+ self.ffn_dim_multiplier = ffn_dim_multiplier
31
+ self.norm_eps = norm_eps
32
+
33
+ self.auto_map = {
34
+ "AutoConfig": "configuration_glm2.gLM2Config",
35
+ "AutoModel": "modeling_glm2.gLM2Model",
36
+ "AutoModelForMaskedLM": "modeling_glm2.gLM2ForMaskedLM"
37
+ }
38
+
39
+
40
+ class gLM2EmbedConfig(gLM2Config):
41
+ model_type = "gLM2Embed"
42
+
43
+ def __init__(self, projection_dim: int = 512, **kwargs):
44
+ super().__init__(**kwargs)
45
+ self.projection_dim = projection_dim
46
+
47
+ self.auto_map = {
48
+ "AutoConfig": "configuration_glm2.gLM2EmbedConfig",
49
+ "AutoModel": "modeling_glm2.gLM2ForEmbedding",
50
+ }
51
+
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:677a38fd3bf753c213d5ec1de58b85e14013577396901397d27a56507fcaae21
3
+ size 1303222360
modeling_glm2.py ADDED
@@ -0,0 +1,619 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """PyTorch gLM2 model.
2
+
3
+ Requires flash attention.
4
+ Some modules adapted from:
5
+ https://github.com/meta-llama/llama/blob/main/llama/model.py
6
+ """
7
+ import math
8
+ import torch
9
+ from einops import rearrange
10
+ from typing import Optional, Tuple, Union
11
+ from torch import nn
12
+ from torch.nn import CrossEntropyLoss
13
+ from transformers.modeling_outputs import (
14
+ BaseModelOutput,
15
+ BaseModelOutputWithPooling,
16
+ MaskedLMOutput,
17
+ )
18
+ from transformers.modeling_utils import PreTrainedModel
19
+ from transformers.utils import logging
20
+
21
+ try:
22
+ from flash_attn.ops.activations import swiglu
23
+ from flash_attn.layers.rotary import apply_rotary_emb_func
24
+ from flash_attn import (
25
+ flash_attn_kvpacked_func,
26
+ flash_attn_varlen_kvpacked_func,
27
+ )
28
+ from flash_attn.bert_padding import pad_input, unpad_input
29
+ from flash_attn.ops.triton.layer_norm import RMSNorm
30
+ except ImportError:
31
+ raise ImportError(
32
+ "gLM2 requires flash attention: `pip install flash-attn --no-build-isolation`")
33
+
34
+ from .configuration_glm2 import gLM2Config, gLM2EmbedConfig
35
+
36
+
37
+ logger = logging.get_logger(__name__)
38
+
39
+
40
+ class RotaryEmbedding(torch.nn.Module):
41
+ """
42
+ Copied from https://github.com/Dao-AILab/flash-attention/blob/main/flash_attn/layers/rotary.py.
43
+ Changed to only support passing in q or k individually, so that we can use varlen rotary.
44
+ """
45
+
46
+ def __init__(
47
+ self,
48
+ dim: int,
49
+ base=10000.0,
50
+ interleaved=False,
51
+ scale_base=None,
52
+ pos_idx_in_fp32=True,
53
+ device=None,
54
+ ):
55
+ super().__init__()
56
+ self.dim = dim
57
+ self.base = float(base)
58
+ self.pos_idx_in_fp32 = pos_idx_in_fp32
59
+ # Generate and save the inverse frequency buffer (non trainable)
60
+ inv_freq = self._compute_inv_freq(device)
61
+ self.register_buffer("inv_freq", inv_freq, persistent=False)
62
+ self.interleaved = interleaved
63
+ self.scale_base = scale_base
64
+ scale = (
65
+ (torch.arange(0, dim, 2, device=device, dtype=torch.float32) + 0.4 * dim)
66
+ / (1.4 * dim)
67
+ if scale_base is not None
68
+ else None
69
+ )
70
+ self.register_buffer("scale", scale, persistent=False)
71
+
72
+ self._seq_len_cached = 0
73
+ self._cos_cached = None
74
+ self._sin_cached = None
75
+ self._cos_k_cached = None
76
+ self._sin_k_cached = None
77
+
78
+ def _compute_inv_freq(self, device=None):
79
+ return 1.0 / (
80
+ self.base
81
+ ** (
82
+ torch.arange(0, self.dim, 2, device=device,
83
+ dtype=torch.float32)
84
+ / self.dim
85
+ )
86
+ )
87
+
88
+ def _update_cos_sin_cache(self, seqlen, device=None, dtype=None):
89
+ # Reset the tables if the sequence length has changed,
90
+ # if we're on a new device (possibly due to tracing for instance),
91
+ # or if we're switching from inference mode to training
92
+ if (
93
+ seqlen > self._seq_len_cached
94
+ or self._cos_cached is None
95
+ or self._cos_cached.device != device
96
+ or self._cos_cached.dtype != dtype
97
+ or (self.training and self._cos_cached.is_inference())
98
+ ):
99
+ self._seq_len_cached = seqlen
100
+ # We want fp32 here, not self.inv_freq.dtype, since the model could be loaded in bf16
101
+ # And the output of arange can be quite large, so bf16 would lose a lot of precision.
102
+ # However, for compatibility reason, we add an option to use the dtype of self.inv_freq.
103
+ if self.pos_idx_in_fp32:
104
+ t = torch.arange(seqlen, device=device, dtype=torch.float32)
105
+ # We want fp32 here as well since inv_freq will be multiplied with t, and the output
106
+ # will be large. Having it in bf16 will lose a lot of precision and cause the
107
+ # cos & sin output to change significantly.
108
+ # We want to recompute self.inv_freq if it was not loaded in fp32
109
+ if self.inv_freq.dtype != torch.float32:
110
+ inv_freq = self._compute_inv_freq(device=device)
111
+ else:
112
+ inv_freq = self.inv_freq
113
+ else:
114
+ t = torch.arange(seqlen, device=device,
115
+ dtype=self.inv_freq.dtype)
116
+ inv_freq = self.inv_freq
117
+ # Don't do einsum, it converts fp32 to fp16 under AMP
118
+ # freqs = torch.einsum("i,j->ij", t, self.inv_freq)
119
+ freqs = torch.outer(t, inv_freq)
120
+ if self.scale is None:
121
+ self._cos_cached = torch.cos(freqs).to(dtype)
122
+ self._sin_cached = torch.sin(freqs).to(dtype)
123
+ else:
124
+ power = (
125
+ torch.arange(
126
+ seqlen, dtype=self.scale.dtype, device=self.scale.device
127
+ )
128
+ - seqlen // 2
129
+ ) / self.scale_base
130
+ scale = self.scale.to(device=power.device) ** rearrange(
131
+ power, "s -> s 1"
132
+ )
133
+ # We want the multiplication by scale to happen in fp32
134
+ self._cos_cached = (torch.cos(freqs) * scale).to(dtype)
135
+ self._sin_cached = (torch.sin(freqs) * scale).to(dtype)
136
+ self._cos_k_cached = (torch.cos(freqs) / scale).to(dtype)
137
+ self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)
138
+
139
+ def forward(
140
+ self,
141
+ q: torch.Tensor,
142
+ k: torch.Tensor,
143
+ seqlen_offset: Union[int, torch.Tensor] = 0,
144
+ cu_seqlens: Optional[torch.Tensor] = None,
145
+ max_seqlen: Optional[int] = None,
146
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
147
+ """
148
+ q: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
149
+ shape (total_seqlen, nheads, headdim).
150
+ k: (batch, seqlen, nheads, headdim). If cu_seqlens is not None,
151
+ shape (total_seqlen, nheads, headdim).
152
+ seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
153
+ Most commonly used in inference when we have KV cache.
154
+ If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
155
+ should pass in max_seqlen, which will update the cos / sin cache up to that length.
156
+ Apply rotary embedding *inplace* to qkv and / or kv.
157
+ """
158
+ if cu_seqlens is not None:
159
+ assert max_seqlen is not None
160
+ seqlen = q.shape[1] if max_seqlen is None else max_seqlen
161
+ if max_seqlen is not None:
162
+ self._update_cos_sin_cache(
163
+ max_seqlen, device=q.device, dtype=q.dtype)
164
+ elif isinstance(seqlen_offset, int):
165
+ self._update_cos_sin_cache(
166
+ seqlen + seqlen_offset, device=q.device, dtype=q.dtype
167
+ )
168
+ q = apply_rotary_emb_func(
169
+ q,
170
+ self._cos_cached,
171
+ self._sin_cached,
172
+ interleaved=self.interleaved,
173
+ inplace=True,
174
+ seqlen_offsets=seqlen_offset,
175
+ cu_seqlens=cu_seqlens,
176
+ max_seqlen=max_seqlen,
177
+ )
178
+ if self.scale is None:
179
+ k = apply_rotary_emb_func(
180
+ k,
181
+ self._cos_cached,
182
+ self._sin_cached,
183
+ interleaved=self.interleaved,
184
+ inplace=True,
185
+ seqlen_offsets=seqlen_offset,
186
+ cu_seqlens=cu_seqlens,
187
+ max_seqlen=max_seqlen,
188
+ )
189
+ else:
190
+ k = apply_rotary_emb_func(
191
+ k,
192
+ self._cos_k_cached,
193
+ self._sin_k_cached,
194
+ interleaved=self.interleaved,
195
+ inplace=True,
196
+ seqlen_offsets=seqlen_offset,
197
+ cu_seqlens=cu_seqlens,
198
+ max_seqlen=max_seqlen,
199
+ )
200
+ return q, k
201
+
202
+
203
+ # @torch.jit.script
204
+ # def rmsnorm_func(hidden_states, weight, variance_epsilon):
205
+ # """Apply the root mean square normalization."""
206
+ # input_dtype = hidden_states.dtype
207
+ # hidden_states = hidden_states.to(torch.float32)
208
+ # variance = hidden_states.pow(2).mean(-1, keepdim=True)
209
+ # hidden_states = hidden_states * torch.rsqrt(variance + variance_epsilon)
210
+ # return (weight * hidden_states).to(input_dtype)
211
+
212
+
213
+ # class RMSNorm(nn.Module):
214
+ # """Root mean square normalization."""
215
+
216
+ # def __init__(self, dim, eps=1e-6):
217
+ # super().__init__()
218
+ # self.weight = nn.Parameter(torch.ones(dim))
219
+ # self.register_buffer(
220
+ # "variance_epsilon",
221
+ # torch.tensor(eps),
222
+ # persistent=False,
223
+ # )
224
+
225
+ # def forward(self, hidden_states):
226
+ # return rmsnorm_func(hidden_states, self.weight, self.variance_epsilon)
227
+
228
+
229
+ class Attention(nn.Module):
230
+ """Multi-head attention module."""
231
+
232
+ def __init__(self, config: gLM2Config):
233
+ super().__init__()
234
+ self.n_heads = config.heads
235
+ self.head_dim = config.dim // config.heads
236
+
237
+ self.wqkv = nn.Linear(config.dim, self.n_heads *
238
+ self.head_dim * 3, bias=False)
239
+ self.wo = nn.Linear(config.heads * self.head_dim,
240
+ config.dim, bias=False)
241
+
242
+ self.rotary_emb = RotaryEmbedding(self.head_dim)
243
+
244
+ def _forward_varlen(
245
+ self,
246
+ x: torch.Tensor,
247
+ cu_seqlens: Optional[torch.Tensor] = None,
248
+ max_seq_len: Optional[torch.Tensor] = None,
249
+ ) -> torch.Tensor:
250
+ total_seqlen, h_size = x.shape
251
+ qkv = self.wqkv(x)
252
+ q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
253
+
254
+ q = q.view(total_seqlen, self.n_heads, self.head_dim)
255
+ k = k.view(total_seqlen, self.n_heads, self.head_dim)
256
+ v = v.view(total_seqlen, self.n_heads, self.head_dim)
257
+
258
+ q, k = self.rotary_emb(
259
+ q, k, cu_seqlens=cu_seqlens, max_seqlen=max_seq_len)
260
+
261
+ # (seqlen, 2, n_heads, head_dim)
262
+ kv = torch.stack([k, v], 1)
263
+
264
+ # (seqlen, n_heads, head_dim)
265
+ output = flash_attn_varlen_kvpacked_func(
266
+ q,
267
+ kv,
268
+ cu_seqlens_q=cu_seqlens,
269
+ cu_seqlens_k=cu_seqlens,
270
+ max_seqlen_q=max_seq_len,
271
+ max_seqlen_k=max_seq_len,
272
+ dropout_p=0.0,
273
+ causal=False,
274
+ )
275
+ output = output.view(total_seqlen, h_size)
276
+ return self.wo(output)
277
+
278
+ def forward(
279
+ self,
280
+ x: torch.Tensor,
281
+ cu_seqlens: Optional[torch.Tensor] = None,
282
+ max_seq_len: Optional[torch.Tensor] = None,
283
+ ) -> torch.Tensor:
284
+ if cu_seqlens is not None:
285
+ assert max_seq_len is not None
286
+ return self._forward_varlen(x, cu_seqlens, max_seq_len)
287
+
288
+ bsz, seqlen, h_size = x.shape
289
+ qkv = self.wqkv(x)
290
+ q, k, v = torch.split(qkv, self.n_heads * self.head_dim, dim=-1)
291
+ q = q.view(bsz, seqlen, self.n_heads, self.head_dim)
292
+ k = k.view(bsz, seqlen, self.n_heads, self.head_dim)
293
+ v = v.view(bsz, seqlen, self.n_heads, self.head_dim)
294
+
295
+ q, k = self.rotary_emb(q, k)
296
+ # (bs, seqlen, 2, n_heads, head_dim)
297
+ kv = torch.stack([k, v], 2)
298
+
299
+ output = flash_attn_kvpacked_func(
300
+ q,
301
+ kv,
302
+ dropout_p=0.0,
303
+ causal=False,
304
+ )
305
+ output = output.view(bsz, seqlen, h_size)
306
+ return self.wo(output)
307
+
308
+
309
+ class FeedForward(nn.Module):
310
+ def __init__(
311
+ self,
312
+ dim: int,
313
+ hidden_dim: int,
314
+ multiple_of: int,
315
+ ffn_dim_multiplier: Optional[float],
316
+ ):
317
+ """
318
+ SwiGLU FeedForward module.
319
+
320
+ Args:
321
+ dim (int): Input dimension.
322
+ hidden_dim (int): Hidden dimension of the feedforward layer.
323
+ multiple_of (int): Value to ensure hidden dimension is a multiple of this value.
324
+ ffn_dim_multiplier (float, optional): Custom multiplier for hidden dimension. Defaults to None.
325
+ """
326
+ super().__init__()
327
+ hidden_dim = int(2 * hidden_dim / 3)
328
+ # custom dim factor multiplier
329
+ if ffn_dim_multiplier is not None:
330
+ hidden_dim = int(ffn_dim_multiplier * hidden_dim)
331
+ hidden_dim = multiple_of * \
332
+ ((hidden_dim + multiple_of - 1) // multiple_of)
333
+
334
+ self.w1 = nn.Linear(dim, hidden_dim, bias=False)
335
+ self.w2 = nn.Linear(hidden_dim, dim, bias=False)
336
+ self.w3 = nn.Linear(dim, hidden_dim, bias=False)
337
+
338
+ def forward(self, x):
339
+ return self.w2(swiglu(self.w1(x), self.w3(x)))
340
+
341
+
342
+ class TransformerBlock(nn.Module):
343
+ def __init__(self, config: gLM2Config):
344
+ super().__init__()
345
+ self.n_heads = config.heads
346
+ self.dim = config.dim
347
+ self.head_dim = config.dim // config.heads
348
+ self.attention = Attention(config)
349
+ self.feed_forward = FeedForward(
350
+ dim=config.dim,
351
+ hidden_dim=4 * config.dim,
352
+ multiple_of=config.swiglu_multiple_of,
353
+ ffn_dim_multiplier=config.ffn_dim_multiplier,
354
+ )
355
+ self.attention_norm = RMSNorm(config.dim, eps=config.norm_eps)
356
+ self.ffn_norm = RMSNorm(config.dim, eps=config.norm_eps)
357
+
358
+ def forward(
359
+ self,
360
+ x: torch.Tensor,
361
+ cu_seqlens: Optional[torch.Tensor] = None,
362
+ max_seq_len: Optional[torch.Tensor] = None,
363
+ ) -> torch.Tensor:
364
+ r = self.attention(
365
+ self.attention_norm(x), cu_seqlens, max_seq_len
366
+ )
367
+ h = x + r
368
+ r = self.feed_forward(self.ffn_norm(h))
369
+ out = h + r
370
+ return out
371
+
372
+
373
+ class TransformerLayers(nn.Module):
374
+ def __init__(self, config: gLM2Config):
375
+ super().__init__()
376
+ self.config = config
377
+ self.layers = torch.nn.ModuleList(
378
+ [TransformerBlock(config=config) for _ in range(config.depth)]
379
+ )
380
+ self.apply(self._init_weights)
381
+ # Apply special scaled init to the residual projections, per GPT-2 paper.
382
+ # Weight w2 is output of FeedForward. Weight wo is output of Attention.
383
+ for pn, p in self.named_parameters():
384
+ if pn.endswith('w2.weight') or pn.endswith('wo.weight'):
385
+ torch.nn.init.normal_(
386
+ p, mean=0.0, std=0.02/math.sqrt(2 * self.config.depth))
387
+
388
+ def _init_weights(self, module):
389
+ if isinstance(module, nn.Linear):
390
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
391
+ if module.bias is not None:
392
+ torch.nn.init.zeros_(module.bias)
393
+
394
+ def forward(
395
+ self,
396
+ x: torch.FloatTensor,
397
+ attention_mask: Optional[torch.BoolTensor] = None,
398
+ return_all_hiddens: bool = False,
399
+ ):
400
+ if x.shape[-1] != self.config.dim:
401
+ raise ValueError(
402
+ f"Input feature dim should be {self.config.dim}, but input has shape {x.shape}"
403
+ )
404
+ batch_size, seq_len = x.shape[:2]
405
+ should_unpad = attention_mask is not None and not attention_mask.all()
406
+ if should_unpad:
407
+ x, indices, cu_seqlens, max_seq_len_in_batch = unpad_input(
408
+ x, attention_mask
409
+ )
410
+ else:
411
+ indices, cu_seqlens, max_seq_len_in_batch = None, None, None
412
+ hiddens = []
413
+ for layer in self.layers:
414
+ x = layer(x, cu_seqlens, max_seq_len_in_batch)
415
+ if return_all_hiddens:
416
+ hiddens.append(x)
417
+
418
+ if should_unpad:
419
+ x = pad_input(x, indices, batch_size, seq_len)
420
+ if return_all_hiddens:
421
+ hiddens = [pad_input(h, indices, batch_size, seq_len)
422
+ for h in hiddens]
423
+
424
+ if return_all_hiddens:
425
+ return x, hiddens
426
+ return x
427
+
428
+
429
+ class gLM2PreTrainedModel(PreTrainedModel):
430
+ """
431
+ An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained
432
+ models.
433
+ """
434
+ config_class = gLM2Config
435
+ base_model_prefix = "glm2"
436
+ supports_gradient_checkpointing = False
437
+
438
+ # https://github.com/huggingface/transformers/blob/7032e0203262ebb2ebf55da8d2e01f873973e835/src/transformers/models/bert/modeling_bert.py#L748
439
+ def _init_weights(module, initializer_range=0.02):
440
+ if isinstance(module, nn.Linear):
441
+ nn.init.normal_(module.weight, std=initializer_range)
442
+ if module.bias is not None:
443
+ nn.init.zeros_(module.bias)
444
+ elif isinstance(module, nn.Embedding):
445
+ nn.init.normal_(module.weight, std=initializer_range)
446
+ if module.padding_idx is not None:
447
+ nn.init.zeros_(module.weight[module.padding_idx])
448
+
449
+
450
+ class gLM2Model(gLM2PreTrainedModel):
451
+ """gLM2 Model."""
452
+
453
+ def __init__(self, config: gLM2Config):
454
+ super().__init__(config)
455
+ self.config = config
456
+
457
+ self.tok_embeddings = nn.Embedding(config.vocab_size, config.dim)
458
+ self._init_weights(self.tok_embeddings)
459
+ self.encoder = TransformerLayers(config)
460
+
461
+ def _init_weights(self, module):
462
+ if isinstance(module, nn.Linear):
463
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
464
+ if module.bias is not None:
465
+ torch.nn.init.zeros_(module.bias)
466
+ elif isinstance(module, nn.Embedding):
467
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
468
+
469
+ def forward(
470
+ self,
471
+ input_ids: torch.Tensor,
472
+ attention_mask: Optional[torch.Tensor] = None,
473
+ output_hidden_states: Optional[bool] = None,
474
+ return_dict: Optional[bool] = None,
475
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutput]:
476
+ output_hidden_states = (
477
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
478
+ )
479
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
480
+
481
+ h = self.tok_embeddings(input_ids)
482
+ if output_hidden_states:
483
+ sequence_output, all_hidden_states = self.encoder(
484
+ h, attention_mask, return_all_hiddens=True)
485
+ else:
486
+ sequence_output = self.encoder(h, attention_mask)
487
+ all_hidden_states = None
488
+
489
+ if not return_dict:
490
+ return (sequence_output, all_hidden_states)
491
+
492
+ return BaseModelOutput(
493
+ last_hidden_state=sequence_output,
494
+ hidden_states=all_hidden_states,
495
+
496
+ )
497
+
498
+
499
+ class MeanPooling(nn.Module):
500
+ def __init__(self):
501
+ super().__init__()
502
+
503
+ def forward(self, embeds: torch.Tensor, attention_mask: Optional[torch.Tensor] = None):
504
+ """Applies mean pooling.
505
+
506
+ Args:
507
+ embeds: [..., seq_len, hidden_dim].
508
+ attention_mask: [..., seq_len].
509
+
510
+ Returns:
511
+ Outputs of shape [..., hidden_dim].
512
+ """
513
+ if attention_mask is None:
514
+ return torch.mean(embeds, dim=-2)
515
+ mask = attention_mask.bool().unsqueeze(-1)
516
+ embeds = torch.where(mask, embeds, 0.0)
517
+ embeds = torch.sum(embeds, -2)
518
+ embeds /= torch.clamp(torch.sum(mask, dim=-2, dtype=embeds.dtype), min=1.0)
519
+ return embeds
520
+
521
+
522
+ class gLM2ForEmbedding(gLM2PreTrainedModel):
523
+ """gLM2 Embedding Model."""
524
+ config_class = gLM2EmbedConfig
525
+
526
+ def __init__(self, config: gLM2EmbedConfig):
527
+ super().__init__(config)
528
+ self.glm2 = gLM2Model(config)
529
+ self.pool = MeanPooling()
530
+ self.projection = nn.Linear(config.dim, config.projection_dim, bias=False)
531
+
532
+ def forward(
533
+ self,
534
+ input_ids: torch.Tensor,
535
+ attention_mask: Optional[torch.Tensor] = None,
536
+ ) -> Union[Tuple[torch.Tensor], BaseModelOutputWithPooling]:
537
+
538
+ hidden_states = self.glm2(
539
+ input_ids,
540
+ attention_mask=attention_mask,
541
+ output_hidden_states=False,
542
+ return_dict=True,
543
+ ).last_hidden_state
544
+
545
+ embeds = self.pool(hidden_states, attention_mask)
546
+ embeds = self.projection(embeds)
547
+ return BaseModelOutputWithPooling(
548
+ pooler_output=embeds,
549
+ )
550
+
551
+
552
+ class gLM2ForMaskedLM(gLM2PreTrainedModel):
553
+
554
+ def __init__(self, config: gLM2Config):
555
+ super().__init__(config)
556
+
557
+ self.glm2 = gLM2Model(config)
558
+ self.lm_head = gLM2LMHead(config)
559
+ self._init_weights(self.lm_head)
560
+
561
+ def _init_weights(self, module):
562
+ if isinstance(module, nn.Linear):
563
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
564
+ if module.bias is not None:
565
+ torch.nn.init.zeros_(module.bias)
566
+ elif isinstance(module, nn.Embedding):
567
+ torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
568
+
569
+ def forward(
570
+ self,
571
+ input_ids: torch.Tensor,
572
+ attention_mask: Optional[torch.Tensor] = None,
573
+ labels: Optional[torch.LongTensor] = None,
574
+ output_hidden_states: Optional[bool] = None,
575
+ return_dict: Optional[bool] = None,
576
+ ) -> Union[Tuple, MaskedLMOutput]:
577
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
578
+
579
+ outputs = self.glm2(
580
+ input_ids,
581
+ attention_mask=attention_mask,
582
+ output_hidden_states=output_hidden_states,
583
+ return_dict=return_dict,
584
+ )
585
+ sequence_output = outputs[0]
586
+ prediction_scores = self.lm_head(sequence_output)
587
+
588
+ masked_lm_loss = None
589
+ if labels is not None:
590
+ loss_fct = CrossEntropyLoss()
591
+
592
+ labels = labels.to(prediction_scores.device)
593
+ masked_lm_loss = loss_fct(
594
+ prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
595
+
596
+ if not return_dict:
597
+ output = (prediction_scores,) + outputs[2:]
598
+ return ((masked_lm_loss,) + output) if masked_lm_loss is not None else output
599
+
600
+ return MaskedLMOutput(
601
+ loss=masked_lm_loss,
602
+ logits=prediction_scores,
603
+ hidden_states=outputs.hidden_states,
604
+ attentions=outputs.attentions,
605
+ )
606
+
607
+
608
+ class gLM2LMHead(nn.Module):
609
+ """gLM2 head for masked language modeling."""
610
+
611
+ def __init__(self, config):
612
+ super().__init__()
613
+
614
+ self.norm = RMSNorm(config.dim, eps=config.norm_eps)
615
+ self.proj_output = nn.Linear(
616
+ config.dim, config.vocab_size, bias=False)
617
+
618
+ def forward(self, features):
619
+ return self.proj_output(self.norm(features))