OpenNLPLab commited on
Commit
d73e1e2
·
1 Parent(s): 4b9f8cb

Publish code & model

Browse files
config.json ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "act_fun": "silu",
3
+ "add_bos_token": false,
4
+ "architectures": [
5
+ "HgrnForCausalLM"
6
+ ],
7
+ "auto_map": {
8
+ "AutoConfig": "configuration_hgrn.HgrnConfig",
9
+ "AutoModelForCausalLM": "modeling_hgrn.HgrnForCausalLM"
10
+ },
11
+ "bias": false,
12
+ "bos_token_id": 50260,
13
+ "causal": true,
14
+ "decoder_embed_dim": 768,
15
+ "decoder_layers": 14,
16
+ "eos_token_id": 50260,
17
+ "glu_act": "swish",
18
+ "glu_dim": 1536,
19
+ "init_std": 0.02,
20
+ "model_type": "hgrn",
21
+ "no_scale_embedding": false,
22
+ "norm_type": "layernorm",
23
+ "pad_token_id": null,
24
+ "torch_dtype": "float32",
25
+ "transformers_version": "4.31.0",
26
+ "use_cache": true,
27
+ "use_triton": false,
28
+ "vocab_size": 50272
29
+ }
configuration_hgrn.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """ Hgrn configuration"""
3
+
4
+ from transformers.configuration_utils import PretrainedConfig
5
+ from transformers.utils import logging
6
+
7
+
8
+ logger = logging.get_logger(__name__)
9
+
10
+ LLAMA_PRETRAINED_CONFIG_ARCHIVE_MAP = {}
11
+
12
+
13
+ class HgrnConfig(PretrainedConfig):
14
+ model_type = "hgrn"
15
+ keys_to_ignore_at_inference = ["past_key_values"]
16
+
17
+ def __init__(
18
+ self,
19
+ pad_token_id=1,
20
+ bos_token_id=0,
21
+ eos_token_id=2,
22
+ vocab_size=50272,
23
+ use_cache=True,
24
+ init_std=0.02,
25
+ # model config
26
+ decoder_embed_dim=1024,
27
+ decoder_layers=24,
28
+ add_bos_token=False,
29
+ act_fun="swish",
30
+ causal=True,
31
+ use_triton=False,
32
+ glu_act="swish",
33
+ glu_dim=2816,
34
+ bias=False,
35
+ norm_type="layernorm",
36
+ no_scale_embedding=False,
37
+ **kwargs,
38
+ ):
39
+ super().__init__(
40
+ pad_token_id=pad_token_id,
41
+ bos_token_id=bos_token_id,
42
+ eos_token_id=eos_token_id,
43
+ **kwargs,
44
+ )
45
+ # hf origin
46
+ self.vocab_size = vocab_size
47
+ self.use_cache = use_cache
48
+ self.init_std = init_std
49
+ # add
50
+ self.decoder_embed_dim = decoder_embed_dim
51
+ self.decoder_layers = decoder_layers
52
+ self.add_bos_token = add_bos_token
53
+ self.act_fun = act_fun
54
+ self.causal = causal
55
+ self.use_triton = use_triton
56
+ self.glu_act = glu_act
57
+ self.glu_dim = glu_dim
58
+ self.bias = bias
59
+ self.norm_type = norm_type
60
+ self.no_scale_embedding = no_scale_embedding
generation_config.json ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "_from_model_config": true,
3
+ "bos_token_id": 50260,
4
+ "eos_token_id": 50260,
5
+ "transformers_version": "4.31.0"
6
+ }
hgrn.png ADDED
merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
modeling_hgrn.py ADDED
@@ -0,0 +1,563 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # coding=utf-8
2
+ """ PyTorch Hgrn model."""
3
+ import math
4
+ from typing import List, Optional, Tuple, Union
5
+
6
+ import torch
7
+ import torch.utils.checkpoint
8
+ from torch import nn
9
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
10
+ from dataclasses import dataclass
11
+ import torch.nn.functional as F
12
+
13
+ from transformers.activations import ACT2FN
14
+ from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast, SequenceClassifierOutputWithPast
15
+ from transformers.modeling_utils import PreTrainedModel
16
+ from transformers.utils import add_start_docstrings, add_start_docstrings_to_model_forward, logging, replace_return_docstrings
17
+ from transformers.utils import ModelOutput
18
+
19
+ from .configuration_hgrn import HgrnConfig
20
+ from .utils import print_module, get_activation_fn, get_norm_fn, print_params, logging_info
21
+ from .norm import SimpleRMSNorm
22
+ from hgru import Hgru1d
23
+
24
+ from einops import rearrange
25
+ import numpy as np
26
+
27
+ logger = logging.get_logger(__name__)
28
+
29
+ _CONFIG_FOR_DOC = "HgrnConfig"
30
+
31
+ class GLU(nn.Module):
32
+ def __init__(self, d1, d2, act_fun, bias=False):
33
+ super().__init__()
34
+ # get local varables
35
+ params = locals()
36
+ # print params
37
+ print_params(**params)
38
+
39
+ self.l1 = nn.Linear(d1, d2, bias=bias)
40
+ self.l2 = nn.Linear(d1, d2, bias=bias)
41
+ self.l3 = nn.Linear(d2, d1, bias=bias)
42
+ self.act_fun = get_activation_fn(act_fun)
43
+
44
+ def forward(self, x):
45
+ o1 = self.act_fun(self.l1(x))
46
+ o2 = self.l2(x)
47
+ output = o1 * o2
48
+ output = self.l3(output)
49
+
50
+ return output
51
+
52
+ class HgrnDecoderLayer(nn.Module):
53
+ def __init__(
54
+ self, config: HgrnConfig
55
+ ):
56
+ super().__init__()
57
+ self.embed_dim = config.decoder_embed_dim
58
+ ##### token mixer
59
+ self.token_mixer = Hgru1d(
60
+ self.embed_dim,
61
+ act_fun=config.act_fun,
62
+ causal=config.causal,
63
+ use_triton=config.use_triton,
64
+ bias=config.bias,
65
+ )
66
+ self.token_norm = get_norm_fn(config.norm_type)(self.embed_dim)
67
+
68
+ ##### channel mixer
69
+ self.glu_act = config.glu_act
70
+ self.glu_dim = config.glu_dim
71
+ self.channel_mixer = GLU(self.embed_dim, self.glu_dim, self.glu_act, bias=config.bias)
72
+ self.channel_norm = get_norm_fn(config.norm_type)(self.embed_dim)
73
+
74
+ def forward(
75
+ self,
76
+ x,
77
+ padding_mask: Optional[torch.Tensor] = None,
78
+ lower_bound: Optional[torch.Tensor] = None,
79
+ ):
80
+ # current does not support padding_mask!
81
+ x = self.token_mixer(self.token_norm(x), lower_bound) + x
82
+ x = self.channel_mixer(self.channel_norm(x)) + x
83
+
84
+ outputs = x
85
+
86
+ return outputs, None
87
+
88
+ HGRN_START_DOCSTRING = r"""
89
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
90
+ library implements for all its model (such as downloading or saving, resizing the input embeddings, pruning heads
91
+ etc.)
92
+
93
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
94
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matter related to general usage
95
+ and behavior.
96
+
97
+ Parameters:
98
+ config ([`HgrnConfig`]):
99
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
100
+ load the weights associated with the model, only the configuration. Check out the
101
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
102
+ """
103
+
104
+
105
+ @add_start_docstrings(
106
+ HGRN_START_DOCSTRING,
107
+ )
108
+ class HgrnPreTrainedModel(PreTrainedModel):
109
+ config_class = HgrnConfig
110
+ base_model_prefix = "model"
111
+ supports_gradient_checkpointing = True
112
+ _no_split_modules = ["HgrnDecoderLayer"]
113
+ _skip_keys_device_placement = "past_key_values"
114
+ _keys_to_ignore_on_load_unexpected = [r"decoder\.version"]
115
+
116
+ def _init_weights(self, module):
117
+ std = self.config.init_std
118
+ if isinstance(module, nn.Linear):
119
+ module.weight.data.normal_(mean=0.0, std=std)
120
+ if module.bias is not None:
121
+ module.bias.data.zero_()
122
+ elif isinstance(module, nn.Embedding):
123
+ module.weight.data.normal_(mean=0.0, std=std)
124
+ if module.padding_idx is not None:
125
+ module.weight.data[module.padding_idx].zero_()
126
+
127
+ def _set_gradient_checkpointing(self, module, value=False):
128
+ if isinstance(module, HgrnModel):
129
+ module.gradient_checkpointing = value
130
+
131
+ @dataclass
132
+ class HgrnModelOutputWithPast(ModelOutput):
133
+ last_hidden_state: torch.FloatTensor = None
134
+ cache_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None
135
+
136
+ HGRN_INPUTS_DOCSTRING = r"""
137
+ Args:
138
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
139
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
140
+ it.
141
+
142
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
143
+ [`PreTrainedTokenizer.__call__`] for details.
144
+
145
+ [What are input IDs?](../glossary#input-ids)
146
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
147
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
148
+
149
+ - 1 for tokens that are **not masked**,
150
+ - 0 for tokens that are **masked**.
151
+
152
+ [What are attention masks?](../glossary#attention-mask)
153
+
154
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
155
+ [`PreTrainedTokenizer.__call__`] for details.
156
+
157
+ If `past_key_values` is used, optionally only the last `decoder_input_ids` have to be input (see
158
+ `past_key_values`).
159
+
160
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
161
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
162
+ information on the default strategy.
163
+
164
+ - 1 indicates the head is **not masked**,
165
+ - 0 indicates the head is **masked**.
166
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
167
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
168
+ config.n_positions - 1]`.
169
+
170
+ [What are position IDs?](../glossary#position-ids)
171
+ past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*, returned when `use_cache=True` is passed or when `config.use_cache=True`):
172
+ Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of shape
173
+ `(batch_size, num_heads, sequence_length, embed_size_per_head)`) and 2 additional tensors of shape
174
+ `(batch_size, num_heads, encoder_sequence_length, embed_size_per_head)`.
175
+
176
+ Contains pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
177
+ blocks) that can be used (see `past_key_values` input) to speed up sequential decoding.
178
+
179
+ If `past_key_values` are used, the user can optionally input only the last `decoder_input_ids` (those that
180
+ don't have their past key value states given to this model) of shape `(batch_size, 1)` instead of all
181
+ `decoder_input_ids` of shape `(batch_size, sequence_length)`.
182
+ use_cache (`bool`, *optional*):
183
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
184
+ `past_key_values`).
185
+ output_attentions (`bool`, *optional*):
186
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
187
+ tensors for more detail.
188
+ output_hidden_states (`bool`, *optional*):
189
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
190
+ more detail.
191
+ return_dict (`bool`, *optional*):
192
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
193
+ """
194
+
195
+
196
+ @add_start_docstrings(
197
+ HGRN_START_DOCSTRING,
198
+ )
199
+ class HgrnModel(HgrnPreTrainedModel):
200
+ """
201
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HgrnDecoderLayer`]
202
+
203
+ Args:
204
+ config: HgrnConfig
205
+ """
206
+
207
+ def __init__(self, config: HgrnConfig):
208
+ super().__init__(config)
209
+ # hf origin
210
+ self.padding_idx = config.pad_token_id
211
+ self.vocab_size = config.vocab_size
212
+ self.gradient_checkpointing = False
213
+
214
+ # params
215
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.decoder_embed_dim, self.padding_idx)
216
+ self.layers = nn.ModuleList([HgrnDecoderLayer(config) for i in range(config.decoder_layers)])
217
+ self.final_norm = get_norm_fn(config.norm_type)(config.decoder_embed_dim)
218
+ self.embed_dim = config.decoder_embed_dim
219
+ self.embed_scale = 1.0 if config.no_scale_embedding else math.sqrt(self.embed_dim)
220
+ self.num_layers = config.decoder_layers
221
+ self.lower_bounds = nn.Parameter(torch.ones(self.num_layers, self.embed_dim), requires_grad=True)
222
+
223
+ # Initialize weights and apply final processing
224
+ self.post_init()
225
+
226
+ def extra_repr(self):
227
+ return print_module(self)
228
+
229
+ def get_input_embeddings(self):
230
+ return self.embed_tokens
231
+
232
+ def set_input_embeddings(self, value):
233
+ self.embed_tokens = value
234
+
235
+ @add_start_docstrings_to_model_forward(HGRN_INPUTS_DOCSTRING)
236
+ def forward(
237
+ self,
238
+ input_ids: torch.LongTensor = None,
239
+ padding_mask: Optional[torch.Tensor] = None,
240
+ inputs_embeds: Optional[torch.FloatTensor] = None,
241
+ return_dict: Optional[bool] = None,
242
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
243
+ if not self.training and padding_mask != None and padding_mask.eq(self.padding_idx):
244
+ raise ValueError("During the inference stage, attn_padding_mask should be either None or should not include the pad token.")
245
+
246
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
247
+
248
+ # retrieve input_ids and inputs_embeds
249
+ if input_ids is not None and inputs_embeds is not None:
250
+ raise ValueError("You cannot specify both decoder_input_ids and decoder_inputs_embeds at the same time")
251
+ elif input_ids is not None:
252
+ batch_size, seq_length = input_ids.shape
253
+ elif inputs_embeds is not None:
254
+ batch_size, seq_length, _ = inputs_embeds.shape
255
+ else:
256
+ raise ValueError("You have to specify either decoder_input_ids or decoder_inputs_embeds")
257
+
258
+ if inputs_embeds is None:
259
+ # !!! use embed_scale
260
+ inputs_embeds = self.embed_scale * self.embed_tokens(input_ids)
261
+
262
+ hidden_states = inputs_embeds
263
+
264
+ cache_values = ()
265
+
266
+ # lower bound
267
+ lower_bounds = self.lower_bounds
268
+ lower_bounds = F.softmax(lower_bounds, dim=0)
269
+ lower_bounds = torch.cumsum(lower_bounds, dim=0)
270
+ lower_bounds -= lower_bounds[0, ...].clone()
271
+
272
+ # b, n, d -> n, b, d
273
+ hidden_states = hidden_states.transpose(1, 0)
274
+
275
+ for idx, layer in enumerate(self.layers):
276
+ lower_bound = lower_bounds[idx]
277
+
278
+ if self.gradient_checkpointing and self.training:
279
+
280
+ def create_custom_forward(module):
281
+ def custom_forward(*inputs):
282
+ # None for past_key_value
283
+ return module(*inputs, None)
284
+
285
+ return custom_forward
286
+
287
+ layer_outputs = torch.utils.checkpoint.checkpoint(
288
+ create_custom_forward(layer),
289
+ hidden_states,
290
+ padding_mask,
291
+ lower_bound,
292
+ )
293
+ else:
294
+ layer_outputs = layer(
295
+ hidden_states,
296
+ padding_mask,
297
+ lower_bound,
298
+ )
299
+
300
+ hidden_states = layer_outputs[0]
301
+
302
+ # tbd
303
+ cache_values += (layer_outputs[1],)
304
+
305
+ hidden_states = self.final_norm(hidden_states)
306
+
307
+ # n, b, d -> b, n, d
308
+ hidden_states = hidden_states.transpose(1, 0)
309
+
310
+ if not return_dict:
311
+ return tuple(v for v in [hidden_states, cache_values] if v is not None)
312
+ return HgrnModelOutputWithPast(
313
+ last_hidden_state=hidden_states,
314
+ cache_values=cache_values
315
+ )
316
+
317
+
318
+ class HgrnForCausalLM(HgrnPreTrainedModel):
319
+ def __init__(self, config):
320
+ super().__init__(config)
321
+ self.model = HgrnModel(config)
322
+
323
+ # the lm_head weight is automatically tied to the embed tokens weight
324
+ self.lm_head = nn.Linear(config.decoder_embed_dim, config.vocab_size, bias=False)
325
+
326
+ # Initialize weights and apply final processing
327
+ self.post_init()
328
+
329
+ def get_input_embeddings(self):
330
+ return self.model.embed_tokens
331
+
332
+ def set_input_embeddings(self, value):
333
+ self.model.embed_tokens = value
334
+
335
+ def get_output_embeddings(self):
336
+ return self.lm_head
337
+
338
+ def set_output_embeddings(self, new_embeddings):
339
+ self.lm_head = new_embeddings
340
+
341
+ def set_decoder(self, decoder):
342
+ self.model = decoder
343
+
344
+ def get_decoder(self):
345
+ return self.model
346
+
347
+ @add_start_docstrings_to_model_forward(HGRN_INPUTS_DOCSTRING)
348
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
349
+ def forward(
350
+ self,
351
+ input_ids: torch.LongTensor = None,
352
+ attention_mask: Optional[torch.Tensor] = None,
353
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
354
+ inputs_embeds: Optional[torch.FloatTensor] = None,
355
+ labels: Optional[torch.LongTensor] = None,
356
+ use_cache: Optional[bool] = None,
357
+ output_attentions: Optional[bool] = None,
358
+ output_hidden_states: Optional[bool] = None,
359
+ return_dict: Optional[bool] = None,
360
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
361
+ r"""
362
+ Args:
363
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
364
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
365
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
366
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
367
+
368
+ Returns:
369
+
370
+ Example:
371
+
372
+ ```python
373
+ >>> from transformers import AutoTokenizer, HgrnForCausalLM
374
+
375
+ >>> model = HgrnForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
376
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
377
+
378
+ >>> prompt = "Hey, are you consciours? Can you talk to me?"
379
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
380
+
381
+ >>> # Generate
382
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
383
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
384
+ "Hey, are you consciours? Can you talk to me?\nI'm not consciours, but I can talk to you."
385
+ ```"""
386
+
387
+ output_hidden_states = (
388
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
389
+ )
390
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
391
+
392
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
393
+ outputs = self.model(
394
+ input_ids=input_ids,
395
+ padding_mask=attention_mask,
396
+ inputs_embeds=inputs_embeds,
397
+ return_dict=return_dict,
398
+ )
399
+
400
+ hidden_states = outputs[0]
401
+ logits = self.lm_head(hidden_states)
402
+
403
+ loss = None
404
+ if labels is not None:
405
+ # Shift so that tokens < n predict n
406
+ shift_logits = logits[..., :-1, :].contiguous()
407
+ shift_labels = labels[..., 1:].contiguous()
408
+ # Flatten the tokens
409
+ loss_fct = CrossEntropyLoss()
410
+ shift_logits = shift_logits.view(-1, self.config.vocab_size)
411
+ shift_labels = shift_labels.view(-1)
412
+ # Enable model parallelism
413
+ shift_labels = shift_labels.to(shift_logits.device)
414
+ loss = loss_fct(shift_logits, shift_labels)
415
+
416
+ if not return_dict:
417
+ output = (logits,) + outputs[1:]
418
+ return (loss,) + output if loss is not None else output
419
+
420
+ return CausalLMOutputWithPast(
421
+ loss=loss,
422
+ logits=logits,
423
+ past_key_values=outputs.cache_values,
424
+ )
425
+
426
+ def prepare_inputs_for_generation(
427
+ self, input_ids, past_key_values=None, attn_padding_mask=None, inputs_embeds=None, **kwargs
428
+ ):
429
+ if past_key_values:
430
+ input_ids = input_ids[:, -1:]
431
+
432
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
433
+ if inputs_embeds is not None and past_key_values is None:
434
+ model_inputs = {"inputs_embeds": inputs_embeds}
435
+ else:
436
+ model_inputs = {"input_ids": input_ids}
437
+
438
+ model_inputs.update(
439
+ {
440
+ }
441
+ )
442
+ return model_inputs
443
+
444
+ @staticmethod
445
+ def _reorder_cache(past_key_values, beam_idx):
446
+ reordered_past = ()
447
+ for layer_past in past_key_values:
448
+ reordered_past += (tuple(past_state.index_select(0, beam_idx) for past_state in layer_past),)
449
+ return reordered_past
450
+
451
+
452
+ @add_start_docstrings(
453
+ """
454
+ The LLaMa Model transformer with a sequence classification head on top (linear layer).
455
+
456
+ [`HgrnForSequenceClassification`] uses the last token in order to do the classification, as other causal models
457
+ (e.g. GPT-2) do.
458
+
459
+ Since it does classification on the last token, it requires to know the position of the last token. If a
460
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
461
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
462
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
463
+ each row of the batch).
464
+ """,
465
+ HGRN_START_DOCSTRING,
466
+ )
467
+ class HgrnForSequenceClassification(HgrnPreTrainedModel):
468
+ _keys_to_ignore_on_load_missing = [r"lm_head.weight"]
469
+
470
+ def __init__(self, config):
471
+ super().__init__(config)
472
+ self.num_labels = config.num_labels
473
+ self.model = HgrnModel(config)
474
+ self.score = nn.Linear(config.decoder_embed_dim, self.num_labels, bias=False)
475
+
476
+ # Initialize weights and apply final processing
477
+ self.post_init()
478
+
479
+ def get_input_embeddings(self):
480
+ return self.model.embed_tokens
481
+
482
+ def set_input_embeddings(self, value):
483
+ self.model.embed_tokens = value
484
+
485
+ @add_start_docstrings_to_model_forward(HGRN_INPUTS_DOCSTRING)
486
+ def forward(
487
+ self,
488
+ input_ids: torch.LongTensor = None,
489
+ attention_mask: Optional[torch.Tensor] = None,
490
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
491
+ inputs_embeds: Optional[torch.FloatTensor] = None,
492
+ labels: Optional[torch.LongTensor] = None,
493
+ use_cache: Optional[bool] = None,
494
+ output_attentions: Optional[bool] = None,
495
+ output_hidden_states: Optional[bool] = None,
496
+ return_dict: Optional[bool] = None,
497
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
498
+ r"""
499
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
500
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
501
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
502
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
503
+ """
504
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
505
+
506
+ outputs = self.model(
507
+ input_ids=input_ids,
508
+ padding_mask=attention_mask,
509
+ inputs_embeds=inputs_embeds,
510
+ return_dict=return_dict,
511
+ )
512
+ hidden_states = outputs[0]
513
+ logits = self.score(hidden_states)
514
+
515
+ if input_ids is not None:
516
+ batch_size = input_ids.shape[0]
517
+ else:
518
+ batch_size = inputs_embeds.shape[0]
519
+
520
+ if self.config.pad_token_id is None and batch_size != 1:
521
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
522
+ if self.config.pad_token_id is None:
523
+ sequence_lengths = -1
524
+ else:
525
+ if input_ids is not None:
526
+ sequence_lengths = (torch.ne(input_ids, self.config.pad_token_id).sum(-1) - 1).to(logits.device)
527
+ else:
528
+ sequence_lengths = -1
529
+
530
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
531
+
532
+ loss = None
533
+ if labels is not None:
534
+ labels = labels.to(logits.device)
535
+ if self.config.problem_type is None:
536
+ if self.num_labels == 1:
537
+ self.config.problem_type = "regression"
538
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
539
+ self.config.problem_type = "single_label_classification"
540
+ else:
541
+ self.config.problem_type = "multi_label_classification"
542
+
543
+ if self.config.problem_type == "regression":
544
+ loss_fct = MSELoss()
545
+ if self.num_labels == 1:
546
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
547
+ else:
548
+ loss = loss_fct(pooled_logits, labels)
549
+ elif self.config.problem_type == "single_label_classification":
550
+ loss_fct = CrossEntropyLoss()
551
+ loss = loss_fct(pooled_logits.view(-1, self.num_labels), labels.view(-1))
552
+ elif self.config.problem_type == "multi_label_classification":
553
+ loss_fct = BCEWithLogitsLoss()
554
+ loss = loss_fct(pooled_logits, labels)
555
+ if not return_dict:
556
+ output = (pooled_logits,) + outputs[1:]
557
+ return ((loss,) + output) if loss is not None else output
558
+
559
+ return SequenceClassifierOutputWithPast(
560
+ loss=loss,
561
+ logits=pooled_logits,
562
+ hidden_states=outputs.hidden_states,
563
+ )
norm.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ class SimpleRMSNorm(nn.Module):
5
+ def __init__(self, dim: int, eps: float = 1e-6):
6
+ super().__init__()
7
+ self.eps = eps
8
+
9
+ def _norm(self, x):
10
+ return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
11
+
12
+ def forward(self, x):
13
+ output = self._norm(x.float()).type_as(x)
14
+
15
+ return output
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f43d6715865308087c6e054982e409d8ed783a1a671233d1924bb56671d7fd78
3
+ size 584327661
special_tokens_map.json ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "bos_token": {
3
+ "content": "<|endoftext|>",
4
+ "lstrip": false,
5
+ "normalized": true,
6
+ "rstrip": false,
7
+ "single_word": false
8
+ },
9
+ "eos_token": {
10
+ "content": "<|endoftext|>",
11
+ "lstrip": false,
12
+ "normalized": true,
13
+ "rstrip": false,
14
+ "single_word": false
15
+ },
16
+ "unk_token": {
17
+ "content": "<|endoftext|>",
18
+ "lstrip": false,
19
+ "normalized": true,
20
+ "rstrip": false,
21
+ "single_word": false
22
+ }
23
+ }
tokenizer_config.json ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "add_bos_token": false,
3
+ "add_prefix_space": false,
4
+ "bos_token": {
5
+ "__type": "AddedToken",
6
+ "content": "<|endoftext|>",
7
+ "lstrip": false,
8
+ "normalized": true,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "clean_up_tokenization_spaces": true,
13
+ "eos_token": {
14
+ "__type": "AddedToken",
15
+ "content": "<|endoftext|>",
16
+ "lstrip": false,
17
+ "normalized": true,
18
+ "rstrip": false,
19
+ "single_word": false
20
+ },
21
+ "errors": "replace",
22
+ "model_max_length": 1000000000000000019884624838656,
23
+ "pad_token": null,
24
+ "tokenizer_class": "GPT2Tokenizer",
25
+ "unk_token": {
26
+ "__type": "AddedToken",
27
+ "content": "<|endoftext|>",
28
+ "lstrip": false,
29
+ "normalized": true,
30
+ "rstrip": false,
31
+ "single_word": false
32
+ }
33
+ }
utils.py ADDED
@@ -0,0 +1,122 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import sys
4
+
5
+ import torch
6
+ import torch.distributed as dist
7
+ import torch.nn.functional as F
8
+ from torch import nn
9
+
10
+ from .norm import SimpleRMSNorm
11
+
12
+ logging.basicConfig(
13
+ format="%(asctime)s | %(levelname)s | %(name)s | %(message)s",
14
+ datefmt="%Y-%m-%d %H:%M:%S",
15
+ level=os.environ.get("LOGLEVEL", "INFO").upper(),
16
+ stream=sys.stdout,
17
+ )
18
+ logger = logging.getLogger("print_config")
19
+
20
+ BASE_DIM = 256
21
+
22
+ def is_dist_avail_and_initialized():
23
+ if not dist.is_available():
24
+ return False
25
+ if not dist.is_initialized():
26
+ return False
27
+ return True
28
+
29
+ def get_world_size():
30
+ if not is_dist_avail_and_initialized():
31
+ return 1
32
+ return dist.get_world_size()
33
+
34
+ def get_rank():
35
+ if not is_dist_avail_and_initialized():
36
+ return 0
37
+ return dist.get_rank()
38
+
39
+ def is_main_process():
40
+ return get_rank() == 0
41
+
42
+ def logging_info(string):
43
+ if is_main_process():
44
+ logger.info(string)
45
+
46
+ def print_params(**kwargs):
47
+ if is_main_process():
48
+ logger.info(f"start print config of {kwargs['__class__']}")
49
+ for key in kwargs:
50
+ if key in ["__class__", "self"]:
51
+ continue
52
+ logger.info(f"{key}: {kwargs[key]}")
53
+ logger.info(f"end print config of {kwargs['__class__']}")
54
+
55
+ def print_config(config):
56
+ if is_main_process():
57
+ logger.info(f"start print config of {config['__class__']}")
58
+ for key in config:
59
+ if key in ["__class__", "self"]:
60
+ continue
61
+ logger.info(f"{key}: {config[key]}")
62
+ logger.info(f"end print config of {config['__class__']}")
63
+
64
+ def print_module(module):
65
+ named_modules = set()
66
+ for p in module.named_modules():
67
+ named_modules.update([p[0]] )
68
+ named_modules = list(named_modules)
69
+
70
+ string_repr = ''
71
+ for p in module.named_parameters():
72
+ name = p[0].split('.')[0]
73
+ if name not in named_modules:
74
+ string_repr = string_repr + '('+ name +'): ' \
75
+ +'Tensor(' + str(tuple(p[1].shape))+ ', requires_grad='+ str(p[1].requires_grad) +')\n'
76
+
77
+ return string_repr.rstrip("\n")
78
+
79
+ def get_activation_fn(activation):
80
+ logger.info(f"activation: {activation}")
81
+ if activation == "gelu":
82
+ return F.gelu
83
+ elif activation == "relu":
84
+ return F.relu
85
+ elif activation == "elu":
86
+ return F.elu
87
+ elif activation == "sigmoid":
88
+ return F.sigmoid
89
+ elif activation == "exp":
90
+ def f(x):
91
+ with torch.no_grad():
92
+ x_max = torch.max(x, dim=-1, keepdims=True).values
93
+ y = torch.exp(x - x_max)
94
+
95
+ return y
96
+ return f
97
+ elif activation == "leak":
98
+ return F.leaky_relu
99
+ elif activation == "1+elu":
100
+ def f(x):
101
+ return 1 + F.elu(x)
102
+ return f
103
+ elif activation == "2+elu":
104
+ def f(x):
105
+ return 2 + F.elu(x)
106
+ return f
107
+ elif activation == "silu" or activation == "swish":
108
+ return F.silu
109
+ elif activation == "sine":
110
+ return torch.sin
111
+ else:
112
+ logger.info(f"activation: does not support {activation}, use Identity!!!")
113
+ return lambda x: x
114
+
115
+ def get_norm_fn(norm_type):
116
+ if norm_type == "simplermsnorm":
117
+ return SimpleRMSNorm
118
+ else:
119
+ return nn.LayerNorm
120
+
121
+ def convert_to_multiple_of_base(x):
122
+ return BASE_DIM * ((x + BASE_DIM - 1) // BASE_DIM)
vocab.json ADDED
The diff for this file is too large to render. See raw diff