|
|
|
|
|
import torch |
|
import transformers |
|
from packaging import version |
|
from transformers import AutoConfig, AutoModel, PretrainedConfig |
|
|
|
from llava.model.multimodal_encoder.vision_encoder import VisionTower, VisionTowerS2 |
|
|
|
if version.parse(transformers.__version__) > version.parse("4.36.2"): |
|
from transformers import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel |
|
else: |
|
from .siglip import SiglipImageProcessor, SiglipVisionConfig, SiglipVisionModel |
|
|
|
|
|
class SiglipVisionTower(VisionTower):
    """Vision tower backed by a pretrained SigLIP vision encoder.

    Loads the SigLIP image processor and vision model from
    ``model_name_or_path`` and marks the tower as loaded.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, state_dict=None):
        """Initialize the tower.

        Args:
            model_name_or_path: HF hub id or local path of the SigLIP checkpoint.
            config: Pretrained config; ``config.model_dtype`` names the torch
                dtype to load weights in (e.g. ``"torch.float16"``).
            state_dict: Optional state dict passed through to ``from_pretrained``.
        """
        super().__init__(model_name_or_path, config)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Resolve the dtype name safely instead of eval(), which would execute
        # arbitrary code taken from the config. Accepts both "torch.float16"
        # and "float16" spellings.
        model_dtype = getattr(torch, str(config.model_dtype).split(".")[-1])
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path, torch_dtype=model_dtype, state_dict=state_dict
        )
        self.is_loaded = True
|
|
|
|
|
class SiglipVisionTowerS2(VisionTowerS2):
    """Multi-scale (S2) vision tower backed by a pretrained SigLIP encoder.

    Loads the SigLIP image processor and vision model, then resizes the
    processor's target resolution to the largest configured S2 scale.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig):
        """Initialize the S2 tower.

        Args:
            model_name_or_path: HF hub id or local path of the SigLIP checkpoint.
            config: Pretrained config; ``config.model_dtype`` names the torch
                dtype to load weights in (e.g. ``"torch.float16"``).
        """
        super().__init__(model_name_or_path, config)
        self.image_processor = SiglipImageProcessor.from_pretrained(model_name_or_path)
        # Resolve the dtype name safely instead of eval(), which would execute
        # arbitrary code taken from the config. Accepts both "torch.float16"
        # and "float16" spellings.
        model_dtype = getattr(torch, str(config.model_dtype).split(".")[-1])
        self.vision_tower = SiglipVisionModel.from_pretrained(
            model_name_or_path, torch_dtype=model_dtype
        )

        # Preprocess images at the largest S2 scale; smaller scales are derived
        # downstream. NOTE(review): self.scales is presumably populated by
        # VisionTowerS2.__init__ — confirm against the base class.
        self.image_processor.size['height'] = self.image_processor.size['width'] = self.scales[-1]

        self.is_loaded = True
|
|
|
# Older transformers releases (<= 4.36.2) do not ship native SigLIP support,
# so register the vendored SigLIP config/model classes with the Auto* factories
# to make AutoConfig/AutoModel resolve "siglip_vision_model" checkpoints.
if version.parse(transformers.__version__) <= version.parse("4.36.2"):

    AutoConfig.register("siglip_vision_model", SiglipVisionConfig)

    AutoModel.register(SiglipVisionConfig, SiglipVisionModel)
|
|