Update Notice.txt
Browse files- Notice.txt +159 -1082
Notice.txt
CHANGED
@@ -1,1083 +1,160 @@
|
|
1 |
-
|
2 |
-
import types
|
3 |
-
import math
|
4 |
-
import torch
|
5 |
-
from torch import Tensor, nn
|
6 |
-
import torch.nn.functional as F
|
7 |
-
from typing import List, Tuple, Optional, Union
|
8 |
-
from contextlib import contextmanager
|
9 |
-
from transformers.modeling_attn_mask_utils import (
|
10 |
-
_prepare_4d_causal_attention_mask_for_sdpa,
|
11 |
-
_prepare_4d_causal_attention_mask_for_sdpa,
|
12 |
-
_prepare_4d_causal_attention_mask,
|
13 |
-
)
|
14 |
-
from transformers.models.clip.configuration_clip import CLIPVisionConfig
|
15 |
-
from transformers.modeling_outputs import BaseModelOutputWithPooling
|
16 |
-
from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
|
17 |
-
from .configuration_hunyuan import HunYuanConfig
|
18 |
-
|
19 |
-
|
20 |
-
def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
|
21 |
-
# input_ids: (B, L)
|
22 |
-
# encoder_input: (L, B, E)
|
23 |
-
# image_tensors [[Tensor],...,[Tensor]]
|
24 |
-
# image_pos [[Tensor],...,[Tensor]]
|
25 |
-
# tokenizer = get_tokenizer()
|
26 |
-
b = len(input_ids)
|
27 |
-
img_embs = None
|
28 |
-
all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
|
29 |
-
if all_nums != 0:
|
30 |
-
img_embs, img_batch_pos = vit(image_tensors)
|
31 |
-
else:
|
32 |
-
# when no input image, initialize a fake tensor
|
33 |
-
pad_nums = 1
|
34 |
-
image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
|
35 |
-
img_embs, img_batch_pos = vit(image_tensors)
|
36 |
-
|
37 |
-
encoder_input = encoder_input.clone()
|
38 |
-
if all_nums > 0:
|
39 |
-
assert len(images_pos) == len(img_batch_pos), \
|
40 |
-
(len(images_pos), len(img_batch_pos))
|
41 |
-
start_token_id = im_start_id
|
42 |
-
end_token_id = im_end_id
|
43 |
-
placeholder_id = image_token_id
|
44 |
-
for idx in range(len(images_pos)):
|
45 |
-
assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
|
46 |
-
(len(images_pos[idx]), len(img_batch_pos[idx]))
|
47 |
-
for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
|
48 |
-
# the positions to be filled [s_start, s_end)
|
49 |
-
s_idx, s_start, s_end = p_img_pos_in_batch
|
50 |
-
current_embs = img_embs[s_idx, s_start:s_end]
|
51 |
-
im_s, im_e = p_batch_img_pos
|
52 |
-
assert len(current_embs) == im_e - im_s, \
|
53 |
-
(img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
|
54 |
-
if not anyres_vit_two_views:
|
55 |
-
assert input_ids[idx, im_s - 1] == start_token_id, \
|
56 |
-
input_ids[idx, im_s - 1]
|
57 |
-
assert input_ids[idx, im_e] == end_token_id, \
|
58 |
-
input_ids[idx, im_e]
|
59 |
-
assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
|
60 |
-
f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
|
61 |
-
encoder_input[idx, im_s:im_e] = current_embs
|
62 |
-
else:
|
63 |
-
# when no input image, to mask vit value
|
64 |
-
vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
|
65 |
-
current_embs = img_embs[0, :]
|
66 |
-
encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
|
67 |
-
return encoder_input, input_ids
|
68 |
-
|
69 |
-
|
70 |
-
def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
|
71 |
-
vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
|
72 |
-
|
73 |
-
b = len(input_ids)
|
74 |
-
if images_pos is None:
|
75 |
-
images_pos = torch.ones([len(input_ids), 1, 3])
|
76 |
-
images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
|
77 |
-
images_pos = images_pos.long()
|
78 |
-
|
79 |
-
real_image_nums = []
|
80 |
-
image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
|
81 |
-
real_images = []
|
82 |
-
|
83 |
-
all_nums = 0
|
84 |
-
img_index = []
|
85 |
-
for s in range(len(images_pos)):
|
86 |
-
real_image_num = 0
|
87 |
-
for (im_s, im_e,index) in images_pos[s]:
|
88 |
-
if im_s == -1:
|
89 |
-
break
|
90 |
-
real_image_num += 1
|
91 |
-
all_nums += 1
|
92 |
-
img_index.append(index)
|
93 |
-
|
94 |
-
real_image_nums.append(real_image_num)
|
95 |
-
real_images.append(image_tensors[s][:real_image_num])
|
96 |
-
|
97 |
-
if vit_patch == 1:
|
98 |
-
img_index = None
|
99 |
-
|
100 |
-
if all_nums == 0:
|
101 |
-
# when no input image, initialize a fake tensor
|
102 |
-
img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
|
103 |
-
img_embs = vit(img_input)
|
104 |
-
img_embs = vit_linear_encoder(img_embs)
|
105 |
-
else:
|
106 |
-
img_input = torch.cat(real_images)
|
107 |
-
img_embs = vit(img_input, img_index = img_index)
|
108 |
-
img_embs = vit_linear_encoder(img_embs)
|
109 |
-
|
110 |
-
encoder_input = encoder_input.clone()
|
111 |
-
start = 0
|
112 |
-
if all_nums > 0:
|
113 |
-
for s, real_image_len in enumerate(real_image_nums):
|
114 |
-
current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
|
115 |
-
for ss in range(current_embs.shape[0]):
|
116 |
-
im_s, im_e, index = images_pos[s, ss]
|
117 |
-
# 子图特征更少
|
118 |
-
if index > 0 and vit_patch_mlp:
|
119 |
-
encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
|
120 |
-
else:
|
121 |
-
encoder_input[s, im_s:im_e] = current_embs[ss, :]
|
122 |
-
start = start + real_image_len
|
123 |
-
else:
|
124 |
-
# when no input image, to mask vit value
|
125 |
-
for s in range(b):
|
126 |
-
vit_mask = torch.zeros([vit_token, 1]).cuda()
|
127 |
-
current_embs = img_embs[:, start:start + 1]
|
128 |
-
encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
|
129 |
-
start = start + 1
|
130 |
-
return encoder_input, input_ids
|
131 |
-
|
132 |
-
|
133 |
-
def group_images_by_max_seq_len(
|
134 |
-
images: List[List[Tensor]], patch_size: int,
|
135 |
-
max_seq_len: int, adaptor_patch_size: int,
|
136 |
-
add_cls_token: bool = False) -> List[List[Tensor]]:
|
137 |
-
|
138 |
-
groups = []
|
139 |
-
group = []
|
140 |
-
pos_groups = []
|
141 |
-
seq_len = 0
|
142 |
-
num_images = 0
|
143 |
-
for image_list in images:
|
144 |
-
pos_group = []
|
145 |
-
for image in image_list:
|
146 |
-
num_images += 1
|
147 |
-
assert isinstance(image, Tensor)
|
148 |
-
|
149 |
-
image_dims = image.shape[-2:]
|
150 |
-
ph, pw = map(lambda t: t // patch_size, image_dims)
|
151 |
-
|
152 |
-
image_seq_len = (ph * pw)
|
153 |
-
new_image_seq_len = image_seq_len
|
154 |
-
grouped_len = seq_len + image_seq_len
|
155 |
-
if add_cls_token:
|
156 |
-
new_image_seq_len += 1
|
157 |
-
grouped_len += num_images
|
158 |
-
|
159 |
-
assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
|
160 |
-
|
161 |
-
if grouped_len > max_seq_len:
|
162 |
-
groups.append(group)
|
163 |
-
group = []
|
164 |
-
seq_len = 0
|
165 |
-
num_images = 1
|
166 |
-
|
167 |
-
group.append(image)
|
168 |
-
start = seq_len // (adaptor_patch_size * adaptor_patch_size)
|
169 |
-
end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
|
170 |
-
batch_idx = len(groups)
|
171 |
-
pos_group.append([batch_idx, start, end])
|
172 |
-
seq_len += image_seq_len
|
173 |
-
pos_groups.append(pos_group)
|
174 |
-
|
175 |
-
if len(group) > 0:
|
176 |
-
groups.append(group)
|
177 |
-
|
178 |
-
return groups, pos_groups
|
179 |
-
|
180 |
-
|
181 |
-
class AnyResCLIPVisionEmbeddings(nn.Module):
|
182 |
-
def __init__(self, config: CLIPVisionConfig):
|
183 |
-
super().__init__()
|
184 |
-
|
185 |
-
self.config = config
|
186 |
-
# self.sparse_attn_mask = args.sparse_attn_mask
|
187 |
-
# self.use_flash_attn = args.use_flash_attn
|
188 |
-
self.embed_dim = config.hidden_size
|
189 |
-
self.image_size = config.max_image_size
|
190 |
-
self.patch_size = config.patch_size
|
191 |
-
self.max_seq_len = config.max_vit_seq_len
|
192 |
-
self.adaptor_patch_size = config.adaptor_patch_size
|
193 |
-
self.anyres_vit_two_views = config.anyres_vit_two_views
|
194 |
-
self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
|
195 |
-
self.vit_remove_prenorm = config.vit_remove_prenorm
|
196 |
-
|
197 |
-
self.patch_embedding = nn.Conv2d(
|
198 |
-
in_channels=config.num_channels,
|
199 |
-
out_channels=self.embed_dim,
|
200 |
-
kernel_size=self.patch_size,
|
201 |
-
stride=self.patch_size,
|
202 |
-
bias=self.vit_add_patchemb_bias,
|
203 |
-
)
|
204 |
-
|
205 |
-
self.num_patches = (self.image_size // self.patch_size) ** 2
|
206 |
-
self.skip_cls_token = True
|
207 |
-
|
208 |
-
# add interpolate_pos_encoding
|
209 |
-
if self.anyres_vit_two_views:
|
210 |
-
self.num_positions = self.num_patches
|
211 |
-
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
|
212 |
-
else:
|
213 |
-
self.num_positions = self.num_patches + 1
|
214 |
-
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
215 |
-
# self.position_ids = torch.arange(self.num_positions).expand((1, -1))
|
216 |
-
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
217 |
-
|
218 |
-
if not self.vit_remove_prenorm:
|
219 |
-
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
220 |
-
|
221 |
-
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
222 |
-
"""
|
223 |
-
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
224 |
-
resolution images.
|
225 |
-
|
226 |
-
Source:
|
227 |
-
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
228 |
-
"""
|
229 |
-
num_patches = embeddings.shape[1]
|
230 |
-
position_embeddings = self.position_embedding(self.position_ids)
|
231 |
-
patch_pos_embed = position_embeddings[:, 1:]
|
232 |
-
num_positions = position_embeddings.shape[1] - 1
|
233 |
-
if num_patches == num_positions and height == width:
|
234 |
-
return patch_pos_embed
|
235 |
-
# class_pos_embed = position_embeddings[:, 0]
|
236 |
-
dim = embeddings.shape[-1]
|
237 |
-
h0 = height // self.patch_size
|
238 |
-
w0 = width // self.patch_size
|
239 |
-
# we add a small number to avoid floating point error in the interpolation
|
240 |
-
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
241 |
-
h0, w0 = h0 + 0.1, w0 + 0.1
|
242 |
-
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
243 |
-
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
244 |
-
raw_type = patch_pos_embed.dtype
|
245 |
-
patch_pos_embed = nn.functional.interpolate(
|
246 |
-
patch_pos_embed.to(torch.float32, non_blocking=True),
|
247 |
-
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
248 |
-
mode="bilinear",
|
249 |
-
align_corners=False,
|
250 |
-
)
|
251 |
-
patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
|
252 |
-
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
253 |
-
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
254 |
-
return patch_pos_embed
|
255 |
-
|
256 |
-
def rescale_positional_embedding(self, out_size):
|
257 |
-
h, w = out_size
|
258 |
-
pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
|
259 |
-
if (h, w) == (pos_embed_shape, pos_embed_shape):
|
260 |
-
return self.position_embedding
|
261 |
-
rescaled_positional_embedding = \
|
262 |
-
self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
|
263 |
-
pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
|
264 |
-
pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
|
265 |
-
rescaled_positional_embedding[0] = pe_2d.T.contiguous()
|
266 |
-
return rescaled_positional_embedding
|
267 |
-
|
268 |
-
def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
|
269 |
-
if pixel_values.ndim == 3:
|
270 |
-
pixel_values = pixel_values[None]
|
271 |
-
batch_size, num_channels, height, width = pixel_values.shape
|
272 |
-
|
273 |
-
if self.anyres_vit_two_views:
|
274 |
-
# padding
|
275 |
-
pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
|
276 |
-
pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
|
277 |
-
pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
|
278 |
-
|
279 |
-
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
280 |
-
b, c, h, w = patch_embeds.shape
|
281 |
-
|
282 |
-
# (b, hw, c)
|
283 |
-
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
284 |
-
if self.anyres_vit_two_views:
|
285 |
-
embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
|
286 |
-
else:
|
287 |
-
embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
|
288 |
-
if not self.vit_remove_prenorm:
|
289 |
-
embeddings = self.pre_layernorm(embeddings)
|
290 |
-
return embeddings, (h, w)
|
291 |
-
|
292 |
-
def forward(self, images: List[List[Tensor]]):
|
293 |
-
'''
|
294 |
-
Input:
|
295 |
-
images: List[List[Tensor]]
|
296 |
-
|
297 |
-
Return:
|
298 |
-
embeddings: Tensor (B, L, E)
|
299 |
-
attn_mask: Tensor (B, L, 2)
|
300 |
-
pos_groups: List[List[(batch_idx, start, end)]]
|
301 |
-
'''
|
302 |
-
batched_images, pos_groups = group_images_by_max_seq_len(
|
303 |
-
images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
|
304 |
-
max_seq_len = self.max_seq_len
|
305 |
-
|
306 |
-
# batched_images is a list of a list
|
307 |
-
B = len(batched_images)
|
308 |
-
L = max_seq_len
|
309 |
-
E = self.embed_dim
|
310 |
-
|
311 |
-
embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
|
312 |
-
attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True presents compute
|
313 |
-
assert len(images) == len(pos_groups), (len(images), len(pos_groups))
|
314 |
-
|
315 |
-
batch_images = []
|
316 |
-
batch_pos = []
|
317 |
-
for images_i, pos_group in zip(images, pos_groups):
|
318 |
-
assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
|
319 |
-
for image, pos in zip(images_i, pos_group):
|
320 |
-
batch_idx, start, end = pos
|
321 |
-
a2 = self.adaptor_patch_size ** 2
|
322 |
-
# recover the real number of the input image tokens
|
323 |
-
start *= a2
|
324 |
-
end *= a2
|
325 |
-
emb, _ = self.forward_single(image)
|
326 |
-
assert emb.ndim == 3, '(B, L, E)'
|
327 |
-
embeddings[batch_idx, start:end] = emb
|
328 |
-
attn_mask[batch_idx, :, start:end, start:end] = True
|
329 |
-
return embeddings, attn_mask, pos_groups
|
330 |
-
|
331 |
-
|
332 |
-
class CLIPVisionEmbeddings(nn.Module):
|
333 |
-
def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
|
334 |
-
super().__init__()
|
335 |
-
self.config = config
|
336 |
-
self.embed_dim = config.hidden_size
|
337 |
-
self.image_size = config.image_size
|
338 |
-
self.image_size = config.vit_input_resolution
|
339 |
-
self.patch_size = config.patch_size
|
340 |
-
|
341 |
-
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
|
342 |
-
|
343 |
-
self.patch_embedding = nn.Conv2d(
|
344 |
-
in_channels=config.num_channels,
|
345 |
-
out_channels=self.embed_dim,
|
346 |
-
kernel_size=self.patch_size,
|
347 |
-
stride=self.patch_size,
|
348 |
-
bias=False,
|
349 |
-
)
|
350 |
-
|
351 |
-
self.num_patches = (self.image_size // self.patch_size) ** 2
|
352 |
-
|
353 |
-
self.skip_cls_token = skip_cls_token
|
354 |
-
|
355 |
-
self.num_positions = self.num_patches + 1
|
356 |
-
|
357 |
-
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
|
358 |
-
if vit_patch > 1:
|
359 |
-
self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
|
360 |
-
# 0 支持最大16张图,目前写死了,如需其他的需要额外定义参数
|
361 |
-
elif vit_patch == 0:
|
362 |
-
self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
|
363 |
-
else:
|
364 |
-
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
|
365 |
-
|
366 |
-
if add_pre_layernorm:
|
367 |
-
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
|
368 |
-
else:
|
369 |
-
self.pre_layernorm = None
|
370 |
-
|
371 |
-
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
|
372 |
-
"""
|
373 |
-
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
|
374 |
-
resolution images.
|
375 |
-
|
376 |
-
Source:
|
377 |
-
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
|
378 |
-
"""
|
379 |
-
num_patches = embeddings.shape[1] - 1
|
380 |
-
position_embeddings = self.position_embedding(self.position_ids)
|
381 |
-
num_positions = position_embeddings.shape[1] - 1
|
382 |
-
if num_patches == num_positions and height == width:
|
383 |
-
return position_embeddings
|
384 |
-
class_pos_embed = position_embeddings[:, 0]
|
385 |
-
patch_pos_embed = position_embeddings[:, 1:]
|
386 |
-
dim = embeddings.shape[-1]
|
387 |
-
h0 = height // self.config.patch_size
|
388 |
-
w0 = width // self.config.patch_size
|
389 |
-
# we add a small number to avoid floating point error in the interpolation
|
390 |
-
# see discussion at https://github.com/facebookresearch/dino/issues/8
|
391 |
-
h0, w0 = h0 + 0.1, w0 + 0.1
|
392 |
-
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
|
393 |
-
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
|
394 |
-
raw_type = patch_pos_embed.dtype
|
395 |
-
patch_pos_embed = nn.functional.interpolate(
|
396 |
-
patch_pos_embed.float(),
|
397 |
-
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
|
398 |
-
mode="bicubic",
|
399 |
-
align_corners=False,
|
400 |
-
)
|
401 |
-
# print(patch_pos_embed.shape)
|
402 |
-
patch_pos_embed = patch_pos_embed.to(raw_type)
|
403 |
-
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
|
404 |
-
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
405 |
-
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
|
406 |
-
|
407 |
-
|
408 |
-
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
|
409 |
-
batch_size, num_channels, height, width = pixel_values.shape
|
410 |
-
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
|
411 |
-
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
|
412 |
-
if self.skip_cls_token:
|
413 |
-
embeddings = patch_embeds
|
414 |
-
if img_index is None:
|
415 |
-
position_ids = self.position_ids[:,1:]
|
416 |
-
embeddings = embeddings + self.position_embedding(position_ids)
|
417 |
-
else:
|
418 |
-
position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
|
419 |
-
+ self.position_ids.expand(batch_size, -1)[:, 1:]
|
420 |
-
embeddings = embeddings + self.position_embedding(position_ids)
|
421 |
-
else:
|
422 |
-
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
|
423 |
-
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
|
424 |
-
if interpolate_pos_encoding:
|
425 |
-
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
|
426 |
-
else:
|
427 |
-
if img_index is None:
|
428 |
-
embeddings = embeddings + self.position_embedding(self.position_ids)
|
429 |
-
else:
|
430 |
-
position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
|
431 |
-
new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
|
432 |
-
position_ids = torch.cat([position_ids,new_position],dim=1)
|
433 |
-
embeddings = embeddings + self.position_embedding(position_ids)
|
434 |
-
if self.pre_layernorm is not None:
|
435 |
-
embeddings = self.pre_layernorm(embeddings)
|
436 |
-
return embeddings
|
437 |
-
|
438 |
-
|
439 |
-
class NaVitTransformer(nn.Module):
|
440 |
-
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
|
441 |
-
super().__init__()
|
442 |
-
self.config = config
|
443 |
-
self.vit_config = vit_config
|
444 |
-
with self.prepare_args(config, vit_config):
|
445 |
-
self._use_sdpa = config._attn_implementation == "sdpa"
|
446 |
-
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
|
447 |
-
self.layers = nn.ModuleList(
|
448 |
-
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
449 |
-
)
|
450 |
-
|
451 |
-
@contextmanager
|
452 |
-
def prepare_args(self, config, vit_config):
|
453 |
-
hidden_act = config.hidden_act
|
454 |
-
hidden_size = config.hidden_size
|
455 |
-
ffn_hidden_size = config.intermediate_size
|
456 |
-
num_attention_heads = config.num_attention_heads
|
457 |
-
num_key_value_heads = config.num_key_value_heads
|
458 |
-
attention_head_dim = config.attention_head_dim
|
459 |
-
use_qk_norm = config.use_qk_norm
|
460 |
-
use_rotary_pos_emb = config.use_rotary_pos_emb
|
461 |
-
num_hidden_layers = config.num_hidden_layers
|
462 |
-
rms_norm_eps = config.rms_norm_eps
|
463 |
-
attention_dropout = config.attention_dropout
|
464 |
-
# hidden_dropout = config.hidden_dropout
|
465 |
-
norm_type = config.norm_type
|
466 |
-
attention_bias = config.attention_bias
|
467 |
-
mlp_bias = config.mlp_bias
|
468 |
-
use_mla = config.use_mla
|
469 |
-
num_experts = config.num_experts
|
470 |
-
_attn_implementation = config._attn_implementation
|
471 |
-
|
472 |
-
config.hidden_act = vit_config.hidden_act
|
473 |
-
config.hidden_size = vit_config.hidden_size
|
474 |
-
config.intermediate_size = vit_config.intermediate_size
|
475 |
-
config.num_attention_heads = vit_config.num_attention_heads
|
476 |
-
config.num_key_value_heads = None
|
477 |
-
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
|
478 |
-
config.use_qk_norm = False
|
479 |
-
config.use_rotary_pos_emb = False
|
480 |
-
config.num_hidden_layers = vit_config.num_hidden_layers
|
481 |
-
config.rms_norm_eps = vit_config.layer_norm_eps
|
482 |
-
config.attention_dropout = vit_config.attention_dropout
|
483 |
-
# config.hidden_dropout = vit_config.hidden_dropout
|
484 |
-
config.norm_type = config.vit_norm_type
|
485 |
-
config.attention_bias = True
|
486 |
-
config.mlp_bias = True
|
487 |
-
config.use_mla = False
|
488 |
-
config.num_experts = 1
|
489 |
-
config._attn_implementation = "eager"
|
490 |
-
|
491 |
-
yield
|
492 |
-
config.hidden_act = hidden_act
|
493 |
-
config.hidden_size = hidden_size
|
494 |
-
config.intermediate_size = ffn_hidden_size
|
495 |
-
config.num_attention_heads = num_attention_heads
|
496 |
-
config.num_key_value_heads = num_key_value_heads
|
497 |
-
config.attention_head_dim = attention_head_dim
|
498 |
-
config.use_qk_norm = use_qk_norm
|
499 |
-
config.use_rotary_pos_emb = use_rotary_pos_emb
|
500 |
-
config.num_hidden_layers = num_hidden_layers
|
501 |
-
config.rms_norm_eps = rms_norm_eps
|
502 |
-
config.attention_dropout = attention_dropout
|
503 |
-
# config.hidden_dropout = hidden_dropout
|
504 |
-
config.attention_bias = attention_bias
|
505 |
-
config.mlp_bias = mlp_bias
|
506 |
-
config.norm_type = norm_type
|
507 |
-
config.use_mla = use_mla
|
508 |
-
config.num_experts = num_experts
|
509 |
-
config._attn_implementation = _attn_implementation
|
510 |
-
|
511 |
-
def forward(
|
512 |
-
self,
|
513 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
514 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
515 |
-
|
516 |
-
hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
|
517 |
-
attention_mask = attention_mask.int()
|
518 |
-
batch_size, seq_length, _ = hidden_states.shape
|
519 |
-
past_key_values_length = 0
|
520 |
-
|
521 |
-
if self._use_flash_attention_2:
|
522 |
-
# 2d mask is passed through the layers
|
523 |
-
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
|
524 |
-
elif self._use_sdpa:
|
525 |
-
# output_attentions=True can not be supported when using SDPA, and we fall back on
|
526 |
-
# the manual implementation that requires a 4D causal mask in all cases.
|
527 |
-
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
|
528 |
-
attention_mask,
|
529 |
-
(batch_size, seq_length),
|
530 |
-
hidden_states,
|
531 |
-
past_key_values_length,
|
532 |
-
)
|
533 |
-
else:
|
534 |
-
attention_mask = _prepare_4d_causal_attention_mask(
|
535 |
-
attention_mask,
|
536 |
-
(batch_size, seq_length),
|
537 |
-
hidden_states,
|
538 |
-
past_key_values_length,
|
539 |
-
)
|
540 |
-
|
541 |
-
for layer_idx, decoder_layer in enumerate(self.layers):
|
542 |
-
layer_outputs = decoder_layer(
|
543 |
-
hidden_states,
|
544 |
-
attention_mask=attention_mask
|
545 |
-
)
|
546 |
-
hidden_states = layer_outputs[0]
|
547 |
-
|
548 |
-
return hidden_states, img_pos
|
549 |
-
|
550 |
-
|
551 |
-
class AnyResVitTransformer(NaVitTransformer):
|
552 |
-
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
|
553 |
-
super().__init__(config, vit_config)
|
554 |
-
old_anyres_vit_max_image_size = vit_config.max_image_size
|
555 |
-
anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
|
556 |
-
vit_config.max_image_size = anyres_vit_max_image_size
|
557 |
-
vit_config.torch_dtype = config.torch_dtype
|
558 |
-
vit_config.anyres_vit_two_views = config.anyres_vit_two_views
|
559 |
-
vit_config.vit_remove_prenorm = config.vit_remove_prenorm
|
560 |
-
vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
|
561 |
-
self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
|
562 |
-
vit_config.max_image_size = old_anyres_vit_max_image_size
|
563 |
-
|
564 |
-
def fix_embeddings_fn(self, pixel_values):
|
565 |
-
# (B, L, E)
|
566 |
-
embeddings, hw = self.embeddings.forward_single(pixel_values)
|
567 |
-
embeddings = self.embeddings.pre_layernorm(embeddings)
|
568 |
-
return embeddings
|
569 |
-
|
570 |
-
|
571 |
-
class CLIPVisionTransformer(nn.Module):
|
572 |
-
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
|
573 |
-
super().__init__()
|
574 |
-
embed_dim = vit_config.hidden_size
|
575 |
-
|
576 |
-
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
|
577 |
-
self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
|
578 |
-
|
579 |
-
with self.prepare_args(config, vit_config):
|
580 |
-
self.layers = nn.ModuleList(
|
581 |
-
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
|
582 |
-
)
|
583 |
-
|
584 |
-
@contextmanager
|
585 |
-
def prepare_args(self, config, vit_config):
|
586 |
-
hidden_act = config.hidden_act
|
587 |
-
hidden_size = config.hidden_size
|
588 |
-
ffn_hidden_size = config.intermediate_size
|
589 |
-
num_attention_heads = config.num_attention_heads
|
590 |
-
num_key_value_heads = config.num_key_value_heads
|
591 |
-
attention_head_dim = config.attention_head_dim
|
592 |
-
use_qk_norm = config.use_qk_norm
|
593 |
-
use_rotary_pos_emb = config.use_rotary_pos_emb
|
594 |
-
num_hidden_layers = config.num_hidden_layers
|
595 |
-
rms_norm_eps = config.rms_norm_eps
|
596 |
-
attention_dropout = config.attention_dropout
|
597 |
-
# hidden_dropout = config.hidden_dropout
|
598 |
-
norm_type = config.norm_type
|
599 |
-
attention_bias = config.attention_bias
|
600 |
-
mlp_bias = config.mlp_bias
|
601 |
-
use_mla = config.use_mla
|
602 |
-
num_experts = config.num_experts
|
603 |
-
_attn_implementation = config._attn_implementation
|
604 |
-
|
605 |
-
config.hidden_act = vit_config.hidden_act
|
606 |
-
config.hidden_size = vit_config.hidden_size
|
607 |
-
config.intermediate_size = vit_config.intermediate_size
|
608 |
-
config.num_attention_heads = vit_config.num_attention_heads
|
609 |
-
config.num_key_value_heads = None
|
610 |
-
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
|
611 |
-
config.use_qk_norm = False
|
612 |
-
config.use_rotary_pos_emb = False
|
613 |
-
config.num_hidden_layers = vit_config.num_hidden_layers
|
614 |
-
config.rms_norm_eps = vit_config.layer_norm_eps
|
615 |
-
config.attention_dropout = vit_config.attention_dropout
|
616 |
-
# config.hidden_dropout = 0.0
|
617 |
-
config.norm_type = "fused"
|
618 |
-
config.attention_bias = True
|
619 |
-
config.mlp_bias = True
|
620 |
-
config.use_mla = False
|
621 |
-
config.num_experts = 1
|
622 |
-
config._attn_implementation = "eager"
|
623 |
-
|
624 |
-
yield
|
625 |
-
|
626 |
-
config.hidden_act = hidden_act
|
627 |
-
config.hidden_size = hidden_size
|
628 |
-
config.intermediate_size = ffn_hidden_size
|
629 |
-
config.num_attention_heads = num_attention_heads
|
630 |
-
config.num_key_value_heads = num_key_value_heads
|
631 |
-
config.attention_head_dim = attention_head_dim
|
632 |
-
config.use_qk_norm = use_qk_norm
|
633 |
-
config.use_rotary_pos_emb = use_rotary_pos_emb
|
634 |
-
config.num_hidden_layers = num_hidden_layers
|
635 |
-
config.rms_norm_eps = rms_norm_eps
|
636 |
-
config.attention_dropout = attention_dropout
|
637 |
-
# config.hidden_dropout = hidden_dropout
|
638 |
-
config.norm_type = norm_type
|
639 |
-
config.attention_bias = attention_bias
|
640 |
-
config.mlp_bias = mlp_bias
|
641 |
-
config.use_mla = use_mla
|
642 |
-
config.num_experts = num_experts
|
643 |
-
config._attn_implementation = _attn_implementation
|
644 |
-
|
645 |
-
def forward(
|
646 |
-
self,
|
647 |
-
pixel_values: Optional[torch.FloatTensor] = None,
|
648 |
-
interpolate_pos_encoding: Optional[bool] = None,
|
649 |
-
img_index=None
|
650 |
-
) -> Union[Tuple, BaseModelOutputWithPooling]:
|
651 |
-
r"""
|
652 |
-
Returns:
|
653 |
-
|
654 |
-
"""
|
655 |
-
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
|
656 |
-
hidden_states = self.pre_layrnorm(hidden_states)
|
657 |
-
batch = hidden_states.shape[0]
|
658 |
-
seq_len = hidden_states.shape[1]
|
659 |
-
device = hidden_states.device
|
660 |
-
attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
|
661 |
-
|
662 |
-
for layer_idx, decoder_layer in enumerate(self.layers):
|
663 |
-
layer_outputs = decoder_layer(
|
664 |
-
hidden_states,
|
665 |
-
attention_mask=attention_mask
|
666 |
-
)
|
667 |
-
hidden_states = layer_outputs[0]
|
668 |
-
|
669 |
-
return hidden_states
|
670 |
-
|
671 |
-
|
672 |
-
class Vit(torch.nn.Module):
|
673 |
-
def __init__(self, config, resampler_token=64, pool_rate=2):
|
674 |
-
super().__init__()
|
675 |
-
self.config = config
|
676 |
-
self.vit_mapping_type = config.vit_mapping_type
|
677 |
-
self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
|
678 |
-
self.skip_cls_token = config.skip_cls_token
|
679 |
-
self.pool_rate = pool_rate
|
680 |
-
self.vit_type = self.config.vit_type
|
681 |
-
self.anyres_vit_two_views = self.config.anyres_vit_two_views
|
682 |
-
if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
|
683 |
-
self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
|
684 |
-
else:
|
685 |
-
raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
|
686 |
-
|
687 |
-
def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
|
688 |
-
if self.vit_type == 'AnyResVit':
|
689 |
-
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
|
690 |
-
self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
|
691 |
-
self.vit_config.image_size = vit_input_resolution
|
692 |
-
self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
|
693 |
-
elif self.vit_type == 'Vit-g':
|
694 |
-
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
|
695 |
-
self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
|
696 |
-
self.vit_config.vit_input_resolution = vit_input_resolution
|
697 |
-
self.vit = CLIPVisionTransformer(self.config, self.vit_config)
|
698 |
-
else:
|
699 |
-
assert False, "other vit_types are not supported"
|
700 |
-
|
701 |
-
if self.vit_mapping_type == 'simple_conv_mlp':
|
702 |
-
self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
|
703 |
-
self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
|
704 |
-
elif self.vit_mapping_type == 'oryx_mlp':
|
705 |
-
self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
|
706 |
-
elif self.vit_mapping_type == 'mlp':
|
707 |
-
self.mlp_depth = 2
|
708 |
-
# one mlp layer already in gpt_model.py
|
709 |
-
mlp_hidden_size = self.vit_config.hidden_size
|
710 |
-
if self.vit_type in ['NaVit', 'EvaVit']:
|
711 |
-
mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
|
712 |
-
if self.mlp_depth > 1:
|
713 |
-
mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
|
714 |
-
if self.vit_type in ['NaVit', 'EvaVit']:
|
715 |
-
for _ in range(1, self.mlp_depth):
|
716 |
-
mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
|
717 |
-
mlp_modules.append(torch.nn.GELU())
|
718 |
-
self.perceive = torch.nn.Sequential(*mlp_modules)
|
719 |
-
else:
|
720 |
-
assert False, "other vit_mapping_types are not supported"
|
721 |
-
|
722 |
-
self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
|
723 |
-
for name, param in self.named_parameters():
|
724 |
-
setattr(param, "is_vit_param", True)
|
725 |
-
|
726 |
-
def forward(self, images, img_index=None):
|
727 |
-
if self.vit_type in ['AnyResVit']:
|
728 |
-
dtype = self.config.torch_dtype
|
729 |
-
device = torch.cuda.current_device()
|
730 |
-
|
731 |
-
images_size = []
|
732 |
-
for i in range(len(images)):
|
733 |
-
images_size.append([])
|
734 |
-
for j in range(len(images[i])):
|
735 |
-
images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
|
736 |
-
|
737 |
-
images_feats, img_batch_pos = self.vit(pixel_values=images)
|
738 |
-
a2 = self.vit_config.adaptor_patch_size ** 2
|
739 |
-
|
740 |
-
if self.anyres_vit_two_views:
|
741 |
-
step = 2
|
742 |
-
else:
|
743 |
-
step = 1
|
744 |
-
perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
|
745 |
-
images_list = []
|
746 |
-
images_fix_i = 0
|
747 |
-
num_img_batch_pos = len(img_batch_pos)
|
748 |
-
for i in range(num_img_batch_pos): # batch_id
|
749 |
-
for j in range(0, len(img_batch_pos[i]), step):
|
750 |
-
if self.anyres_vit_two_views:
|
751 |
-
lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
|
752 |
-
lower_begin = lower_begin * a2
|
753 |
-
lower_end = lower_end * a2
|
754 |
-
higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
|
755 |
-
higher_begin = higher_begin * a2
|
756 |
-
higher_end = higher_end * a2
|
757 |
-
lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
|
758 |
-
higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
|
759 |
-
lower_images_size = images_size[i][j]
|
760 |
-
higher_images_size = images_size[i][j + 1]
|
761 |
-
images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
|
762 |
-
else:
|
763 |
-
idx, begin, end = img_batch_pos[i][j]
|
764 |
-
begin = begin * a2
|
765 |
-
end = end * a2
|
766 |
-
is_video = hasattr(images[i][j],'_is_video') and images[i][j]._is_video
|
767 |
-
images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
|
768 |
-
|
769 |
-
images = torch.cat(images_list, dim=1)
|
770 |
-
|
771 |
-
new_batch_pos = []
|
772 |
-
k = 0; cur_len = 0
|
773 |
-
for i in range(len(images_size)):
|
774 |
-
new_batch_pos.append([])
|
775 |
-
for j in range(0, len(images_size[i]), step):
|
776 |
-
new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
|
777 |
-
cur_len += images_list[k].size(1)
|
778 |
-
k += 1
|
779 |
-
new_batch_pos[i].append(new_pos)
|
780 |
-
return images, new_batch_pos
|
781 |
-
elif self.vit_type == 'Vit-g':
|
782 |
-
images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
|
783 |
-
else:
|
784 |
-
assert False, "other vit_types are not supported"
|
785 |
-
|
786 |
-
if self.vit_mapping_type == 'mlp':
|
787 |
-
if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
|
788 |
-
images = images[:,1:,:]
|
789 |
-
b, v, d = images.shape
|
790 |
-
s = int(math.sqrt(v))
|
791 |
-
images = images.reshape(b, s, s, d)
|
792 |
-
|
793 |
-
|
794 |
-
if self.vit_patch_mlp and img_index is not None:
|
795 |
-
L_tensor = torch.tensor(img_index)
|
796 |
-
device = images.device
|
797 |
-
# 获取子图位置
|
798 |
-
nonzero_indices = torch.nonzero(L_tensor).squeeze().to(device)
|
799 |
-
# 获取主图位置
|
800 |
-
zero_indices = torch.nonzero(L_tensor == 0).squeeze().to(device)
|
801 |
-
|
802 |
-
|
803 |
-
images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
|
804 |
-
images_zero = torch.index_select(images, 0, zero_indices).to(device)
|
805 |
-
|
806 |
-
# 子图额外多pool一次
|
807 |
-
pool_rate = self.pool_rate * 2
|
808 |
-
images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
|
809 |
-
images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
|
810 |
-
pool_rate*pool_rate).mean(-1)
|
811 |
-
|
812 |
-
# 为了组batch折衷方案
|
813 |
-
images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
|
814 |
-
images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
|
815 |
-
images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
|
816 |
-
self.pool_rate*self.pool_rate).mean(-1)
|
817 |
-
# 组batch
|
818 |
-
images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
|
819 |
-
images.index_copy_(0, nonzero_indices, images_nonzero)
|
820 |
-
images.index_copy_(0, zero_indices, images_zero)
|
821 |
-
|
822 |
-
if self.mlp_depth >= 2:
|
823 |
-
images = self.perceive(images)
|
824 |
-
else:
|
825 |
-
if s % self.pool_rate == 0:
|
826 |
-
images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
|
827 |
-
images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
|
828 |
-
if self.mlp_depth >= 2:
|
829 |
-
images = self.perceive(images)
|
830 |
-
else:
|
831 |
-
raise ValueError
|
832 |
-
return images
|
833 |
-
|
834 |
-
|
835 |
-
class SimpleConvMlp(nn.Module):
|
836 |
-
def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
|
837 |
-
super().__init__()
|
838 |
-
|
839 |
-
embed_std = 1 / math.sqrt(out_channels)
|
840 |
-
if poolmlp:
|
841 |
-
# if args.learnable_mlp_pooling_size is not None:
|
842 |
-
# in_channels *= args.learnable_mlp_pooling_size ** 2
|
843 |
-
self.proj = nn.Sequential(
|
844 |
-
nn.Linear(in_channels, out_channels),
|
845 |
-
nn.GELU()
|
846 |
-
)
|
847 |
-
self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
|
848 |
-
self.image_newline = nn.Parameter(
|
849 |
-
torch.randn(out_channels) * embed_std
|
850 |
-
)
|
851 |
-
else:
|
852 |
-
self.proj = nn.Sequential(
|
853 |
-
nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
|
854 |
-
nn.GELU(),
|
855 |
-
nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
|
856 |
-
)
|
857 |
-
self.mlp = nn.Linear(in_channels * 4, out_channels)
|
858 |
-
self.image_newline = nn.Parameter(
|
859 |
-
torch.randn(in_channels * 4) * embed_std
|
860 |
-
)
|
861 |
-
self.poolmlp = poolmlp
|
862 |
-
|
863 |
-
self.image_begin = nn.Parameter(
|
864 |
-
torch.randn(out_channels) * embed_std
|
865 |
-
)
|
866 |
-
self.image_end = nn.Parameter(
|
867 |
-
torch.randn(out_channels) * embed_std
|
868 |
-
)
|
869 |
-
|
870 |
-
if twoview:
|
871 |
-
self.image_sep = nn.Parameter(
|
872 |
-
torch.randn(out_channels) * embed_std
|
873 |
-
)
|
874 |
-
|
875 |
-
self.cat_extra_token = cat_extra_token
|
876 |
-
self.use_rms_norm = vit_used_rms_norm
|
877 |
-
if self.use_rms_norm:
|
878 |
-
self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
|
879 |
-
self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
|
880 |
-
|
881 |
-
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
882 |
-
return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
|
883 |
-
|
884 |
-
def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
885 |
-
remove_vit_special_tokens = False
|
886 |
-
learnable_mlp_pooling_size = None
|
887 |
-
if self.use_rms_norm:
|
888 |
-
x = self.before_rms(x)
|
889 |
-
h, w = size
|
890 |
-
dtype = x.dtype
|
891 |
-
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
|
892 |
-
if self.poolmlp:
|
893 |
-
if learnable_mlp_pooling_size is None:
|
894 |
-
x = F.avg_pool2d(x, anyres_pooling_size)
|
895 |
-
x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
|
896 |
-
else:
|
897 |
-
x = x.permute(0, 2, 3, 1) # b, h, w, c
|
898 |
-
x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
|
899 |
-
w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
|
900 |
-
x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
|
901 |
-
x = self.proj(x)
|
902 |
-
x = self.vit_linear_encoder(x)
|
903 |
-
b, h, w, c = x.shape
|
904 |
-
if not remove_vit_special_tokens:
|
905 |
-
x = torch.cat([
|
906 |
-
x,
|
907 |
-
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
|
908 |
-
], dim=2)
|
909 |
-
x = x.reshape(b, -1, c)
|
910 |
-
else:
|
911 |
-
x = self.proj(x) #b,c,h,w
|
912 |
-
if is_video:
|
913 |
-
video_avgpool_size = 2
|
914 |
-
stride = 2
|
915 |
-
x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
|
916 |
-
b, c, h, w = x.shape
|
917 |
-
if not remove_vit_special_tokens:
|
918 |
-
x = torch.cat([
|
919 |
-
x,
|
920 |
-
self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
|
921 |
-
], dim=-1)
|
922 |
-
x = x.reshape(b, c, -1).permute(0, 2, 1)
|
923 |
-
x = self.mlp(x)
|
924 |
-
|
925 |
-
|
926 |
-
if x2 is not None:
|
927 |
-
h2, w2 = size2
|
928 |
-
x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
|
929 |
-
if self.poolmlp:
|
930 |
-
x2 = F.avg_pool2d(x2, 2)
|
931 |
-
x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
|
932 |
-
x2 = self.vit_linear_encoder(x2)
|
933 |
-
b2, h2, w2, c2 = x2.shape
|
934 |
-
if not remove_vit_special_tokens:
|
935 |
-
x2 = torch.cat([
|
936 |
-
x2,
|
937 |
-
self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
|
938 |
-
], dim=2)
|
939 |
-
x2 = x2.reshape(b2, -1, c2)
|
940 |
-
else:
|
941 |
-
x2 = self.proj(x2)
|
942 |
-
b2, c2, h2, w2 = x2.shape
|
943 |
-
if not remove_vit_special_tokens:
|
944 |
-
x2 = torch.cat([
|
945 |
-
x2,
|
946 |
-
self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
|
947 |
-
], dim=-1)
|
948 |
-
x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
|
949 |
-
x2 = self.mlp(x2)
|
950 |
-
|
951 |
-
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
|
952 |
-
|
953 |
-
x = torch.cat([x, sep, x2], dim=1)
|
954 |
-
|
955 |
-
if self.cat_extra_token:
|
956 |
-
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
|
957 |
-
end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
|
958 |
-
x = torch.cat([begin, x, end], dim=1)
|
959 |
-
|
960 |
-
if self.use_rms_norm:
|
961 |
-
return self.after_rms(x)
|
962 |
-
else:
|
963 |
-
return x
|
964 |
-
|
965 |
-
|
966 |
-
class NormalizedDwPooler(nn.Module):
|
967 |
-
def __init__(self, dim):
|
968 |
-
super().__init__()
|
969 |
-
self.dim = dim
|
970 |
-
self.predictor = nn.Sequential(
|
971 |
-
nn.Linear(dim*2, dim),
|
972 |
-
nn.GELU(),
|
973 |
-
nn.Linear(dim, dim),
|
974 |
-
)
|
975 |
-
|
976 |
-
def forward(self, x, forward_type='2x'):
|
977 |
-
B, H, W, C = x.shape
|
978 |
-
|
979 |
-
if forward_type == '2x':
|
980 |
-
new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
|
981 |
-
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
|
982 |
-
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
983 |
-
elif forward_type == '1x':
|
984 |
-
new_x = x.reshape(B, H, W, 1, C)
|
985 |
-
fused_x = torch.cat([new_x, new_x], dim=-1)
|
986 |
-
elif forward_type == '4x':
|
987 |
-
new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
|
988 |
-
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
|
989 |
-
fused_x = torch.cat([new_x, pooled_x], dim=-1)
|
990 |
-
|
991 |
-
score = self.predictor(fused_x)
|
992 |
-
normalized_score = F.softmax(score, dim=-2)
|
993 |
-
new_x = (new_x * normalized_score).sum(dim=-2)
|
994 |
-
return new_x
|
995 |
-
|
996 |
-
|
997 |
-
class OryxMLPv2(nn.Module):
|
998 |
-
def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
|
999 |
-
super().__init__()
|
1000 |
-
|
1001 |
-
self.proj1 = nn.Linear(in_channels, out_channels)
|
1002 |
-
self.proj2 = nn.Linear(out_channels, out_channels)
|
1003 |
-
self.act = nn.GELU()
|
1004 |
-
self.pooler = NormalizedDwPooler(out_channels)
|
1005 |
-
embed_std = 1 / math.sqrt(out_channels)
|
1006 |
-
|
1007 |
-
self.use_pe = use_pe
|
1008 |
-
if not use_pe:
|
1009 |
-
self.image_newline = nn.Parameter(
|
1010 |
-
torch.randn(out_channels) * embed_std
|
1011 |
-
)
|
1012 |
-
self.image_begin = nn.Parameter(
|
1013 |
-
torch.randn(out_channels) * embed_std
|
1014 |
-
)
|
1015 |
-
self.image_end = nn.Parameter(
|
1016 |
-
torch.randn(out_channels) * embed_std
|
1017 |
-
)
|
1018 |
-
|
1019 |
-
if twoview:
|
1020 |
-
self.image_sep = nn.Parameter(
|
1021 |
-
torch.randn(out_channels) * embed_std
|
1022 |
-
)
|
1023 |
-
|
1024 |
-
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
|
1025 |
-
h, w = size
|
1026 |
-
dtype = x.dtype
|
1027 |
-
x = x.reshape(x.shape[0], h, w, -1)
|
1028 |
-
# x = self.pooler(x, forward_type=REGIONAL_POOL)
|
1029 |
-
# x = self.proj(x) #b,h,w, c
|
1030 |
-
x = self.proj1(x)
|
1031 |
-
x = self.pooler(x, forward_type='2x')
|
1032 |
-
x = self.act(x)
|
1033 |
-
x = self.proj2(x)
|
1034 |
-
|
1035 |
-
|
1036 |
-
b, h, w, c = x.shape
|
1037 |
-
if not self.use_pe:
|
1038 |
-
x = torch.cat([
|
1039 |
-
x,
|
1040 |
-
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
|
1041 |
-
], dim=2)
|
1042 |
-
else:
|
1043 |
-
pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
|
1044 |
-
pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
|
1045 |
-
pe = torch.cat([pe_h, pe_w], dim=-1)
|
1046 |
-
|
1047 |
-
x = x.reshape(b, -1, c)
|
1048 |
-
|
1049 |
-
if x2 is not None:
|
1050 |
-
h2, w2 = size2
|
1051 |
-
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
|
1052 |
-
# x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
|
1053 |
-
## x2 = self.proj(x2) #b,h,w, c
|
1054 |
-
x2 = self.proj1(x2)
|
1055 |
-
x2 = self.pooler(x2, forward_type='2x')
|
1056 |
-
x2 = self.act(x2)
|
1057 |
-
x2 = self.proj2(x2)
|
1058 |
-
|
1059 |
-
b2, h2, w2, c2 = x2.shape
|
1060 |
-
if not self.use_pe:
|
1061 |
-
x2 = torch.cat([
|
1062 |
-
x2,
|
1063 |
-
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
|
1064 |
-
], dim=2)
|
1065 |
-
x2 = x2.reshape(b, -1, c)
|
1066 |
-
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
|
1067 |
-
x = torch.cat([x, sep, x2], dim=1)
|
1068 |
-
|
1069 |
-
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
1070 |
-
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
|
1071 |
-
x = torch.cat([begin, x, end], dim=1)
|
1072 |
-
# print(x.shape, x2.shape, h, w, h2, w2)
|
1073 |
-
# print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
|
1074 |
-
if self.use_pe:
|
1075 |
-
zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
|
1076 |
-
pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
|
1077 |
-
assert pe.shape[1] == x.shape[1]
|
1078 |
-
return x, pe
|
1079 |
-
else:
|
1080 |
-
nseq = x.shape[1]
|
1081 |
-
fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
|
1082 |
-
return x #, fake_pe
|
1083 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Usage and Legal Notices:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
2 |
|
3 |
+
Tencent is pleased to support the open source community by making Tencent Hunyuan A13B available.
|
4 |
+
|
5 |
+
Copyright (C) Tencent. All rights reserved. The below software and/or models in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) Tencent.
|
6 |
+
|
7 |
+
Tencent Hunyuan A13B is licensed under TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT, which can be found in this repository called "LICENSE", except for the third-party components listed below. Tencent Hunyuan A13B does not impose any additional limitations beyond what is outlined in the respective licenses of these third-party components. Users must comply with all terms and conditions of original licenses of these third-party components and must ensure that the usage of the third party components adheres to all relevant laws and regulations.
|
8 |
+
|
9 |
+
For avoidance of doubts, Tencent Hunyuan A13B refers to the inference code, training code, parameters and the weights of Tencent Hunyuan A13B only, which are made publicly available by Tencent in accordance with the TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
|
10 |
+
|
11 |
+
|
12 |
+
Other dependencies and licenses:
|
13 |
+
|
14 |
+
|
15 |
+
Open Source Software Licensed under the Apache License Version 2.0:
|
16 |
+
The below software in this distribution may have been modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
17 |
+
--------------------------------------------------------------------
|
18 |
+
1. pytorch
|
19 |
+
Copyright 2016-2017 TorchAPI
|
20 |
+
Copyright 2016-2017 Contributors
|
21 |
+
|
22 |
+
2. VLLM
|
23 |
+
Copyright (c) vllm original author and authors
|
24 |
+
Please note this software has been modified by Tencent in this distribution.
|
25 |
+
|
26 |
+
3. transformers
|
27 |
+
Copyright 2018- The Hugging Face team. All rights reserved.
|
28 |
+
|
29 |
+
4. accelerate
|
30 |
+
Copyright (c) accelerate original author and authors
|
31 |
+
|
32 |
+
|
33 |
+
Terms of the Apache License Version 2.0:
|
34 |
+
--------------------------------------------------------------------
|
35 |
+
Apache License
|
36 |
+
|
37 |
+
Version 2.0, January 2004
|
38 |
+
|
39 |
+
http://www.apache.org/licenses/
|
40 |
+
|
41 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
42 |
+
1. Definitions.
|
43 |
+
|
44 |
+
"License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document.
|
45 |
+
|
46 |
+
"Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License.
|
47 |
+
|
48 |
+
"Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
49 |
+
|
50 |
+
"You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License.
|
51 |
+
|
52 |
+
"Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files.
|
53 |
+
|
54 |
+
"Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types.
|
55 |
+
|
56 |
+
"Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below).
|
57 |
+
|
58 |
+
"Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof.
|
59 |
+
|
60 |
+
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work.
|
63 |
+
|
64 |
+
2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form.
|
65 |
+
|
66 |
+
3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed.
|
67 |
+
|
68 |
+
4. Redistribution. You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions:
|
69 |
+
|
70 |
+
You must give any other recipients of the Work or Derivative Works a copy of this License; and
|
71 |
+
|
72 |
+
You must cause any modified files to carry prominent notices stating that You changed the files; and
|
73 |
+
|
74 |
+
You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and
|
75 |
+
|
76 |
+
If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License.
|
77 |
+
|
78 |
+
You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License.
|
79 |
+
|
80 |
+
5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions.
|
81 |
+
|
82 |
+
6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file.
|
83 |
+
|
84 |
+
7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License.
|
85 |
+
|
86 |
+
8. Limitation of Liability. In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages.
|
87 |
+
|
88 |
+
9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability.
|
89 |
+
|
90 |
+
END OF TERMS AND CONDITIONS
|
91 |
+
|
92 |
+
|
93 |
+
|
94 |
+
Open Source Software Licensed under the BSD 3-Clause License and Other Licenses of the Third-Party Components therein:
|
95 |
+
--------------------------------------------------------------------
|
96 |
+
1. pytorch
|
97 |
+
Copyright (c) 2016- Facebook, Inc (Adam Paszke)
|
98 |
+
Copyright (c) 2014- Facebook, Inc (Soumith Chintala)
|
99 |
+
Copyright (c) 2011-2014 Idiap Research Institute (Ronan Collobert)
|
100 |
+
Copyright (c) 2012-2014 Deepmind Technologies (Koray Kavukcuoglu)
|
101 |
+
Copyright (c) 2011-2012 NEC Laboratories America (Koray Kavukcuoglu)
|
102 |
+
Copyright (c) 2011-2013 NYU (Clement Farabet)
|
103 |
+
Copyright (c) 2006-2010 NEC Laboratories America (Ronan Collobert, Leon Bottou, Iain Melvin, Jason Weston)
|
104 |
+
Copyright (c) 2006 Idiap Research Institute (Samy Bengio)
|
105 |
+
Copyright (c) 2001-2004 Idiap Research Institute (Ronan Collobert, Samy Bengio, Johnny Mariethoz)
|
106 |
+
|
107 |
+
|
108 |
+
Terms of the BSD 3-Clause:
|
109 |
+
--------------------------------------------------------------------
|
110 |
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
111 |
+
|
112 |
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
|
113 |
+
|
114 |
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
|
115 |
+
|
116 |
+
3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
|
117 |
+
|
118 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
119 |
+
|
120 |
+
For the license of other third party components, please refer to the following URL:
|
121 |
+
https://github.com/pytorch/pytorch/blob/v2.1.1/NOTICE
|
122 |
+
https://github.com/pytorch/pytorch/tree/v2.1.1/third_party
|
123 |
+
|
124 |
+
|
125 |
+
Open Source Software Licensed under the BSD 3-Clause License:
|
126 |
+
--------------------------------------------------------------------
|
127 |
+
1. flash_attn
|
128 |
+
Copyright (c) 2022, the respective contributors, as shown by the AUTHORS file.
|
129 |
+
All rights reserved.
|
130 |
+
|
131 |
+
|
132 |
+
A copy of the BSD 3-Clause is included in this file.
|
133 |
+
|
134 |
+
|
135 |
+
|
136 |
+
Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
137 |
+
The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
138 |
+
--------------------------------------------------------------------
|
139 |
+
1. sglang
|
140 |
+
Copyright 2023-2024 SGLang Team
|
141 |
+
|
142 |
+
|
143 |
+
A copy of the Apache 2.0 is included in this file.
|
144 |
+
|
145 |
+
For the license of other third party components, please refer to the following URL:
|
146 |
+
https://github.com/sgl-project/sglang/tree/v0.4.7/3rdparty/amd
|
147 |
+
|
148 |
+
|
149 |
+
|
150 |
+
Open Source Software Licensed under the Apache License Version 2.0 and Other Licenses of the Third-Party Components therein:
|
151 |
+
The below software in this distribution is modified by Tencent ("Tencent Modifications"). All Tencent Modifications are Copyright (C) 2025 Tencent.
|
152 |
+
--------------------------------------------------------------------
|
153 |
+
1. TensorRT-LLM
|
154 |
+
Copyright (c) TensorRT-LLM original author and authors
|
155 |
+
|
156 |
+
|
157 |
+
A copy of the Apache 2.0 is included in this file.
|
158 |
+
|
159 |
+
For the license of other third party components, please refer to the following URL:
|
160 |
+
https://github.com/NVIDIA/TensorRT-LLM/tree/v0.20.0/3rdparty
|