import torch
import torch.nn as nn
import torch.nn.functional as F

from copy import deepcopy

from .siglip_vision_tower import SiglipVisionTower

class MultiBackboneChannelConcatenationVisionTower(nn.Module):

    def __init__(self,
                 vision_tower,
                 args,
                 grid_size=32,
                 convnext_img_size=1024,
                 normalize_type=None, raw_config=None):

        super().__init__()

        self.is_loaded = False
        self.grid_size = grid_size
        self.num_tokens = self.grid_size ** 2
        # Note: the `normalize_type` argument is currently unused; the value
        # from `args` takes precedence.
        self.normalize_type = args.normalize_type
        self.moe_version_type = args.moe_version_type
        self.raw_config = raw_config
        print("moe_version_type: ", self.moe_version_type)
        assert self.moe_version_type in [None, 'all_tiling', 'seq_concat', 'feat_concat', 'convnext_512_siglip_448'], \
            f"Unknown moe_version_type: {self.moe_version_type}"

        vision_tower_name_list = vision_tower.split(";")
        self.input_image_size = 1024
        self.convnext_img_size = convnext_img_size
        self.load_vision_towers(vision_tower_name_list, args)

    def load_vision_towers(self, vision_tower_name_list, args):
        self.vision_towers = nn.ModuleList()

        freeze_backbone_list = args.freeze_backbones
        if freeze_backbone_list is not None and len(freeze_backbone_list) > 0:
            print("The frozen backbones: ", freeze_backbone_list)
        else:
            # An empty string freezes nothing.
            freeze_backbone_list = ""

        for name in vision_tower_name_list:
            if name == 'convnext-1024':
                convnext_args = deepcopy(args)

                convnext_args.freeze_vision = False
                if 'convnext-1024' in freeze_backbone_list:
                    convnext_args.freeze_vision = True

                from .convnext_encoder import ConvNextVisionTower
                convnext_args.input_image_size = self.convnext_img_size
                convnext_path = args.vision_tower_convnext_path
                convnext_vision_tower = ConvNextVisionTower(convnext_path,
                                                            convnext_args, delay_load=args.delay_load, normalize_type=self.normalize_type)
                convnext_vision_tower.load_model()
                self.vision_towers.append(convnext_vision_tower)

            elif name == 'palisiglip':
                palisiglip_args = deepcopy(args)
                palisiglip_args.input_image_size = 448

                palisiglip_args.freeze_vision = False
                if 'palisiglip' in freeze_backbone_list:
                    palisiglip_args.freeze_vision = True

                palisiglip_vision_tower = SiglipVisionTower(args.vision_tower_siglip_path, palisiglip_args, delay_load=args.delay_load, raw_config=self.raw_config)

                palisiglip_vision_tower.load_model()
                self.vision_towers.append(palisiglip_vision_tower)

        # Preprocessing is handled outside of this module.
        self.image_processor = None
        self.is_loaded = True

    def load_model(self):
        assert self.is_loaded, "All the vision encoders should be loaded during initialization!"
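
    # Input contract for forward(), inferred from the branches below:
    #   * moe_version_type in (None, 'all_tiling') or 'convnext_512_siglip_448':
    #     `x` is a square image batch of shape (B, 3, H, W).
    #   * moe_version_type in ('seq_concat', 'feat_concat'):
    #     `x` is a dict with 'pixel_values' (all tiles of all images stacked
    #     along the batch dimension) and 'num_patches' (tiles per image); the
    #     last tile of each image is assumed to be its un-tiled thumbnail.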
    def forward(self, x):
        if self.moe_version_type in [None, 'all_tiling']:
            features = []
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be square, but got size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                # Resize the batch to this tower's expected input resolution.
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)

                if len(feature.shape) == 3:
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue
                    # Fold the token sequence back into a square feature map.
                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                # Align every tower to the shared grid, then flatten to tokens.
                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))

            # Channel-wise concatenation across backbones.
            features = torch.cat(features, dim=-1)
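            # `features` is now (B, grid_size**2, sum of the towers' hidden
            # sizes); see the `hidden_size` property below.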
        elif self.moe_version_type == 'convnext_512_siglip_448':
            features = {}
            image_input_size = x.shape[2]
            assert x.shape[2] == x.shape[3], f"Image should be square, but got size ({x.shape[2]} x {x.shape[3]})"
            for vision_tower in self.vision_towers:
                if vision_tower.input_image_size != image_input_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)
                features[vision_tower.name] = feature
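            # In this mode the per-tower outputs are returned unmodified, as a
            # dict keyed by tower name (mirroring the dict returned by the
            # `hidden_size` property), presumably so a downstream projector can
            # handle each backbone separately.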
        else:
            assert isinstance(x, dict), "x is expected to be a dict but got {}".format(type(x))
            pixel_values = x['pixel_values']
            num_patches = x['num_patches']

            if self.moe_version_type == 'seq_concat':
                # In 'seq_concat' mode each entry of `num_patches` counts one
                # extra patch, so subtract it to match `pixel_values`.
                image_in_num_patches = [i - 1 for i in num_patches]
            else:
                image_in_num_patches = list(num_patches)

            assert sum(image_in_num_patches) == pixel_values.size(0), "sum(image_in_num_patches) ({}) != pixel_values.size(0) ({})".format(sum(image_in_num_patches), pixel_values.size(0))

            # Index of the last tile of each image in the flattened tile batch,
            # assumed to be the un-tiled thumbnail view.
            thumbnail_image_id = torch.cumsum(torch.tensor(image_in_num_patches).to(pixel_values.device), 0) - 1
            image_no_tiling = pixel_values[thumbnail_image_id]
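
            # Worked example (hypothetical numbers): with num_patches = [3, 5]
            # in 'feat_concat' mode, image_in_num_patches = [3, 5],
            # cumsum -> [3, 8], minus one -> thumbnail_image_id = [2, 7],
            # i.e. the last tile of each image in the flattened batch.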

            features = []
            for layer_id, vision_tower in enumerate(self.vision_towers):
                if layer_id == 0:
                    # The first tower sees every tile; the remaining towers
                    # only see the un-tiled thumbnails.
                    x = pixel_values
                else:
                    x = image_no_tiling

                if vision_tower.input_image_size != self.input_image_size:
                    resized_x = F.interpolate(x.float(),
                                              size=(vision_tower.input_image_size, vision_tower.input_image_size),
                                              mode='bilinear',
                                              align_corners=True).to(dtype=x.dtype)
                else:
                    resized_x = x

                feature = vision_tower(resized_x)
                if len(feature.shape) == 3:
                    b, n, c = feature.shape
                    if n == self.num_tokens:
                        features.append(feature)
                        continue

                    w = h = int(n ** 0.5)
                    feature = feature.transpose(1, 2).reshape(b, c, h, w)
                else:
                    b, c, h, w = feature.shape

                if w != self.grid_size:
                    feature = F.interpolate(feature.float(), size=(self.grid_size, self.grid_size), mode='bilinear', align_corners=True).to(dtype=x.dtype)
                features.append(feature.flatten(2, 3).transpose(1, 2))

            clip_embeds = features[0]
            if len(features) <= 1:
                no_tiling_embeds = None
            else:
                no_tiling_embeds = torch.cat(features[1:], dim=-1)

            if self.moe_version_type == 'feat_concat':
                # Concatenate the first tower's thumbnail embeddings with the
                # other towers' embeddings along the channel dimension.
                clip_thumbnail_embeds = clip_embeds[thumbnail_image_id]
                if no_tiling_embeds is not None:
                    no_tiling_embeds = torch.cat([clip_thumbnail_embeds, no_tiling_embeds], dim=-1)
                else:
                    no_tiling_embeds = clip_thumbnail_embeds

                # Drop the thumbnail rows from the tiled embeddings; they are
                # now carried by `no_tiling_embeds`.
                clip_embeds_mask = ~torch.isin(torch.arange(clip_embeds.shape[0]).to(clip_embeds.device), thumbnail_image_id)
                clip_embeds = clip_embeds[clip_embeds_mask]

            features = {
                'clip_embeds': clip_embeds,
                'no_tiling_embeds': no_tiling_embeds,
                'num_patches': num_patches
            }

        return features

    @property
    def dummy_feature(self):
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        # Report the dtype of the first backbone.
        return next(self.vision_towers[0].parameters()).dtype

    @property
    def device(self):
        # Report the device of the first backbone.
        return next(self.vision_towers[0].parameters()).device

    @property
    def config(self):
        raise NotImplementedError

    @property
    def hidden_size(self):
        if self.moe_version_type == 'convnext_512_siglip_448':
            # One hidden size per backbone, keyed by tower name.
            res = {}
            for vision_tower in self.vision_towers:
                res[vision_tower.name] = vision_tower.hidden_size
            return res
        else:
            # Channel concatenation: the towers' hidden sizes add up.
            return sum([_.hidden_size for _ in self.vision_towers])

    @property
    def num_patches(self):
        return self.num_tokens
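

# Minimal usage sketch (illustrative, not part of the module's API): the
# `args` namespace below is an assumption based on the attributes this file
# reads; real values come from the surrounding configuration, and the paths
# are placeholders that must point at local checkpoints. Because of the
# relative imports, run it with `python -m <package>.<module>`.
if __name__ == "__main__":
    from types import SimpleNamespace

    args = SimpleNamespace(
        normalize_type='siglip',                        # hypothetical value
        moe_version_type='convnext_512_siglip_448',
        freeze_backbones='',                            # freeze nothing
        delay_load=False,
        vision_tower_convnext_path='path/to/convnext',  # placeholder
        vision_tower_siglip_path='path/to/siglip',      # placeholder
    )
    tower = MultiBackboneChannelConcatenationVisionTower(
        vision_tower="convnext-1024;palisiglip",
        args=args,
    )
    images = torch.randn(2, 3, 448, 448)  # square batch, as forward() asserts
    out = tower(images)                   # dict of per-tower features in this mode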