|
import os |
|
import torch |
|
import torch.nn as nn |
|
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig, Dinov2Model, AutoConfig |
|
|
|
from . import register_vision_tower |
|
from .base import VisionTower |
|
|
|
|
|
|
|
|
|
|
|
class MoF(nn.Module):
    """Mixture-of-Features backbone: runs the same image batch through a
    CLIP vision encoder and a DINOv2 encoder and returns both feature maps.
    """

    def __init__(self, cfg):
        """Build the two sub-encoders.

        Args:
            cfg: CLIP vision config used to construct the CLIP branch.
                It must also carry ``model_name_or_path2`` — the checkpoint
                whose config builds the DINOv2 branch.
        """
        super().__init__()
        # CLIP branch is constructed directly from the supplied config.
        self.clip = CLIPVisionModel(cfg)
        # DINOv2 branch resolves its own config from the second model path.
        cfg_dinov2 = AutoConfig.from_pretrained(cfg.model_name_or_path2)
        self.dinov2 = Dinov2Model(cfg_dinov2)

    def forward(self, x, **kwargs):
        """Encode ``x`` with both branches.

        kwargs:
            vision_feature_layer: index into ``hidden_states`` (default -2).
            vision_feature_select_strategy: 'patch' drops the leading CLS
                token from both branches, 'cls_patch' keeps everything;
                any other value raises ValueError.

        Returns:
            Tuple ``(clip_features, dinov2_features)``.
        """
        layer = kwargs.get('vision_feature_layer', -2)
        strategy = kwargs.get('vision_feature_select_strategy', 'patch')

        feats_clip = self.clip(x, output_hidden_states=True).hidden_states[layer]
        feats_dinov2 = self.dinov2(x, output_hidden_states=True).hidden_states[layer]

        if strategy == 'patch':
            # Drop the CLS token, keep only patch tokens.
            feats_clip = feats_clip[:, 1:]
            feats_dinov2 = feats_dinov2[:, 1:]
        elif strategy != 'cls_patch':
            raise ValueError(f"Unexpected select feature: {kwargs.get('vision_feature_select_strategy')}")

        return feats_clip, feats_dinov2
|
|
|
|
|
|
|
|
|
|
|
|
|
@register_vision_tower('mof')
class MoFVisionTower(VisionTower):
    """Vision tower that fuses CLIP and DINOv2 features via the MoF module."""

    def __init__(self, cfg):
        """Create the MoF backbone and the CLIP image processor.

        Args:
            cfg: config carrying ``model_name_or_path`` (CLIP checkpoint,
                used for the image processor) and ``model_name_or_path2``
                (DINOv2 checkpoint, consumed inside MoF).
        """
        super().__init__(cfg)
        self._vision_tower = MoF(cfg)
        self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path)

    def _load_model(self, vision_tower_name, **kwargs):
        """Populate the tower's weights.

        When ``pretrained_vision_tower_path`` is absent, each branch is
        loaded from its own pretrained checkpoint; otherwise the full MoF
        state dict is restored from ``pytorch_model.bin`` in that directory.

        Args:
            vision_tower_name: checkpoint path/name for the CLIP branch.
            **kwargs: must include ``model_name_or_path2`` (DINOv2
                checkpoint) when loading branches individually; may include
                ``pretrained_vision_tower_path``.
        """
        pretrained_vision_tower_path = kwargs.pop('pretrained_vision_tower_path', None)
        if pretrained_vision_tower_path is None:
            model_name_or_path_dinov2 = kwargs.pop('model_name_or_path2')
            self._vision_tower.clip = self._vision_tower.clip.from_pretrained(vision_tower_name, **kwargs)
            self._vision_tower.dinov2 = self._vision_tower.dinov2.from_pretrained(model_name_or_path_dinov2, **kwargs)
            print("Loading vision tower1 from ", vision_tower_name)
            print("Loading vision tower2 from ", model_name_or_path_dinov2)
        else:
            # NOTE(review): torch.load on an untrusted checkpoint can execute
            # arbitrary code; consider weights_only=True if the file holds
            # only tensors.
            vision_tower_weights = torch.load(
                os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'),
                map_location='cpu',
            )
            self._vision_tower.load_state_dict(vision_tower_weights)
            print("Loading vision tower from ", pretrained_vision_tower_path)

    def forward(self, x, **kwargs):
        """Move the tower onto the input's device, then run the MoF backbone.

        Returns whatever MoF.forward returns: a (clip, dinov2) feature tuple.
        """
        # x.device replaces the deprecated x.data.device accessor.
        self.to(x.device)
        return self._vision_tower(x, **kwargs)
|
|
|
|
|
|