from transformers.image_processing_utils import ImageProcessingMixin, BatchFeature
from torchvision.transforms import transforms as tf
import torchvision.transforms.functional as F
from PIL import Image
import torch

# Default category vocabulary. The position of a name in this sequence is the
# integer emitted for it in ``category_indices``. Kept immutable (a tuple) so
# it can safely serve as a default argument.
DEFAULT_CATEGORIES = (
    "Bags",
    "Feet",
    "Hands",
    "Head",
    "Lower Body",
    "Neck",
    "Outwear",
    "Upper Body",
    "Waist",
    "Whole Body",
)


class CondViTProcessor(ImageProcessingMixin):
    """Prepare images (and optional category names) for a CondViT model.

    Images are converted to RGB, padded to a centered square with
    ``bkg_color``, resized to ``input_resolution`` x ``input_resolution``,
    converted to tensors and normalized. Category names are mapped to their
    index in ``categories``.

    Parameters
    ----------
    bkg_color : int
        Fill value used when padding images to a square.
    input_resolution : int
        Side length, in pixels, of the square output images.
    image_mean : Sequence[float]
        Per-channel mean used for normalization (RGB order).
    image_std : Sequence[float]
        Per-channel standard deviation used for normalization (RGB order).
    categories : Sequence[str]
        Category vocabulary; a name's position is its emitted index.
    """

    def __init__(
        self,
        bkg_color=255,
        input_resolution=224,
        image_mean=(0.48145466, 0.4578275, 0.40821073),
        image_std=(0.26862954, 0.26130258, 0.27577711),
        categories=DEFAULT_CATEGORIES,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.bkg_color = bkg_color
        self.input_resolution = input_resolution
        self.image_mean = image_mean
        self.image_std = image_std
        # Copy into a fresh list: instances never share (and can't mutate)
        # one default object, and the attribute keeps its original list type.
        self.categories = list(categories)

    def square_pad(self, image):
        """Pad a PIL image with ``bkg_color`` so it becomes square and centered.

        When the size difference is odd, the extra pixel goes on the
        right / bottom side.
        """
        width, height = image.size
        side = max(width, height)
        pad_left = (side - width) // 2
        pad_top = (side - height) // 2
        pad_right = side - width - pad_left
        pad_bottom = side - height - pad_top
        padding = (pad_left, pad_top, pad_right, pad_bottom)
        return F.pad(image, padding, self.bkg_color, "constant")

    def process_img(self, image):
        """Turn one PIL image into a normalized (3, R, R) float tensor.

        R is ``input_resolution``.
        """
        # The normalization uses one mean/std per RGB channel; grayscale,
        # palette or RGBA images would yield 1 or 4 channels and fail there.
        # Note: converting RGBA simply drops alpha (no compositing).
        if image.mode != "RGB":
            image = image.convert("RGB")
        img = self.square_pad(image)
        # The image is square, so resizing the shorter edge yields R x R.
        img = F.resize(img, self.input_resolution)
        img = F.to_tensor(img)
        return F.normalize(img, self.image_mean, self.image_std)

    def process_cat(self, cat):
        """Map a category name to a 0-d int64 tensor holding its index.

        Returns None when ``cat`` is None.

        Raises
        ------
        ValueError
            If ``cat`` is not one of ``self.categories``.
        """
        if cat is None:
            return None
        try:
            index = self.categories.index(cat)
        except ValueError:
            raise ValueError(
                f"Unknown category {cat!r}; expected one of {self.categories}"
            ) from None
        return torch.tensor(index, dtype=torch.long)

    def __call__(self, images, categories=None):
        """
        Parameters
        ----------
        images : Union[Image.Image, List[Image.Image]]
            Image or list of images to process
        categories : Optional[Union[str, List[str]]]
            Category or list of categories to process. A single category
            name is applied to every image; a list must have one entry per
            image.

        Returns
        -------
        BatchFeature
            pixel_values : torch.Tensor
                Processed image tensor (B C H W)
            category_indices : torch.Tensor
                Categories indices (B), int64; present only when
                ``categories`` is given

        Raises
        ------
        ValueError
            If no images are given, if the number of categories does not
            match the number of images, or if a category name is unknown.
        """
        if isinstance(images, Image.Image):
            images = [images]
        else:
            images = list(images)
        if not images:
            raise ValueError("At least one image is required")

        if categories is not None:
            if isinstance(categories, str):
                # One category name for the whole batch.
                categories = [categories] * len(images)
            else:
                categories = list(categories)
            if len(categories) != len(images):
                raise ValueError(
                    f"Got {len(images)} images but {len(categories)} categories"
                )

        data = {
            "pixel_values": torch.stack([self.process_img(img) for img in images])
        }
        if categories is not None:
            data["category_indices"] = torch.stack(
                [self.process_cat(c) for c in categories]
            )
        return BatchFeature(data=data)