Spaces:
Runtime error
Runtime error
import os | |
import torch | |
import torch.nn as nn | |
from transformers import PreTrainedModel | |
# from tinyllava.utils.data_utils import get_value_from_kwargs | |
def get_value_from_kwargs(kwargs, name, default=None):
    """Remove ``name`` from ``kwargs`` and return its value.

    ``kwargs`` is mutated in place: the key is removed when present.

    Args:
        kwargs: Mapping to pop from (typically a ``**kwargs`` dict).
        name: Key to remove.
        default: Value returned when ``name`` is absent (``None`` by default,
            matching the previous behavior).

    Returns:
        The popped value, or ``default`` if ``name`` was not in ``kwargs``.
    """
    return kwargs.pop(name, default)
class VisionTower(nn.Module):
    """Base wrapper around an image encoder used as a vision tower.

    The encoder lives in ``self._vision_tower`` and starts as ``None``.
    Presumably concrete subclasses assign it (and ``self._image_processor``)
    before :meth:`load_model` is called — TODO confirm against subclasses.
    """

    def __init__(self, cfg):
        """Store ``cfg`` as ``self.config``; the encoder is attached later."""
        super().__init__()
        self._vision_tower = None
        self._image_processor = None
        self.config = cfg

    def __setattr__(self, name, value):
        # nn.Module.__setattr__ puts Module values directly into _modules and
        # never consults class-level properties, so `self.vision_tower = m`
        # would bypass the setter below. Dispatch that one name through
        # object.__setattr__, which honors the property's setter.
        if name == 'vision_tower':
            object.__setattr__(self, name, value)
        else:
            super().__setattr__(name, value)

    def load_model(self, vision_tower_name, **kwargs):
        """Load the encoder's weights and then freeze its parameters.

        Args:
            vision_tower_name: Model id or path forwarded to ``_load_model``.
            **kwargs: Forwarded to ``_load_model``.
        """
        self._load_model(vision_tower_name, **kwargs)
        # Disable gradients for every encoder parameter.
        self._vision_tower.requires_grad_(False)

    def _load_model(self, vision_tower_name, **kwargs):
        """Populate ``self._vision_tower`` with pretrained weights.

        If the current encoder is a Hugging Face ``PreTrainedModel``, it is
        replaced by ``from_pretrained`` on ``pretrained_vision_tower_path``
        when that kwarg is given, else on ``vision_tower_name``. For a plain
        ``nn.Module``, a state dict is loaded from
        ``<pretrained_vision_tower_path>/pytorch_model.bin`` only when the
        path is given; otherwise nothing is loaded.

        Args:
            vision_tower_name: Model id or path for the Hugging Face branch.
            **kwargs: ``pretrained_vision_tower_path`` is popped. The rest is
                forwarded to ``from_pretrained`` (Hugging Face branch only).
        """
        pretrained_vision_tower_path = get_value_from_kwargs(kwargs, 'pretrained_vision_tower_path')
        if isinstance(self._vision_tower, PreTrainedModel):  # hf model
            if pretrained_vision_tower_path is not None:
                vision_tower_name = pretrained_vision_tower_path
            self._vision_tower = self._vision_tower.from_pretrained(vision_tower_name, **kwargs)
        elif pretrained_vision_tower_path is not None:  # plain nn.Module
            # NOTE(security): torch.load unpickles the file, which can execute
            # arbitrary code; only point this at trusted checkpoints.
            vision_tower_weights = torch.load(
                os.path.join(pretrained_vision_tower_path, 'pytorch_model.bin'),
                map_location='cpu',
            )
            self._vision_tower.load_state_dict(vision_tower_weights)
        print("Loading vision tower from ", vision_tower_name)

    def forward(self, x, **kwargs):
        """Encode ``x`` and return one layer of hidden states.

        Args:
            x: Input tensor; cast to ``torch.float32`` before encoding.
            **kwargs:
                vision_feature_layer: Index into the encoder's
                    ``hidden_states`` (default ``-2``).
                vision_feature_select_strategy: ``'patch'`` (default) drops
                    index 0 along dim 1; ``'cls_patch'`` keeps every token.

        Returns:
            The selected hidden-state tensor.

        Raises:
            ValueError: If the selection strategy is not recognized.
        """
        # NOTE(review): forcing float32 input assumes the encoder weights are
        # float32 too — confirm before running the tower in half precision.
        x = x.to(torch.float32)
        # Assumes the encoder returns an object exposing `hidden_states`
        # (Hugging Face style) when output_hidden_states=True.
        outputs = self._vision_tower(x, output_hidden_states=True)
        image_features = outputs.hidden_states[kwargs.get('vision_feature_layer', -2)]

        strategy = kwargs.get('vision_feature_select_strategy', 'patch')
        if strategy == 'patch':
            # Drop token 0 along dim 1 (presumably the CLS token).
            return image_features[:, 1:]
        if strategy == 'cls_patch':
            return image_features
        raise ValueError(f"Unexpected select feature: {strategy}")

    @property
    def vision_tower(self):
        """The wrapped encoder module (alias for ``self._vision_tower``)."""
        return self._vision_tower

    @vision_tower.setter
    def vision_tower(self, vision_tower):
        self._vision_tower = vision_tower