Camil Ziane
init space
74b17e0
raw
history blame
3.61 kB
import os
import torch
import torch.nn as nn
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig, Dinov2Model, AutoConfig
from . import register_vision_tower
from .base import VisionTower
class MoF(nn.Module):
    """Run a CLIP vision model and a DINOv2 model side by side on one input.

    Both sub-models are constructed from configuration objects; this class
    does not load checkpoint weights itself.
    """

    def __init__(self, cfg):
        """Build the two encoders.

        Args:
            cfg: Configuration passed directly to ``CLIPVisionModel``. It must
                also expose ``model_name_or_path2``, which is handed to
                ``AutoConfig.from_pretrained`` to obtain the DINOv2 config.
        """
        super().__init__()
        self.clip = CLIPVisionModel(cfg)
        cfg_dinov2 = AutoConfig.from_pretrained(cfg.model_name_or_path2)
        self.dinov2 = Dinov2Model(cfg_dinov2)

    def forward(self, x, **kwargs):
        """Return one hidden-state tensor from each encoder for ``x``.

        Args:
            x: Input forwarded unchanged to both encoders.
            **kwargs: Optional settings:
                ``vision_feature_layer`` (default ``-2``): index into each
                encoder's ``hidden_states`` sequence.
                ``vision_feature_select_strategy`` (default ``'patch'``):
                ``'patch'`` drops index 0 along dimension 1 of both tensors;
                ``'cls_patch'`` returns them unsliced.

        Returns:
            A tuple ``(image_features_clip, image_features_dinov2)``.

        Raises:
            ValueError: If the select strategy is neither ``'patch'`` nor
                ``'cls_patch'``.
        """
        feature_layer = kwargs.get('vision_feature_layer', -2)
        select_strategy = kwargs.get('vision_feature_select_strategy', 'patch')

        # Reject an unsupported strategy before paying for two forward passes.
        if select_strategy not in ('patch', 'cls_patch'):
            raise ValueError(f"Unexpected select feature: {select_strategy}")

        image_features_clip = self.clip(x, output_hidden_states=True).hidden_states[feature_layer]
        image_features_dinov2 = self.dinov2(x, output_hidden_states=True).hidden_states[feature_layer]

        if select_strategy == 'patch':
            # Drop the first position of dimension 1 (presumably the CLS
            # token of each encoder -- confirm against the model outputs).
            image_features_clip = image_features_clip[:, 1:]
            image_features_dinov2 = image_features_dinov2[:, 1:]

        return image_features_clip, image_features_dinov2
@register_vision_tower('mof')
class MoFVisionTower(VisionTower):
    """Vision tower registered as ``'mof'`` that wraps a ``MoF`` module."""

    def __init__(self, cfg):
        """Create the ``MoF`` module and the CLIP image processor.

        Args:
            cfg: Configuration forwarded to the base class and to ``MoF``;
                ``cfg.model_name_or_path`` selects the ``CLIPImageProcessor``.
        """
        super().__init__(cfg)
        self._vision_tower = MoF(cfg)
        self._image_processor = CLIPImageProcessor.from_pretrained(cfg.model_name_or_path)

    def _load_model(self, vision_tower_name, **kwargs):
        """Load weights for both encoders.

        Args:
            vision_tower_name: Identifier passed to the CLIP sub-model's
                ``from_pretrained`` when no local checkpoint is given.
            **kwargs: ``pretrained_vision_tower_path`` (optional) selects a
                directory containing ``pytorch_model.bin``. When it is absent
                or ``None``, ``model_name_or_path2`` is required and names the
                DINOv2 checkpoint; the remaining kwargs are forwarded to both
                ``from_pretrained`` calls.

        Raises:
            KeyError: If ``pretrained_vision_tower_path`` is not provided and
                ``model_name_or_path2`` is missing from ``kwargs``.
        """
        pretrained_vision_tower_path = kwargs.pop('pretrained_vision_tower_path', None)

        if pretrained_vision_tower_path is None:
            model_name_or_path_dinov2 = kwargs.pop('model_name_or_path2')
            self._vision_tower.clip = self._vision_tower.clip.from_pretrained(vision_tower_name, **kwargs)
            self._vision_tower.dinov2 = self._vision_tower.dinov2.from_pretrained(model_name_or_path_dinov2, **kwargs)
            print("Loading vision tower1 from ", vision_tower_name)
            print("Loading vision tower2 from ", model_name_or_path_dinov2)
        else:
            weights_file = os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin')
            # NOTE(review): torch.load unpickles the file; only point this at
            # checkpoints from a trusted source.
            vision_tower_weights = torch.load(weights_file, map_location='cpu')
            # The state dict is loaded into the combined MoF module, so its
            # keys must match that module's parameter names.
            self._vision_tower.load_state_dict(vision_tower_weights)
            print("Loading vision tower from ", pretrained_vision_tower_path)

    def forward(self, x, **kwargs):
        """Move this module to the device of ``x`` and run the wrapped ``MoF``.

        Args:
            x: Input tensor; its device decides where the module is moved.
            **kwargs: Forwarded unchanged to the wrapped module.

        Returns:
            Whatever the wrapped ``MoF`` module returns for ``x``.
        """
        device = x.data.device
        # Keep the tower on the same device as the incoming batch.
        self.to(device)
        return self._vision_tower(x, **kwargs)