manaestras commited on
Commit
f055b72
·
verified ·
1 Parent(s): 7a537be

Update LICENSE.txt

Browse files
Files changed (1) hide show
  1. LICENSE.txt +77 -1083
LICENSE.txt CHANGED
@@ -1,1083 +1,77 @@
1
- import json
2
- import types
3
- import math
4
- import torch
5
- from torch import Tensor, nn
6
- import torch.nn.functional as F
7
- from typing import List, Tuple, Optional, Union
8
- from contextlib import contextmanager
9
- from transformers.modeling_attn_mask_utils import (
10
- _prepare_4d_causal_attention_mask_for_sdpa,
11
- _prepare_4d_causal_attention_mask_for_sdpa,
12
- _prepare_4d_causal_attention_mask,
13
- )
14
- from transformers.models.clip.configuration_clip import CLIPVisionConfig
15
- from transformers.modeling_outputs import BaseModelOutputWithPooling
16
- from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
17
- from .configuration_hunyuan import HunYuanConfig
18
-
19
-
20
- def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
21
- # input_ids: (B, L)
22
- # encoder_input: (L, B, E)
23
- # image_tensors [[Tensor],...,[Tensor]]
24
- # image_pos [[Tensor],...,[Tensor]]
25
- # tokenizer = get_tokenizer()
26
- b = len(input_ids)
27
- img_embs = None
28
- all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
29
- if all_nums != 0:
30
- img_embs, img_batch_pos = vit(image_tensors)
31
- else:
32
- # when no input image, initialize a fake tensor
33
- pad_nums = 1
34
- image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
35
- img_embs, img_batch_pos = vit(image_tensors)
36
-
37
- encoder_input = encoder_input.clone()
38
- if all_nums > 0:
39
- assert len(images_pos) == len(img_batch_pos), \
40
- (len(images_pos), len(img_batch_pos))
41
- start_token_id = im_start_id
42
- end_token_id = im_end_id
43
- placeholder_id = image_token_id
44
- for idx in range(len(images_pos)):
45
- assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
46
- (len(images_pos[idx]), len(img_batch_pos[idx]))
47
- for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
48
- # the positions to be filled [s_start, s_end)
49
- s_idx, s_start, s_end = p_img_pos_in_batch
50
- current_embs = img_embs[s_idx, s_start:s_end]
51
- im_s, im_e = p_batch_img_pos
52
- assert len(current_embs) == im_e - im_s, \
53
- (img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
54
- if not anyres_vit_two_views:
55
- assert input_ids[idx, im_s - 1] == start_token_id, \
56
- input_ids[idx, im_s - 1]
57
- assert input_ids[idx, im_e] == end_token_id, \
58
- input_ids[idx, im_e]
59
- assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
60
- f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
61
- encoder_input[idx, im_s:im_e] = current_embs
62
- else:
63
- # when no input image, to mask vit value
64
- vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
65
- current_embs = img_embs[0, :]
66
- encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
67
- return encoder_input, input_ids
68
-
69
-
70
- def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
71
- vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
72
-
73
- b = len(input_ids)
74
- if images_pos is None:
75
- images_pos = torch.ones([len(input_ids), 1, 3])
76
- images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
77
- images_pos = images_pos.long()
78
-
79
- real_image_nums = []
80
- image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
81
- real_images = []
82
-
83
- all_nums = 0
84
- img_index = []
85
- for s in range(len(images_pos)):
86
- real_image_num = 0
87
- for (im_s, im_e,index) in images_pos[s]:
88
- if im_s == -1:
89
- break
90
- real_image_num += 1
91
- all_nums += 1
92
- img_index.append(index)
93
-
94
- real_image_nums.append(real_image_num)
95
- real_images.append(image_tensors[s][:real_image_num])
96
-
97
- if vit_patch == 1:
98
- img_index = None
99
-
100
- if all_nums == 0:
101
- # when no input image, initialize a fake tensor
102
- img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
103
- img_embs = vit(img_input)
104
- img_embs = vit_linear_encoder(img_embs)
105
- else:
106
- img_input = torch.cat(real_images)
107
- img_embs = vit(img_input, img_index = img_index)
108
- img_embs = vit_linear_encoder(img_embs)
109
-
110
- encoder_input = encoder_input.clone()
111
- start = 0
112
- if all_nums > 0:
113
- for s, real_image_len in enumerate(real_image_nums):
114
- current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
115
- for ss in range(current_embs.shape[0]):
116
- im_s, im_e, index = images_pos[s, ss]
117
- # 子图特征更少
118
- if index > 0 and vit_patch_mlp:
119
- encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
120
- else:
121
- encoder_input[s, im_s:im_e] = current_embs[ss, :]
122
- start = start + real_image_len
123
- else:
124
- # when no input image, to mask vit value
125
- for s in range(b):
126
- vit_mask = torch.zeros([vit_token, 1]).cuda()
127
- current_embs = img_embs[:, start:start + 1]
128
- encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
129
- start = start + 1
130
- return encoder_input, input_ids
131
-
132
-
133
- def group_images_by_max_seq_len(
134
- images: List[List[Tensor]], patch_size: int,
135
- max_seq_len: int, adaptor_patch_size: int,
136
- add_cls_token: bool = False) -> List[List[Tensor]]:
137
-
138
- groups = []
139
- group = []
140
- pos_groups = []
141
- seq_len = 0
142
- num_images = 0
143
- for image_list in images:
144
- pos_group = []
145
- for image in image_list:
146
- num_images += 1
147
- assert isinstance(image, Tensor)
148
-
149
- image_dims = image.shape[-2:]
150
- ph, pw = map(lambda t: t // patch_size, image_dims)
151
-
152
- image_seq_len = (ph * pw)
153
- new_image_seq_len = image_seq_len
154
- grouped_len = seq_len + image_seq_len
155
- if add_cls_token:
156
- new_image_seq_len += 1
157
- grouped_len += num_images
158
-
159
- assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
160
-
161
- if grouped_len > max_seq_len:
162
- groups.append(group)
163
- group = []
164
- seq_len = 0
165
- num_images = 1
166
-
167
- group.append(image)
168
- start = seq_len // (adaptor_patch_size * adaptor_patch_size)
169
- end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
170
- batch_idx = len(groups)
171
- pos_group.append([batch_idx, start, end])
172
- seq_len += image_seq_len
173
- pos_groups.append(pos_group)
174
-
175
- if len(group) > 0:
176
- groups.append(group)
177
-
178
- return groups, pos_groups
179
-
180
-
181
- class AnyResCLIPVisionEmbeddings(nn.Module):
182
- def __init__(self, config: CLIPVisionConfig):
183
- super().__init__()
184
-
185
- self.config = config
186
- # self.sparse_attn_mask = args.sparse_attn_mask
187
- # self.use_flash_attn = args.use_flash_attn
188
- self.embed_dim = config.hidden_size
189
- self.image_size = config.max_image_size
190
- self.patch_size = config.patch_size
191
- self.max_seq_len = config.max_vit_seq_len
192
- self.adaptor_patch_size = config.adaptor_patch_size
193
- self.anyres_vit_two_views = config.anyres_vit_two_views
194
- self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
195
- self.vit_remove_prenorm = config.vit_remove_prenorm
196
-
197
- self.patch_embedding = nn.Conv2d(
198
- in_channels=config.num_channels,
199
- out_channels=self.embed_dim,
200
- kernel_size=self.patch_size,
201
- stride=self.patch_size,
202
- bias=self.vit_add_patchemb_bias,
203
- )
204
-
205
- self.num_patches = (self.image_size // self.patch_size) ** 2
206
- self.skip_cls_token = True
207
-
208
- # add interpolate_pos_encoding
209
- if self.anyres_vit_two_views:
210
- self.num_positions = self.num_patches
211
- self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
212
- else:
213
- self.num_positions = self.num_patches + 1
214
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
215
- # self.position_ids = torch.arange(self.num_positions).expand((1, -1))
216
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
217
-
218
- if not self.vit_remove_prenorm:
219
- self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
220
-
221
- def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
222
- """
223
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
224
- resolution images.
225
-
226
- Source:
227
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
228
- """
229
- num_patches = embeddings.shape[1]
230
- position_embeddings = self.position_embedding(self.position_ids)
231
- patch_pos_embed = position_embeddings[:, 1:]
232
- num_positions = position_embeddings.shape[1] - 1
233
- if num_patches == num_positions and height == width:
234
- return patch_pos_embed
235
- # class_pos_embed = position_embeddings[:, 0]
236
- dim = embeddings.shape[-1]
237
- h0 = height // self.patch_size
238
- w0 = width // self.patch_size
239
- # we add a small number to avoid floating point error in the interpolation
240
- # see discussion at https://github.com/facebookresearch/dino/issues/8
241
- h0, w0 = h0 + 0.1, w0 + 0.1
242
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
243
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
244
- raw_type = patch_pos_embed.dtype
245
- patch_pos_embed = nn.functional.interpolate(
246
- patch_pos_embed.to(torch.float32, non_blocking=True),
247
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
248
- mode="bilinear",
249
- align_corners=False,
250
- )
251
- patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
252
- assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
253
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
254
- return patch_pos_embed
255
-
256
- def rescale_positional_embedding(self, out_size):
257
- h, w = out_size
258
- pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
259
- if (h, w) == (pos_embed_shape, pos_embed_shape):
260
- return self.position_embedding
261
- rescaled_positional_embedding = \
262
- self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
263
- pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
264
- pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
265
- rescaled_positional_embedding[0] = pe_2d.T.contiguous()
266
- return rescaled_positional_embedding
267
-
268
- def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
269
- if pixel_values.ndim == 3:
270
- pixel_values = pixel_values[None]
271
- batch_size, num_channels, height, width = pixel_values.shape
272
-
273
- if self.anyres_vit_two_views:
274
- # padding
275
- pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
276
- pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
277
- pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
278
-
279
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
280
- b, c, h, w = patch_embeds.shape
281
-
282
- # (b, hw, c)
283
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
284
- if self.anyres_vit_two_views:
285
- embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
286
- else:
287
- embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
288
- if not self.vit_remove_prenorm:
289
- embeddings = self.pre_layernorm(embeddings)
290
- return embeddings, (h, w)
291
-
292
- def forward(self, images: List[List[Tensor]]):
293
- '''
294
- Input:
295
- images: List[List[Tensor]]
296
-
297
- Return:
298
- embeddings: Tensor (B, L, E)
299
- attn_mask: Tensor (B, L, 2)
300
- pos_groups: List[List[(batch_idx, start, end)]]
301
- '''
302
- batched_images, pos_groups = group_images_by_max_seq_len(
303
- images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
304
- max_seq_len = self.max_seq_len
305
-
306
- # batched_images is a list of a list
307
- B = len(batched_images)
308
- L = max_seq_len
309
- E = self.embed_dim
310
-
311
- embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
312
- attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True presents compute
313
- assert len(images) == len(pos_groups), (len(images), len(pos_groups))
314
-
315
- batch_images = []
316
- batch_pos = []
317
- for images_i, pos_group in zip(images, pos_groups):
318
- assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
319
- for image, pos in zip(images_i, pos_group):
320
- batch_idx, start, end = pos
321
- a2 = self.adaptor_patch_size ** 2
322
- # recover the real number of the input image tokens
323
- start *= a2
324
- end *= a2
325
- emb, _ = self.forward_single(image)
326
- assert emb.ndim == 3, '(B, L, E)'
327
- embeddings[batch_idx, start:end] = emb
328
- attn_mask[batch_idx, :, start:end, start:end] = True
329
- return embeddings, attn_mask, pos_groups
330
-
331
-
332
- class CLIPVisionEmbeddings(nn.Module):
333
- def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
334
- super().__init__()
335
- self.config = config
336
- self.embed_dim = config.hidden_size
337
- self.image_size = config.image_size
338
- self.image_size = config.vit_input_resolution
339
- self.patch_size = config.patch_size
340
-
341
- self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
342
-
343
- self.patch_embedding = nn.Conv2d(
344
- in_channels=config.num_channels,
345
- out_channels=self.embed_dim,
346
- kernel_size=self.patch_size,
347
- stride=self.patch_size,
348
- bias=False,
349
- )
350
-
351
- self.num_patches = (self.image_size // self.patch_size) ** 2
352
-
353
- self.skip_cls_token = skip_cls_token
354
-
355
- self.num_positions = self.num_patches + 1
356
-
357
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
358
- if vit_patch > 1:
359
- self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
360
- # 0 支持最大16张图,目前写死了,如需其他的需要额外定义参数
361
- elif vit_patch == 0:
362
- self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
363
- else:
364
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
365
-
366
- if add_pre_layernorm:
367
- self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
368
- else:
369
- self.pre_layernorm = None
370
-
371
- def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
372
- """
373
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
374
- resolution images.
375
-
376
- Source:
377
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
378
- """
379
- num_patches = embeddings.shape[1] - 1
380
- position_embeddings = self.position_embedding(self.position_ids)
381
- num_positions = position_embeddings.shape[1] - 1
382
- if num_patches == num_positions and height == width:
383
- return position_embeddings
384
- class_pos_embed = position_embeddings[:, 0]
385
- patch_pos_embed = position_embeddings[:, 1:]
386
- dim = embeddings.shape[-1]
387
- h0 = height // self.config.patch_size
388
- w0 = width // self.config.patch_size
389
- # we add a small number to avoid floating point error in the interpolation
390
- # see discussion at https://github.com/facebookresearch/dino/issues/8
391
- h0, w0 = h0 + 0.1, w0 + 0.1
392
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
393
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
394
- raw_type = patch_pos_embed.dtype
395
- patch_pos_embed = nn.functional.interpolate(
396
- patch_pos_embed.float(),
397
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
398
- mode="bicubic",
399
- align_corners=False,
400
- )
401
- # print(patch_pos_embed.shape)
402
- patch_pos_embed = patch_pos_embed.to(raw_type)
403
- assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
404
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
405
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
406
-
407
-
408
- def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
409
- batch_size, num_channels, height, width = pixel_values.shape
410
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
411
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
412
- if self.skip_cls_token:
413
- embeddings = patch_embeds
414
- if img_index is None:
415
- position_ids = self.position_ids[:,1:]
416
- embeddings = embeddings + self.position_embedding(position_ids)
417
- else:
418
- position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
419
- + self.position_ids.expand(batch_size, -1)[:, 1:]
420
- embeddings = embeddings + self.position_embedding(position_ids)
421
- else:
422
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
423
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
424
- if interpolate_pos_encoding:
425
- embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
426
- else:
427
- if img_index is None:
428
- embeddings = embeddings + self.position_embedding(self.position_ids)
429
- else:
430
- position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
431
- new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
432
- position_ids = torch.cat([position_ids,new_position],dim=1)
433
- embeddings = embeddings + self.position_embedding(position_ids)
434
- if self.pre_layernorm is not None:
435
- embeddings = self.pre_layernorm(embeddings)
436
- return embeddings
437
-
438
-
439
- class NaVitTransformer(nn.Module):
440
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
441
- super().__init__()
442
- self.config = config
443
- self.vit_config = vit_config
444
- with self.prepare_args(config, vit_config):
445
- self._use_sdpa = config._attn_implementation == "sdpa"
446
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
447
- self.layers = nn.ModuleList(
448
- [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
449
- )
450
-
451
- @contextmanager
452
- def prepare_args(self, config, vit_config):
453
- hidden_act = config.hidden_act
454
- hidden_size = config.hidden_size
455
- ffn_hidden_size = config.intermediate_size
456
- num_attention_heads = config.num_attention_heads
457
- num_key_value_heads = config.num_key_value_heads
458
- attention_head_dim = config.attention_head_dim
459
- use_qk_norm = config.use_qk_norm
460
- use_rotary_pos_emb = config.use_rotary_pos_emb
461
- num_hidden_layers = config.num_hidden_layers
462
- rms_norm_eps = config.rms_norm_eps
463
- attention_dropout = config.attention_dropout
464
- # hidden_dropout = config.hidden_dropout
465
- norm_type = config.norm_type
466
- attention_bias = config.attention_bias
467
- mlp_bias = config.mlp_bias
468
- use_mla = config.use_mla
469
- num_experts = config.num_experts
470
- _attn_implementation = config._attn_implementation
471
-
472
- config.hidden_act = vit_config.hidden_act
473
- config.hidden_size = vit_config.hidden_size
474
- config.intermediate_size = vit_config.intermediate_size
475
- config.num_attention_heads = vit_config.num_attention_heads
476
- config.num_key_value_heads = None
477
- config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
478
- config.use_qk_norm = False
479
- config.use_rotary_pos_emb = False
480
- config.num_hidden_layers = vit_config.num_hidden_layers
481
- config.rms_norm_eps = vit_config.layer_norm_eps
482
- config.attention_dropout = vit_config.attention_dropout
483
- # config.hidden_dropout = vit_config.hidden_dropout
484
- config.norm_type = config.vit_norm_type
485
- config.attention_bias = True
486
- config.mlp_bias = True
487
- config.use_mla = False
488
- config.num_experts = 1
489
- config._attn_implementation = "eager"
490
-
491
- yield
492
- config.hidden_act = hidden_act
493
- config.hidden_size = hidden_size
494
- config.intermediate_size = ffn_hidden_size
495
- config.num_attention_heads = num_attention_heads
496
- config.num_key_value_heads = num_key_value_heads
497
- config.attention_head_dim = attention_head_dim
498
- config.use_qk_norm = use_qk_norm
499
- config.use_rotary_pos_emb = use_rotary_pos_emb
500
- config.num_hidden_layers = num_hidden_layers
501
- config.rms_norm_eps = rms_norm_eps
502
- config.attention_dropout = attention_dropout
503
- # config.hidden_dropout = hidden_dropout
504
- config.attention_bias = attention_bias
505
- config.mlp_bias = mlp_bias
506
- config.norm_type = norm_type
507
- config.use_mla = use_mla
508
- config.num_experts = num_experts
509
- config._attn_implementation = _attn_implementation
510
-
511
- def forward(
512
- self,
513
- pixel_values: Optional[torch.FloatTensor] = None,
514
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
515
-
516
- hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
517
- attention_mask = attention_mask.int()
518
- batch_size, seq_length, _ = hidden_states.shape
519
- past_key_values_length = 0
520
-
521
- if self._use_flash_attention_2:
522
- # 2d mask is passed through the layers
523
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
524
- elif self._use_sdpa:
525
- # output_attentions=True can not be supported when using SDPA, and we fall back on
526
- # the manual implementation that requires a 4D causal mask in all cases.
527
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
528
- attention_mask,
529
- (batch_size, seq_length),
530
- hidden_states,
531
- past_key_values_length,
532
- )
533
- else:
534
- attention_mask = _prepare_4d_causal_attention_mask(
535
- attention_mask,
536
- (batch_size, seq_length),
537
- hidden_states,
538
- past_key_values_length,
539
- )
540
-
541
- for layer_idx, decoder_layer in enumerate(self.layers):
542
- layer_outputs = decoder_layer(
543
- hidden_states,
544
- attention_mask=attention_mask
545
- )
546
- hidden_states = layer_outputs[0]
547
-
548
- return hidden_states, img_pos
549
-
550
-
551
- class AnyResVitTransformer(NaVitTransformer):
552
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
553
- super().__init__(config, vit_config)
554
- old_anyres_vit_max_image_size = vit_config.max_image_size
555
- anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
556
- vit_config.max_image_size = anyres_vit_max_image_size
557
- vit_config.torch_dtype = config.torch_dtype
558
- vit_config.anyres_vit_two_views = config.anyres_vit_two_views
559
- vit_config.vit_remove_prenorm = config.vit_remove_prenorm
560
- vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
561
- self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
562
- vit_config.max_image_size = old_anyres_vit_max_image_size
563
-
564
- def fix_embeddings_fn(self, pixel_values):
565
- # (B, L, E)
566
- embeddings, hw = self.embeddings.forward_single(pixel_values)
567
- embeddings = self.embeddings.pre_layernorm(embeddings)
568
- return embeddings
569
-
570
-
571
- class CLIPVisionTransformer(nn.Module):
572
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
573
- super().__init__()
574
- embed_dim = vit_config.hidden_size
575
-
576
- self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
577
- self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
578
-
579
- with self.prepare_args(config, vit_config):
580
- self.layers = nn.ModuleList(
581
- [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
582
- )
583
-
584
- @contextmanager
585
- def prepare_args(self, config, vit_config):
586
- hidden_act = config.hidden_act
587
- hidden_size = config.hidden_size
588
- ffn_hidden_size = config.intermediate_size
589
- num_attention_heads = config.num_attention_heads
590
- num_key_value_heads = config.num_key_value_heads
591
- attention_head_dim = config.attention_head_dim
592
- use_qk_norm = config.use_qk_norm
593
- use_rotary_pos_emb = config.use_rotary_pos_emb
594
- num_hidden_layers = config.num_hidden_layers
595
- rms_norm_eps = config.rms_norm_eps
596
- attention_dropout = config.attention_dropout
597
- # hidden_dropout = config.hidden_dropout
598
- norm_type = config.norm_type
599
- attention_bias = config.attention_bias
600
- mlp_bias = config.mlp_bias
601
- use_mla = config.use_mla
602
- num_experts = config.num_experts
603
- _attn_implementation = config._attn_implementation
604
-
605
- config.hidden_act = vit_config.hidden_act
606
- config.hidden_size = vit_config.hidden_size
607
- config.intermediate_size = vit_config.intermediate_size
608
- config.num_attention_heads = vit_config.num_attention_heads
609
- config.num_key_value_heads = None
610
- config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
611
- config.use_qk_norm = False
612
- config.use_rotary_pos_emb = False
613
- config.num_hidden_layers = vit_config.num_hidden_layers
614
- config.rms_norm_eps = vit_config.layer_norm_eps
615
- config.attention_dropout = vit_config.attention_dropout
616
- # config.hidden_dropout = 0.0
617
- config.norm_type = "fused"
618
- config.attention_bias = True
619
- config.mlp_bias = True
620
- config.use_mla = False
621
- config.num_experts = 1
622
- config._attn_implementation = "eager"
623
-
624
- yield
625
-
626
- config.hidden_act = hidden_act
627
- config.hidden_size = hidden_size
628
- config.intermediate_size = ffn_hidden_size
629
- config.num_attention_heads = num_attention_heads
630
- config.num_key_value_heads = num_key_value_heads
631
- config.attention_head_dim = attention_head_dim
632
- config.use_qk_norm = use_qk_norm
633
- config.use_rotary_pos_emb = use_rotary_pos_emb
634
- config.num_hidden_layers = num_hidden_layers
635
- config.rms_norm_eps = rms_norm_eps
636
- config.attention_dropout = attention_dropout
637
- # config.hidden_dropout = hidden_dropout
638
- config.norm_type = norm_type
639
- config.attention_bias = attention_bias
640
- config.mlp_bias = mlp_bias
641
- config.use_mla = use_mla
642
- config.num_experts = num_experts
643
- config._attn_implementation = _attn_implementation
644
-
645
- def forward(
646
- self,
647
- pixel_values: Optional[torch.FloatTensor] = None,
648
- interpolate_pos_encoding: Optional[bool] = None,
649
- img_index=None
650
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
651
- r"""
652
- Returns:
653
-
654
- """
655
- hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
656
- hidden_states = self.pre_layrnorm(hidden_states)
657
- batch = hidden_states.shape[0]
658
- seq_len = hidden_states.shape[1]
659
- device = hidden_states.device
660
- attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
661
-
662
- for layer_idx, decoder_layer in enumerate(self.layers):
663
- layer_outputs = decoder_layer(
664
- hidden_states,
665
- attention_mask=attention_mask
666
- )
667
- hidden_states = layer_outputs[0]
668
-
669
- return hidden_states
670
-
671
-
672
- class Vit(torch.nn.Module):
673
- def __init__(self, config, resampler_token=64, pool_rate=2):
674
- super().__init__()
675
- self.config = config
676
- self.vit_mapping_type = config.vit_mapping_type
677
- self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
678
- self.skip_cls_token = config.skip_cls_token
679
- self.pool_rate = pool_rate
680
- self.vit_type = self.config.vit_type
681
- self.anyres_vit_two_views = self.config.anyres_vit_two_views
682
- if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
683
- self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
684
- else:
685
- raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
686
-
687
- def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
688
- if self.vit_type == 'AnyResVit':
689
- vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
690
- self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
691
- self.vit_config.image_size = vit_input_resolution
692
- self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
693
- elif self.vit_type == 'Vit-g':
694
- vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
695
- self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
696
- self.vit_config.vit_input_resolution = vit_input_resolution
697
- self.vit = CLIPVisionTransformer(self.config, self.vit_config)
698
- else:
699
- assert False, "other vit_types are not supported"
700
-
701
- if self.vit_mapping_type == 'simple_conv_mlp':
702
- self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
703
- self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
704
- elif self.vit_mapping_type == 'oryx_mlp':
705
- self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
706
- elif self.vit_mapping_type == 'mlp':
707
- self.mlp_depth = 2
708
- # one mlp layer already in gpt_model.py
709
- mlp_hidden_size = self.vit_config.hidden_size
710
- if self.vit_type in ['NaVit', 'EvaVit']:
711
- mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
712
- if self.mlp_depth > 1:
713
- mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
714
- if self.vit_type in ['NaVit', 'EvaVit']:
715
- for _ in range(1, self.mlp_depth):
716
- mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
717
- mlp_modules.append(torch.nn.GELU())
718
- self.perceive = torch.nn.Sequential(*mlp_modules)
719
- else:
720
- assert False, "other vit_mapping_types are not supported"
721
-
722
- self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
723
- for name, param in self.named_parameters():
724
- setattr(param, "is_vit_param", True)
725
-
726
- def forward(self, images, img_index=None):
727
- if self.vit_type in ['AnyResVit']:
728
- dtype = self.config.torch_dtype
729
- device = torch.cuda.current_device()
730
-
731
- images_size = []
732
- for i in range(len(images)):
733
- images_size.append([])
734
- for j in range(len(images[i])):
735
- images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
736
-
737
- images_feats, img_batch_pos = self.vit(pixel_values=images)
738
- a2 = self.vit_config.adaptor_patch_size ** 2
739
-
740
- if self.anyres_vit_two_views:
741
- step = 2
742
- else:
743
- step = 1
744
- perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
745
- images_list = []
746
- images_fix_i = 0
747
- num_img_batch_pos = len(img_batch_pos)
748
- for i in range(num_img_batch_pos): # batch_id
749
- for j in range(0, len(img_batch_pos[i]), step):
750
- if self.anyres_vit_two_views:
751
- lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
752
- lower_begin = lower_begin * a2
753
- lower_end = lower_end * a2
754
- higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
755
- higher_begin = higher_begin * a2
756
- higher_end = higher_end * a2
757
- lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
758
- higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
759
- lower_images_size = images_size[i][j]
760
- higher_images_size = images_size[i][j + 1]
761
- images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
762
- else:
763
- idx, begin, end = img_batch_pos[i][j]
764
- begin = begin * a2
765
- end = end * a2
766
- is_video = hasattr(images[i][j],'_is_video') and images[i][j]._is_video
767
- images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
768
-
769
- images = torch.cat(images_list, dim=1)
770
-
771
- new_batch_pos = []
772
- k = 0; cur_len = 0
773
- for i in range(len(images_size)):
774
- new_batch_pos.append([])
775
- for j in range(0, len(images_size[i]), step):
776
- new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
777
- cur_len += images_list[k].size(1)
778
- k += 1
779
- new_batch_pos[i].append(new_pos)
780
- return images, new_batch_pos
781
- elif self.vit_type == 'Vit-g':
782
- images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
783
- else:
784
- assert False, "other vit_types are not supported"
785
-
786
- if self.vit_mapping_type == 'mlp':
787
- if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
788
- images = images[:,1:,:]
789
- b, v, d = images.shape
790
- s = int(math.sqrt(v))
791
- images = images.reshape(b, s, s, d)
792
-
793
-
794
- if self.vit_patch_mlp and img_index is not None:
795
- L_tensor = torch.tensor(img_index)
796
- device = images.device
797
- # 获取子图位置
798
- nonzero_indices = torch.nonzero(L_tensor).squeeze().to(device)
799
- # 获取主图位置
800
- zero_indices = torch.nonzero(L_tensor == 0).squeeze().to(device)
801
-
802
-
803
- images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
804
- images_zero = torch.index_select(images, 0, zero_indices).to(device)
805
-
806
- # 子图额外多pool一次
807
- pool_rate = self.pool_rate * 2
808
- images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
809
- images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
810
- pool_rate*pool_rate).mean(-1)
811
-
812
- # 为了组batch折衷方案
813
- images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
814
- images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
815
- images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
816
- self.pool_rate*self.pool_rate).mean(-1)
817
- # 组batch
818
- images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
819
- images.index_copy_(0, nonzero_indices, images_nonzero)
820
- images.index_copy_(0, zero_indices, images_zero)
821
-
822
- if self.mlp_depth >= 2:
823
- images = self.perceive(images)
824
- else:
825
- if s % self.pool_rate == 0:
826
- images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
827
- images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
828
- if self.mlp_depth >= 2:
829
- images = self.perceive(images)
830
- else:
831
- raise ValueError
832
- return images
833
-
834
-
835
- class SimpleConvMlp(nn.Module):
836
- def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
837
- super().__init__()
838
-
839
- embed_std = 1 / math.sqrt(out_channels)
840
- if poolmlp:
841
- # if args.learnable_mlp_pooling_size is not None:
842
- # in_channels *= args.learnable_mlp_pooling_size ** 2
843
- self.proj = nn.Sequential(
844
- nn.Linear(in_channels, out_channels),
845
- nn.GELU()
846
- )
847
- self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
848
- self.image_newline = nn.Parameter(
849
- torch.randn(out_channels) * embed_std
850
- )
851
- else:
852
- self.proj = nn.Sequential(
853
- nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
854
- nn.GELU(),
855
- nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
856
- )
857
- self.mlp = nn.Linear(in_channels * 4, out_channels)
858
- self.image_newline = nn.Parameter(
859
- torch.randn(in_channels * 4) * embed_std
860
- )
861
- self.poolmlp = poolmlp
862
-
863
- self.image_begin = nn.Parameter(
864
- torch.randn(out_channels) * embed_std
865
- )
866
- self.image_end = nn.Parameter(
867
- torch.randn(out_channels) * embed_std
868
- )
869
-
870
- if twoview:
871
- self.image_sep = nn.Parameter(
872
- torch.randn(out_channels) * embed_std
873
- )
874
-
875
- self.cat_extra_token = cat_extra_token
876
- self.use_rms_norm = vit_used_rms_norm
877
- if self.use_rms_norm:
878
- self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
879
- self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
880
-
881
- def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
882
- return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
883
-
884
- def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
885
- remove_vit_special_tokens = False
886
- learnable_mlp_pooling_size = None
887
- if self.use_rms_norm:
888
- x = self.before_rms(x)
889
- h, w = size
890
- dtype = x.dtype
891
- x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
892
- if self.poolmlp:
893
- if learnable_mlp_pooling_size is None:
894
- x = F.avg_pool2d(x, anyres_pooling_size)
895
- x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
896
- else:
897
- x = x.permute(0, 2, 3, 1) # b, h, w, c
898
- x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
899
- w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
900
- x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
901
- x = self.proj(x)
902
- x = self.vit_linear_encoder(x)
903
- b, h, w, c = x.shape
904
- if not remove_vit_special_tokens:
905
- x = torch.cat([
906
- x,
907
- self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
908
- ], dim=2)
909
- x = x.reshape(b, -1, c)
910
- else:
911
- x = self.proj(x) #b,c,h,w
912
- if is_video:
913
- video_avgpool_size = 2
914
- stride = 2
915
- x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
916
- b, c, h, w = x.shape
917
- if not remove_vit_special_tokens:
918
- x = torch.cat([
919
- x,
920
- self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
921
- ], dim=-1)
922
- x = x.reshape(b, c, -1).permute(0, 2, 1)
923
- x = self.mlp(x)
924
-
925
-
926
- if x2 is not None:
927
- h2, w2 = size2
928
- x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
929
- if self.poolmlp:
930
- x2 = F.avg_pool2d(x2, 2)
931
- x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
932
- x2 = self.vit_linear_encoder(x2)
933
- b2, h2, w2, c2 = x2.shape
934
- if not remove_vit_special_tokens:
935
- x2 = torch.cat([
936
- x2,
937
- self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
938
- ], dim=2)
939
- x2 = x2.reshape(b2, -1, c2)
940
- else:
941
- x2 = self.proj(x2)
942
- b2, c2, h2, w2 = x2.shape
943
- if not remove_vit_special_tokens:
944
- x2 = torch.cat([
945
- x2,
946
- self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
947
- ], dim=-1)
948
- x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
949
- x2 = self.mlp(x2)
950
-
951
- sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
952
-
953
- x = torch.cat([x, sep, x2], dim=1)
954
-
955
- if self.cat_extra_token:
956
- begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
957
- end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
958
- x = torch.cat([begin, x, end], dim=1)
959
-
960
- if self.use_rms_norm:
961
- return self.after_rms(x)
962
- else:
963
- return x
964
-
965
-
966
- class NormalizedDwPooler(nn.Module):
967
- def __init__(self, dim):
968
- super().__init__()
969
- self.dim = dim
970
- self.predictor = nn.Sequential(
971
- nn.Linear(dim*2, dim),
972
- nn.GELU(),
973
- nn.Linear(dim, dim),
974
- )
975
-
976
- def forward(self, x, forward_type='2x'):
977
- B, H, W, C = x.shape
978
-
979
- if forward_type == '2x':
980
- new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
981
- pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
982
- fused_x = torch.cat([new_x, pooled_x], dim=-1)
983
- elif forward_type == '1x':
984
- new_x = x.reshape(B, H, W, 1, C)
985
- fused_x = torch.cat([new_x, new_x], dim=-1)
986
- elif forward_type == '4x':
987
- new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
988
- pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
989
- fused_x = torch.cat([new_x, pooled_x], dim=-1)
990
-
991
- score = self.predictor(fused_x)
992
- normalized_score = F.softmax(score, dim=-2)
993
- new_x = (new_x * normalized_score).sum(dim=-2)
994
- return new_x
995
-
996
-
997
- class OryxMLPv2(nn.Module):
998
- def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
999
- super().__init__()
1000
-
1001
- self.proj1 = nn.Linear(in_channels, out_channels)
1002
- self.proj2 = nn.Linear(out_channels, out_channels)
1003
- self.act = nn.GELU()
1004
- self.pooler = NormalizedDwPooler(out_channels)
1005
- embed_std = 1 / math.sqrt(out_channels)
1006
-
1007
- self.use_pe = use_pe
1008
- if not use_pe:
1009
- self.image_newline = nn.Parameter(
1010
- torch.randn(out_channels) * embed_std
1011
- )
1012
- self.image_begin = nn.Parameter(
1013
- torch.randn(out_channels) * embed_std
1014
- )
1015
- self.image_end = nn.Parameter(
1016
- torch.randn(out_channels) * embed_std
1017
- )
1018
-
1019
- if twoview:
1020
- self.image_sep = nn.Parameter(
1021
- torch.randn(out_channels) * embed_std
1022
- )
1023
-
1024
- def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
1025
- h, w = size
1026
- dtype = x.dtype
1027
- x = x.reshape(x.shape[0], h, w, -1)
1028
- # x = self.pooler(x, forward_type=REGIONAL_POOL)
1029
- # x = self.proj(x) #b,h,w, c
1030
- x = self.proj1(x)
1031
- x = self.pooler(x, forward_type='2x')
1032
- x = self.act(x)
1033
- x = self.proj2(x)
1034
-
1035
-
1036
- b, h, w, c = x.shape
1037
- if not self.use_pe:
1038
- x = torch.cat([
1039
- x,
1040
- self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
1041
- ], dim=2)
1042
- else:
1043
- pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1044
- pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1045
- pe = torch.cat([pe_h, pe_w], dim=-1)
1046
-
1047
- x = x.reshape(b, -1, c)
1048
-
1049
- if x2 is not None:
1050
- h2, w2 = size2
1051
- x2 = x2.reshape(x2.shape[0], h2, w2, -1)
1052
- # x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
1053
- ## x2 = self.proj(x2) #b,h,w, c
1054
- x2 = self.proj1(x2)
1055
- x2 = self.pooler(x2, forward_type='2x')
1056
- x2 = self.act(x2)
1057
- x2 = self.proj2(x2)
1058
-
1059
- b2, h2, w2, c2 = x2.shape
1060
- if not self.use_pe:
1061
- x2 = torch.cat([
1062
- x2,
1063
- self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
1064
- ], dim=2)
1065
- x2 = x2.reshape(b, -1, c)
1066
- sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
1067
- x = torch.cat([x, sep, x2], dim=1)
1068
-
1069
- begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1070
- end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1071
- x = torch.cat([begin, x, end], dim=1)
1072
- # print(x.shape, x2.shape, h, w, h2, w2)
1073
- # print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
1074
- if self.use_pe:
1075
- zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
1076
- pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
1077
- assert pe.shape[1] == x.shape[1]
1078
- return x, pe
1079
- else:
1080
- nseq = x.shape[1]
1081
- fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
1082
- return x #, fake_pe
1083
-
 
1
+ TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT
2
+ Tencent Hunyuan A13B Release Date: June 27, 2025
3
+ THIS LICENSE AGREEMENT DOES NOT APPLY IN THE EUROPEAN UNION, UNITED KINGDOM AND SOUTH KOREA AND IS EXPRESSLY LIMITED TO THE TERRITORY, AS DEFINED BELOW.
4
+ By clicking to agree or by using, reproducing, modifying, distributing, performing or displaying any portion or element of the Tencent Hunyuan Works, including via any Hosted Service, You will be deemed to have recognized and accepted the content of this Agreement, which is effective immediately.
5
+ 1. DEFINITIONS.
6
+ a. “Acceptable Use Policy” shall mean the policy made available by Tencent as set forth in the Exhibit A.
7
+ b. “Agreement” shall mean the terms and conditions for use, reproduction, distribution, modification, performance and displaying of Tencent Hunyuan Works or any portion or element thereof set forth herein.
8
+ c. “Documentation” shall mean the specifications, manuals and documentation for Tencent Hunyuan made publicly available by Tencent.
9
+ d. “Hosted Service” shall mean a hosted service offered via an application programming interface (API), web access, or any other electronic or remote means.
10
+ e. “Licensee,” “You” or “Your” shall mean a natural person or legal entity exercising the rights granted by this Agreement and/or using the Tencent Hunyuan Works for any purpose and in any field of use.
11
+ f. “Materials” shall mean, collectively, Tencent’s proprietary Tencent Hunyuan and Documentation (and any portion thereof) as made available by Tencent under this Agreement.
12
+ g. “Model Derivatives” shall mean all: (i) modifications to Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; (ii) works based on Tencent Hunyuan or any Model Derivative of Tencent Hunyuan; or (iii) any other machine learning model which is created by transfer of patterns of the weights, parameters, operations, or Output of Tencent Hunyuan or any Model Derivative of Tencent Hunyuan, to that model in order to cause that model to perform similarly to Tencent Hunyuan or a Model Derivative of Tencent Hunyuan, including distillation methods, methods that use intermediate data representations, or methods based on the generation of synthetic data Outputs by Tencent Hunyuan or a Model Derivative of Tencent Hunyuan for training that model. For clarity, Outputs by themselves are not deemed Model Derivatives.
13
+ h. “Output” shall mean the information and/or content output of Tencent Hunyuan or a Model Derivative that results from operating or otherwise using Tencent Hunyuan or a Model Derivative, including via a Hosted Service.
14
+ i. “Tencent,” “We” or “Us” shall mean the applicable entity or entities in the Tencent corporate family that own(s) intellectual property or other rights embodied in or utilized by the Materials.
15
+ j. “Tencent Hunyuan” shall mean the large language models, text/image/video/audio/3D generation models, and multimodal large language models and their software and algorithms, including trained model weights, parameters (including optimizer states), machine-learning model code, inference-enabling code, training-enabling code, fine-tuning enabling code and other elements of the foregoing made publicly available by Us, including, without limitation to, Tencent Hunyuan A13B released at [https://github.com/Tencent-Hunyuan/Hunyuan-A13B].
16
+ k. “Tencent Hunyuan Works” shall mean: (i) the Materials; (ii) Model Derivatives; and (iii) all derivative works thereof.
17
+ l. “Territory” shall mean the worldwide territory, excluding the territory of the European Union, United Kingdom and South Korea.
18
+ m. “Third Party” or “Third Parties” shall mean individuals or legal entities that are not under common control with Us or You.
19
+ n. “including” shall mean including but not limited to.
20
+ 2. GRANT OF RIGHTS.
21
+ We grant You, for the Territory only, a non-exclusive, non-transferable and royalty-free limited license under Tencent’s intellectual property or other rights owned by Us embodied in or utilized by the Materials to use, reproduce, distribute, create derivative works of (including Model Derivatives), and make modifications to the Materials, only in accordance with the terms of this Agreement and the Acceptable Use Policy, and You must not violate (or encourage or permit anyone else to violate) any term of this Agreement or the Acceptable Use Policy.
22
+ 3. DISTRIBUTION.
23
+ You may, subject to Your compliance with this Agreement, distribute or make available to Third Parties the Tencent Hunyuan Works, exclusively in the Territory, provided that You meet all of the following conditions:
24
+ a. You must provide all such Third Party recipients of the Tencent Hunyuan Works or products or services using them a copy of this Agreement;
25
+ b. You must cause any modified files to carry prominent notices stating that You changed the files;
26
+ c. You are encouraged to: (i) publish at least one technology introduction blogpost or one public statement expressing Your experience of using the Tencent Hunyuan Works; and (ii) mark the products or services developed by using the Tencent Hunyuan Works to indicate that the product/service is “Powered by Tencent Hunyuan”; and
27
+ d. All distributions to Third Parties (other than through a Hosted Service) must be accompanied by a “Notice” text file that contains the following notice: “Tencent Hunyuan is licensed under the Tencent Hunyuan Community License Agreement, Copyright © 2025 Tencent. All Rights Reserved. The trademark rights of “Tencent Hunyuan” are owned by Tencent or its affiliate.”
28
+ You may add Your own copyright statement to Your modifications and, except as set forth in this Section and in Section 5, may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Model Derivatives as a whole, provided Your use, reproduction, modification, distribution, performance and display of the work otherwise complies with the terms and conditions of this Agreement (including as regards the Territory). If You receive Tencent Hunyuan Works from a Licensee as part of an integrated end user product, then this Section 3 of this Agreement will not apply to You.
29
+ 4. ADDITIONAL COMMERCIAL TERMS.
30
+ If, on the Tencent Hunyuan version release date, the monthly active users of all products or services made available by or for Licensee is greater than 100 million monthly active users in the preceding calendar month, You must request a license from Tencent, which Tencent may grant to You in its sole discretion, and You are not authorized to exercise any of the rights under this Agreement unless or until Tencent otherwise expressly grants You such rights.
31
+ 5. RULES OF USE.
32
+ a. Your use of the Tencent Hunyuan Works must comply with applicable laws and regulations (including trade compliance laws and regulations) and adhere to the Acceptable Use Policy for the Tencent Hunyuan Works, which is hereby incorporated by reference into this Agreement. You must include the use restrictions referenced in these Sections 5(a) and 5(b) as an enforceable provision in any agreement (e.g., license agreement, terms of use, etc.) governing the use and/or distribution of Tencent Hunyuan Works and You must provide notice to subsequent users to whom You distribute that Tencent Hunyuan Works are subject to the use restrictions in these Sections 5(a) and 5(b).
33
+ b. You must not use the Tencent Hunyuan Works or any Output or results of the Tencent Hunyuan Works to improve any other AI model (other than Tencent Hunyuan or Model Derivatives thereof).
34
+ c. You must not use, reproduce, modify, distribute, or display the Tencent Hunyuan Works, Output or results of the Tencent Hunyuan Works outside the Territory. Any such use outside the Territory is unlicensed and unauthorized under this Agreement.
35
+ 6. INTELLECTUAL PROPERTY.
36
+ a. Subject to Tencent’s ownership of Tencent Hunyuan Works made by or for Tencent and intellectual property rights therein, conditioned upon Your compliance with the terms and conditions of this Agreement, as between You and Tencent, You will be the owner of any derivative works and modifications of the Materials and any Model Derivatives that are made by or for You.
37
+ b. No trademark licenses are granted under this Agreement, and in connection with the Tencent Hunyuan Works, Licensee may not use any name or mark owned by or associated with Tencent or any of its affiliates, except as required for reasonable and customary use in describing and distributing the Tencent Hunyuan Works. Tencent hereby grants You a license to use “Tencent Hunyuan” (the “Mark”) in the Territory solely as required to comply with the provisions of Section 3(c), provided that You comply with any applicable laws related to trademark protection. All goodwill arising out of Your use of the Mark will inure to the benefit of Tencent.
38
+ c. If You commence a lawsuit or other proceedings (including a cross-claim or counterclaim in a lawsuit) against Us or any person or entity alleging that the Materials or any Output, or any portion of any of the foregoing, infringe any intellectual property or other right owned or licensable by You, then all licenses granted to You under this Agreement shall terminate as of the date such lawsuit or other proceeding is filed. You will defend, indemnify and hold harmless Us from and against any claim by any Third Party arising out of or related to Your or the Third Party’s use or distribution of the Tencent Hunyuan Works.
39
+ d. Tencent claims no rights in Outputs You generate. You and Your users are solely responsible for Outputs and their subsequent uses.
40
+ 7. DISCLAIMERS OF WARRANTY AND LIMITATIONS OF LIABILITY.
41
+ a. We are not obligated to support, update, provide training for, or develop any further version of the Tencent Hunyuan Works or to grant any license thereto.
42
+ b. UNLESS AND ONLY TO THE EXTENT REQUIRED BY APPLICABLE LAW, THE TENCENT HUNYUAN WORKS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED “AS IS” WITHOUT ANY EXPRESS OR IMPLIED WARRANTIES OF ANY KIND INCLUDING ANY WARRANTIES OF TITLE, MERCHANTABILITY, NONINFRINGEMENT, COURSE OF DEALING, USAGE OF TRADE, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING, REPRODUCING, MODIFYING, PERFORMING, DISPLAYING OR DISTRIBUTING ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND ASSUME ANY AND ALL RISKS ASSOCIATED WITH YOUR OR A THIRD PARTY’S USE OR DISTRIBUTION OF ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS AND YOUR EXERCISE OF RIGHTS AND PERMISSIONS UNDER THIS AGREEMENT.
43
+ c. TO THE FULLEST EXTENT PERMITTED BY APPLICABLE LAW, IN NO EVENT SHALL TENCENT OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, FOR ANY DAMAGES, INCLUDING ANY DIRECT, INDIRECT, SPECIAL, INCIDENTAL, EXEMPLARY, CONSEQUENTIAL OR PUNITIVE DAMAGES, OR LOST PROFITS OF ANY KIND ARISING FROM THIS AGREEMENT OR RELATED TO ANY OF THE TENCENT HUNYUAN WORKS OR OUTPUTS, EVEN IF TENCENT OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
44
+ 8. SURVIVAL AND TERMINATION.
45
+ a. The term of this Agreement shall commence upon Your acceptance of this Agreement or access to the Materials and will continue in full force and effect until terminated in accordance with the terms and conditions herein.
46
+ b. We may terminate this Agreement if You breach any of the terms or conditions of this Agreement. Upon termination of this Agreement, You must promptly delete and cease use of the Tencent Hunyuan Works. Sections 6(a), 6(c), 7 and 9 shall survive the termination of this Agreement.
47
+ 9. GOVERNING LAW AND JURISDICTION.
48
+ a. This Agreement and any dispute arising out of or relating to it will be governed by the laws of the Hong Kong Special Administrative Region of the People’s Republic of China, without regard to conflict of law principles, and the UN Convention on Contracts for the International Sale of Goods does not apply to this Agreement.
49
+ b. Exclusive jurisdiction and venue for any dispute arising out of or relating to this Agreement will be a court of competent jurisdiction in the Hong Kong Special Administrative Region of the People’s Republic of China, and Tencent and Licensee consent to the exclusive jurisdiction of such court with respect to any such dispute.
50
+
51
+ EXHIBIT A
52
+ ACCEPTABLE USE POLICY
53
+
54
+ Tencent reserves the right to update this Acceptable Use Policy from time to time.
55
+ Last modified: November 5, 2024
56
+
57
+ Tencent endeavors to promote safe and fair use of its tools and features, including Tencent Hunyuan. You agree not to use Tencent Hunyuan or Model Derivatives:
58
+ 1. Outside the Territory;
59
+ 2. In any way that violates any applicable national, federal, state, local, international or any other law or regulation;
60
+ 3. To harm Yourself or others;
61
+ 4. To repurpose or distribute output from Tencent Hunyuan or any Model Derivatives to harm Yourself or others;
62
+ 5. To override or circumvent the safety guardrails and safeguards We have put in place;
63
+ 6. For the purpose of exploiting, harming or attempting to exploit or harm minors in any way;
64
+ 7. To generate or disseminate verifiably false information and/or content with the purpose of harming others or influencing elections;
65
+ 8. To generate or facilitate false online engagement, including fake reviews and other means of fake online engagement;
66
+ 9. To intentionally defame, disparage or otherwise harass others;
67
+ 10. To generate and/or disseminate malware (including ransomware) or any other content to be used for the purpose of harming electronic systems;
68
+ 11. To generate or disseminate personal identifiable information with the purpose of harming others;
69
+ 12. To generate or disseminate information (including images, code, posts, articles), and place the information in any public context (including –through the use of bot generated tweets), without expressly and conspicuously identifying that the information and/or content is machine generated;
70
+ 13. To impersonate another individual without consent, authorization, or legal right;
71
+ 14. To make high-stakes automated decisions in domains that affect an individual’s safety, rights or wellbeing (e.g., law enforcement, migration, medicine/health, management of critical infrastructure, safety components of products, essential services, credit, employment, housing, education, social scoring, or insurance);
72
+ 15. In a manner that violates or disrespects the social ethics and moral standards of other countries or regions;
73
+ 16. To perform, facilitate, threaten, incite, plan, promote or encourage violent extremism or terrorism;
74
+ 17. For any use intended to discriminate against or harm individuals or groups based on protected characteristics or categories, online or offline social behavior or known or predicted personal or personality characteristics;
75
+ 18. To intentionally exploit any of the vulnerabilities of a specific group of persons based on their age, social, physical or mental characteristics, in order to materially distort the behavior of a person pertaining to that group in a manner that causes or is likely to cause that person or another person physical or psychological harm;
76
+ 19. For military purposes;
77
+ 20. To engage in the unauthorized or unlicensed practice of any profession including, but not limited to, financial, legal, medical/health, or other professional practices.