|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, List, Literal, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch.nn as nn |
|
import torchvision.transforms |
|
from einops import rearrange |
|
|
|
from .siglip_vit import create_siglip_vit |
|
|
|
|
|
class CLIPVisionTower(nn.Module):
    """Vision encoder wrapper exposing a uniform ``images -> features`` API.

    The concrete backbone is chosen from the prefix of ``model_name``
    (``siglip*``, ``sam*``, otherwise a HuggingFace ``CLIPVisionModel``).
    Input pixels are optionally normalized with the given mean/std before
    encoding, and ``select_feature`` controls which tokens are returned.
    """

    def __init__(
        self,
        model_name: str = "siglip_large_patch16_384",
        image_size: Union[Tuple[int, int], int] = 336,
        select_feature: str = "patch",
        select_layer: int = -2,
        select_layers: Optional[List[int]] = None,
        ckpt_path: str = "",
        pixel_mean: Optional[List[float]] = None,
        pixel_std: Optional[List[float]] = None,
        **kwargs,
    ):
        """
        Args:
            model_name: backbone identifier; its prefix selects the builder.
            image_size: input resolution forwarded to the backbone factory.
            select_feature: token-selection mode: "patch" (drop CLS token),
                "cls_patch" (keep all tokens), or "same" (pass through).
            select_layer: hidden-state index used on the HF CLIP path.
            select_layers: optional list of layer indices; stored on the
                instance but not read anywhere in this class.
            ckpt_path: optional checkpoint path forwarded to the factory.
            pixel_mean: per-channel normalization mean; normalization is only
                applied when BOTH ``pixel_mean`` and ``pixel_std`` are given.
            pixel_std: per-channel normalization std (see ``pixel_mean``).
            **kwargs: extra options forwarded to the backbone factory.
        """
        super().__init__()

        self.model_name = model_name
        self.select_feature = select_feature
        self.select_layer = select_layer
        self.select_layers = select_layers

        vision_tower_params = {
            "model_name": model_name,
            "image_size": image_size,
            "ckpt_path": ckpt_path,
            "select_layer": select_layer,
        }
        vision_tower_params.update(kwargs)
        self.vision_tower, self.forward_kwargs = self.build_vision_tower(
            vision_tower_params
        )

        # Normalize only when both statistics are supplied; otherwise the
        # caller is expected to feed already-normalized images.
        if pixel_mean is not None and pixel_std is not None:
            self.image_norm = torchvision.transforms.Normalize(
                mean=pixel_mean, std=pixel_std
            )
        else:
            self.image_norm = None

    def build_vision_tower(self, vision_tower_params):
        """Instantiate the backbone selected by ``self.model_name``.

        Returns:
            Tuple of ``(vision_tower, forward_kwargs)``; ``forward_kwargs``
            is passed verbatim to the backbone on every forward call.
        """
        if self.model_name.startswith("siglip"):
            # SigLIP backbones already return the selected features, so
            # downstream token selection is forced to a pass-through.
            self.select_feature = "same"
            vision_tower = create_siglip_vit(**vision_tower_params)
            forward_kwargs = dict()

        elif self.model_name.startswith("sam"):
            # NOTE(review): `create_sam_vit` is not imported anywhere in this
            # file, so this branch raises NameError as written. Presumably a
            # `from .sam_vit import create_sam_vit`-style import is missing --
            # confirm the module name against the package layout.
            vision_tower = create_sam_vit(**vision_tower_params)
            forward_kwargs = dict()

        else:
            from transformers import CLIPVisionModel

            # NOTE(review): this passes model_name / image_size / ckpt_path /
            # select_layer as keyword arguments to `from_pretrained`, which
            # expects `pretrained_model_name_or_path` -- verify these kwargs
            # are accepted by the targeted transformers version.
            vision_tower = CLIPVisionModel.from_pretrained(**vision_tower_params)
            # HF CLIP needs hidden states so feature_select can index a layer.
            forward_kwargs = dict(output_hidden_states=True)

        return vision_tower, forward_kwargs

    def feature_select(self, image_forward_outs):
        """Pick the requested feature tokens from the backbone output.

        Accepts either a raw feature tensor (SigLIP/SAM paths) or a HF model
        output exposing ``hidden_states`` (CLIP path).

        Raises:
            ValueError: if ``self.select_feature`` is not one of
                "patch", "cls_patch", or "same".
        """
        if isinstance(image_forward_outs, torch.Tensor):
            # The backbone already returned the selected hidden state.
            image_features = image_forward_outs
        else:
            image_features = image_forward_outs.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            # Drop the leading CLS token; keep only patch tokens.
            image_features = image_features[:, 1:]
        elif self.select_feature in ("cls_patch", "same"):
            # Keep the features unchanged (CLS + patches, or as-is).
            pass
        else:
            raise ValueError(f"Unexpected select feature: {self.select_feature}")
        return image_features

    def forward(self, images):
        """
        Args:
            images (torch.Tensor): [b, 3, H, W]

        Returns:
            image_features (torch.Tensor): [b, n_patch, d]
        """
        if self.image_norm is not None:
            images = self.image_norm(images)

        image_forward_outs = self.vision_tower(images, **self.forward_kwargs)
        image_features = self.feature_select(image_forward_outs)
        return image_features
|
|