File size: 693 Bytes
74b17e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 |
from transformers import SiglipVisionModel, SiglipVisionConfig, SiglipImageProcessor
from . import register_vision_tower
from .base import VisionTower
@register_vision_tower('siglip')
class SIGLIPVisionTower(VisionTower):
    """Vision tower registered under the key ``'siglip'``.

    Holds a ``SiglipVisionModel`` in ``_vision_tower`` and a
    ``SiglipImageProcessor`` in ``_image_processor``.
    """

    def __init__(self, cfg):
        """Create the SigLIP vision model and load its image processor.

        Args:
            cfg: Configuration object. It is passed to the base-class
                constructor and to ``SiglipVisionModel``; its
                ``model_name_or_path`` attribute is the location handed to
                ``SiglipImageProcessor.from_pretrained``.
                NOTE(review): presumably a ``SiglipVisionConfig`` -- confirm
                against the caller.
        """
        super().__init__(cfg)

        # Instantiated directly from the config (not via from_pretrained).
        vision_model = SiglipVisionModel(cfg)
        self._vision_tower = vision_model

        # The image processor is loaded from the path/name carried by cfg.
        processor_source = cfg.model_name_or_path
        self._image_processor = SiglipImageProcessor.from_pretrained(processor_source)

    # Disabled override of forward(), left commented out. It would select the
    # hidden state at index ``vision_feature_layer`` (default -2).
    #
    # def forward(self, x, **kwargs):
    #     image_features = self._vision_tower(x, output_hidden_states=True)
    #     image_features = image_features.hidden_states[kwargs.get('vision_feature_layer', -2)]
    #     return image_features
|