manaestras committed on
Commit e45f5c8 · verified · 1 Parent(s): ec139cf

Delete vit_model.py

Files changed (1)
  1. vit_model.py +0 -1083
vit_model.py DELETED
@@ -1,1083 +0,0 @@
1
- import json
2
- import types
3
- import math
4
- import torch
5
- from torch import Tensor, nn
6
- import torch.nn.functional as F
7
- from typing import List, Tuple, Optional, Union
8
- from contextlib import contextmanager
9
- from transformers.modeling_attn_mask_utils import (
10
- _prepare_4d_causal_attention_mask_for_sdpa,
12
- _prepare_4d_causal_attention_mask,
13
- )
14
- from transformers.models.clip.configuration_clip import CLIPVisionConfig
15
- from transformers.modeling_outputs import BaseModelOutputWithPooling
16
- from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
17
- from .configuration_hunyuan import HunYuanConfig
18
-
19
-
20
- def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
21
- # input_ids: (B, L)
22
- # encoder_input: (L, B, E)
23
- # image_tensors [[Tensor],...,[Tensor]]
24
- # image_pos [[Tensor],...,[Tensor]]
25
- # tokenizer = get_tokenizer()
26
- b = len(input_ids)
27
- img_embs = None
28
- all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
29
- if all_nums != 0:
30
- img_embs, img_batch_pos = vit(image_tensors)
31
- else:
32
- # when no input image, initialize a fake tensor
33
- pad_nums = 1
34
- image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
35
- img_embs, img_batch_pos = vit(image_tensors)
36
-
37
- encoder_input = encoder_input.clone()
38
- if all_nums > 0:
39
- assert len(images_pos) == len(img_batch_pos), \
40
- (len(images_pos), len(img_batch_pos))
41
- start_token_id = im_start_id
42
- end_token_id = im_end_id
43
- placeholder_id = image_token_id
44
- for idx in range(len(images_pos)):
45
- assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
46
- (len(images_pos[idx]), len(img_batch_pos[idx]))
47
- for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
48
- # the positions to be filled [s_start, s_end)
49
- s_idx, s_start, s_end = p_img_pos_in_batch
50
- current_embs = img_embs[s_idx, s_start:s_end]
51
- im_s, im_e = p_batch_img_pos
52
- assert len(current_embs) == im_e - im_s, \
53
- (img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
54
- if not anyres_vit_two_views:
55
- assert input_ids[idx, im_s - 1] == start_token_id, \
56
- input_ids[idx, im_s - 1]
57
- assert input_ids[idx, im_e] == end_token_id, \
58
- input_ids[idx, im_e]
59
- assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
60
- f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
61
- encoder_input[idx, im_s:im_e] = current_embs
62
- else:
63
- # when there is no input image, mask out the ViT output (the zero mask keeps a gradient path through the fake forward)
64
- vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
65
- current_embs = img_embs[0, :]
66
- encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
67
- return encoder_input, input_ids
68
-
69
-
70
- def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
71
- vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
72
-
73
- b = len(input_ids)
74
- if images_pos is None:
75
- images_pos = torch.ones([len(input_ids), 1, 3])
76
- images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
77
- images_pos = images_pos.long()
78
-
79
- real_image_nums = []
80
- image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
81
- real_images = []
82
-
83
- all_nums = 0
84
- img_index = []
85
- for s in range(len(images_pos)):
86
- real_image_num = 0
87
- for (im_s, im_e,index) in images_pos[s]:
88
- if im_s == -1:
89
- break
90
- real_image_num += 1
91
- all_nums += 1
92
- img_index.append(index)
93
-
94
- real_image_nums.append(real_image_num)
95
- real_images.append(image_tensors[s][:real_image_num])
96
-
97
- if vit_patch == 1:
98
- img_index = None
99
-
100
- if all_nums == 0:
101
- # when no input image, initialize a fake tensor
102
- img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
103
- img_embs = vit(img_input)
104
- img_embs = vit_linear_encoder(img_embs)
105
- else:
106
- img_input = torch.cat(real_images)
107
- img_embs = vit(img_input, img_index = img_index)
108
- img_embs = vit_linear_encoder(img_embs)
109
-
110
- encoder_input = encoder_input.clone()
111
- start = 0
112
- if all_nums > 0:
113
- for s, real_image_len in enumerate(real_image_nums):
114
- current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
115
- for ss in range(current_embs.shape[0]):
116
- im_s, im_e, index = images_pos[s, ss]
117
- # sub-images yield fewer features, so copy only as many as fit
118
- if index > 0 and vit_patch_mlp:
119
- encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
120
- else:
121
- encoder_input[s, im_s:im_e] = current_embs[ss, :]
122
- start = start + real_image_len
123
- else:
124
- # when there is no input image, mask out the ViT output (the zero mask keeps a gradient path through the fake forward)
125
- for s in range(b):
126
- vit_mask = torch.zeros([vit_token, 1]).cuda()
127
- current_embs = img_embs[:, start:start + 1]
128
- encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
129
- start = start + 1
130
- return encoder_input, input_ids
131
-
132
-
133
- def group_images_by_max_seq_len(
134
- images: List[List[Tensor]], patch_size: int,
135
- max_seq_len: int, adaptor_patch_size: int,
136
- add_cls_token: bool = False) -> List[List[Tensor]]:
137
-
138
- groups = []
139
- group = []
140
- pos_groups = []
141
- seq_len = 0
142
- num_images = 0
143
- for image_list in images:
144
- pos_group = []
145
- for image in image_list:
146
- num_images += 1
147
- assert isinstance(image, Tensor)
148
-
149
- image_dims = image.shape[-2:]
150
- ph, pw = map(lambda t: t // patch_size, image_dims)
151
-
152
- image_seq_len = (ph * pw)
153
- new_image_seq_len = image_seq_len
154
- grouped_len = seq_len + image_seq_len
155
- if add_cls_token:
156
- new_image_seq_len += 1
157
- grouped_len += num_images
158
-
159
- assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
160
-
161
- if grouped_len > max_seq_len:
162
- groups.append(group)
163
- group = []
164
- seq_len = 0
165
- num_images = 1
166
-
167
- group.append(image)
168
- start = seq_len // (adaptor_patch_size * adaptor_patch_size)
169
- end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
170
- batch_idx = len(groups)
171
- pos_group.append([batch_idx, start, end])
172
- seq_len += image_seq_len
173
- pos_groups.append(pos_group)
174
-
175
- if len(group) > 0:
176
- groups.append(group)
177
-
178
- return groups, pos_groups
179
-
180
-
181
- class AnyResCLIPVisionEmbeddings(nn.Module):
182
- def __init__(self, config: CLIPVisionConfig):
183
- super().__init__()
184
-
185
- self.config = config
186
- # self.sparse_attn_mask = args.sparse_attn_mask
187
- # self.use_flash_attn = args.use_flash_attn
188
- self.embed_dim = config.hidden_size
189
- self.image_size = config.max_image_size
190
- self.patch_size = config.patch_size
191
- self.max_seq_len = config.max_vit_seq_len
192
- self.adaptor_patch_size = config.adaptor_patch_size
193
- self.anyres_vit_two_views = config.anyres_vit_two_views
194
- self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
195
- self.vit_remove_prenorm = config.vit_remove_prenorm
196
-
197
- self.patch_embedding = nn.Conv2d(
198
- in_channels=config.num_channels,
199
- out_channels=self.embed_dim,
200
- kernel_size=self.patch_size,
201
- stride=self.patch_size,
202
- bias=self.vit_add_patchemb_bias,
203
- )
204
-
205
- self.num_patches = (self.image_size // self.patch_size) ** 2
206
- self.skip_cls_token = True
207
-
208
- # add interpolate_pos_encoding
209
- if self.anyres_vit_two_views:
210
- self.num_positions = self.num_patches
211
- self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
212
- else:
213
- self.num_positions = self.num_patches + 1
214
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
215
- # self.position_ids = torch.arange(self.num_positions).expand((1, -1))
216
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
217
-
218
- if not self.vit_remove_prenorm:
219
- self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
220
-
221
- def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
222
- """
223
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
224
- resolution images.
225
-
226
- Source:
227
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
228
- """
229
- num_patches = embeddings.shape[1]
230
- position_embeddings = self.position_embedding(self.position_ids)
231
- patch_pos_embed = position_embeddings[:, 1:]
232
- num_positions = position_embeddings.shape[1] - 1
233
- if num_patches == num_positions and height == width:
234
- return patch_pos_embed
235
- # class_pos_embed = position_embeddings[:, 0]
236
- dim = embeddings.shape[-1]
237
- h0 = height // self.patch_size
238
- w0 = width // self.patch_size
239
- # we add a small number to avoid floating point error in the interpolation
240
- # see discussion at https://github.com/facebookresearch/dino/issues/8
241
- h0, w0 = h0 + 0.1, w0 + 0.1
242
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
243
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
244
- raw_type = patch_pos_embed.dtype
245
- patch_pos_embed = nn.functional.interpolate(
246
- patch_pos_embed.to(torch.float32, non_blocking=True),
247
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
248
- mode="bilinear",
249
- align_corners=False,
250
- )
251
- patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
252
- assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
253
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
254
- return patch_pos_embed
255
-
256
- def rescale_positional_embedding(self, out_size):
257
- h, w = out_size
258
- pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
259
- if (h, w) == (pos_embed_shape, pos_embed_shape):
260
- return self.position_embedding
261
- rescaled_positional_embedding = \
262
- self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
263
- pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
264
- pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
265
- rescaled_positional_embedding[0] = pe_2d.T.contiguous()
266
- return rescaled_positional_embedding
267
-
268
- def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
269
- if pixel_values.ndim == 3:
270
- pixel_values = pixel_values[None]
271
- batch_size, num_channels, height, width = pixel_values.shape
272
-
273
- if self.anyres_vit_two_views:
274
- # padding
275
- pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
276
- pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
277
- pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
278
-
279
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
280
- b, c, h, w = patch_embeds.shape
281
-
282
- # (b, hw, c)
283
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
284
- if self.anyres_vit_two_views:
285
- embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
286
- else:
287
- embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
288
- if not self.vit_remove_prenorm:
289
- embeddings = self.pre_layernorm(embeddings)
290
- return embeddings, (h, w)
291
-
292
- def forward(self, images: List[List[Tensor]]):
293
- '''
294
- Input:
295
- images: List[List[Tensor]]
296
-
297
- Return:
298
- embeddings: Tensor (B, L, E)
299
- attn_mask: Tensor (B, 1, L, L), True marks positions that attend to each other
300
- pos_groups: List[List[(batch_idx, start, end)]]
301
- '''
302
- batched_images, pos_groups = group_images_by_max_seq_len(
303
- images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
304
- max_seq_len = self.max_seq_len
305
-
306
- # batched_images is a list of a list
307
- B = len(batched_images)
308
- L = max_seq_len
309
- E = self.embed_dim
310
-
311
- embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
312
- attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True marks positions that participate in attention
313
- assert len(images) == len(pos_groups), (len(images), len(pos_groups))
314
-
315
- batch_images = []
316
- batch_pos = []
317
- for images_i, pos_group in zip(images, pos_groups):
318
- assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
319
- for image, pos in zip(images_i, pos_group):
320
- batch_idx, start, end = pos
321
- a2 = self.adaptor_patch_size ** 2
322
- # recover the real number of the input image tokens
323
- start *= a2
324
- end *= a2
325
- emb, _ = self.forward_single(image)
326
- assert emb.ndim == 3, '(B, L, E)'
327
- embeddings[batch_idx, start:end] = emb
328
- attn_mask[batch_idx, :, start:end, start:end] = True
329
- return embeddings, attn_mask, pos_groups
330
-
331
-
332
- class CLIPVisionEmbeddings(nn.Module):
333
- def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
334
- super().__init__()
335
- self.config = config
336
- self.embed_dim = config.hidden_size
338
- self.image_size = config.vit_input_resolution
339
- self.patch_size = config.patch_size
340
-
341
- self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
342
-
343
- self.patch_embedding = nn.Conv2d(
344
- in_channels=config.num_channels,
345
- out_channels=self.embed_dim,
346
- kernel_size=self.patch_size,
347
- stride=self.patch_size,
348
- bias=False,
349
- )
350
-
351
- self.num_patches = (self.image_size // self.patch_size) ** 2
352
-
353
- self.skip_cls_token = skip_cls_token
354
-
355
- self.num_positions = self.num_patches + 1
356
-
357
- self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
358
- if vit_patch > 1:
359
- self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
360
- # 0 means at most 16 images are supported; the limit is hard-coded and would need an extra parameter to change
361
- elif vit_patch == 0:
362
- self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
363
- else:
364
- self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
365
-
366
- if add_pre_layernorm:
367
- self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
368
- else:
369
- self.pre_layernorm = None
370
-
371
- def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
372
- """
373
- This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
374
- resolution images.
375
-
376
- Source:
377
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
378
- """
379
- num_patches = embeddings.shape[1] - 1
380
- position_embeddings = self.position_embedding(self.position_ids)
381
- num_positions = position_embeddings.shape[1] - 1
382
- if num_patches == num_positions and height == width:
383
- return position_embeddings
384
- class_pos_embed = position_embeddings[:, 0]
385
- patch_pos_embed = position_embeddings[:, 1:]
386
- dim = embeddings.shape[-1]
387
- h0 = height // self.config.patch_size
388
- w0 = width // self.config.patch_size
389
- # we add a small number to avoid floating point error in the interpolation
390
- # see discussion at https://github.com/facebookresearch/dino/issues/8
391
- h0, w0 = h0 + 0.1, w0 + 0.1
392
- patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
393
- patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
394
- raw_type = patch_pos_embed.dtype
395
- patch_pos_embed = nn.functional.interpolate(
396
- patch_pos_embed.float(),
397
- scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
398
- mode="bicubic",
399
- align_corners=False,
400
- )
401
- # print(patch_pos_embed.shape)
402
- patch_pos_embed = patch_pos_embed.to(raw_type)
403
- assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
404
- patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
405
- return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
406
-
407
-
408
- def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
409
- batch_size, num_channels, height, width = pixel_values.shape
410
- patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
411
- patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
412
- if self.skip_cls_token:
413
- embeddings = patch_embeds
414
- if img_index is None:
415
- position_ids = self.position_ids[:,1:]
416
- embeddings = embeddings + self.position_embedding(position_ids)
417
- else:
418
- position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
419
- + self.position_ids.expand(batch_size, -1)[:, 1:]
420
- embeddings = embeddings + self.position_embedding(position_ids)
421
- else:
422
- class_embeds = self.class_embedding.expand(batch_size, 1, -1)
423
- embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
424
- if interpolate_pos_encoding:
425
- embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
426
- else:
427
- if img_index is None:
428
- embeddings = embeddings + self.position_embedding(self.position_ids)
429
- else:
430
- position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
431
- new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
432
- position_ids = torch.cat([position_ids,new_position],dim=1)
433
- embeddings = embeddings + self.position_embedding(position_ids)
434
- if self.pre_layernorm is not None:
435
- embeddings = self.pre_layernorm(embeddings)
436
- return embeddings
437
-
438
-
439
- class NaVitTransformer(nn.Module):
440
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
441
- super().__init__()
442
- self.config = config
443
- self.vit_config = vit_config
444
- with self.prepare_args(config, vit_config):
445
- self._use_sdpa = config._attn_implementation == "sdpa"
446
- self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
447
- self.layers = nn.ModuleList(
448
- [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
449
- )
450
-
451
- @contextmanager
452
- def prepare_args(self, config, vit_config):
453
- hidden_act = config.hidden_act
454
- hidden_size = config.hidden_size
455
- ffn_hidden_size = config.intermediate_size
456
- num_attention_heads = config.num_attention_heads
457
- num_key_value_heads = config.num_key_value_heads
458
- attention_head_dim = config.attention_head_dim
459
- use_qk_norm = config.use_qk_norm
460
- use_rotary_pos_emb = config.use_rotary_pos_emb
461
- num_hidden_layers = config.num_hidden_layers
462
- rms_norm_eps = config.rms_norm_eps
463
- attention_dropout = config.attention_dropout
464
- # hidden_dropout = config.hidden_dropout
465
- norm_type = config.norm_type
466
- attention_bias = config.attention_bias
467
- mlp_bias = config.mlp_bias
468
- use_mla = config.use_mla
469
- num_experts = config.num_experts
470
- _attn_implementation = config._attn_implementation
471
-
472
- config.hidden_act = vit_config.hidden_act
473
- config.hidden_size = vit_config.hidden_size
474
- config.intermediate_size = vit_config.intermediate_size
475
- config.num_attention_heads = vit_config.num_attention_heads
476
- config.num_key_value_heads = None
477
- config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
478
- config.use_qk_norm = False
479
- config.use_rotary_pos_emb = False
480
- config.num_hidden_layers = vit_config.num_hidden_layers
481
- config.rms_norm_eps = vit_config.layer_norm_eps
482
- config.attention_dropout = vit_config.attention_dropout
483
- # config.hidden_dropout = vit_config.hidden_dropout
484
- config.norm_type = config.vit_norm_type
485
- config.attention_bias = True
486
- config.mlp_bias = True
487
- config.use_mla = False
488
- config.num_experts = 1
489
- config._attn_implementation = "eager"
490
-
491
- yield
492
- config.hidden_act = hidden_act
493
- config.hidden_size = hidden_size
494
- config.intermediate_size = ffn_hidden_size
495
- config.num_attention_heads = num_attention_heads
496
- config.num_key_value_heads = num_key_value_heads
497
- config.attention_head_dim = attention_head_dim
498
- config.use_qk_norm = use_qk_norm
499
- config.use_rotary_pos_emb = use_rotary_pos_emb
500
- config.num_hidden_layers = num_hidden_layers
501
- config.rms_norm_eps = rms_norm_eps
502
- config.attention_dropout = attention_dropout
503
- # config.hidden_dropout = hidden_dropout
504
- config.attention_bias = attention_bias
505
- config.mlp_bias = mlp_bias
506
- config.norm_type = norm_type
507
- config.use_mla = use_mla
508
- config.num_experts = num_experts
509
- config._attn_implementation = _attn_implementation
510
-
511
- def forward(
512
- self,
513
- pixel_values: Optional[torch.FloatTensor] = None,
514
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
515
-
516
- hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
517
- attention_mask = attention_mask.int()
518
- batch_size, seq_length, _ = hidden_states.shape
519
- past_key_values_length = 0
520
-
521
- if self._use_flash_attention_2:
522
- # 2d mask is passed through the layers
523
- attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
524
- elif self._use_sdpa:
525
- # output_attentions=True can not be supported when using SDPA, and we fall back on
526
- # the manual implementation that requires a 4D causal mask in all cases.
527
- attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
528
- attention_mask,
529
- (batch_size, seq_length),
530
- hidden_states,
531
- past_key_values_length,
532
- )
533
- else:
534
- attention_mask = _prepare_4d_causal_attention_mask(
535
- attention_mask,
536
- (batch_size, seq_length),
537
- hidden_states,
538
- past_key_values_length,
539
- )
540
-
541
- for layer_idx, decoder_layer in enumerate(self.layers):
542
- layer_outputs = decoder_layer(
543
- hidden_states,
544
- attention_mask=attention_mask
545
- )
546
- hidden_states = layer_outputs[0]
547
-
548
- return hidden_states, img_pos
549
-
550
-
551
- class AnyResVitTransformer(NaVitTransformer):
552
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
553
- super().__init__(config, vit_config)
554
- old_anyres_vit_max_image_size = vit_config.max_image_size
555
- anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
556
- vit_config.max_image_size = anyres_vit_max_image_size
557
- vit_config.torch_dtype = config.torch_dtype
558
- vit_config.anyres_vit_two_views = config.anyres_vit_two_views
559
- vit_config.vit_remove_prenorm = config.vit_remove_prenorm
560
- vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
561
- self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
562
- vit_config.max_image_size = old_anyres_vit_max_image_size
563
-
564
- def fix_embeddings_fn(self, pixel_values):
565
- # (B, L, E)
566
- embeddings, hw = self.embeddings.forward_single(pixel_values)
567
- embeddings = self.embeddings.pre_layernorm(embeddings)
568
- return embeddings
569
-
570
-
571
- class CLIPVisionTransformer(nn.Module):
572
- def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
573
- super().__init__()
574
- embed_dim = vit_config.hidden_size
575
-
576
- self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
577
- self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
578
-
579
- with self.prepare_args(config, vit_config):
580
- self.layers = nn.ModuleList(
581
- [HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
582
- )
583
-
584
- @contextmanager
585
- def prepare_args(self, config, vit_config):
586
- hidden_act = config.hidden_act
587
- hidden_size = config.hidden_size
588
- ffn_hidden_size = config.intermediate_size
589
- num_attention_heads = config.num_attention_heads
590
- num_key_value_heads = config.num_key_value_heads
591
- attention_head_dim = config.attention_head_dim
592
- use_qk_norm = config.use_qk_norm
593
- use_rotary_pos_emb = config.use_rotary_pos_emb
594
- num_hidden_layers = config.num_hidden_layers
595
- rms_norm_eps = config.rms_norm_eps
596
- attention_dropout = config.attention_dropout
597
- # hidden_dropout = config.hidden_dropout
598
- norm_type = config.norm_type
599
- attention_bias = config.attention_bias
600
- mlp_bias = config.mlp_bias
601
- use_mla = config.use_mla
602
- num_experts = config.num_experts
603
- _attn_implementation = config._attn_implementation
604
-
605
- config.hidden_act = vit_config.hidden_act
606
- config.hidden_size = vit_config.hidden_size
607
- config.intermediate_size = vit_config.intermediate_size
608
- config.num_attention_heads = vit_config.num_attention_heads
609
- config.num_key_value_heads = None
610
- config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
611
- config.use_qk_norm = False
612
- config.use_rotary_pos_emb = False
613
- config.num_hidden_layers = vit_config.num_hidden_layers
614
- config.rms_norm_eps = vit_config.layer_norm_eps
615
- config.attention_dropout = vit_config.attention_dropout
616
- # config.hidden_dropout = 0.0
617
- config.norm_type = "fused"
618
- config.attention_bias = True
619
- config.mlp_bias = True
620
- config.use_mla = False
621
- config.num_experts = 1
622
- config._attn_implementation = "eager"
623
-
624
- yield
625
-
626
- config.hidden_act = hidden_act
627
- config.hidden_size = hidden_size
628
- config.intermediate_size = ffn_hidden_size
629
- config.num_attention_heads = num_attention_heads
630
- config.num_key_value_heads = num_key_value_heads
631
- config.attention_head_dim = attention_head_dim
632
- config.use_qk_norm = use_qk_norm
633
- config.use_rotary_pos_emb = use_rotary_pos_emb
634
- config.num_hidden_layers = num_hidden_layers
635
- config.rms_norm_eps = rms_norm_eps
636
- config.attention_dropout = attention_dropout
637
- # config.hidden_dropout = hidden_dropout
638
- config.norm_type = norm_type
639
- config.attention_bias = attention_bias
640
- config.mlp_bias = mlp_bias
641
- config.use_mla = use_mla
642
- config.num_experts = num_experts
643
- config._attn_implementation = _attn_implementation
644
-
645
- def forward(
646
- self,
647
- pixel_values: Optional[torch.FloatTensor] = None,
648
- interpolate_pos_encoding: Optional[bool] = None,
649
- img_index=None
650
- ) -> Union[Tuple, BaseModelOutputWithPooling]:
651
- r"""
652
- Returns:
653
-
654
- """
655
- hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
656
- hidden_states = self.pre_layrnorm(hidden_states)
657
- batch = hidden_states.shape[0]
658
- seq_len = hidden_states.shape[1]
659
- device = hidden_states.device
660
- attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
661
-
662
- for layer_idx, decoder_layer in enumerate(self.layers):
663
- layer_outputs = decoder_layer(
664
- hidden_states,
665
- attention_mask=attention_mask
666
- )
667
- hidden_states = layer_outputs[0]
668
-
669
- return hidden_states
670
-
671
-
672
- class Vit(torch.nn.Module):
673
- def __init__(self, config, resampler_token=64, pool_rate=2):
674
- super().__init__()
675
- self.config = config
676
- self.vit_mapping_type = config.vit_mapping_type
677
- self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
678
- self.skip_cls_token = config.skip_cls_token
679
- self.pool_rate = pool_rate
680
- self.vit_type = self.config.vit_type
681
- self.anyres_vit_two_views = self.config.anyres_vit_two_views
682
- if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
683
- self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
684
- else:
685
- raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
686
-
687
- def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
688
- if self.vit_type == 'AnyResVit':
689
- vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
690
- self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
691
- self.vit_config.image_size = vit_input_resolution
692
- self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
693
- elif self.vit_type == 'Vit-g':
694
- vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
695
- self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
696
- self.vit_config.vit_input_resolution = vit_input_resolution
697
- self.vit = CLIPVisionTransformer(self.config, self.vit_config)
698
- else:
699
- assert False, "other vit_types are not supported"
700
-
701
- if self.vit_mapping_type == 'simple_conv_mlp':
702
- self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
703
- self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
704
- elif self.vit_mapping_type == 'oryx_mlp':
705
- self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
706
- elif self.vit_mapping_type == 'mlp':
707
- self.mlp_depth = 2
708
- # one mlp layer already in gpt_model.py
709
- mlp_hidden_size = self.vit_config.hidden_size
710
- if self.vit_type in ['NaVit', 'EvaVit']:
711
- mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
712
- if self.mlp_depth > 1:
713
- mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
714
- if self.vit_type in ['NaVit', 'EvaVit']:
715
- for _ in range(1, self.mlp_depth):
716
- mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
717
- mlp_modules.append(torch.nn.GELU())
718
- self.perceive = torch.nn.Sequential(*mlp_modules)
719
- else:
720
- assert False, "other vit_mapping_types are not supported"
721
-
722
- self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
723
- for name, param in self.named_parameters():
724
- setattr(param, "is_vit_param", True)
725
-
726
- def forward(self, images, img_index=None):
727
- if self.vit_type in ['AnyResVit']:
728
- dtype = self.config.torch_dtype
729
- device = torch.cuda.current_device()
730
-
731
- images_size = []
732
- for i in range(len(images)):
733
- images_size.append([])
734
- for j in range(len(images[i])):
735
- images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
736
-
737
- images_feats, img_batch_pos = self.vit(pixel_values=images)
738
- a2 = self.vit_config.adaptor_patch_size ** 2
739
-
740
- if self.anyres_vit_two_views:
741
- step = 2
742
- else:
743
- step = 1
744
- perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
745
- images_list = []
746
- images_fix_i = 0
747
- num_img_batch_pos = len(img_batch_pos)
748
- for i in range(num_img_batch_pos): # batch_id
749
- for j in range(0, len(img_batch_pos[i]), step):
750
- if self.anyres_vit_two_views:
751
- lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
752
- lower_begin = lower_begin * a2
753
- lower_end = lower_end * a2
754
- higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
755
- higher_begin = higher_begin * a2
756
- higher_end = higher_end * a2
757
- lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
758
- higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
759
- lower_images_size = images_size[i][j]
760
- higher_images_size = images_size[i][j + 1]
761
- images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
762
- else:
763
- idx, begin, end = img_batch_pos[i][j]
764
- begin = begin * a2
765
- end = end * a2
766
- is_video = hasattr(images[i][j],'_is_video') and images[i][j]._is_video
767
- images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
768
-
769
- images = torch.cat(images_list, dim=1)
770
-
771
- new_batch_pos = []
772
- k = 0; cur_len = 0
773
- for i in range(len(images_size)):
774
- new_batch_pos.append([])
775
- for j in range(0, len(images_size[i]), step):
776
- new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
777
- cur_len += images_list[k].size(1)
778
- k += 1
779
- new_batch_pos[i].append(new_pos)
780
- return images, new_batch_pos
781
- elif self.vit_type == 'Vit-g':
782
- images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
783
- else:
784
- assert False, "other vit_types are not supported"
785
-
786
- if self.vit_mapping_type == 'mlp':
787
- if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
788
- images = images[:,1:,:]
789
- b, v, d = images.shape
790
- s = int(math.sqrt(v))
791
- images = images.reshape(b, s, s, d)
792
-
793
-
794
- if self.vit_patch_mlp and img_index is not None:
795
- L_tensor = torch.tensor(img_index)
796
- device = images.device
797
- # indices of the sub-image entries
798
- nonzero_indices = torch.nonzero(L_tensor).squeeze().to(device)
799
- # indices of the main-image entries
800
- zero_indices = torch.nonzero(L_tensor == 0).squeeze().to(device)
801
-
802
-
803
- images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
804
- images_zero = torch.index_select(images, 0, zero_indices).to(device)
805
-
806
- # pool the sub-images one extra time
807
- pool_rate = self.pool_rate * 2
808
- images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
809
- images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
810
- pool_rate*pool_rate).mean(-1)
811
-
812
- # compromise for batching: pad the sub-image features to the main-image length
813
- images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
814
- images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
815
- images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
816
- self.pool_rate*self.pool_rate).mean(-1)
817
- # reassemble the batch
818
- images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
819
- images.index_copy_(0, nonzero_indices, images_nonzero)
820
- images.index_copy_(0, zero_indices, images_zero)
821
-
822
- if self.mlp_depth >= 2:
823
- images = self.perceive(images)
824
- else:
825
- if s % self.pool_rate == 0:
826
- images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
827
- images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
828
- if self.mlp_depth >= 2:
829
- images = self.perceive(images)
830
- else:
831
- raise ValueError
832
- return images
833
-
834
-
835
- class SimpleConvMlp(nn.Module):
836
- def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
837
- super().__init__()
838
-
839
- embed_std = 1 / math.sqrt(out_channels)
840
- if poolmlp:
841
- # if args.learnable_mlp_pooling_size is not None:
842
- # in_channels *= args.learnable_mlp_pooling_size ** 2
843
- self.proj = nn.Sequential(
844
- nn.Linear(in_channels, out_channels),
845
- nn.GELU()
846
- )
847
- self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
848
- self.image_newline = nn.Parameter(
849
- torch.randn(out_channels) * embed_std
850
- )
851
- else:
852
- self.proj = nn.Sequential(
853
- nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
854
- nn.GELU(),
855
- nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
856
- )
857
- self.mlp = nn.Linear(in_channels * 4, out_channels)
858
- self.image_newline = nn.Parameter(
859
- torch.randn(in_channels * 4) * embed_std
860
- )
861
- self.poolmlp = poolmlp
- self.anyres_pooling_size = anyres_pooling_size
862
-
863
- self.image_begin = nn.Parameter(
864
- torch.randn(out_channels) * embed_std
865
- )
866
- self.image_end = nn.Parameter(
867
- torch.randn(out_channels) * embed_std
868
- )
869
-
870
- if twoview:
871
- self.image_sep = nn.Parameter(
872
- torch.randn(out_channels) * embed_std
873
- )
874
-
875
- self.cat_extra_token = cat_extra_token
876
- self.use_rms_norm = vit_used_rms_norm
877
- if self.use_rms_norm:
878
- self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
879
- self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
880
-
881
- def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
882
- return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
883
-
884
- def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
885
- remove_vit_special_tokens = False
886
- learnable_mlp_pooling_size = None
887
- if self.use_rms_norm:
888
- x = self.before_rms(x)
889
- h, w = size
890
- dtype = x.dtype
891
- x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
892
- if self.poolmlp:
893
- if learnable_mlp_pooling_size is None:
894
- x = F.avg_pool2d(x, self.anyres_pooling_size)
895
- x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
896
- else:
897
- x = x.permute(0, 2, 3, 1) # b, h, w, c
898
- x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
899
- w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
900
- x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
901
- x = self.proj(x)
902
- x = self.vit_linear_encoder(x)
903
- b, h, w, c = x.shape
904
- if not remove_vit_special_tokens:
905
- x = torch.cat([
906
- x,
907
- self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
908
- ], dim=2)
909
- x = x.reshape(b, -1, c)
910
- else:
911
- x = self.proj(x) #b,c,h,w
912
- if is_video:
913
- video_avgpool_size = 2
914
- stride = 2
915
- x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
916
- b, c, h, w = x.shape
917
- if not remove_vit_special_tokens:
918
- x = torch.cat([
919
- x,
920
- self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
921
- ], dim=-1)
922
- x = x.reshape(b, c, -1).permute(0, 2, 1)
923
- x = self.mlp(x)
924
-
925
-
926
- if x2 is not None:
927
- h2, w2 = size2
928
- x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
929
- if self.poolmlp:
930
- x2 = F.avg_pool2d(x2, 2)
931
- x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
932
- x2 = self.vit_linear_encoder(x2)
933
- b2, h2, w2, c2 = x2.shape
934
- if not remove_vit_special_tokens:
935
- x2 = torch.cat([
936
- x2,
937
- self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
938
- ], dim=2)
939
- x2 = x2.reshape(b2, -1, c2)
940
- else:
941
- x2 = self.proj(x2)
942
- b2, c2, h2, w2 = x2.shape
943
- if not remove_vit_special_tokens:
944
- x2 = torch.cat([
945
- x2,
946
- self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
947
- ], dim=-1)
948
- x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
949
- x2 = self.mlp(x2)
950
-
951
- sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
952
-
953
- x = torch.cat([x, sep, x2], dim=1)
954
-
955
- if self.cat_extra_token:
956
- begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
957
- end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
958
- x = torch.cat([begin, x, end], dim=1)
959
-
960
- if self.use_rms_norm:
961
- return self.after_rms(x)
962
- else:
963
- return x
964
-
965
-
966
- class NormalizedDwPooler(nn.Module):
967
- def __init__(self, dim):
968
- super().__init__()
969
- self.dim = dim
970
- self.predictor = nn.Sequential(
971
- nn.Linear(dim*2, dim),
972
- nn.GELU(),
973
- nn.Linear(dim, dim),
974
- )
975
-
976
- def forward(self, x, forward_type='2x'):
977
- B, H, W, C = x.shape
978
-
979
- if forward_type == '2x':
980
- new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
981
- pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
982
- fused_x = torch.cat([new_x, pooled_x], dim=-1)
983
- elif forward_type == '1x':
984
- new_x = x.reshape(B, H, W, 1, C)
985
- fused_x = torch.cat([new_x, new_x], dim=-1)
986
- elif forward_type == '4x':
987
- new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
988
- pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
989
- fused_x = torch.cat([new_x, pooled_x], dim=-1)
990
-
991
- score = self.predictor(fused_x)
992
- normalized_score = F.softmax(score, dim=-2)
993
- new_x = (new_x * normalized_score).sum(dim=-2)
994
- return new_x
995
-
996
-
997
- class OryxMLPv2(nn.Module):
998
- def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
999
- super().__init__()
1000
-
1001
- self.proj1 = nn.Linear(in_channels, out_channels)
1002
- self.proj2 = nn.Linear(out_channels, out_channels)
1003
- self.act = nn.GELU()
1004
- self.pooler = NormalizedDwPooler(out_channels)
1005
- embed_std = 1 / math.sqrt(out_channels)
1006
-
1007
- self.use_pe = use_pe
1008
- if not use_pe:
1009
- self.image_newline = nn.Parameter(
1010
- torch.randn(out_channels) * embed_std
1011
- )
1012
- self.image_begin = nn.Parameter(
1013
- torch.randn(out_channels) * embed_std
1014
- )
1015
- self.image_end = nn.Parameter(
1016
- torch.randn(out_channels) * embed_std
1017
- )
1018
-
1019
- if twoview:
1020
- self.image_sep = nn.Parameter(
1021
- torch.randn(out_channels) * embed_std
1022
- )
1023
-
1024
- def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
1025
- h, w = size
1026
- dtype = x.dtype
1027
- x = x.reshape(x.shape[0], h, w, -1)
1028
- # x = self.pooler(x, forward_type=REGIONAL_POOL)
1029
- # x = self.proj(x) #b,h,w, c
1030
- x = self.proj1(x)
1031
- x = self.pooler(x, forward_type='2x')
1032
- x = self.act(x)
1033
- x = self.proj2(x)
1034
-
1035
-
1036
- b, h, w, c = x.shape
1037
- if not self.use_pe:
1038
- x = torch.cat([
1039
- x,
1040
- self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
1041
- ], dim=2)
1042
- else:
1043
- pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1044
- pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
1045
- pe = torch.cat([pe_h, pe_w], dim=-1)
1046
-
1047
- x = x.reshape(b, -1, c)
1048
-
1049
- if x2 is not None:
1050
- h2, w2 = size2
1051
- x2 = x2.reshape(x2.shape[0], h2, w2, -1)
1052
- # x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
1053
- ## x2 = self.proj(x2) #b,h,w, c
1054
- x2 = self.proj1(x2)
1055
- x2 = self.pooler(x2, forward_type='2x')
1056
- x2 = self.act(x2)
1057
- x2 = self.proj2(x2)
1058
-
1059
- b2, h2, w2, c2 = x2.shape
1060
- if not self.use_pe:
1061
- x2 = torch.cat([
1062
- x2,
1063
- self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
1064
- ], dim=2)
1065
- x2 = x2.reshape(b, -1, c)
1066
- sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
1067
- x = torch.cat([x, sep, x2], dim=1)
1068
-
1069
- begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1070
- end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
1071
- x = torch.cat([begin, x, end], dim=1)
1072
- # print(x.shape, x2.shape, h, w, h2, w2)
1073
- # print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
1074
- if self.use_pe:
1075
- zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
1076
- pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
1077
- assert pe.shape[1] == x.shape[1]
1078
- return x, pe
1079
- else:
1080
- nseq = x.shape[1]
1081
- fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
1082
- return x #, fake_pe
1083
-
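
Note: the weighted 2x2 pooling in the deleted NormalizedDwPooler hinges on a reshape/permute pattern that collects every 2x2 spatial window into its own axis before reducing it. Below is a minimal, self-contained sketch of just that grouping step; tensor names and sizes are illustrative and not taken from the original file, and the real module predicts softmax weights per window element instead of taking a plain mean.

import torch

B, H, W, C = 2, 4, 6, 8  # any even H and W
x = torch.randn(B, H, W, C)

# collect every 2x2 spatial window into a dedicated axis of size 4
windows = (
    x.reshape(B, H // 2, 2, W // 2, 2, C)
     .permute(0, 1, 3, 2, 4, 5)
     .reshape(B, H // 2, W // 2, 4, C)
)

# plain average over each window; NormalizedDwPooler instead concatenates
# each window with its mean, scores every element with a small MLP, and
# takes the softmax-weighted sum over the window axis
pooled = windows.mean(dim=-2)
assert pooled.shape == (B, H // 2, W // 2, C)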