Spaces:
Sleeping
Sleeping
File size: 5,670 Bytes
04f8e39 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
# Copyright 2024 Zhenwei Shao and MILVLG team.
# Licensed under the Apache License, Version 2.0.
# Adopted from https://github.com/haotian-liu/LLaVA.
import torch
import torch.nn as nn
from typing import Dict, Optional, Union
import numpy as np
from transformers import CLIPVisionModel, CLIPImageProcessor, CLIPVisionConfig
from .siglip.image_processing_flashsloth import ImpImageProcessor
from .siglip.modeling_siglip import SiglipVisionModel
from .siglip.configuration_siglip import SiglipVisionConfig
class CLIPVisionTower(nn.Module):
    """Frozen CLIP / SigLIP vision encoder used as an image feature extractor.

    The encoder family is chosen from the checkpoint name: names containing
    ``'siglip'`` use the SigLIP classes, everything else uses the CLIP
    classes. After loading, the encoder is truncated according to
    ``select_layer``, its parameters are frozen and it is put in eval mode.

    Args:
        vision_tower: Checkpoint name or path passed to ``from_pretrained``.
        args: Object providing ``mm_vision_select_layer`` and, optionally,
            ``mm_vision_select_feature`` (defaults to ``'patch'``).
        delay_load: When true, only the vision config is fetched in
            ``__init__``; call :meth:`load_model` later to load the weights.
    """

    def __init__(self, vision_tower, args, delay_load=False):
        super().__init__()

        self.is_loaded = False

        self.vision_tower_name = vision_tower
        self.select_layer = args.mm_vision_select_layer
        self.select_feature = getattr(args, 'mm_vision_select_feature', 'patch')

        if not delay_load:
            self.load_model()
        else:
            # Use the config class that matches the encoder family so that
            # `hidden_size` / `num_patches` are right before the weights load.
            config_cls = SiglipVisionConfig if self._is_siglip else CLIPVisionConfig
            self.cfg_only = config_cls.from_pretrained(self.vision_tower_name)

    @property
    def _is_siglip(self):
        """Whether the checkpoint name designates a SigLIP encoder."""
        return 'siglip' in self.vision_tower_name

    @staticmethod
    def _truncate_layers(layers, select_layer):
        """Drop, in place, the encoder layers after index ``select_layer``.

        ``select_layer == -1`` means "keep everything". It must be special
        cased because ``select_layer + 1`` is then ``0`` and the slice
        ``[0:]`` would delete every layer.

        NOTE(review): for non-negative ``select_layer`` this keeps
        ``select_layer + 1`` layers, exactly as the original code did; confirm
        that this matches the intended hidden-state index.
        """
        first_dropped = select_layer + 1
        if first_dropped != 0:
            del layers[first_dropped:]

    def load_model(self):
        """Load the image processor and the (truncated, frozen) encoder.

        Does nothing if the model has already been loaded.
        """
        if self.is_loaded:
            return

        # Hacky check for whether we are being initialised under a meta-device
        # context, which is enabled when a trained model is loaded through the
        # Hugging Face `from_pretrained` API. If a full model that includes the
        # vision tower is loaded, loading the tower to CPU here would emit a
        # warning, so `device_map='auto'` is passed in that case to avoid it.
        # [Edited by zhenwei - 2024-02-02 13:03]
        is_meta = getattr(nn.Linear(1, 1, bias=False).weight, 'is_meta', False)
        load_kwargs = {'device_map': 'auto'} if is_meta else {}

        if self._is_siglip:
            # e.g. "google/siglip-so400m-patch14-384"
            self.image_processor = ImpImageProcessor()
            self.vision_tower = SiglipVisionModel.from_pretrained(
                self.vision_tower_name, **load_kwargs
            )
            self._truncate_layers(
                self.vision_tower.vision_model.encoder.layers, self.select_layer
            )
            # Replace the final norm and pooling head with pass-throughs.
            self.vision_tower.vision_model.post_layernorm = nn.Identity()
            self.vision_tower.vision_model.head = nn.Identity()
        else:
            self.image_processor = CLIPImageProcessor.from_pretrained(self.vision_tower_name)
            self.vision_tower = CLIPVisionModel.from_pretrained(
                self.vision_tower_name, **load_kwargs
            )
            self._truncate_layers(
                self.vision_tower.vision_model.encoder.layers, self.select_layer
            )

        self.vision_tower.requires_grad_(False)
        self.vision_tower.eval()

        self.is_loaded = True

    def feature_select(self, image_forward_outs):
        """Pick the token features to expose from an encoder output.

        The encoder is already truncated in :meth:`load_model`, so the last
        entry of ``hidden_states`` is used rather than indexing by
        ``select_layer``.

        Args:
            image_forward_outs: Encoder output exposing ``hidden_states``.

        Returns:
            The last hidden state, restricted to the trailing ``num_patches``
            tokens when ``select_feature == 'patch'`` and returned whole when
            ``select_feature == 'cls_patch'``.

        Raises:
            ValueError: If ``select_feature`` is neither of the two values.
            AssertionError: If the token count does not match the expectation
                for the chosen mode.
        """
        image_features = image_forward_outs.hidden_states[-1]
        if self.select_feature == 'patch':
            # Keep only the trailing patch tokens (drops any leading token).
            image_features = image_features[:, -self.num_patches:]
            assert image_features.shape[-2] == self.num_patches, f'select_feature=patch, image_features.shape[-2]={image_features.shape[-2]} != num_patches={self.num_patches}'
        elif self.select_feature == 'cls_patch':
            assert image_features.shape[-2] == self.num_patches + 1, f'select_feature=cls_patch, image_features.shape[-2]={image_features.shape[-2]} != num_patches+1={self.num_patches+1}'
        else:
            raise ValueError(f'Unexpected select feature: {self.select_feature}')
        return image_features

    @torch.no_grad()
    def forward(self, images):
        """Encode images into token features without tracking gradients.

        Args:
            images: Either a single tensor that is fed to the encoder as is,
                or a list of tensors, each of which gets a leading dimension
                added via ``unsqueeze(0)`` and is encoded on its own.

        Returns:
            A tensor of selected features for tensor input, or a list with one
            feature tensor per image for list input. Features are cast back to
            the dtype of the corresponding input.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_forward_out = self.vision_tower(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0),
                    output_hidden_states=True,
                )
                image_feature = self.feature_select(image_forward_out).to(image.dtype)
                image_features.append(image_feature)
        else:
            image_forward_outs = self.vision_tower(
                images.to(device=self.device, dtype=self.dtype),
                output_hidden_states=True,
            )
            image_features = self.feature_select(image_forward_outs).to(images.dtype)

        return image_features

    @property
    def dummy_feature(self):
        """An all-zero ``(1, hidden_size)`` tensor on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """Dtype of the first encoder parameter, or ``None`` if there is none."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.dtype

    @property
    def device(self):
        """Device of the first encoder parameter, or ``None`` if there is none."""
        first_param = next(self.vision_tower.parameters(), None)
        return None if first_param is None else first_param.device

    @property
    def is_meta(self):
        """Whether the encoder parameters live on the ``meta`` device."""
        return self.device.type == 'meta'

    @property
    def config(self):
        """The loaded encoder's config, or the config fetched for delayed loading."""
        if self.is_loaded:
            return self.vision_tower.config
        return self.cfg_only

    @property
    def hidden_size(self):
        """Hidden size reported by the vision config."""
        return self.config.hidden_size

    @property
    def num_patches(self):
        """Number of patches in the square grid: ``(image_size // patch_size) ** 2``."""
        return (self.config.image_size // self.config.patch_size) ** 2
|