File size: 49,500 Bytes
350220c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 489 490 491 492 493 494 495 496 497 498 499 500 501 502 503 504 505 506 507 508 509 510 511 512 513 514 515 516 517 518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550 551 552 553 554 555 556 557 558 559 560 561 562 563 564 565 566 567 568 569 570 571 572 573 574 575 576 577 578 579 580 581 582 583 584 585 586 587 588 589 590 591 592 593 594 595 596 597 598 599 600 601 602 603 604 605 606 607 608 609 610 611 612 613 614 615 616 617 618 619 620 621 622 623 624 625 626 627 628 629 630 631 632 633 634 635 636 637 638 639 640 641 642 643 644 645 646 647 648 649 650 651 652 653 654 655 656 657 658 659 660 661 662 663 664 665 666 667 668 669 670 671 672 673 674 675 676 677 678 679 680 681 682 683 684 685 686 687 688 689 690 691 692 693 694 695 696 697 698 699 700 701 702 703 704 705 706 707 708 709 710 711 712 713 714 715 716 717 718 719 720 721 722 723 724 725 726 727 728 729 730 731 732 733 734 735 736 737 738 739 740 741 742 743 744 745 746 747 748 749 750 751 752 753 754 755 756 757 758 759 760 761 762 763 764 765 766 767 768 769 770 771 772 773 774 775 776 777 778 779 780 781 782 783 784 785 786 787 788 789 790 791 792 793 794 795 796 797 798 799 800 801 802 803 804 805 806 807 808 809 810 811 812 813 814 815 816 817 818 819 820 821 822 823 824 825 826 827 828 829 830 831 832 833 834 835 836 837 838 839 840 841 842 843 844 845 846 847 848 849 850 851 852 853 854 855 856 857 858 859 860 861 862 863 864 865 866 867 868 869 870 871 872 873 874 875 876 877 878 879 880 881 882 883 884 885 886 887 888 889 890 891 892 893 894 895 896 897 898 899 900 901 902 903 904 905 906 907 908 909 910 911 912 913 914 915 916 917 918 919 920 921 922 923 924 925 926 927 928 929 930 931 932 933 934 935 936 937 938 939 940 941 942 943 944 945 946 947 948 949 950 951 952 953 954 955 956 957 958 959 960 961 962 963 964 965 966 967 968 969 970 971 972 973 974 975 976 977 978 979 980 981 982 983 984 985 986 987 988 989 990 991 992 993 994 995 996 997 998 999 1000 1001 1002 1003 1004 1005 1006 1007 1008 1009 1010 1011 1012 1013 1014 1015 1016 1017 1018 1019 1020 1021 1022 1023 1024 1025 1026 1027 1028 1029 1030 1031 1032 1033 1034 1035 1036 1037 1038 1039 1040 1041 1042 1043 1044 1045 1046 1047 1048 1049 1050 1051 1052 1053 1054 1055 1056 1057 1058 1059 1060 1061 1062 1063 1064 1065 1066 1067 1068 1069 1070 1071 1072 1073 1074 1075 1076 1077 1078 1079 1080 1081 1082 1083 1084 |
import json
import types
import math
import torch
from torch import Tensor, nn
import torch.nn.functional as F
from typing import List, Tuple, Optional, Union
from contextlib import contextmanager
from transformers.modeling_attn_mask_utils import (
_prepare_4d_causal_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask_for_sdpa,
_prepare_4d_causal_attention_mask,
)
from transformers.models.clip.configuration_clip import CLIPVisionConfig
from transformers.modeling_outputs import BaseModelOutputWithPooling
from .modeling_hunyuan import HunYuanDecoderLayer, HunYuanRMSNorm
from .configuration_hunyuan import HunYuanConfig
def NaVitForward(input_ids, encoder_input, vit, image_tensors, images_pos, vit_input_resolution, im_start_id, im_end_id, image_token_id, anyres_vit_two_views, dtype):
# input_ids: (B, L)
# encoder_input: (L, B, E)
# image_tensors [[Tensor],...,[Tensor]]
# image_pos [[Tensor],...,[Tensor]]
# tokenizer = get_tokenizer()
b = len(input_ids)
img_embs = None
all_nums = sum([len(tensors) for tensors in image_tensors]) if image_tensors else 0
if all_nums != 0:
img_embs, img_batch_pos = vit(image_tensors)
else:
# when no input image, initialize a fake tensor
pad_nums = 1
image_tensors = [[torch.rand(3, vit_input_resolution, vit_input_resolution, dtype=dtype, device=torch.cuda.current_device()) for _ in range(pad_nums)]]
img_embs, img_batch_pos = vit(image_tensors)
encoder_input = encoder_input.clone()
if all_nums > 0:
assert len(images_pos) == len(img_batch_pos), \
(len(images_pos), len(img_batch_pos))
start_token_id = im_start_id
end_token_id = im_end_id
placeholder_id = image_token_id
for idx in range(len(images_pos)):
assert len(images_pos[idx]) == len(img_batch_pos[idx]), \
(len(images_pos[idx]), len(img_batch_pos[idx]))
for p_img_pos_in_batch, p_batch_img_pos in zip(img_batch_pos[idx], images_pos[idx]):
# the positions to be filled [s_start, s_end)
s_idx, s_start, s_end = p_img_pos_in_batch
current_embs = img_embs[s_idx, s_start:s_end]
im_s, im_e = p_batch_img_pos
assert len(current_embs) == im_e - im_s, \
(img_embs.shape, (s_start, s_end, s_idx), current_embs.shape, (im_s, im_e, idx))
if not anyres_vit_two_views:
assert input_ids[idx, im_s - 1] == start_token_id, \
input_ids[idx, im_s - 1]
assert input_ids[idx, im_e] == end_token_id, \
input_ids[idx, im_e]
assert (input_ids[idx, im_s:im_e] == placeholder_id).all(), \
f'The tokens to be filled are not the placeholder_id {placeholder_id}: {(input_ids[idx, im_s:im_e] == placeholder_id).sum()} vs {im_e - im_s}'
encoder_input[idx, im_s:im_e] = current_embs
else:
# when no input image, to mask vit value
vit_mask = torch.zeros([1, img_embs.shape[0]], device=torch.cuda.current_device())
current_embs = img_embs[0, :]
encoder_input[0, 1:img_embs.shape[0] + 1] = encoder_input[0, 1:img_embs.shape[0] + 1] * (1 - vit_mask) + current_embs * vit_mask
return encoder_input, input_ids
def VitForward(input_ids, encoder_input, vit, vit_linear_encoder, image_tensors, images_pos, vit_input_resolution, vit_mapping_type, vit_patch, vit_token):
vit_patch_mlp = (vit_patch > 1 and vit_mapping_type == 'mlp') or vit_patch == 0
b = len(input_ids)
if images_pos is None:
images_pos = torch.ones([len(input_ids), 1, 3])
images_pos[:, :, 1] = images_pos[:, :, 1]*(vit_token + 1)
images_pos = images_pos.long()
real_image_nums = []
image_tensors = image_tensors.view(b, -1, 3, vit_input_resolution, vit_input_resolution)
real_images = []
all_nums = 0
img_index = []
for s in range(len(images_pos)):
real_image_num = 0
for (im_s, im_e,index) in images_pos[s]:
if im_s == -1:
break
real_image_num += 1
all_nums += 1
img_index.append(index)
real_image_nums.append(real_image_num)
real_images.append(image_tensors[s][:real_image_num])
if vit_patch == 1:
img_index = None
if all_nums == 0:
# when no input image, initialize a fake tensor
img_input = torch.rand(b, 3, vit_input_resolution, vit_input_resolution).cuda().type(image_tensors.dtype)
img_embs = vit(img_input)
img_embs = vit_linear_encoder(img_embs)
else:
img_input = torch.cat(real_images)
img_embs = vit(img_input, img_index = img_index)
img_embs = vit_linear_encoder(img_embs)
encoder_input = encoder_input.clone()
start = 0
if all_nums > 0:
for s, real_image_len in enumerate(real_image_nums):
current_embs = img_embs[start:start + real_image_len, :] #[30, 256, 4096]
for ss in range(current_embs.shape[0]):
im_s, im_e, index = images_pos[s, ss]
# 子图特征更少
if index > 0 and vit_patch_mlp:
encoder_input[s, im_s:im_e,] = current_embs[ss, :(im_e-im_s)]
else:
encoder_input[s, im_s:im_e] = current_embs[ss, :]
start = start + real_image_len
else:
# when no input image, to mask vit value
for s in range(b):
vit_mask = torch.zeros([vit_token, 1]).cuda()
current_embs = img_embs[:, start:start + 1]
encoder_input[1:vit_token + 1, s] = encoder_input[1:vit_token + 1, s] * (1 - vit_mask) + current_embs[:, 0, :] * vit_mask
start = start + 1
return encoder_input, input_ids
def group_images_by_max_seq_len(
images: List[List[Tensor]], patch_size: int,
max_seq_len: int, adaptor_patch_size: int,
add_cls_token: bool = False) -> List[List[Tensor]]:
groups = []
group = []
pos_groups = []
seq_len = 0
num_images = 0
for image_list in images:
pos_group = []
for image in image_list:
num_images += 1
assert isinstance(image, Tensor)
image_dims = image.shape[-2:]
ph, pw = map(lambda t: t // patch_size, image_dims)
image_seq_len = (ph * pw)
new_image_seq_len = image_seq_len
grouped_len = seq_len + image_seq_len
if add_cls_token:
new_image_seq_len += 1
grouped_len += num_images
assert new_image_seq_len <= max_seq_len, f'image with dimensions {image_dims} exceeds maximum sequence length'
if grouped_len > max_seq_len:
groups.append(group)
group = []
seq_len = 0
num_images = 1
group.append(image)
start = seq_len // (adaptor_patch_size * adaptor_patch_size)
end = start + image_seq_len//(adaptor_patch_size * adaptor_patch_size)
batch_idx = len(groups)
pos_group.append([batch_idx, start, end])
seq_len += image_seq_len
pos_groups.append(pos_group)
if len(group) > 0:
groups.append(group)
return groups, pos_groups
class AnyResCLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig):
super().__init__()
self.config = config
# self.sparse_attn_mask = args.sparse_attn_mask
# self.use_flash_attn = args.use_flash_attn
self.embed_dim = config.hidden_size
self.image_size = config.max_image_size
self.patch_size = config.patch_size
self.max_seq_len = config.max_vit_seq_len
self.adaptor_patch_size = config.adaptor_patch_size
self.anyres_vit_two_views = config.anyres_vit_two_views
self.vit_add_patchemb_bias = config.vit_add_patchemb_bias
self.vit_remove_prenorm = config.vit_remove_prenorm
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=self.vit_add_patchemb_bias,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.skip_cls_token = True
# add interpolate_pos_encoding
if self.anyres_vit_two_views:
self.num_positions = self.num_patches
self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim) * 0.02)
else:
self.num_positions = self.num_patches + 1
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
# self.position_ids = torch.arange(self.num_positions).expand((1, -1))
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
if not self.vit_remove_prenorm:
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1]
position_embeddings = self.position_embedding(self.position_ids)
patch_pos_embed = position_embeddings[:, 1:]
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return patch_pos_embed
# class_pos_embed = position_embeddings[:, 0]
dim = embeddings.shape[-1]
h0 = height // self.patch_size
w0 = width // self.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
raw_type = patch_pos_embed.dtype
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.to(torch.float32, non_blocking=True),
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bilinear",
align_corners=False,
)
patch_pos_embed = patch_pos_embed.to(raw_type, non_blocking=True)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return patch_pos_embed
def rescale_positional_embedding(self, out_size):
h, w = out_size
pos_embed_shape = int((self.position_embedding.shape[1]) ** 0.5)
if (h, w) == (pos_embed_shape, pos_embed_shape):
return self.position_embedding
rescaled_positional_embedding = \
self.position_embedding.new_zeros(1, h*w, self.position_embedding.shape[2])
pe_2d = self.position_embedding[0].T.contiguous().view(1, -1, pos_embed_shape, pos_embed_shape)
pe_2d = F.interpolate(pe_2d, out_size, mode='bilinear', align_corners=False).view(-1, h*w)
rescaled_positional_embedding[0] = pe_2d.T.contiguous()
return rescaled_positional_embedding
def forward_single(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
if pixel_values.ndim == 3:
pixel_values = pixel_values[None]
batch_size, num_channels, height, width = pixel_values.shape
if self.anyres_vit_two_views:
# padding
pad_h = (self.patch_size - height % self.patch_size) % self.patch_size
pad_w = (self.patch_size - width % self.patch_size) % self.patch_size
pixel_values = F.pad(pixel_values, (0, pad_w, 0, pad_h))
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
b, c, h, w = patch_embeds.shape
# (b, hw, c)
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
if self.anyres_vit_two_views:
embeddings = patch_embeds + self.rescale_positional_embedding(out_size=(h, w))
else:
embeddings = patch_embeds + self.interpolate_pos_encoding(patch_embeds, height, width)
if not self.vit_remove_prenorm:
embeddings = self.pre_layernorm(embeddings)
return embeddings, (h, w)
def forward(self, images: List[List[Tensor]]):
'''
Input:
images: List[List[Tensor]]
Return:
embeddings: Tensor (B, L, E)
attn_mask: Tensor (B, L, 2)
pos_groups: List[List[(batch_idx, start, end)]]
'''
batched_images, pos_groups = group_images_by_max_seq_len(
images, self.patch_size, self.max_seq_len, self.adaptor_patch_size, add_cls_token=not self.skip_cls_token)
max_seq_len = self.max_seq_len
# batched_images is a list of a list
B = len(batched_images)
L = max_seq_len
E = self.embed_dim
embeddings = torch.zeros(B, L, E, dtype=self.config.torch_dtype, requires_grad=True).cuda(non_blocking=True)
attn_mask = embeddings.new_full((B, 1, L, L), False, dtype=torch.bool) # True presents compute
assert len(images) == len(pos_groups), (len(images), len(pos_groups))
batch_images = []
batch_pos = []
for images_i, pos_group in zip(images, pos_groups):
assert len(images_i) == len(pos_group), (len(images_i), len(pos_group))
for image, pos in zip(images_i, pos_group):
batch_idx, start, end = pos
a2 = self.adaptor_patch_size ** 2
# recover the real number of the input image tokens
start *= a2
end *= a2
emb, _ = self.forward_single(image)
assert emb.ndim == 3, '(B, L, E)'
embeddings[batch_idx, start:end] = emb
attn_mask[batch_idx, :, start:end, start:end] = True
return embeddings, attn_mask, pos_groups
class CLIPVisionEmbeddings(nn.Module):
def __init__(self, config: CLIPVisionConfig, add_pre_layernorm=False, skip_cls_token=True, vit_patch=1):
super().__init__()
self.config = config
self.embed_dim = config.hidden_size
self.image_size = config.image_size
self.image_size = config.vit_input_resolution
self.patch_size = config.patch_size
self.class_embedding = nn.Parameter(torch.randn(self.embed_dim))
self.patch_embedding = nn.Conv2d(
in_channels=config.num_channels,
out_channels=self.embed_dim,
kernel_size=self.patch_size,
stride=self.patch_size,
bias=False,
)
self.num_patches = (self.image_size // self.patch_size) ** 2
self.skip_cls_token = skip_cls_token
self.num_positions = self.num_patches + 1
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)))
if vit_patch > 1:
self.position_embedding = nn.Embedding(self.num_patches * (vit_patch ** 2 + 1) + 1, self.embed_dim)
# 0 支持最大16张图,目前写死了,如需其他的需要额外定义参数
elif vit_patch == 0:
self.position_embedding = nn.Embedding(self.num_patches * (16 ** 2 + 1) + 1, self.embed_dim)
else:
self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
if add_pre_layernorm:
self.pre_layernorm = nn.LayerNorm(self.embed_dim, eps=config.layer_norm_eps)
else:
self.pre_layernorm = None
def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher
resolution images.
Source:
https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174
"""
num_patches = embeddings.shape[1] - 1
position_embeddings = self.position_embedding(self.position_ids)
num_positions = position_embeddings.shape[1] - 1
if num_patches == num_positions and height == width:
return position_embeddings
class_pos_embed = position_embeddings[:, 0]
patch_pos_embed = position_embeddings[:, 1:]
dim = embeddings.shape[-1]
h0 = height // self.config.patch_size
w0 = width // self.config.patch_size
# we add a small number to avoid floating point error in the interpolation
# see discussion at https://github.com/facebookresearch/dino/issues/8
h0, w0 = h0 + 0.1, w0 + 0.1
patch_pos_embed = patch_pos_embed.reshape(1, int(math.sqrt(num_positions)), int(math.sqrt(num_positions)), dim)
patch_pos_embed = patch_pos_embed.permute(0, 3, 1, 2)
raw_type = patch_pos_embed.dtype
patch_pos_embed = nn.functional.interpolate(
patch_pos_embed.float(),
scale_factor=(h0 / math.sqrt(num_positions), w0 / math.sqrt(num_positions)),
mode="bicubic",
align_corners=False,
)
# print(patch_pos_embed.shape)
patch_pos_embed = patch_pos_embed.to(raw_type)
assert int(h0) == patch_pos_embed.shape[-2] and int(w0) == patch_pos_embed.shape[-1]
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1)
def forward(self, pixel_values: torch.FloatTensor, interpolate_pos_encoding: bool = False, img_index=None) -> torch.Tensor:
batch_size, num_channels, height, width = pixel_values.shape
patch_embeds = self.patch_embedding(pixel_values) # shape = [*, width, grid, grid]
patch_embeds = patch_embeds.flatten(2).transpose(1, 2)
if self.skip_cls_token:
embeddings = patch_embeds
if img_index is None:
position_ids = self.position_ids[:,1:]
embeddings = embeddings + self.position_embedding(position_ids)
else:
position_ids = (torch.tensor(img_index).cuda() * (self.num_positions - 1)).unsqueeze(1).repeat(1, self.num_positions - 1) \
+ self.position_ids.expand(batch_size, -1)[:, 1:]
embeddings = embeddings + self.position_embedding(position_ids)
else:
class_embeds = self.class_embedding.expand(batch_size, 1, -1)
embeddings = torch.cat([class_embeds, patch_embeds], dim=1)
if interpolate_pos_encoding:
embeddings = embeddings + self.interpolate_pos_encoding(embeddings, height, width)
else:
if img_index is None:
embeddings = embeddings + self.position_embedding(self.position_ids)
else:
position_ids = self.position_ids.expand(batch_size,-1)[:,0].unsqueeze(1)
new_position = (torch.tensor(img_index).cuda() * (self.num_positions -1)).unsqueeze(1).repeat(1,self.num_positions-1) + self.position_ids.expand(batch_size,-1)[:,1:]
position_ids = torch.cat([position_ids,new_position],dim=1)
embeddings = embeddings + self.position_embedding(position_ids)
if self.pre_layernorm is not None:
embeddings = self.pre_layernorm(embeddings)
return embeddings
class NaVitTransformer(nn.Module):
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
super().__init__()
self.config = config
self.vit_config = vit_config
with self.prepare_args(config, vit_config):
self._use_sdpa = config._attn_implementation == "sdpa"
self._use_flash_attention_2 = config._attn_implementation == "flash_attention_2"
self.layers = nn.ModuleList(
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@contextmanager
def prepare_args(self, config, vit_config):
hidden_act = config.hidden_act
hidden_size = config.hidden_size
ffn_hidden_size = config.intermediate_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
attention_head_dim = config.attention_head_dim
use_qk_norm = config.use_qk_norm
use_rotary_pos_emb = config.use_rotary_pos_emb
num_hidden_layers = config.num_hidden_layers
rms_norm_eps = config.rms_norm_eps
attention_dropout = config.attention_dropout
# hidden_dropout = config.hidden_dropout
norm_type = config.norm_type
attention_bias = config.attention_bias
mlp_bias = config.mlp_bias
use_mla = config.use_mla
num_experts = config.num_experts
_attn_implementation = config._attn_implementation
config.hidden_act = vit_config.hidden_act
config.hidden_size = vit_config.hidden_size
config.intermediate_size = vit_config.intermediate_size
config.num_attention_heads = vit_config.num_attention_heads
config.num_key_value_heads = None
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
config.use_qk_norm = False
config.use_rotary_pos_emb = False
config.num_hidden_layers = vit_config.num_hidden_layers
config.rms_norm_eps = vit_config.layer_norm_eps
config.attention_dropout = vit_config.attention_dropout
# config.hidden_dropout = vit_config.hidden_dropout
config.norm_type = config.vit_norm_type
config.attention_bias = True
config.mlp_bias = True
config.use_mla = False
config.num_experts = 1
config._attn_implementation = "eager"
yield
config.hidden_act = hidden_act
config.hidden_size = hidden_size
config.intermediate_size = ffn_hidden_size
config.num_attention_heads = num_attention_heads
config.num_key_value_heads = num_key_value_heads
config.attention_head_dim = attention_head_dim
config.use_qk_norm = use_qk_norm
config.use_rotary_pos_emb = use_rotary_pos_emb
config.num_hidden_layers = num_hidden_layers
config.rms_norm_eps = rms_norm_eps
config.attention_dropout = attention_dropout
# config.hidden_dropout = hidden_dropout
config.attention_bias = attention_bias
config.mlp_bias = mlp_bias
config.norm_type = norm_type
config.use_mla = use_mla
config.num_experts = num_experts
config._attn_implementation = _attn_implementation
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
) -> Union[Tuple, BaseModelOutputWithPooling]:
hidden_states, attention_mask, img_pos = self.embeddings(pixel_values)
attention_mask = attention_mask.int()
batch_size, seq_length, _ = hidden_states.shape
past_key_values_length = 0
if self._use_flash_attention_2:
# 2d mask is passed through the layers
attention_mask = attention_mask if (attention_mask is not None and 0 in attention_mask) else None
elif self._use_sdpa:
# output_attentions=True can not be supported when using SDPA, and we fall back on
# the manual implementation that requires a 4D causal mask in all cases.
attention_mask = _prepare_4d_causal_attention_mask_for_sdpa(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
else:
attention_mask = _prepare_4d_causal_attention_mask(
attention_mask,
(batch_size, seq_length),
hidden_states,
past_key_values_length,
)
for layer_idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask
)
hidden_states = layer_outputs[0]
return hidden_states, img_pos
class AnyResVitTransformer(NaVitTransformer):
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig, anyres_vit_max_image_size):
super().__init__(config, vit_config)
old_anyres_vit_max_image_size = vit_config.max_image_size
anyres_vit_max_image_size = anyres_vit_max_image_size or old_anyres_vit_max_image_size
vit_config.max_image_size = anyres_vit_max_image_size
vit_config.torch_dtype = config.torch_dtype
vit_config.anyres_vit_two_views = config.anyres_vit_two_views
vit_config.vit_remove_prenorm = config.vit_remove_prenorm
vit_config.vit_add_patchemb_bias = config.vit_add_patchemb_bias
self.embeddings = AnyResCLIPVisionEmbeddings(vit_config)
vit_config.max_image_size = old_anyres_vit_max_image_size
def fix_embeddings_fn(self, pixel_values):
# (B, L, E)
embeddings, hw = self.embeddings.forward_single(pixel_values)
embeddings = self.embeddings.pre_layernorm(embeddings)
return embeddings
class CLIPVisionTransformer(nn.Module):
def __init__(self, config: HunYuanConfig, vit_config: CLIPVisionConfig):
super().__init__()
embed_dim = vit_config.hidden_size
self.pre_layrnorm = nn.LayerNorm(embed_dim, eps=vit_config.layer_norm_eps)
self.embeddings = CLIPVisionEmbeddings(vit_config, skip_cls_token=config.skip_cls_token, vit_patch=config.vit_patch)
with self.prepare_args(config, vit_config):
self.layers = nn.ModuleList(
[HunYuanDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
)
@contextmanager
def prepare_args(self, config, vit_config):
hidden_act = config.hidden_act
hidden_size = config.hidden_size
ffn_hidden_size = config.intermediate_size
num_attention_heads = config.num_attention_heads
num_key_value_heads = config.num_key_value_heads
attention_head_dim = config.attention_head_dim
use_qk_norm = config.use_qk_norm
use_rotary_pos_emb = config.use_rotary_pos_emb
num_hidden_layers = config.num_hidden_layers
rms_norm_eps = config.rms_norm_eps
attention_dropout = config.attention_dropout
# hidden_dropout = config.hidden_dropout
norm_type = config.norm_type
attention_bias = config.attention_bias
mlp_bias = config.mlp_bias
use_mla = config.use_mla
num_experts = config.num_experts
_attn_implementation = config._attn_implementation
config.hidden_act = vit_config.hidden_act
config.hidden_size = vit_config.hidden_size
config.intermediate_size = vit_config.intermediate_size
config.num_attention_heads = vit_config.num_attention_heads
config.num_key_value_heads = None
config.attention_head_dim = vit_config.hidden_size // vit_config.num_attention_heads
config.use_qk_norm = False
config.use_rotary_pos_emb = False
config.num_hidden_layers = vit_config.num_hidden_layers
config.rms_norm_eps = vit_config.layer_norm_eps
config.attention_dropout = vit_config.attention_dropout
# config.hidden_dropout = 0.0
config.norm_type = "fused"
config.attention_bias = True
config.mlp_bias = True
config.use_mla = False
config.num_experts = 1
config._attn_implementation = "eager"
yield
config.hidden_act = hidden_act
config.hidden_size = hidden_size
config.intermediate_size = ffn_hidden_size
config.num_attention_heads = num_attention_heads
config.num_key_value_heads = num_key_value_heads
config.attention_head_dim = attention_head_dim
config.use_qk_norm = use_qk_norm
config.use_rotary_pos_emb = use_rotary_pos_emb
config.num_hidden_layers = num_hidden_layers
config.rms_norm_eps = rms_norm_eps
config.attention_dropout = attention_dropout
# config.hidden_dropout = hidden_dropout
config.norm_type = norm_type
config.attention_bias = attention_bias
config.mlp_bias = mlp_bias
config.use_mla = use_mla
config.num_experts = num_experts
config._attn_implementation = _attn_implementation
def forward(
self,
pixel_values: Optional[torch.FloatTensor] = None,
interpolate_pos_encoding: Optional[bool] = None,
img_index=None
) -> Union[Tuple, BaseModelOutputWithPooling]:
r"""
Returns:
"""
hidden_states = self.embeddings(pixel_values, interpolate_pos_encoding=interpolate_pos_encoding, img_index=img_index)
hidden_states = self.pre_layrnorm(hidden_states)
batch = hidden_states.shape[0]
seq_len = hidden_states.shape[1]
device = hidden_states.device
attention_mask = torch.ones(batch, 1, seq_len, seq_len, dtype=torch.float32, device=device)
for layer_idx, decoder_layer in enumerate(self.layers):
layer_outputs = decoder_layer(
hidden_states,
attention_mask=attention_mask
)
hidden_states = layer_outputs[0]
return hidden_states
class Vit(torch.nn.Module):
def __init__(self, config, resampler_token=64, pool_rate=2):
super().__init__()
self.config = config
self.vit_mapping_type = config.vit_mapping_type
self.anyres_vit_max_image_size = config.anyres_vit_max_image_size
self.skip_cls_token = config.skip_cls_token
self.pool_rate = pool_rate
self.vit_type = self.config.vit_type
self.anyres_vit_two_views = self.config.anyres_vit_two_views
if self.vit_type in ['Vit-g', 'Vit-bigG', 'NaVit', 'EvaVit', 'AnyResVit']:
self.img_init(resampler_token, config.vit_input_resolution, config.vit_mapping_type, pool_rate)
else:
raise NotImplementedError(f"unsupported vit type: {self.vit_type}")
def img_init(self, resampler_token=64, vit_input_resolution=224, vit_mapping_type='resampler', pool_rate=2):
if self.vit_type == 'AnyResVit':
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
self.vit_config = types.SimpleNamespace(**vit_config["vision_config"])
self.vit_config.image_size = vit_input_resolution
self.vit = AnyResVitTransformer(self.config, self.vit_config, self.anyres_vit_max_image_size)
elif self.vit_type == 'Vit-g':
vit_config = json.load(open(f"{self.config.vit_path}/config.json"))
self.vit_config = types.SimpleNamespace(**{**vit_config["vision_config_dict"],**vit_config["vision_config"]})
self.vit_config.vit_input_resolution = vit_input_resolution
self.vit = CLIPVisionTransformer(self.config, self.vit_config)
else:
assert False, "other vit_types are not supported"
if self.vit_mapping_type == 'simple_conv_mlp':
self.perceive = SimpleConvMlp(self.vit_config.hidden_size, self.config.hidden_size, self.config.anyres_pooling_size, \
self.config.vit_used_rms_norm, self.config.rms_norm_eps, poolmlp=False, twoview=True)
elif self.vit_mapping_type == 'oryx_mlp':
self.perceive = OryxMLPv2(self.vit_config.hidden_size, self.config.hidden_size, twoview=True, use_pe=False)
elif self.vit_mapping_type == 'mlp':
self.mlp_depth = 2
# one mlp layer already in gpt_model.py
mlp_hidden_size = self.vit_config.hidden_size
if self.vit_type in ['NaVit', 'EvaVit']:
mlp_hidden_size *= self.vit_config.adaptor_patch_size **2
if self.mlp_depth > 1:
mlp_modules = [torch.nn.Linear(mlp_hidden_size, self.config.hidden_size), torch.nn.GELU()]
if self.vit_type in ['NaVit', 'EvaVit']:
for _ in range(1, self.mlp_depth):
mlp_modules.append(torch.nn.Linear(self.config.hidden_size, self.config.hidden_size))
mlp_modules.append(torch.nn.GELU())
self.perceive = torch.nn.Sequential(*mlp_modules)
else:
assert False, "other vit_mapping_types are not supported"
self.vit_patch_mlp = (self.config.vit_patch > 1 and self.vit_mapping_type == 'mlp') or self.config.vit_patch == 0
for name, param in self.named_parameters():
setattr(param, "is_vit_param", True)
def forward(self, images, img_index=None):
if self.vit_type in ['AnyResVit']:
dtype = self.config.torch_dtype
device = torch.cuda.current_device()
images_size = []
for i in range(len(images)):
images_size.append([])
for j in range(len(images[i])):
images_size[i].append((images[i][j].size()[1] // self.vit_config.patch_size, images[i][j].size()[2] // self.vit_config.patch_size))
images_feats, img_batch_pos = self.vit(pixel_values=images)
a2 = self.vit_config.adaptor_patch_size ** 2
if self.anyres_vit_two_views:
step = 2
else:
step = 1
perceive_fn = lambda x, img_size, is_video: self.perceive(x, img_size, is_video=is_video)
images_list = []
images_fix_i = 0
num_img_batch_pos = len(img_batch_pos)
for i in range(num_img_batch_pos): # batch_id
for j in range(0, len(img_batch_pos[i]), step):
if self.anyres_vit_two_views:
lower_idx, lower_begin, lower_end = img_batch_pos[i][j]
lower_begin = lower_begin * a2
lower_end = lower_end * a2
higher_idx, higher_begin, higher_end = img_batch_pos[i][j + 1]
higher_begin = higher_begin * a2
higher_end = higher_end * a2
lower_res_feat = images_feats[lower_idx, lower_begin:lower_end].unsqueeze(0)
higher_res_feat = images_feats[higher_idx, higher_begin:higher_end].unsqueeze(0)
lower_images_size = images_size[i][j]
higher_images_size = images_size[i][j + 1]
images_list.append(self.perceive(lower_res_feat, lower_images_size, higher_res_feat, higher_images_size))
else:
idx, begin, end = img_batch_pos[i][j]
begin = begin * a2
end = end * a2
is_video = hasattr(images[i][j],'_is_video') and images[i][j]._is_video
images_list.append(perceive_fn(images_feats[idx, begin:end].unsqueeze(0), images_size[i][j], is_video=is_video))
images = torch.cat(images_list, dim=1)
new_batch_pos = []
k = 0; cur_len = 0
for i in range(len(images_size)):
new_batch_pos.append([])
for j in range(0, len(images_size[i]), step):
new_pos = [0, cur_len, cur_len + images_list[k].size(1)]
cur_len += images_list[k].size(1)
k += 1
new_batch_pos[i].append(new_pos)
return images, new_batch_pos
elif self.vit_type == 'Vit-g':
images = self.vit(pixel_values=images, interpolate_pos_encoding=False, img_index=img_index)
else:
assert False, "other vit_types are not supported"
if self.vit_mapping_type == 'mlp':
if self.vit_type in ['Vit-g'] and not self.skip_cls_token:
images = images[:,1:,:]
b, v, d = images.shape
s = int(math.sqrt(v))
images = images.reshape(b, s, s, d)
if self.vit_patch_mlp and img_index is not None:
L_tensor = torch.tensor(img_index)
device = images.device
# 获取子图位置
nonzero_indices = torch.nonzero(L_tensor).squeeze().to(device)
# 获取主图位置
zero_indices = torch.nonzero(L_tensor == 0).squeeze().to(device)
images_nonzero = torch.index_select(images,0, nonzero_indices).to(device)
images_zero = torch.index_select(images, 0, zero_indices).to(device)
# 子图额外多pool一次
pool_rate = self.pool_rate * 2
images_nonzero = images_nonzero.reshape(-1, s // pool_rate, pool_rate, s // pool_rate, pool_rate, d)
images_nonzero = images_nonzero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // pool_rate) * (s // pool_rate), d,
pool_rate*pool_rate).mean(-1)
# 为了组batch折衷方案
images_nonzero = F.pad(images_nonzero, (0, 0, 0, (s // self.pool_rate) * (s // self.pool_rate)- (s // pool_rate) * (s // pool_rate)))
images_zero = images_zero.reshape(-1, s // self.pool_rate, self.pool_rate, s // self.pool_rate, self.pool_rate, d)
images_zero = images_zero.permute(0, 1, 3, 5, 2, 4).reshape(-1, (s // self.pool_rate) * (s // self.pool_rate), d,
self.pool_rate*self.pool_rate).mean(-1)
# 组batch
images = torch.zeros(b, (s // self.pool_rate) * (s // self.pool_rate), d).to(device).to(images.dtype)
images.index_copy_(0, nonzero_indices, images_nonzero)
images.index_copy_(0, zero_indices, images_zero)
if self.mlp_depth >= 2:
images = self.perceive(images)
else:
if s % self.pool_rate == 0:
images = images.reshape(b, s//self.pool_rate, self.pool_rate, s//self.pool_rate, self.pool_rate, d)
images = images.permute(0, 1, 3, 5, 2, 4).reshape(b, (s//self.pool_rate) * (s//self.pool_rate), d, -1).mean(-1)
if self.mlp_depth >= 2:
images = self.perceive(images)
else:
raise ValueError
return images
class SimpleConvMlp(nn.Module):
def __init__(self, in_channels, out_channels, anyres_pooling_size, vit_used_rms_norm, rms_norm_eps, twoview=False, poolmlp=True, cat_extra_token=True):
super().__init__()
embed_std = 1 / math.sqrt(out_channels)
if poolmlp:
# if args.learnable_mlp_pooling_size is not None:
# in_channels *= args.learnable_mlp_pooling_size ** 2
self.proj = nn.Sequential(
nn.Linear(in_channels, out_channels),
nn.GELU()
)
self.vit_linear_encoder = nn.Linear(out_channels, out_channels)
self.image_newline = nn.Parameter(
torch.randn(out_channels) * embed_std
)
else:
self.proj = nn.Sequential(
nn.Conv2d(in_channels, in_channels * 2, kernel_size=anyres_pooling_size, stride=anyres_pooling_size),
nn.GELU(),
nn.Conv2d(in_channels * 2, in_channels * 4, kernel_size=1),
)
self.mlp = nn.Linear(in_channels * 4, out_channels)
self.image_newline = nn.Parameter(
torch.randn(in_channels * 4) * embed_std
)
self.poolmlp = poolmlp
self.image_begin = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.image_end = nn.Parameter(
torch.randn(out_channels) * embed_std
)
if twoview:
self.image_sep = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.cat_extra_token = cat_extra_token
self.use_rms_norm = vit_used_rms_norm
if self.use_rms_norm:
self.before_rms = HunYuanRMSNorm(in_channels, eps=rms_norm_eps)
self.after_rms = HunYuanRMSNorm(out_channels, eps=rms_norm_eps)
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
return self.single_forward(x=x, size=size, x2=x2, size2=size2, is_video=is_video)
def single_forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
remove_vit_special_tokens = False
learnable_mlp_pooling_size = None
if self.use_rms_norm:
x = self.before_rms(x)
h, w = size
dtype = x.dtype
x = x.permute(0, 2, 1).reshape(x.shape[0], -1, h, w)
if self.poolmlp:
if learnable_mlp_pooling_size is None:
x = F.avg_pool2d(x, anyres_pooling_size)
x = self.proj(x.permute(0, 2, 3, 1)) # b, h, w, c
else:
x = x.permute(0, 2, 3, 1) # b, h, w, c
x = x.reshape(x.shape[0], h // learnable_mlp_pooling_size, learnable_mlp_pooling_size,
w // learnable_mlp_pooling_size, learnable_mlp_pooling_size, -1)
x = x.permute(0, 1, 3, 2, 4, 5).reshape(x.shape[0], h // learnable_mlp_pooling_size, w // learnable_mlp_pooling_size, -1)
x = self.proj(x)
x = self.vit_linear_encoder(x)
b, h, w, c = x.shape
if not remove_vit_special_tokens:
x = torch.cat([
x,
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype, non_blocking=True)
], dim=2)
x = x.reshape(b, -1, c)
else:
x = self.proj(x) #b,c,h,w
if is_video:
video_avgpool_size = 2
stride = 2
x = F.avg_pool2d(x, kernel_size = video_avgpool_size, stride = stride)
b, c, h, w = x.shape
if not remove_vit_special_tokens:
x = torch.cat([
x,
self.image_newline.reshape(1, c, 1, 1).expand(b, c, h, 1).to(dtype, non_blocking=True)
], dim=-1)
x = x.reshape(b, c, -1).permute(0, 2, 1)
x = self.mlp(x)
if x2 is not None:
h2, w2 = size2
x2 = x2.permute(0, 2, 1).reshape(x2.shape[0], -1, h2, w2)
if self.poolmlp:
x2 = F.avg_pool2d(x2, 2)
x2 = self.proj(x2.permute(0, 2, 3, 1)) # b, h, w, c
x2 = self.vit_linear_encoder(x2)
b2, h2, w2, c2 = x2.shape
if not remove_vit_special_tokens:
x2 = torch.cat([
x2,
self.image_newline.reshape(1, 1, 1, c2).expand(b2, h2, 1, c2).to(dtype, non_blocking=True)
], dim=2)
x2 = x2.reshape(b2, -1, c2)
else:
x2 = self.proj(x2)
b2, c2, h2, w2 = x2.shape
if not remove_vit_special_tokens:
x2 = torch.cat([
x2,
self.image_newline.reshape(1, c2, 1, 1).expand(b2, c2, h2, 1).to(dtype, non_blocking=True)
], dim=-1)
x2 = x2.reshape(b2, c2, -1).permute(0, 2, 1) #b,n,c
x2 = self.mlp(x2)
sep = self.image_sep.reshape(1, 1, -1).expand(b2, 1, x2.shape[-1]).to(dtype, non_blocking=True)
x = torch.cat([x, sep, x2], dim=1)
if self.cat_extra_token:
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
end = self.image_end.reshape(1, 1, -1).expand(b, 1, x.shape[-1]).to(dtype, non_blocking=True)
x = torch.cat([begin, x, end], dim=1)
if self.use_rms_norm:
return self.after_rms(x)
else:
return x
class NormalizedDwPooler(nn.Module):
def __init__(self, dim):
super().__init__()
self.dim = dim
self.predictor = nn.Sequential(
nn.Linear(dim*2, dim),
nn.GELU(),
nn.Linear(dim, dim),
)
def forward(self, x, forward_type='2x'):
B, H, W, C = x.shape
if forward_type == '2x':
new_x = x.reshape(B, H//2, 2, W//2, 2, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//2, W//2, 4, C)
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 4, -1)
fused_x = torch.cat([new_x, pooled_x], dim=-1)
elif forward_type == '1x':
new_x = x.reshape(B, H, W, 1, C)
fused_x = torch.cat([new_x, new_x], dim=-1)
elif forward_type == '4x':
new_x = x.reshape(B, H//4, 4, W//4, 4, C).permute(0, 1, 3, 2, 4, 5).reshape(B, H//4, W//4, 16, C)
pooled_x = new_x.mean(-2, keepdim=True).expand(-1, -1, -1, 16, -1)
fused_x = torch.cat([new_x, pooled_x], dim=-1)
score = self.predictor(fused_x)
normalized_score = F.softmax(score, dim=-2)
new_x = (new_x * normalized_score).sum(dim=-2)
return new_x
class OryxMLPv2(nn.Module):
def __init__(self, in_channels, out_channels, twoview=False, use_pe=False):
super().__init__()
self.proj1 = nn.Linear(in_channels, out_channels)
self.proj2 = nn.Linear(out_channels, out_channels)
self.act = nn.GELU()
self.pooler = NormalizedDwPooler(out_channels)
embed_std = 1 / math.sqrt(out_channels)
self.use_pe = use_pe
if not use_pe:
self.image_newline = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.image_begin = nn.Parameter(
torch.randn(out_channels) * embed_std
)
self.image_end = nn.Parameter(
torch.randn(out_channels) * embed_std
)
if twoview:
self.image_sep = nn.Parameter(
torch.randn(out_channels) * embed_std
)
def forward(self, x, size=(16,16), x2=None, size2=(16, 16), is_video=False):
h, w = size
dtype = x.dtype
x = x.reshape(x.shape[0], h, w, -1)
# x = self.pooler(x, forward_type=REGIONAL_POOL)
# x = self.proj(x) #b,h,w, c
x = self.proj1(x)
x = self.pooler(x, forward_type='2x')
x = self.act(x)
x = self.proj2(x)
b, h, w, c = x.shape
if not self.use_pe:
x = torch.cat([
x,
self.image_newline.reshape(1, 1, 1, c).expand(b, h, 1, c).to(dtype)
], dim=2)
else:
pe_h = torch.arange(h, dtype=torch.long, device=x.device).reshape(1, h, 1, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
pe_w = torch.arange(w, dtype=torch.long, device=x.device).reshape(1, 1, w, 1).expand(b, h, w, 1).reshape(b, h*w, 1)
pe = torch.cat([pe_h, pe_w], dim=-1)
x = x.reshape(b, -1, c)
if x2 is not None:
h2, w2 = size2
x2 = x2.reshape(x2.shape[0], h2, w2, -1)
# x2 = self.pooler(x2, forward_type=REGIONAL_POOL)
## x2 = self.proj(x2) #b,h,w, c
x2 = self.proj1(x2)
x2 = self.pooler(x2, forward_type='2x')
x2 = self.act(x2)
x2 = self.proj2(x2)
b2, h2, w2, c2 = x2.shape
if not self.use_pe:
x2 = torch.cat([
x2,
self.image_newline.reshape(1, 1, 1, c).expand(b, h2, 1, c).to(dtype)
], dim=2)
x2 = x2.reshape(b, -1, c)
sep = self.image_sep.reshape(1, 1, -1).expand(b, 1, c2).to(dtype)
x = torch.cat([x, sep, x2], dim=1)
begin = self.image_begin.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
end = self.image_end.reshape(1, 1, -1).expand(b, 1, c).to(dtype)
x = torch.cat([begin, x, end], dim=1)
# print(x.shape, x2.shape, h, w, h2, w2)
# print("vit rank = " + str(torch.distributed.get_rank()) +" x = " + str(x))
if self.use_pe:
zero_pad = torch.zeros(b, 1, 2, device=x.device, dtype=torch.long)
pe = torch.cat([zero_pad, pe, zero_pad], dim=1)
assert pe.shape[1] == x.shape[1]
return x, pe
else:
nseq = x.shape[1]
fake_pe = torch.zeros(b, nseq, 2, device=x.device, dtype=torch.long)
return x #, fake_pe
|