manaestras committed on
Commit e0f117d · verified · 1 Parent(s): e1488e8

Upload ./hunyuan.py with huggingface_hub

Files changed (1)
  1. hunyuan.py +879 -0
hunyuan.py ADDED
@@ -0,0 +1,879 @@
1
+ # coding=utf-8
2
+ # Copyright (C) 2024 THL A29 Limited, a Tencent company. All rights reserved.
3
+ #
4
+ """ PyTorch HunYuan model."""
5
+
6
+ import math
7
+ import warnings
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ import torch
11
+ from torch import Tensor
12
+ import torch.nn.functional as F
13
+ import torch.utils.checkpoint
14
+ from torch import nn
15
+ from torch.nn import BCEWithLogitsLoss, CrossEntropyLoss, MSELoss
16
+
17
+ from transformers.activations import ACT2FN
18
+ from transformers.cache_utils import Cache, DynamicCache
19
+ from transformers.modeling_attn_mask_utils import (
20
+ AttentionMaskConverter,
21
+ _prepare_4d_attention_mask,
22
+ _prepare_4d_causal_attention_mask,
23
+ _prepare_4d_causal_attention_mask_for_sdpa,
24
+ )
25
+ from transformers.modeling_outputs import (
26
+ BaseModelOutputWithPast,
27
+ CausalLMOutputWithPast,
28
+ SequenceClassifierOutputWithPast
29
+ )
30
+ from transformers.modeling_utils import PreTrainedModel
31
+ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
32
+ from transformers.utils import (
33
+ add_start_docstrings,
34
+ add_start_docstrings_to_model_forward,
35
+ is_flash_attn_2_available,
36
+ is_flash_attn_greater_or_equal_2_10,
37
+ logging,
38
+ replace_return_docstrings,
39
+ )
40
+ from transformers.utils.import_utils import is_torch_fx_available
41
+ from transformers.generation.utils import GenerateOutput
42
+ from .configuration_hunyuan import HunYuanConfig
43
+ from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
44
+ from .vit_model import NaVitForward, VitForward, Vit
45
+
46
+
47
+ if is_flash_attn_2_available():
48
+ from flash_attn import flash_attn_func, flash_attn_varlen_func
49
+ from flash_attn.bert_padding import index_first_axis, pad_input, unpad_input # noqa
50
+
51
+
52
+ # This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
53
+ # It means that the function will not be traced through and simply appear as a node in the graph.
54
+ if is_torch_fx_available():
55
+ if not is_torch_greater_or_equal_than_1_13:
56
+ import torch.fx
57
+
58
+ _prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)
59
+
60
+
61
+ logger = logging.get_logger(__name__)
62
+ _CONFIG_FOR_DOC = "HunYuanConfig"
63
+
64
+
65
+ HUNYUAN_START_DOCSTRING = r"""
66
+ This model inherits from [`PreTrainedModel`]. Check the superclass documentation for the generic methods the
67
+ library implements for all its models (such as downloading or saving, resizing the input embeddings, pruning heads
68
+ etc.)
69
+
70
+ This model is also a PyTorch [torch.nn.Module](https://pytorch.org/docs/stable/nn.html#torch.nn.Module) subclass.
71
+ Use it as a regular PyTorch Module and refer to the PyTorch documentation for all matters related to general usage
72
+ and behavior.
73
+
74
+ Parameters:
75
+ config ([`HunYuanConfig`]):
76
+ Model configuration class with all the parameters of the model. Initializing with a config file does not
77
+ load the weights associated with the model, only the configuration. Check out the
78
+ [`~PreTrainedModel.from_pretrained`] method to load the model weights.
79
+ """
80
+
81
+
82
+ @add_start_docstrings(
83
+ "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
84
+ HUNYUAN_START_DOCSTRING,
85
+ )
86
+ class HunYuanPreTrainedModel(PreTrainedModel):
87
+ config_class = HunYuanConfig
88
+ base_model_prefix = "model"
89
+ supports_gradient_checkpointing = True
90
+ _no_split_modules = ["HunYuanDecoderLayer"]
91
+ _skip_keys_device_placement = "past_key_values"
92
+ _supports_flash_attn_2 = True
93
+ _supports_sdpa = True
94
+ _supports_cache_class = True
95
+
96
+ def _init_weights(self, module):
97
+ std = self.config.initializer_range
98
+ if isinstance(module, nn.Linear):
99
+ module.weight.data.normal_(mean=0.0, std=std)
100
+ if module.bias is not None:
101
+ module.bias.data.zero_()
102
+ elif isinstance(module, nn.Embedding):
103
+ module.weight.data.normal_(mean=0.0, std=std)
104
+ if module.padding_idx is not None:
105
+ module.weight.data[module.padding_idx].zero_()
106
+
107
+
108
+ HUNYUAN_INPUTS_DOCSTRING = r"""
109
+ Args:
110
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
111
+ Indices of input sequence tokens in the vocabulary. Padding will be ignored by default should you provide
112
+ it.
113
+
114
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
115
+ [`PreTrainedTokenizer.__call__`] for details.
116
+
117
+ [What are input IDs?](../glossary#input-ids)
118
+ attention_mask (`torch.Tensor` of shape `(batch_size, sequence_length)`, *optional*):
119
+ Mask to avoid performing attention on padding token indices. Mask values selected in `[0, 1]`:
120
+
121
+ - 1 for tokens that are **not masked**,
122
+ - 0 for tokens that are **masked**.
123
+
124
+ [What are attention masks?](../glossary#attention-mask)
125
+
126
+ Indices can be obtained using [`AutoTokenizer`]. See [`PreTrainedTokenizer.encode`] and
127
+ [`PreTrainedTokenizer.__call__`] for details.
128
+
129
+ If `past_key_values` is used, optionally only the last `input_ids` have to be input (see
130
+ `past_key_values`).
131
+
132
+ If you want to change padding behavior, you should read [`modeling_opt._prepare_decoder_attention_mask`]
133
+ and modify to your needs. See diagram 1 in [the paper](https://arxiv.org/abs/1910.13461) for more
134
+ information on the default strategy.
135
+
136
+ - 1 indicates the head is **not masked**,
137
+ - 0 indicates the head is **masked**.
138
+ position_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
139
+ Indices of positions of each input sequence tokens in the position embeddings. Selected in the range `[0,
140
+ config.n_positions - 1]`.
141
+
142
+ [What are position IDs?](../glossary#position-ids)
143
+ past_key_values (`Cache` or `tuple(tuple(torch.FloatTensor))`, *optional*):
144
+ Pre-computed hidden-states (key and values in the self-attention blocks and in the cross-attention
145
+ blocks) that can be used to speed up sequential decoding. This typically consists of the `past_key_values`
146
+ returned by the model at a previous stage of decoding, when `use_cache=True` or `config.use_cache=True`.
147
+
148
+ Two formats are allowed:
149
+ - a [`~cache_utils.Cache`] instance;
150
+ - Tuple of `tuple(torch.FloatTensor)` of length `config.n_layers`, with each tuple having 2 tensors of
151
+ shape `(batch_size, num_heads, sequence_length, embed_size_per_head)`. This is also known as the legacy
152
+ cache format.
153
+
154
+ The model will output the same cache format that is fed as input. If no `past_key_values` are passed, the
155
+ legacy cache format will be returned.
156
+
157
+ If `past_key_values` are used, the user can optionally input only the last `input_ids` (those that don't
158
+ have their past key value states given to this model) of shape `(batch_size, 1)` instead of all `input_ids`
159
+ of shape `(batch_size, sequence_length)`.
160
+ inputs_embeds (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
161
+ Optionally, instead of passing `input_ids` you can choose to directly pass an embedded representation. This
162
+ is useful if you want more control over how to convert `input_ids` indices into associated vectors than the
163
+ model's internal embedding lookup matrix.
164
+ use_cache (`bool`, *optional*):
165
+ If set to `True`, `past_key_values` key value states are returned and can be used to speed up decoding (see
166
+ `past_key_values`).
167
+ output_attentions (`bool`, *optional*):
168
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under returned
169
+ tensors for more detail.
170
+ output_hidden_states (`bool`, *optional*):
171
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors for
172
+ more detail.
173
+ return_dict (`bool`, *optional*):
174
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
175
+ """
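+
+ # Illustrative sketch of the two `past_key_values` formats described above, using assumed toy
+ # shapes (2 layers, 4 heads, a 5-token prefix, head_dim=8):
+ #
+ #     legacy = tuple(
+ #         (torch.zeros(1, 4, 5, 8), torch.zeros(1, 4, 5, 8))   # one (key, value) pair per layer
+ #         for _ in range(2)
+ #     )
+ #     cache = DynamicCache.from_legacy_cache(legacy)  # the same conversion HunYuanModel.forward applies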
176
+
177
+
178
+ @add_start_docstrings(
179
+ "The bare HunYuan Model outputting raw hidden-states without any specific head on top.",
180
+ HUNYUAN_START_DOCSTRING,
181
+ )
182
+ class HunYuanModel(HunYuanPreTrainedModel):
183
+ """
184
+ Transformer decoder consisting of *config.num_hidden_layers* layers. Each layer is a [`HunYuanDecoderLayer`]
185
+
186
+ Args:
187
+ config: HunYuanConfig
188
+ """
189
+
190
+ def __init__(self, config: HunYuanConfig):
191
+ super().__init__(config)
192
+ self.padding_idx = config.pad_token_id
193
+ self.vocab_size = config.vocab_size
194
+ self.add_classification_head = config.add_classification_head
195
+ self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
196
+ self.layers = nn.ModuleList(
197
+ [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
198
+ )
199
+ self._use_sdpa = config._attn_implementation == "sdpa"
200
+ self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
201
+ if not config.add_classification_head:
202
+ self.norm = HunYuanRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
203
+
204
+ self.cla = config.use_cla
205
+ self.cla_share_factor = config.cla_share_factor
206
+
207
+ self.gradient_checkpointing = False
208
+ # Initialize weights and apply final processing
209
+ self.post_init()
210
+
211
+ def get_input_embeddings(self):
212
+ return self.embed_tokens
213
+
214
+ def set_input_embeddings(self, value):
215
+ self.embed_tokens = value
216
+
217
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
218
+ def forward(
219
+ self,
220
+ input_ids: torch.LongTensor = None,
221
+ attention_mask: Optional[torch.Tensor] = None,
222
+ position_ids: Optional[torch.LongTensor] = None,
223
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
224
+ inputs_embeds: Optional[torch.FloatTensor] = None,
225
+ use_cache: Optional[bool] = None,
226
+ output_attentions: Optional[bool] = None,
227
+ output_hidden_states: Optional[bool] = None,
228
+ return_dict: Optional[bool] = None,
229
+ ) -> Union[Tuple, BaseModelOutputWithPast]:
230
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
231
+ output_hidden_states = (
232
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
233
+ )
234
+ use_cache = use_cache if use_cache is not None else self.config.use_cache
235
+
236
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
237
+
238
+ # retrieve input_ids and inputs_embeds
239
+ # if input_ids is not None and inputs_embeds is not None:
240
+ # raise ValueError("You cannot specify both input_ids and inputs_embeds at the same time")
241
+ if input_ids is not None:
242
+ batch_size, seq_length = input_ids.shape[:2]
243
+ elif inputs_embeds is not None:
244
+ batch_size, seq_length = inputs_embeds.shape[:2]
245
+ else:
246
+ raise ValueError("You have to specify either input_ids or inputs_embeds")
247
+
248
+ if self.gradient_checkpointing and self.training:
249
+ if use_cache:
250
+ logger.warning_once(
251
+ "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
252
+ )
253
+ use_cache = False
254
+
255
+ past_key_values_length = 0
256
+ if use_cache:
257
+ use_legacy_cache = not isinstance(past_key_values, Cache)
258
+ if use_legacy_cache:
259
+ past_key_values = DynamicCache.from_legacy_cache(past_key_values)
260
+ past_key_values_length = past_key_values.get_usable_length(seq_length)
261
+
262
+ if position_ids is None:
263
+ device = input_ids.device if input_ids is not None else inputs_embeds.device
264
+ position_ids = torch.arange(
265
+ past_key_values_length, seq_length + past_key_values_length, dtype=torch.long, device=device
266
+ )
267
+ position_ids = position_ids.unsqueeze(0)
268
+
269
+ if inputs_embeds is None:
270
+ inputs_embeds = self.embed_tokens(input_ids)
271
+
272
+ # Fix lora with gradient checkpointing training
273
+ if self.training and inputs_embeds.is_leaf:
274
+ inputs_embeds.requires_grad = True
275
+
276
+ if self._use_flash_attention_2:
277
+ # 2d mask is passed through the layers
278
+ attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
279
+ elif self._use_sdpa and not output_attentions:
280
+ # output_attentions=True can not be supported when using SDPA, and we fall back on
281
+ # the manual implementation that requires a 4D causal mask in all cases.
282
+ attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
283
+ attention_mask,
284
+ (batch_size, seq_length),
285
+ inputs_embeds,
286
+ past_key_values_length,
287
+ )
288
+ else:
289
+ # 4d mask is passed through the layers
290
+ attention_mask = _prepare_4d_causal_attention_mask(
291
+ attention_mask, (batch_size, seq_length), inputs_embeds, past_key_values_length
292
+ )
293
+
294
+ # embed positions
295
+ hidden_states = inputs_embeds
296
+
297
+ # decoder layers
298
+ all_hidden_states = () if output_hidden_states else None
299
+ all_self_attns = () if output_attentions else None
300
+ next_decoder_cache = None
301
+
302
+ prev_kv_states = None
303
+ for layer_idx, decoder_layer in enumerate(self.layers):
304
+ if output_hidden_states:
305
+ all_hidden_states += (hidden_states,)
306
+
307
+ if self.gradient_checkpointing and self.training:
308
+ layer_outputs = self._gradient_checkpointing_func(
309
+ decoder_layer.__call__,
310
+ hidden_states,
311
+ attention_mask,
312
+ position_ids,
313
+ past_key_values,
314
+ output_attentions,
315
+ use_cache,
316
+ prev_kv_states,
317
+ )
318
+ else:
319
+ layer_outputs = decoder_layer(
320
+ hidden_states,
321
+ attention_mask=attention_mask,
322
+ position_ids=position_ids,
323
+ past_key_value=past_key_values,
324
+ output_attentions=output_attentions,
325
+ use_cache=use_cache,
326
+ kv_states=prev_kv_states
327
+ )
328
+
329
+ hidden_states = layer_outputs[0]
330
+
331
+ if use_cache:
332
+ next_decoder_cache = layer_outputs[2 if output_attentions else 1]
333
+
334
+ if output_attentions:
335
+ all_self_attns += (layer_outputs[1],)
336
+
337
+ kv_states = layer_outputs[-1]
338
+
339
+ if self.cla and layer_idx % self.cla_share_factor == 0:
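+ # Cross-Layer Attention (CLA) sharing: layers whose index is a multiple of `cla_share_factor`
+ # refresh the shared key/value states below, and the layers in between reuse them through the
+ # `kv_states=prev_kv_states` argument (e.g. with cla_share_factor=2, layers 0, 2, 4, ... produce
+ # the kv_states consumed by layers 1, 3, 5, ...).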
340
+ prev_kv_states = kv_states
341
+ if not self.add_classification_head:
342
+ hidden_states = self.norm(hidden_states)
343
+
344
+ # add hidden states from the last decoder layer
345
+ if output_hidden_states:
346
+ all_hidden_states += (hidden_states,)
347
+
348
+ next_cache = None
349
+ if use_cache:
350
+ next_cache = next_decoder_cache.to_legacy_cache() if use_legacy_cache else next_decoder_cache
351
+ if not return_dict:
352
+ return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
353
+ return BaseModelOutputWithPast(
354
+ last_hidden_state=hidden_states,
355
+ past_key_values=next_cache,
356
+ hidden_states=all_hidden_states,
357
+ attentions=all_self_attns,
358
+ )
359
+
360
+
361
+ class HunYuanMoEV1ForCausalLM(HunYuanPreTrainedModel):
362
+ _tied_weights_keys = ["lm_head.weight"]
363
+
364
+ def __init__(self, config: HunYuanConfig):
365
+ super().__init__(config)
366
+ if config.vit_path is not None:
367
+ if "-tp" in config.vit_type:
368
+ config.vit_type = config.vit_type.replace("-tp", "")
369
+ self.vit_type = config.vit_type
370
+ if self.vit_type not in ['NaVit', 'EvaVit']:
371
+ if config.vit_mapping_type == 'mlp':
372
+ self.vit_linear_encoder = torch.nn.Linear(config.hidden_size, config.hidden_size)
373
+ self.vit = Vit(config)
374
+ else:
375
+ self.vit = None
376
+ self.config = config
377
+ self.model = HunYuanModel(config)
378
+ self.add_classification_head = config.add_classification_head
379
+ self.pad_id = config.pad_id
380
+ self.vocab_size = config.vocab_size
381
+ self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
382
+ if config.add_classification_head:
383
+ self.pool_head = nn.Linear(config.hidden_size, config.hidden_size, bias=False)
384
+ self.pool_head2 = nn.Linear(config.hidden_size, config.class_num, bias=False)
385
+ # Initialize weights and apply final processing
386
+ self.post_init()
387
+
388
+ def get_input_embeddings(self):
389
+ return self.model.embed_tokens
390
+
391
+ def set_input_embeddings(self, value):
392
+ self.model.embed_tokens = value
393
+
394
+ def get_output_embeddings(self):
395
+ return self.lm_head
396
+
397
+ def set_output_embeddings(self, new_embeddings):
398
+ self.lm_head = new_embeddings
399
+
400
+ def set_decoder(self, decoder):
401
+ self.model = decoder
402
+
403
+ def get_decoder(self):
404
+ return self.model
405
+
406
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
407
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
408
+ def forward(
409
+ self,
410
+ input_ids: torch.LongTensor = None,
411
+ attention_mask: Optional[torch.Tensor] = None,
412
+ position_ids: Optional[torch.LongTensor] = None,
413
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
414
+ inputs_embeds: Optional[torch.FloatTensor] = None,
415
+ labels: Optional[torch.LongTensor] = None,
416
+ use_cache: Optional[bool] = None,
417
+ output_attentions: Optional[bool] = None,
418
+ output_hidden_states: Optional[bool] = None,
419
+ return_dict: Optional[bool] = None,
420
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
421
+ r"""
422
+ Args:
423
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
424
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
425
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
426
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
427
+
428
+ Returns:
429
+
430
+ Example:
431
+
432
+ ```python
433
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
434
+
435
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
436
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
437
+
438
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
439
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
440
+
441
+ >>> # Generate
442
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
443
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
444
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
445
+ ```"""
446
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
447
+ output_hidden_states = (
448
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
449
+ )
450
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
451
+
452
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
453
+ outputs = self.model(
454
+ input_ids=input_ids,
455
+ attention_mask=attention_mask,
456
+ position_ids=position_ids,
457
+ past_key_values=past_key_values,
458
+ inputs_embeds=inputs_embeds,
459
+ use_cache=use_cache,
460
+ output_attentions=output_attentions,
461
+ output_hidden_states=output_hidden_states,
462
+ return_dict=return_dict,
463
+ )
464
+
465
+ hidden_states = outputs[0]
466
+
467
+ if not self.add_classification_head:
468
+ if self.config.pretraining_tp > 1:
469
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
470
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
471
+ logits = torch.cat(logits, dim=-1)
472
+ else:
473
+ logits = self.lm_head(hidden_states)
474
+ logits = logits.float()
475
+ else:
476
+ logits = hidden_states
477
+ logits = logits.float()
478
+ pooled_output = self.pool_head(logits)
479
+ pooled_output = torch.tanh(pooled_output)
480
+ pooled_output = self.pool_head2(pooled_output).contiguous() # bs * class_num
481
+ if len(pooled_output.shape) < 2:
482
+ raise ValueError("pooled_output does not have enough dimensions for transpose")
483
+
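+ # Reward pooling (descriptive note): "mean" averages the per-token class scores over the sequence
+ # dimension, "last" takes the scores at the final non-padding position (located via `pad_id`), and
+ # any other `pool_type` falls back to the scores at position 0.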
484
+ if self.config.pool_type == "mean":
485
+ reward = pooled_output.mean(dim=1).squeeze(-1)
486
+ elif self.config.pool_type == "last":
487
+ # bs * hidden_size
488
+ seq_length = (input_ids != self.pad_id).long().sum(dim=1) - 1
489
+ batch_size = input_ids.size(0)
490
+ reward = pooled_output[torch.arange(batch_size, device=pooled_output.device), seq_length].squeeze(-1)
491
+ else:
492
+ reward = pooled_output[:, 0].squeeze(-1)
493
+
494
+ loss = None
495
+ if labels is not None:
496
+ # Shift so that tokens < n predict n
497
+ shift_logits = logits[..., :-1, :].contiguous()
498
+ shift_labels = labels[..., 1:].contiguous()
499
+ # Flatten the tokens
500
+ loss_fct = CrossEntropyLoss()
501
+ shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
502
+ shift_labels = shift_labels.reshape(-1)
503
+ # Enable model parallelism
504
+ shift_labels = shift_labels.to(shift_logits.device)
505
+ loss = loss_fct(shift_logits, shift_labels)
506
+
507
+ if not return_dict:
508
+ output = (logits,) + outputs[1:]
509
+ return (loss,) + output if loss is not None else output
510
+
511
+ output = CausalLMOutputWithPast(
512
+ loss=loss,
513
+ logits=logits,
514
+ past_key_values=outputs.past_key_values,
515
+ hidden_states=outputs.hidden_states,
516
+ attentions=outputs.attentions,
517
+ )
518
+ if self.add_classification_head:
519
+ output['reward'] = reward
520
+
521
+ return output
522
+
523
+ def prepare_inputs_for_generation(
524
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
525
+ ):
526
+ if past_key_values is not None:
527
+ if isinstance(past_key_values, Cache):
528
+ cache_length = past_key_values.get_seq_length()
529
+ past_length = past_key_values.seen_tokens
530
+ max_cache_length = past_key_values.get_max_length()
531
+ else:
532
+ cache_length = past_length = past_key_values[0][0].shape[2]
533
+ max_cache_length = None
534
+
535
+ # Keep only the unprocessed tokens:
536
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
537
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as
538
+ # input)
539
+ if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
540
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length):]
541
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
542
+ # input_ids based on the past_length.
543
+ elif past_length < input_ids.shape[1]:
544
+ input_ids = input_ids[:, past_length:]
545
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
546
+
547
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
548
+ if (
549
+ max_cache_length is not None
550
+ and attention_mask is not None
551
+ and cache_length + input_ids.shape[1] > max_cache_length
552
+ ):
553
+ attention_mask = attention_mask[:, -max_cache_length:]
554
+
555
+ position_ids = kwargs.get("position_ids", None)
556
+ if attention_mask is not None and position_ids is None:
557
+ # create position_ids on the fly for batch generation
558
+ position_ids = attention_mask.long().cumsum(-1) - 1
559
+ position_ids.masked_fill_(attention_mask == 0, 1)
560
+ if past_key_values:
561
+ position_ids = position_ids[:, -input_ids.shape[1]:]
562
+
563
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
564
+ if inputs_embeds is not None and past_key_values is None:
565
+ model_inputs = {"inputs_embeds": inputs_embeds}
566
+ else:
567
+ model_inputs = {"input_ids": input_ids}
568
+
569
+ model_inputs.update(
570
+ {
571
+ "position_ids": position_ids,
572
+ "past_key_values": past_key_values,
573
+ "use_cache": kwargs.get("use_cache"),
574
+ "attention_mask": attention_mask,
575
+ }
576
+ )
577
+ return model_inputs
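+ # Worked example of the trimming above (assumed shapes): with a cache already holding 10 tokens
+ # (past_length=10) and `input_ids` of length 11, only `input_ids[:, 10:]`, i.e. the single new
+ # token, is passed to the model, and `position_ids` is sliced to the same trailing length.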
578
+
579
+ @staticmethod
580
+ def _reorder_cache(past_key_values, beam_idx):
581
+ reordered_past = ()
582
+ for layer_past in past_key_values:
583
+ reordered_past += (
584
+ tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
585
+ )
586
+ return reordered_past
587
+
588
+
589
+ class MultimodelHunYuanForCausalLM(HunYuanMoEV1ForCausalLM):
590
+ _tied_weights_keys = ["lm_head.weight"]
591
+
592
+ def __init__(self, config: HunYuanConfig):
593
+ super().__init__(config)
594
+
595
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
596
+ @replace_return_docstrings(output_type=CausalLMOutputWithPast, config_class=_CONFIG_FOR_DOC)
597
+ def forward(
598
+ self,
599
+ input_ids: torch.LongTensor = None,
600
+ attention_mask: Optional[torch.Tensor] = None,
601
+ position_ids: Optional[torch.LongTensor] = None,
602
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
603
+ inputs_embeds: Optional[torch.FloatTensor] = None,
604
+ labels: Optional[torch.LongTensor] = None,
605
+ imgs: Optional[List[torch.FloatTensor]] = None,
606
+ imgs_pos: Optional[List[int]] = None,
607
+ use_cache: Optional[bool] = None,
608
+ output_attentions: Optional[bool] = None,
609
+ output_hidden_states: Optional[bool] = None,
610
+ return_dict: Optional[bool] = None,
611
+ ) -> Union[Tuple, CausalLMOutputWithPast]:
612
+ r"""
613
+ Args:
614
+ labels (`torch.LongTensor` of shape `(batch_size, sequence_length)`, *optional*):
615
+ Labels for computing the masked language modeling loss. Indices should either be in `[0, ...,
616
+ config.vocab_size]` or -100 (see `input_ids` docstring). Tokens with indices set to `-100` are ignored
617
+ (masked), the loss is only computed for the tokens with labels in `[0, ..., config.vocab_size]`.
618
+
619
+ Returns:
620
+
621
+ Example:
622
+
623
+ ```python
624
+ >>> from transformers import AutoTokenizer, AutoModelForCausalLM
625
+
626
+ >>> model = AutoModelForCausalLM.from_pretrained(PATH_TO_CONVERTED_WEIGHTS)
627
+ >>> tokenizer = AutoTokenizer.from_pretrained(PATH_TO_CONVERTED_TOKENIZER)
628
+
629
+ >>> prompt = "Hey, are you conscious? Can you talk to me?"
630
+ >>> inputs = tokenizer(prompt, return_tensors="pt")
631
+
632
+ >>> # Generate
633
+ >>> generate_ids = model.generate(inputs.input_ids, max_length=30)
634
+ >>> tokenizer.batch_decode(generate_ids, skip_special_tokens=True, clean_up_tokenization_spaces=False)[0]
635
+ "Hey, are you conscious? Can you talk to me?\nI'm not conscious, but I can talk to you."
636
+ ```"""
637
+ mask_init_id = self.config.mask_init_id
638
+ pad_id = self.config.pad_token_id
639
+ eod_id = self.config.eod_token_id
640
+ image_token_id = self.config.image_token_id
641
+ im_start_id = self.config.im_start_id
642
+ im_end_id = self.config.im_end_id
643
+ video_start_id = self.config.video_start_id
644
+ video_end_id = self.config.video_end_id
645
+
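+ # Vision path (descriptive note): when a ViT is configured and images are passed, the token
+ # embeddings are computed first and NaVitForward/VitForward (from vit_model.py) merge the image
+ # features into them, returning the updated (inputs_embeds, input_ids) pair that is fed to the
+ # language model below.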
646
+ if self.vit is not None and imgs is not None:
647
+ encoder_input = self.model.embed_tokens(input_ids)
648
+ if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
649
+ inputs_embeds, input_ids = NaVitForward(input_ids, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
650
+ im_start_id, im_end_id, image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
651
+ else:
652
+ inputs_embeds, input_ids = VitForward(input_ids, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
653
+ self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
654
+
655
+ output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
656
+ output_hidden_states = (
657
+ output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
658
+ )
659
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
660
+ # decoder outputs consists of (dec_features, layer_state, dec_hidden, dec_attn)
661
+
662
+ outputs = self.model(
663
+ input_ids=input_ids,
664
+ attention_mask=attention_mask,
665
+ position_ids=position_ids,
666
+ past_key_values=past_key_values,
667
+ inputs_embeds=inputs_embeds,
668
+ use_cache=use_cache,
669
+ output_attentions=output_attentions,
670
+ output_hidden_states=output_hidden_states,
671
+ return_dict=return_dict,
672
+ )
673
+
674
+ hidden_states = outputs[0]
675
+ if self.config.pretraining_tp > 1:
676
+ lm_head_slices = self.lm_head.weight.split(self.vocab_size // self.config.pretraining_tp, dim=0)
677
+ logits = [F.linear(hidden_states, lm_head_slices[i]) for i in range(self.config.pretraining_tp)]
678
+ logits = torch.cat(logits, dim=-1)
679
+ else:
680
+ logits = self.lm_head(hidden_states)
681
+ logits = logits.float()
682
+
683
+ loss = None
684
+ if labels is not None:
685
+ labels = labels.to(logits.device)
686
+ # Shift so that tokens < n predict n
687
+ shift_logits = logits
688
+ shift_labels = labels
689
+ # Flatten the tokens
690
+ loss_fct = CrossEntropyLoss()
691
+ shift_logits = shift_logits.reshape(-1, self.config.vocab_size)
692
+ shift_labels = shift_labels.reshape(-1)
693
+ shift_tokens = input_ids.reshape(-1)
694
+ # compute loss
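+ # The mask below keeps only ordinary text positions: labels that are special tokens (pad, image,
+ # im_start/im_end, video_start/video_end, or any id >= mask_init_id) and positions whose input
+ # token is pad/eod are excluded from the cross-entropy loss.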
695
+ mask = (shift_labels < mask_init_id) & (shift_labels != pad_id) & (shift_labels != image_token_id) & (shift_labels != im_start_id) \
696
+ & (shift_labels != im_end_id) & (shift_labels != video_start_id) & (shift_labels != video_end_id) & (shift_tokens != pad_id) & (shift_tokens != eod_id)
697
+ shift_logits = shift_logits[mask, :]
698
+ shift_labels = shift_labels[mask]
699
+ loss = loss_fct(shift_logits, shift_labels)
700
+
701
+ if not return_dict:
702
+ output = (logits,) + outputs[1:]
703
+ return (loss,) + output if loss is not None else output
704
+
705
+ return CausalLMOutputWithPast(
706
+ loss=loss,
707
+ logits=logits,
708
+ past_key_values=outputs.past_key_values,
709
+ hidden_states=outputs.hidden_states,
710
+ attentions=outputs.attentions,
711
+ )
712
+
713
+ def prepare_inputs_for_generation(
714
+ self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
715
+ ):
716
+ imgs = kwargs.pop("imgs", None)
717
+ imgs_pos = kwargs.pop("imgs_pos", None)
718
+ inputs = super().prepare_inputs_for_generation(
719
+ input_ids, past_key_values=past_key_values, attention_mask=attention_mask, inputs_embeds=inputs_embeds, **kwargs
720
+ )
721
+
722
+ if imgs is not None:
723
+ inputs['imgs'] = imgs
724
+ if imgs_pos is not None:
725
+ inputs['imgs_pos'] = imgs_pos
726
+ return inputs
727
+
728
+ @torch.no_grad()
729
+ def generate(
730
+ self,
731
+ inputs: Optional[torch.Tensor] = None,
732
+ attention_mask: Optional[torch.Tensor] = None,
733
+ position_ids: Optional[torch.LongTensor] = None,
734
+ imgs: Optional[List[torch.FloatTensor]] = None,
735
+ imgs_pos: Optional[List[int]] = None,
736
+ **kwargs,
737
+ ) -> Union[GenerateOutput, torch.LongTensor]:
738
+ if "inputs_embeds" in kwargs:
739
+ raise NotImplementedError("`inputs_embeds` is not supported")
740
+ input_ids, inputs_embeds = inputs, None  # defaults for the text-only path (no ViT configured)
741
+ if self.vit is not None:
742
+ encoder_input = self.model.embed_tokens(inputs)
743
+ if self.vit_type in ['NaVit', 'EvaVit', 'AnyResVit']:
744
+ inputs_embeds, input_ids = NaVitForward(inputs, encoder_input, self.vit, imgs, imgs_pos, self.config.vit_input_resolution, \
745
+ self.config.im_start_id, self.config.im_end_id, self.config.image_token_id, self.config.anyres_vit_two_views, self.config.torch_dtype)
746
+ else:
747
+ inputs_embeds, input_ids = VitForward(inputs, encoder_input, self.vit, self.vit_linear_encoder, imgs, imgs_pos, \
748
+ self.config.vit_input_resolution, self.config.vit_mapping_type, self.config.vit_patch, self.config.vit_token)
749
+
750
+ return super().generate(
751
+ inputs=input_ids,
752
+ position_ids=position_ids,
753
+ attention_mask=attention_mask,
754
+ inputs_embeds=inputs_embeds,
755
+ eos_token_id=self.config.eod_token_id,
756
+ **kwargs
757
+ )
758
+
759
+
760
+ @add_start_docstrings(
761
+ """
762
+ The HunYuan Model transformer with a sequence classification head on top (linear layer).
763
+
764
+ [`HunYuanForSequenceClassification`] uses the last token in order to do the classification, as other causal models
765
+ (e.g. GPT-2) do.
766
+
767
+ Since it does classification on the last token, it requires to know the position of the last token. If a
768
+ `pad_token_id` is defined in the configuration, it finds the last token that is not a padding token in each row. If
769
+ no `pad_token_id` is defined, it simply takes the last value in each row of the batch. Since it cannot guess the
770
+ padding tokens when `inputs_embeds` are passed instead of `input_ids`, it does the same (take the last value in
771
+ each row of the batch).
772
+ """,
773
+ HUNYUAN_START_DOCSTRING,
774
+ )
775
+ class HunYuanForSequenceClassification(HunYuanPreTrainedModel):
776
+ def __init__(self, config):
777
+ super().__init__(config)
778
+ self.num_labels = config.num_labels
779
+ self.model = HunYuanModel(config)
780
+ self.score = nn.Linear(config.hidden_size, self.num_labels, bias=False)
781
+
782
+ # Initialize weights and apply final processing
783
+ self.post_init()
784
+
785
+ def get_input_embeddings(self):
786
+ return self.model.embed_tokens
787
+
788
+ def set_input_embeddings(self, value):
789
+ self.model.embed_tokens = value
790
+
791
+ @add_start_docstrings_to_model_forward(HUNYUAN_INPUTS_DOCSTRING)
792
+ def forward(
793
+ self,
794
+ input_ids: torch.LongTensor = None,
795
+ attention_mask: Optional[torch.Tensor] = None,
796
+ position_ids: Optional[torch.LongTensor] = None,
797
+ past_key_values: Optional[List[torch.FloatTensor]] = None,
798
+ inputs_embeds: Optional[torch.FloatTensor] = None,
799
+ labels: Optional[torch.LongTensor] = None,
800
+ use_cache: Optional[bool] = None,
801
+ output_attentions: Optional[bool] = None,
802
+ output_hidden_states: Optional[bool] = None,
803
+ return_dict: Optional[bool] = None,
804
+ ) -> Union[Tuple, SequenceClassifierOutputWithPast]:
805
+ r"""
806
+ labels (`torch.LongTensor` of shape `(batch_size,)`, *optional*):
807
+ Labels for computing the sequence classification/regression loss. Indices should be in `[0, ...,
808
+ config.num_labels - 1]`. If `config.num_labels == 1` a regression loss is computed (Mean-Square loss), If
809
+ `config.num_labels > 1` a classification loss is computed (Cross-Entropy).
810
+ """
811
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
812
+
813
+ transformer_outputs = self.model(
814
+ input_ids,
815
+ attention_mask=attention_mask,
816
+ position_ids=position_ids,
817
+ past_key_values=past_key_values,
818
+ inputs_embeds=inputs_embeds,
819
+ use_cache=use_cache,
820
+ output_attentions=output_attentions,
821
+ output_hidden_states=output_hidden_states,
822
+ return_dict=return_dict,
823
+ )
824
+ hidden_states = transformer_outputs[0]
825
+ logits = self.score(hidden_states)
826
+
827
+ if input_ids is not None:
828
+ batch_size = input_ids.shape[0]
829
+ else:
830
+ batch_size = inputs_embeds.shape[0]
831
+
832
+ if self.config.pad_token_id is None and batch_size != 1:
833
+ raise ValueError("Cannot handle batch sizes > 1 if no padding token is defined.")
834
+ if self.config.pad_token_id is None:
835
+ sequence_lengths = -1
836
+ else:
837
+ if input_ids is not None:
838
+ sequence_lengths = (torch.eq(input_ids, self.config.pad_token_id).int().argmax(-1) - 1).to(
839
+ logits.device
840
+ )
841
+ else:
842
+ sequence_lengths = -1
843
+
844
+ pooled_logits = logits[torch.arange(batch_size, device=logits.device), sequence_lengths]
845
+
846
+ loss = None
847
+ if labels is not None:
848
+ labels = labels.to(logits.device)
849
+ if self.config.problem_type is None:
850
+ if self.num_labels == 1:
851
+ self.config.problem_type = "regression"
852
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
853
+ self.config.problem_type = "single_label_classification"
854
+ else:
855
+ self.config.problem_type = "multi_label_classification"
856
+
857
+ if self.config.problem_type == "regression":
858
+ loss_fct = MSELoss()
859
+ if self.num_labels == 1:
860
+ loss = loss_fct(pooled_logits.squeeze(), labels.squeeze())
861
+ else:
862
+ loss = loss_fct(pooled_logits, labels)
863
+ elif self.config.problem_type == "single_label_classification":
864
+ loss_fct = CrossEntropyLoss()
865
+ loss = loss_fct(pooled_logits.reshape(-1, self.num_labels), labels.reshape(-1))
866
+ elif self.config.problem_type == "multi_label_classification":
867
+ loss_fct = BCEWithLogitsLoss()
868
+ loss = loss_fct(pooled_logits, labels)
869
+ if not return_dict:
870
+ output = (pooled_logits,) + transformer_outputs[1:]
871
+ return ((loss,) + output) if loss is not None else output
872
+
873
+ return SequenceClassifierOutputWithPast(
874
+ loss=loss,
875
+ logits=pooled_logits,
876
+ past_key_values=transformer_outputs.past_key_values,
877
+ hidden_states=transformer_outputs.hidden_states,
878
+ attentions=transformer_outputs.attentions,
879
+ )
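
A minimal usage sketch, assuming this file is served from a model repository that exposes it as custom code (so `trust_remote_code=True` is required) and that the placeholder path below points at a checkpoint shipping `configuration_hunyuan.py`, `modeling_hunyuan.py` and `vit_model.py` alongside this file:

```python
from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder: point this at the local directory or Hub repo that contains hunyuan.py.
PATH_OR_REPO_ID = "path/to/hunyuan-checkpoint"

tokenizer = AutoTokenizer.from_pretrained(PATH_OR_REPO_ID, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(PATH_OR_REPO_ID, trust_remote_code=True)

prompt = "Hey, are you conscious? Can you talk to me?"
inputs = tokenizer(prompt, return_tensors="pt")

# Greedy generation, mirroring the example in the forward() docstring above.
generate_ids = model.generate(inputs.input_ids, max_length=30)
print(tokenizer.batch_decode(generate_ids, skip_special_tokens=True)[0])
```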