from transformers import CLIPImageProcessor
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_transforms import get_resize_output_image_size
import torch
import torch.nn.functional as F
import numpy as np


def _resize_frames(frames, output_size):
    """Resize a channels-last stack of frames to ``output_size`` with bicubic interpolation.

    Args:
        frames: float tensor of shape (num_frames, height, width, channels).
        output_size: target ``(height, width)``.

    Returns:
        Tensor of shape (num_frames, output_size[0], output_size[1], channels).
    """
    # F.interpolate resizes the trailing spatial dims of an (N, C, H, W) tensor, so
    # move channels in front of the spatial dims for the call and back afterwards.
    # NOTE(review): unlike the PIL-based path of CLIPImageProcessor, this does not
    # antialias when downsampling, does not clamp bicubic overshoot, and ignores
    # the `resample` option -- confirm that is acceptable for the target model.
    frames = frames.permute(0, 3, 1, 2)
    frames = F.interpolate(frames, size=output_size, mode="bicubic")
    return frames.permute(0, 2, 3, 1)


def _center_crop_frames(frames, crop_size):
    """Center-crop a channels-last stack of frames to ``crop_size``.

    Args:
        frames: tensor of shape (num_frames, height, width, channels).
        crop_size: dict with ``"height"`` and ``"width"`` keys.

    Returns:
        Tensor of shape (num_frames, crop_size["height"], crop_size["width"], channels).
        Along any dimension where the frame is smaller than the crop, the frame is
        zero-padded first (extra pixel on the top/left), as
        ``transformers.image_transforms.center_crop`` does.
    """
    crop_height, crop_width = crop_size["height"], crop_size["width"]
    # Layout is (N, H, W, C): dim 1 is height, dim 2 is width.
    img_height, img_width = frames.shape[1:3]

    pad_h = max(crop_height - img_height, 0)
    pad_w = max(crop_width - img_width, 0)
    if pad_h or pad_w:
        # F.pad takes (before, after) pairs starting from the LAST dim: C, then W, then H.
        frames = F.pad(
            frames,
            (0, 0, (pad_w + 1) // 2, pad_w // 2, (pad_h + 1) // 2, pad_h // 2),
        )
        img_height, img_width = frames.shape[1:3]

    top = (img_height - crop_height) // 2
    left = (img_width - crop_width) // 2
    return frames[:, top:top + crop_height, left:left + crop_width]


class VideoFramesProcessor(CLIPImageProcessor):
    """CLIPImageProcessor with a batched torch path for numpy stacks of video frames.

    Inputs that are not ``np.ndarray`` are handled by ``CLIPImageProcessor`` itself.
    """

    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    def preprocess(self, images, **kwargs):
        """Preprocess a stack of video frames, processing all frames at once with torch.

        Non-``np.ndarray`` inputs are delegated unchanged to
        ``CLIPImageProcessor.preprocess``. An ``np.ndarray`` is treated as a stack of
        channels-last frames, shape (num_frames, height, width, channels). It goes
        through resize -> center crop -> rescale -> normalize, each step applied to
        the whole stack.

        The options ``do_resize``, ``size``, ``do_center_crop``, ``crop_size``,
        ``do_rescale``, ``rescale_factor``, ``do_normalize``, ``image_mean`` and
        ``image_std`` fall back to the instance attributes when missing or ``None``,
        as in the parent class. ``return_tensors`` is forwarded to ``BatchFeature``.
        Other keyword arguments are ignored on the ndarray path.

        Returns:
            BatchFeature whose ``"pixel_values"`` entry holds the processed frames in
            (num_frames, channels, height, width) layout.

        Raises:
            ValueError: if the array is not 4-D, or if ``size`` contains neither
                ``"shortest_edge"`` nor both ``"height"`` and ``"width"``.
        """
        if not isinstance(images, np.ndarray):
            return super().preprocess(images=images, **kwargs)
        if images.ndim != 4:
            raise ValueError(
                "Expected a 4-D array of frames (num_frames, height, width, channels), "
                f"got shape {images.shape}."
            )

        def option(name):
            # Same semantics as the parent: an explicit None means "use the default".
            value = kwargs.get(name)
            return getattr(self, name) if value is None else value

        do_resize = option("do_resize")
        size = get_size_dict(option("size"), param_name="size", default_to_square=False)
        do_center_crop = option("do_center_crop")
        crop_size = get_size_dict(option("crop_size"), param_name="crop_size", default_to_square=True)
        do_rescale = option("do_rescale")
        rescale_factor = option("rescale_factor")
        do_normalize = option("do_normalize")
        image_mean = option("image_mean")
        image_std = option("image_std")
        return_tensors = kwargs.get("return_tensors", None)

        # torch.from_numpy rejects arrays with negative strides (e.g. frames flipped
        # BGR->RGB via `[..., ::-1]`); ascontiguousarray is a no-op otherwise.
        frames = torch.from_numpy(np.ascontiguousarray(images)).float()

        if do_resize:
            if "shortest_edge" in size:
                # Output size depends only on frame geometry; pass the numpy frame,
                # which is the input type this HF helper documents.
                output_size = get_resize_output_image_size(
                    images[0], size=size["shortest_edge"], default_to_square=False
                )
            elif "height" in size and "width" in size:
                output_size = (size["height"], size["width"])
            else:
                raise ValueError(
                    f"Size must contain either 'shortest_edge' or 'height' and 'width'. Got {size.keys()}"
                )
            frames = _resize_frames(frames, output_size)

        if do_center_crop:
            frames = _center_crop_frames(frames, crop_size)

        if do_rescale:
            frames = frames * rescale_factor

        if do_normalize:
            # Match the frames' dtype so float64 stats don't upcast the output;
            # per-channel values broadcast against the trailing channel dim.
            mean = torch.as_tensor(image_mean, dtype=frames.dtype)
            std = torch.as_tensor(image_std, dtype=frames.dtype)
            frames = (frames - mean) / std

        # Channels-last -> channels-first: (num_frames, channels, height, width).
        pixel_values = frames.permute(0, 3, 1, 2)
        return BatchFeature(data={"pixel_values": pixel_values}, tensor_type=return_tensors)