import os
import torch
import torch.nn as nn
from torch.utils.checkpoint import checkpoint
from .siglip_vision_tower import SiglipVisionTower
# from .hr_clip_encoder import HRCLIPVisionTower
# from .eva_vit import EVAVITVisionTower
# from .SAM.modeling_sam import SAMVisionTower
# from .pix2struct_large import Pix2StructLargeVisionTower
import torch.nn.functional as F
from torch.nn.init import trunc_normal_
from copy import deepcopy
import random
import math
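
# MultiBackboneChannelConcatenationVisionTower runs several vision backbones on the
# same image, resamples each backbone's output to a common grid_size x grid_size
# token grid, and concatenates the per-token features along the channel dimension.
# The exact input/output contract of forward() depends on args.moe_version_type.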
class MultiBackboneChannelConcatenationVisionTower(nn.Module):
def __init__(self,
vision_tower,
args,
grid_size=32,
convnext_img_size=1024,
normalize_type=None, raw_config=None):
super().__init__()
self.is_loaded = False
self.grid_size = grid_size
self.num_tokens = self.grid_size ** 2
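        # in the default pipeline, every backbone's output is resampled to a grid_size x grid_size grid of tokens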
self.normalize_type = args.normalize_type
self.moe_version_type = args.moe_version_type
self.raw_config = raw_config
print("moe_version_type: ", self.moe_version_type)
assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], f"Unknown self.moe_version_type: {self.moe_version_type}"
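        # None / 'all_tiling': forward() takes a single image Tensor and returns a Tensor
        # 'seq_concat' / 'feat_concat': forward() takes a dict with 'pixel_values' and 'num_patches' and returns a dict
        # 'convnext_512_siglip_448': forward() returns a dict of per-tower features keyed by tower name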
vision_tower_name_list = vision_tower.split(";")
self.input_image_size = 1024
self.convnext_img_size = convnext_img_size
self.load_vision_towers(vision_tower_name_list, args)
def load_vision_towers(self, vision_tower_name_list, args):
self.vision_towers = nn.ModuleList()
freeze_backbone_list = args.freeze_backbones # note this is a str
if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
print("The frozen backbones: ", freeze_backbone_list)
        else:
            # fall back to an empty string so the substring checks below never fail
            freeze_backbone_list = ""
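        # instantiate one backbone per entry; a backbone is frozen when its name appears in freeze_backbone_list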
for name in vision_tower_name_list:
## ConvNeXt
if name == 'convnext-1024':
convnext_args = deepcopy(args)
convnext_args.freeze_vision = False
if 'convnext-1024' in freeze_backbone_list:
convnext_args.freeze_vision = True
from .convnext_encoder import ConvNextVisionTower
convnext_args.input_image_size = self.convnext_img_size
convnext_vision_tower = args.vision_tower_convnext_path
convnext_vision_tower = ConvNextVisionTower(convnext_vision_tower,
convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
convnext_vision_tower.load_model()
self.vision_towers.append(convnext_vision_tower)
## PaliSigLIP
elif name == 'palisiglip':
palisiglip_args = deepcopy(args)
palisiglip_args.input_image_size = 448
palisiglip_args.freeze_vision = False
if 'palisiglip' in freeze_backbone_list:
palisiglip_args.freeze_vision = True
palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)
palisiglip_vision_tower.load_model()
self.vision_towers.append(palisiglip_vision_tower)
# Set the image processor
self.image_processor = None
self.is_loaded = True
def load_model(self):
assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
def forward(self, x):
# x is a Tensor if moe_version_type is None or 'all_tiling'
# else is a tuple(Tensor, Tensor)
if self.moe_version_type in [None, 'all_tiling']:
# The default pipeline
features = []
image_input_size = x.shape[2]
assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
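            # for each backbone: resize the input to its native resolution, run it, reshape token
            # outputs to a spatial map if needed, resample to grid_size x grid_size, and flatten
            # back to (b, num_tokens, c)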
for vision_tower in self.vision_towers:
if vision_tower.input_image_size != image_input_size:
resized_x = F.interpolate(x.float(),
size=(vision_tower.input_image_size, vision_tower.input_image_size),
mode='bilinear',
align_corners=True).to(dtype=x.dtype)
else:
resized_x = x
feature = vision_tower(resized_x)
if len(feature.shape) == 3: # b, n, c
b, n, c = feature.shape
if n == self.num_tokens:
features.append(feature)
continue
w = h = int(n**0.5)
feature = feature.transpose(1,2).reshape(b, c, h, w)
else:
b, c, h, w = feature.shape
if w != self.grid_size:
feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
features.append(feature.flatten(2,3).transpose(1,2))
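            # channel-concatenate all backbones: (b, num_tokens, sum of per-backbone hidden sizes)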
features = torch.cat(features, dim=-1)
elif self.moe_version_type == 'convnext_512_siglip_448':
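            # keep each backbone's feature map at its native resolution and return them as a dict keyed by tower name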
features = {}
image_input_size = x.shape[2]
assert x.shape[2] == x.shape[3], f"Image should be a square but size ({x.shape[2]} x {x.shape[3]})"
for vision_tower in self.vision_towers:
if vision_tower.input_image_size != image_input_size:
resized_x = F.interpolate(x.float(),
size=(vision_tower.input_image_size, vision_tower.input_image_size),
mode='bilinear',
align_corners=True).to(dtype=x.dtype)
else:
resized_x = x
feature = vision_tower(resized_x)
# if len(feature.shape) == 3: # b, n, c
# b, n, c = feature.shape
# if n == self.num_tokens:
# features.append(feature)
# continue
# w = h = int(n**0.5)
# feature = feature.transpose(1,2).reshape(b, c, h, w)
# else:
# b, c, h, w = feature.shape
features[vision_tower.name] = feature
else:
assert isinstance(x, dict), "x is expected to be a dict but {}".format(type(x))
pixel_values = x['pixel_values']
            num_patches = x['num_patches'] # number of image-patch tokens per sample in the text
            # compute the actual number of image patches per sample
if self.moe_version_type == 'seq_concat':
image_in_num_patches = [i-1 for i in num_patches]
else:
                image_in_num_patches = list(num_patches)
assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))
            # find the thumbnail image index (the last patch of each sample)
thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
image_no_tiling = pixel_values[thumbnail_image_id]
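            # image_no_tiling holds one untiled (thumbnail) image per sample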
            # By default, the 1st vision tower sees all tiled patches (x); the remaining towers only see the no-tiling thumbnail images
features = []
for layer_id, vision_tower in enumerate(self.vision_towers):
if layer_id == 0:
x = pixel_values
else:
x = image_no_tiling
if vision_tower.input_image_size != self.input_image_size:
resized_x = F.interpolate(x.float(),
size=(vision_tower.input_image_size, vision_tower.input_image_size),
mode='bilinear',
align_corners=True).to(dtype=x.dtype)
else:
resized_x = x
feature = vision_tower(resized_x)
if len(feature.shape) == 3: # b, n, c
b, n, c = feature.shape
if n == self.num_tokens:
features.append(feature)
continue
w = h = int(n**0.5)
feature = feature.transpose(1,2).reshape(b, c, h, w)
else:
b, c, h, w = feature.shape
if w != self.grid_size:
feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
features.append(feature.flatten(2,3).transpose(1,2))
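            # features[0] comes from the 1st tower over all tiled patches; features[1:] come from the
            # remaining towers over the thumbnail images only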
clip_embeds = features[0]
if len(features) <= 1:
no_tiling_embeds = None
else:
no_tiling_embeds = torch.cat(features[1:], dim=-1)
if self.moe_version_type == 'feat_concat':
# concat thumbnail images features together
clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
if no_tiling_embeds is not None:
no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
else:
no_tiling_embeds = clip_thumbnail_embeds
                # keep only the tiled patch features (drop the thumbnail entries)
clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
clip_embeds = clip_embeds[clip_embeds_mask]
features = {
'clip_embeds': clip_embeds,
'no_tiling_embeds': no_tiling_embeds,
'num_patches': num_patches
}
        # features is a Tensor when moe_version_type is None or 'all_tiling'; otherwise it is a dict
return features
@property
def dummy_feature(self):
return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)
    @property
    def dtype(self):
        # use the first vision tower as the reference for dtype
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        # use the first vision tower as the reference for device
        return next(self.vision_towers[0].parameters()).device
    @property
    def config(self):
        raise NotImplementedError
@property
def hidden_size(self):
if self.moe_version_type == 'convnext_512_siglip_448':
res = {}
for vision_tower in self.vision_towers:
res[vision_tower.name] = vision_tower.hidden_size
return res
        else:
            return sum(tower.hidden_size for tower in self.vision_towers)
@property
def num_patches(self):
return self.num_tokens
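
# --- Usage sketch (illustrative, not part of the original training code) ---
# A minimal example of how this tower might be built and called with the default
# pipeline (moe_version_type in [None, 'all_tiling']). The SimpleNamespace below is
# an assumption for illustration: the real code passes an args/config object that
# already carries these fields, and the checkpoint paths are placeholders.
#
#   from types import SimpleNamespace
#   import torch
#
#   args = SimpleNamespace(
#       normalize_type=None,
#       moe_version_type=None,
#       freeze_backbones='',
#       delay_load=False,
#       vision_tower_siglip_path='<path-to-siglip-checkpoint>',     # placeholder
#       vision_tower_convnext_path='<path-to-convnext-checkpoint>'  # placeholder
#   )
#   tower = MultiBackboneChannelConcatenationVisionTower(
#       'convnext-1024;palisiglip', args, grid_size=32, convnext_img_size=1024)
#   images = torch.randn(2, 3, 1024, 1024)
#   feats = tower(images)   # (2, 32 * 32, tower.hidden_size)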