# Apollo-LMMs-Apollo-3B-t32 / vision_tower.py
# Uploaded by hawky-ai-labs ("Upload folder using huggingface_hub"), revision 7441f42 (verified).
import torch, os, PIL, numbers
from PIL import Image
import cv2
from transformers.modeling_utils import PreTrainedModel
from transformers.models.siglip.modeling_siglip import SiglipVisionModel
from transformers import AutoConfig, AutoModel, SiglipImageProcessor, SiglipVisionConfig, PretrainedConfig
from typing import Union
import torch.nn.functional as F
import numpy as np
def crop_clip(clip, min_h, min_w, h, w):
    """Crop the same rectangle out of every frame of a clip.

    Args:
        clip (list): Frames, either all ``numpy.ndarray`` indexed as
            (row, column, channel) or all ``PIL.Image.Image``. Only the
            type of the first frame is inspected.
        min_h (int): First row (top edge) of the crop window.
        min_w (int): First column (left edge) of the crop window.
        h (int): Height of the crop window.
        w (int): Width of the crop window.

    Returns:
        list: The cropped frames, in the same order as ``clip``.

    Raises:
        TypeError: If the first frame is neither an ndarray nor a PIL image.
    """
    first = clip[0]
    if isinstance(first, np.ndarray):
        return [img[min_h:min_h + h, min_w:min_w + w, :] for img in clip]
    if isinstance(first, PIL.Image.Image):
        # PIL crop boxes are given as (left, upper, right, lower).
        return [img.crop((min_w, min_h, min_w + w, min_h + h)) for img in clip]
    # The two literals used to be joined without a space ("PIL.Imagebut got").
    raise TypeError(
        'Expected numpy.ndarray or PIL.Image '
        'but got list of {0}'.format(type(first)))
class Normalize(object):
    """Channel-wise standardization of a tensor clip.

    For per-channel statistics ``mean`` and ``std`` the result is
    ``(clip[channel] - mean[channel]) / std[channel]``.

    The work is delegated to the module-level ``normalize`` function, which
    clones its input by default, so the clip passed in is left untouched.

    Args:
        mean (sequence): One mean per channel.
        std (sequence): One standard deviation per channel.
    """

    def __init__(self, mean, std):
        self.mean, self.std = mean, std

    def __call__(self, clip):
        """Return a normalized copy of ``clip``.

        Args:
            clip (Tensor): 4-D tensor clip. The statistics are broadcast
                along its first dimension (see ``normalize``).

        Returns:
            Tensor: Normalized tensor clip.
        """
        normalized = normalize(clip, self.mean, self.std)
        return normalized

    def __repr__(self):
        return '{0}(mean={1}, std={2})'.format(
            type(self).__name__, self.mean, self.std)
class CenterCrop(object):
    """Extract the center crop, at the same location, from a list of images.

    Args:
        size (sequence or int): Desired output size of the crop in the
            format (h, w). A single number is used for both sides.
    """

    def __init__(self, size):
        if isinstance(size, numbers.Number):
            size = (size, size)
        self.size = size

    def __call__(self, clip):
        """Crop every frame of ``clip`` around its center.

        Args:
            clip (list): Frames, either all ``numpy.ndarray`` in
                (h, w, c) format or all ``PIL.Image.Image``. Only the first
                frame is inspected for type and size.

        Returns:
            list: The cropped frames (same element type as the input).

        Raises:
            TypeError: If the first frame is neither an ndarray nor a PIL
                image.
            ValueError: If the crop is larger than the frame in either
                direction.
        """
        h, w = self.size
        first = clip[0]
        if isinstance(first, np.ndarray):
            im_h, im_w, _ = first.shape
        elif isinstance(first, PIL.Image.Image):
            # PIL reports size as (width, height).
            im_w, im_h = first.size
        else:
            # The two literals used to be joined without a space.
            raise TypeError(
                'Expected numpy.ndarray or PIL.Image '
                'but got list of {0}'.format(type(first)))
        if w > im_w or h > im_h:
            raise ValueError(
                'Initial image size should be larger than '
                'cropped size but got cropped sizes : ({w}, {h}) while '
                'initial image is ({im_w}, {im_h})'.format(
                    im_w=im_w, im_h=im_h, w=w, h=h))

        # Top-left corner of the centered crop window.
        x1 = int(round((im_w - w) / 2.))
        y1 = int(round((im_h - h) / 2.))
        return crop_clip(clip, y1, x1, h, w)
def resize_clip(clip, size, interpolation='bilinear'):
    """Resize every frame of a clip.

    Args:
        clip (list): Frames, either all ``numpy.ndarray`` (H x W x C) or all
            ``PIL.Image.Image``. Only the first frame is inspected for type
            and size.
        size (number or sequence): A single number is the target length of
            the shorter side; the other side is scaled by
            ``get_resize_sizes``. A two-element sequence is an explicit
            target size.
            NOTE(review): a sequence is forwarded unchanged to ``cv2.resize``
            for ndarray frames but with its two entries swapped to
            ``Image.resize`` for PIL frames, so the two branches read the
            sequence in opposite orders -- confirm which order callers intend.
        interpolation (str): ``'bilinear'`` selects bilinear interpolation;
            any other value selects nearest-neighbour.

    Returns:
        list: The resized frames. When ``size`` is a number and the shorter
        side already equals it, the input list itself is returned.

    Raises:
        TypeError: If the first frame is neither an ndarray nor a PIL image.
    """
    first = clip[0]
    bilinear = interpolation == 'bilinear'

    if isinstance(first, np.ndarray):
        if isinstance(size, numbers.Number):
            im_h, im_w, _ = first.shape
            # Shorter side already has the requested length: nothing to do.
            if min(im_h, im_w) == size:
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[0], size[1]
        np_inter = cv2.INTER_LINEAR if bilinear else cv2.INTER_NEAREST
        return [cv2.resize(img, size, interpolation=np_inter) for img in clip]

    if isinstance(first, PIL.Image.Image):
        if isinstance(size, numbers.Number):
            im_w, im_h = first.size
            # Shorter side already has the requested length: nothing to do.
            if min(im_h, im_w) == size:
                return clip
            new_h, new_w = get_resize_sizes(im_h, im_w, size)
            size = (new_w, new_h)
        else:
            size = size[1], size[0]
        pil_inter = PIL.Image.BILINEAR if bilinear else PIL.Image.NEAREST
        return [img.resize(size, pil_inter) for img in clip]

    # The two literals used to be joined without a space.
    raise TypeError(
        'Expected numpy.ndarray or PIL.Image '
        'but got list of {0}'.format(type(first)))
def _is_tensor_clip(clip):
return torch.is_tensor(clip) and clip.ndimension() == 4
def get_resize_sizes(im_h, im_w, size):
    """Compute the output size that makes the shorter side equal to ``size``.

    The longer side is scaled by the same factor and truncated to an int.

    Args:
        im_h: Current image height.
        im_w: Current image width.
        size: Target length of the shorter side.

    Returns:
        tuple: ``(new_height, new_width)``.
    """
    if im_w < im_h:
        # Width is the shorter side.
        return int(size * im_h / im_w), size
    # Height is the shorter side (or the image is square).
    return size, int(size * im_w / im_h)
def normalize(clip, mean, std, inplace=False):
    """Standardize a 4-D tensor clip with per-channel statistics.

    ``mean`` and ``std`` are converted to tensors of the clip's dtype and
    device and broadcast along the clip's first dimension.

    Args:
        clip (Tensor): 4-D tensor clip.
        mean (sequence): One mean per entry of the first dimension.
        std (sequence): One standard deviation per entry of the first
            dimension.
        inplace (bool): When False (default) a clone is normalized and the
            input is left unchanged; when True the input itself is modified.

    Returns:
        Tensor: The normalized clip.

    Raises:
        TypeError: If ``clip`` is not a 4-D torch tensor.
    """
    if not _is_tensor_clip(clip):
        raise TypeError('tensor is not a torch clip.')

    target = clip if inplace else clip.clone()
    like_target = dict(dtype=target.dtype, device=target.device)
    # Trailing singleton axes let the statistics broadcast over dims 1..3.
    mean_t = torch.as_tensor(mean, **like_target)[:, None, None, None]
    std_t = torch.as_tensor(std, **like_target)[:, None, None, None]
    target.sub_(mean_t).div_(std_t)
    return target
class Resize(object):
    """Resize every frame of a clip by delegating to ``resize_clip``.

    The larger the original frames are, the longer the interpolation takes.

    Args:
        size (number or sequence): Forwarded unchanged to ``resize_clip``;
            a single number is the target length of the shorter side, a
            two-element sequence is an explicit target size.
        interpolation (str): ``'bilinear'`` or ``'nearest'``; defaults to
            ``'nearest'``.
    """

    def __init__(self, size, interpolation='nearest'):
        self.size, self.interpolation = size, interpolation

    def __call__(self, clip):
        """Return the resized frames of ``clip``."""
        return resize_clip(clip, self.size, interpolation=self.interpolation)
class Compose(object):
    """Chain several clip transforms into a single callable.

    Args:
        transforms (list of callables): Transforms applied in list order;
            each receives the output of the previous one.
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, clip):
        """Run ``clip`` through every transform and return the final result."""
        result = clip
        for transform in self.transforms:
            result = transform(result)
        return result
def convert_img(img):
    """Move the channel axis of an image array to the front.

    A 3-D (H, W, C) array becomes (C, H, W); a 2-D (H, W) array gets a
    leading singleton axis and becomes (1, H, W). Arrays of any other rank
    are returned unchanged.
    """
    rank = len(img.shape)
    if rank == 3:
        return img.transpose(2, 0, 1)
    if rank == 2:
        return np.expand_dims(img, 0)
    return img
class ClipToTensor(object):
    """Stack a list of frames into one channel-first clip array.

    Converts a list of m (H x W x C) numpy.ndarrays (or PIL images) to an
    array of shape (C x m x H x W). With ``div_255`` the values are divided
    by 255, which maps inputs in the range [0, 255] to [0, 1.0].

    Args:
        channel_nb (int): Number of channels of the output clip; ndarray
            frames must have exactly this many channels.
        div_255 (bool): Divide the result by 255.
        numpy (bool): Return a ``numpy.ndarray`` instead of a
            ``torch.FloatTensor``.
    """

    def __init__(self, channel_nb=3, div_255=True, numpy=False):
        self.channel_nb = channel_nb
        self.div_255 = div_255
        self.numpy = numpy

    def __call__(self, clip):
        """Convert ``clip`` to a (C x m x H x W) array or tensor.

        Args:
            clip (list): Frames, each a ``numpy.ndarray`` or a PIL image.
                The output height and width are taken from the first frame.

        Returns:
            numpy.ndarray or torch.FloatTensor, depending on ``self.numpy``.

        Raises:
            TypeError: If a frame is neither an ndarray nor a PIL image.
        """
        # Retrieve the spatial shape from the first frame.
        first = clip[0]
        if isinstance(first, np.ndarray):
            h, w, ch = first.shape
            assert ch == self.channel_nb, "Got {0} instead of {1} channels".format(
                ch, self.channel_nb)
        elif isinstance(first, Image.Image):
            # PIL reports size as (width, height).
            w, h = first.size
        else:
            raise TypeError(
                "Expected numpy.ndarray or PIL.Image "
                "but got list of {0}".format(type(first)))

        np_clip = np.zeros([self.channel_nb, len(clip), int(h), int(w)])

        for img_idx, img in enumerate(clip):
            if isinstance(img, Image.Image):
                # np.array(..., copy=False) raises under NumPy 2 whenever a
                # copy is required; np.asarray copies only when needed.
                img = np.asarray(img)
            elif not isinstance(img, np.ndarray):
                # Report the offending frame, not the first one.
                raise TypeError(
                    "Expected numpy.ndarray or PIL.Image "
                    "but got list of {0}".format(type(img)))
            np_clip[:, img_idx, :, :] = convert_img(img)

        if self.numpy:
            if self.div_255:
                np_clip = np_clip / 255.0
            return np_clip

        # .float() is a no-op for tensors that are already float32.
        tensor_clip = torch.from_numpy(np_clip).float()
        if self.div_255:
            tensor_clip = torch.div(tensor_clip, 255)
        return tensor_clip
class VisionTowerConfig(PretrainedConfig):
    """Configuration object for a ``VisionTower``.

    Args:
        vision_tower_name (str, optional): Name or path of the vision tower.
        **kwargs: Any further configuration options; they are passed on to
            ``PretrainedConfig``.
    """

    model_type = "vision_tower"

    def __init__(self, vision_tower_name: str = None, **kwargs):
        # Forward kwargs to the base class; previously they were accepted
        # but silently dropped, so options supplied by callers were lost.
        super().__init__(**kwargs)
        self.vision_tower_name = vision_tower_name
class ProcessorWrapper:
    """Uniform front end over either a clip transform or an image processor.

    Exactly one of ``transform`` and ``processor`` must be given.

    Args:
        transform (callable, optional): Callable applied to the input in
            ``preprocess``.
        processor (optional): Object called as
            ``processor(image, return_tensors='pt')`` and offering
            ``save_pretrained``.
        height (int): Stored under ``size["height"]``.
        width (int): Stored under ``size["width"]``.
        frames_per_clip (int): Stored under ``size["frames_per_clip"]``.
        image_mean (list, optional): Per-channel mean exposed as
            ``image_mean``. Defaults to a fresh
            ``[0.48145466, 0.4578275, 0.40821073]`` list per instance.
    """

    def __init__(self, transform=None, processor=None, height=378, width=378, frames_per_clip=1,
                 image_mean=None):
        assert transform is not None or processor is not None, "ERROR: you did not define both `transform` and `processor`! You must define either transform or processor"
        assert transform is None or processor is None, "ERROR: you did defined both `transform` and `processor`! You must define only one of: transform or processor"
        if image_mean is None:
            # Built per instance: a list literal as the default argument
            # would be one shared object mutated across all instances.
            image_mean = [0.48145466, 0.4578275, 0.40821073]

        self._size = {
            "height": height,
            "width": width,
            "frames_per_clip": frames_per_clip
        }
        self._transforms = transform
        self._processor = processor
        self.image_mean = image_mean

    @property
    def size(self):
        """Dict with the ``height``, ``width`` and ``frames_per_clip`` values."""
        return self._size

    def preprocess(self, image, return_tensors='pt'):
        """Run the wrapped transform or processor on ``image``.

        Args:
            image: Input forwarded unchanged to the transform/processor.
            return_tensors (str): NOTE(review): currently unused -- the
                processor is always called with ``return_tensors='pt'``.

        Returns:
            With a transform: ``{'pixel_values': [transformed]}``.
            With a processor: whatever the processor returns.
        """
        if self._transforms is not None:
            return {'pixel_values': [self._transforms(image)]}
        return self._processor(image, return_tensors='pt')

    def save_pretrained(self, save_path):
        """Write the preprocessing configuration into ``save_path``.

        With a transform, a ``preprocessor_config.json`` file is written;
        otherwise the call is delegated to the processor's own
        ``save_pretrained``.
        """
        if self._transforms is None:
            self._processor.save_pretrained(save_path)
            return

        # Local import: `json` is not imported at the top of this file.
        import json

        # NOTE(review): `transform_to_dict` is not defined in the part of
        # the file reviewed here; this branch needs it to be provided.
        transform_dict = transform_to_dict(self._transforms)
        transform_dict["image_processor_type"] = "transforms"
        with open(os.path.join(save_path, 'preprocessor_config.json'), 'w') as f:
            json.dump(transform_dict, f, indent=4)
class VisionTower(PreTrainedModel):
    """Base class wrapping a vision encoder stored in ``self.vision_tower``.

    This class itself does not create ``self.vision_tower``, ``self.W``,
    ``self.H``, ``self.T`` or ``self.hidden_size``; the methods below read
    them, so subclasses are expected to set them.

    Args:
        model_name_or_path (str): Stored as ``vision_tower_name``.
        config (PretrainedConfig): Source of the optional settings
            ``mm_vision_select_layer`` (default -2),
            ``mm_vision_select_feature`` (default "patch"),
            ``encode_batch_size`` and ``num_encode_batch`` (default 0, both
            halved with integer division).
        vision_config (VisionTowerConfig): Passed to
            ``PreTrainedModel.__init__``; its optional ``tubelet_size``
            (default 1) is stored as ``temporal_tubelet_size``.
    """

    config_class = VisionTowerConfig

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: VisionTowerConfig = None):
        super().__init__(vision_config)
        self.vision_tower_name = model_name_or_path
        self.vision_config = vision_config
        self.select_layer = getattr(config, "mm_vision_select_layer", -2)
        self.select_feature = getattr(config, "mm_vision_select_feature", "patch")
        self.encode_batch_size = getattr(config, "encode_batch_size", 0) // 2
        self.num_encode_batch = getattr(config, "num_encode_batch", 0) // 2
        self.temporal_tubelet_size = getattr(vision_config, "tubelet_size", 1)

    def feature_select(self, image_features):
        """Pick the configured hidden layer and token subset.

        When ``select_layer`` is not None, ``image_features`` is replaced
        by ``image_features.hidden_states[select_layer]``. With
        ``select_feature == "patch"`` the first token along dim 1 is
        dropped; with ``"cls_patch"`` all tokens are kept.

        Raises:
            ValueError: For any other ``select_feature`` value.
        """
        if self.select_layer is not None:
            image_features = image_features.hidden_states[self.select_layer]

        if self.select_feature == "patch":
            return image_features[:, 1:]
        if self.select_feature == "cls_patch":
            return image_features
        raise ValueError(f"Unexpected select feature: {self.select_feature}")

    def vision_tower_forward(self, image):
        """Run the wrapped encoder, requesting all hidden states."""
        return self.vision_tower(image, output_hidden_states=True)

    def _forward(self, images, out_T=1):
        """Encode images and reshape the tokens onto a spatial grid.

        Args:
            images: Either a list of tensors (each encoded separately with
                a batch axis added) or one batched tensor. For a 5-D tensor
                and ``self.T == 1`` the second axis is treated as time.
            out_T (int): Number of time steps to keep when subsampling a
                5-D input temporally.

        Returns:
            A list of per-image feature tensors for list input, otherwise
            one feature tensor.
        """
        if isinstance(images, list):
            image_features = []
            for image in images:
                image_feature = self.vision_tower_forward(
                    image.to(device=self.device, dtype=self.dtype).unsqueeze(0))
                image_feature = self.feature_select(image_feature).to(image.dtype)
                # Reshape the tensor itself. The previous code called
                # .reshape on the accumulator list and used `self.D`, which
                # is never defined in the visible code; `hidden_size` is
                # what the batched branch below uses.
                image_feature = image_feature.reshape(
                    image_feature.shape[0], self.W, self.H, self.hidden_size)
                image_features.append(image_feature)
            return image_features

        original_shape = images.shape
        frame_wise = len(original_shape) == 5 and self.T == 1
        if frame_wise:
            # Downsample temporally if needed, and fold time into the batch
            # axis: (B, T, ...) -> (B*T, ...).
            # Clamp the stride to 1 so that inputs with fewer than out_T
            # frames do not produce an invalid zero slice step.
            step = max(1, original_shape[1] // out_T)
            images = images[:, ::step, ...]
            original_shape = images.shape
            images = images.view(-1, *original_shape[2:])

        image_features = self.vision_tower_forward(images.to(device=self.device, dtype=self.dtype))
        image_features = self.feature_select(image_features).to(images.dtype)

        if frame_wise:
            # Lay the tokens out on the (W, H) grid, then restore the
            # separate batch and time axes.
            new_shape = list(image_features.shape[:-2]) + [self.W, self.H, self.hidden_size]
            image_features = image_features.reshape(new_shape)
            feature_size = image_features.shape[1:]
            image_features = image_features.view(original_shape[0], original_shape[1], *feature_size)
        else:
            image_features = image_features.reshape(
                image_features.shape[0], self.T, self.W, self.H, self.hidden_size)

        return image_features

    def forward(self, images):
        """Encode ``images`` with the default ``out_T`` of ``_forward``."""
        return self._forward(images)

    @property
    def dummy_feature(self):
        """All-zero tensor of shape (1, hidden_size) on the tower's device/dtype."""
        return torch.zeros(1, self.hidden_size, device=self.device, dtype=self.dtype)

    @property
    def dtype(self):
        """dtype reported by the wrapped encoder."""
        return self.vision_tower.dtype

    @property
    def device(self):
        """Device reported by the wrapped encoder."""
        return self.vision_tower.device

    @property
    def num_patches(self):
        """``(image_size // patch_size) ** 2`` computed from ``self.config``."""
        return (self.config.image_size // self.config.patch_size) ** 2
class InternVideoTower(VisionTower):
    """Vision tower backed by a model loaded with ``AutoModel`` (remote code).

    Args:
        model_name_or_path (str): Checkpoint to load the model (and, when
            ``vision_config`` is None, its config) from.
        config (PretrainedConfig): Must provide ``model_dtype``; also read
            by ``VisionTower.__init__``.
        vision_config (PretrainedConfig, optional): Encoder config. The
            code reads ``img_size``, ``patch_size``, ``num_frames``,
            ``tubelet_size`` and ``d_model`` from it.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        if vision_config is None:
            vision_config = AutoConfig.from_pretrained(model_name_or_path, trust_remote_code=True)

        super().__init__(model_name_or_path, config, vision_config)
        self.vision_config = vision_config
        # (mean, std) per channel. Named so that it does not shadow the
        # module-level `normalize` function.
        mean_std = ((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        print('loading: ', model_name_or_path)
        model = AutoModel.from_pretrained(model_name_or_path, trust_remote_code=True)
        # SECURITY: this used to be `eval(config.model_dtype)`, which runs
        # arbitrary code taken from a config value. The dtype is now looked
        # up by name on the torch module instead.
        self.vision_tower = model.to(dtype=self._resolve_dtype(config.model_dtype))

        img_size = self.vision_config.img_size
        transform = Compose([
            Resize(img_size, interpolation='bilinear'),
            CenterCrop(size=(img_size, img_size)),
            ClipToTensor(),
            Normalize(mean=mean_std[0], std=mean_std[1])
        ])

        self.vision_processor = ProcessorWrapper(transform=transform,
                                                 height=img_size,
                                                 width=img_size,
                                                 frames_per_clip=self.vision_config.num_frames,
                                                 image_mean=mean_std[0])

        self.W = self.H = vision_config.img_size // vision_config.patch_size
        self.T = self.vision_config.num_frames // self.vision_config.tubelet_size
        self.num_frames = self.vision_config.num_frames
        self.hidden_size = vision_config.d_model
        # The layer index is handed to the model directly in
        # vision_tower_forward, so the generic hidden-state selection in
        # VisionTower.feature_select is switched off.
        self.vision_select_layer = self.select_layer
        self.select_layer = None

    @staticmethod
    def _resolve_dtype(spec):
        """Turn a dtype spec such as ``"torch.bfloat16"`` into a ``torch.dtype``.

        Accepts a ``torch.dtype`` (returned unchanged) or a dtype name with
        or without the ``torch.`` prefix.

        Raises:
            ValueError: If the name does not denote a torch dtype.
        """
        if isinstance(spec, torch.dtype):
            return spec
        name = str(spec).strip()
        if name.startswith("torch."):
            name = name[len("torch."):]
        dtype = getattr(torch, name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError(f"Unsupported model_dtype: {spec!r}")
        return dtype

    def vision_tower_forward(self, video):
        """Match the frame count to ``num_frames`` and run the model.

        The frame axis is taken to be dim -3 of ``video``.
        """
        if video.shape[-3] < self.num_frames:
            # NOTE(review): repeat_interleave multiplies the frame count by
            # num_frames, which gives exactly num_frames only for
            # single-frame input -- confirm shorter multi-frame clips never
            # reach this point.
            video = video.repeat_interleave(self.num_frames, dim=-3)
        elif video.shape[-3] > self.num_frames:
            # NOTE(review): strided subsampling on the third axis; when the
            # frame count is not a multiple of num_frames this keeps more
            # than num_frames frames -- confirm this is acceptable.
            video = video[:, :, ::video.shape[-3] // self.num_frames, ...]

        video_feature = self.vision_tower(video.to(device=self.device, dtype=self.dtype),
                                          x_vis_return_idx=self.vision_select_layer, x_vis_only=True)

        return video_feature

    @property
    def device(self):
        """Device of the model's ``pos_embed`` parameter."""
        return self.vision_tower.pos_embed.device
class SiglipVisionTower(VisionTower):
    """Vision tower backed by ``SiglipVisionModel``.

    Loads the image processor and the vision model from
    ``model_name_or_path`` and records the token-grid geometry used by
    ``VisionTower._forward``.

    Args:
        model_name_or_path (str): Checkpoint for the processor, the model
            and (when ``vision_config`` is None) the config.
        config (PretrainedConfig): Forwarded to ``VisionTower.__init__``.
        vision_config (PretrainedConfig, optional): Encoder config; loaded
            with ``SiglipVisionConfig.from_pretrained`` when omitted.
    """

    def __init__(self, model_name_or_path: str, config: PretrainedConfig, vision_config: PretrainedConfig = None):
        vision_config = (
            vision_config
            if vision_config is not None
            else SiglipVisionConfig.from_pretrained(model_name_or_path)
        )
        super().__init__(model_name_or_path, config, vision_config)

        self.vision_config = vision_config
        self.vision_tower_name = model_name_or_path

        checkpoint = self.vision_tower_name
        self.vision_processor = SiglipImageProcessor.from_pretrained(checkpoint)
        print('loading: ', model_name_or_path)
        self.vision_tower = SiglipVisionModel.from_pretrained(checkpoint)

        self.hidden_size = self.vision_config.hidden_size
        # Number of patches along one side of the (square) token grid.
        grid_side = self.vision_config.image_size // self.vision_config.patch_size
        self.W = grid_side
        self.H = grid_side
        # Each frame is encoded on its own.
        self.T = 1
        # Keep every token in feature_select (no leading token is dropped).
        self.select_feature = "cls_patch"
class ApolloVisionTower(PreTrainedModel):
    """Bundle of several vision towers whose features are concatenated.

    Each name in ``vision_tower_cfg.vision_towers`` is instantiated as an
    ``InternVideoTower`` (name contains "internvideo") or a
    ``SiglipVisionTower`` (name contains "siglip"), loaded from
    ``<vision_tower_cfg._name_or_path>/<name>`` and registered as an
    attribute under that name.

    Args:
        config: Model config; forwarded to every tower and updated with
            ``num_vision_encoders``, ``vision_towers`` and
            ``token_output_shape``.
        vision_tower_cfg: Config providing ``_name_or_path`` and
            ``vision_towers``.

    Raises:
        ValueError: If a tower name matches neither known tower type.
    """

    def __init__(self, config, vision_tower_cfg):
        super(ApolloVisionTower, self).__init__(config, vision_tower_cfg)
        self.model_name_or_path = vision_tower_cfg._name_or_path
        self.vision_towers = vision_tower_cfg.vision_towers
        self._config = vision_tower_cfg

        for vision_tower_name in self.vision_towers:
            tower_path = os.path.join(vision_tower_cfg._name_or_path, vision_tower_name)
            lowered = vision_tower_name.lower()
            if 'internvideo' in lowered:
                vision_tower = InternVideoTower(tower_path, config)
            elif 'siglip' in lowered:
                vision_tower = SiglipVisionTower(tower_path, config)
            else:
                # Previously an unknown name either failed with an unbound
                # local or silently registered the previous tower again.
                raise ValueError(f"Unsupported vision tower: {vision_tower_name}")

            setattr(self, vision_tower_name, vision_tower)

        towers = [getattr(self, vt) for vt in self.vision_towers]
        self.vision_processor = [tower.vision_processor for tower in towers]
        self.num_vision_encoders = len(self.vision_towers)
        # The common output grid is the largest one among the towers.
        self.W = self.H = max([tower.W for tower in towers])
        self.T = max([tower.T for tower in towers])
        self.max_tubelet_size = max(
            [getattr(tower.vision_config, 'tubelet_size', 1) for tower in towers])

        # Features are concatenated on the last axis, so the sizes add up.
        self._hidden_size = sum([tower.hidden_size for tower in towers])
        self.token_output_shape = (self.T, self.W, self.H)
        self.config.num_vision_encoders = self.num_vision_encoders
        self.config.vision_towers = self.vision_towers
        self.config.token_output_shape = self.token_output_shape

    def forward(self, x):
        """Encode one input per tower and fuse the results.

        Args:
            x: Iterable of inputs, paired positionally with
                ``self.vision_towers``.

        Returns:
            Tensor: Tower features resampled to ``token_output_shape``,
            concatenated on the last axis, with all axes between the first
            and the last flattened into one.
        """
        output_features = []
        for x_s, vision_tower_name in zip(x, self.vision_towers):
            vision_tower = getattr(self, vision_tower_name)
            features = vision_tower._forward(x_s, out_T=self.T)

            # Add the missing axis at position 1 when the tower returned
            # one axis fewer than batch + (T, W, H) + feature.
            if len(features.shape) != len(self.token_output_shape) + 2:
                features = features.unsqueeze(1)

            if features.shape[-len(self.token_output_shape) - 1:-1] != self.token_output_shape:
                # Interpolation needs the feature axis right after the batch.
                features = features.permute(0, 4, 1, 2, 3).contiguous()  # shape [B, D, T, W, H]
                features = F.interpolate(features.to(torch.float32), size=self.token_output_shape, mode='trilinear',
                                         align_corners=False).to(features.dtype)
                features = features.permute(0, 2, 3, 4, 1).contiguous()

            output_features.append(features)

        output_features = torch.cat(output_features, dim=-1)
        output_features = torch.flatten(output_features, start_dim=1, end_dim=-2)
        return output_features

    def save_pretrained(
            self,
            save_directory: Union[str, os.PathLike],
            state_dict=None,
            **kwargs,
    ):
        """Save every tower, its processor and the config.

        Each tower's weights and processor go to
        ``<save_directory>/<tower name>``; the config is written to
        ``save_directory`` itself after its ``configs`` attribute has been
        reset to an empty dict.

        Args:
            save_directory: Target directory.
            state_dict (optional): Weights to save; defaults to
                ``self.state_dict()``.
            **kwargs: Forwarded to each tower model's ``save_pretrained``.
        """
        # Local import: `OrderedDict` is not imported at the top of this file.
        from collections import OrderedDict

        if state_dict is None:
            state_dict = self.state_dict()

        for vision_tower_name in self.vision_towers:
            vision_tower = getattr(self, vision_tower_name)
            tower_dir = os.path.join(save_directory, vision_tower_name)
            key_prefix = f"vision_tower.{vision_tower_name}.vision_tower."
            # Keep the entries whose key mentions this tower and strip the
            # wrapper prefix from the key where it is present.
            vision_tower_state_dict = OrderedDict(
                (k.split(key_prefix)[-1], v)
                for k, v in state_dict.items()
                if vision_tower_name in k
            )
            vision_tower.vision_tower.save_pretrained(tower_dir,
                                                      state_dict=vision_tower_state_dict, **kwargs)
            vision_tower.vision_processor.save_pretrained(tower_dir)

        config = self.config
        config.configs = {}
        config.save_pretrained(save_directory)

    @property
    def patch_size(self):
        # NOTE(review): `_patch_size` is never assigned in the visible code,
        # so this access is expected to fail unless it is set elsewhere.
        return self._patch_size

    @property
    def image_size(self):
        # NOTE(review): `_image_size` is never assigned in the visible code,
        # so this access is expected to fail unless it is set elsewhere.
        return self._image_size

    @property
    def hidden_size(self):
        """Sum of the hidden sizes of all towers."""
        return self._hidden_size