japanese-clip-vit-h-14-bert-deeper / image_processing_custom_clip.py
bsyx001's picture
Upload processor
b63f335 verified
raw
history blame
3.42 kB
"""
Use torchvision instead of transformers to perform the resize and center crop.
This is because the transformers implementation is sometimes off by one pixel.
For example, if the image size is 640x480, both results are consistent.
(e.g., "http://images.cocodataset.org/val2017/000000039769.jpg")
However, if the image size is 500x334, the following happens:
(e.g., "http://images.cocodataset.org/val2014/COCO_val2014_000000324158.jpg")
The two results then agree only after a one-pixel horizontal shift:
>>> # Results' shape: (h, w)
>>> torch.allclose(torchvision_result[:, :-1], transformers_result[:, 1:])
True
Note that if only resize is performed with torchvision,
the inconsistency remains.
Therefore, center crop must also be done with torchvision.
"""
import PIL
from torchvision.transforms import CenterCrop, InterpolationMode, Resize
from transformers import AutoImageProcessor, CLIPImageProcessor
from transformers.image_processing_utils import BatchFeature, get_size_dict
from transformers.image_utils import ImageInput, PILImageResampling, make_list_of_images
def PILImageResampling_to_InterpolationMode(
    resample: PILImageResampling,
) -> InterpolationMode:
    """Translate a transformers resampling filter into torchvision's enum.

    ``resample`` may be a ``PILImageResampling`` member or its raw value; it is
    normalized through the enum constructor, and the torchvision member with
    the same *name* is returned (e.g. ``BICUBIC`` -> ``InterpolationMode.BICUBIC``).

    Raises:
        ValueError: if ``resample`` is not a valid ``PILImageResampling`` value.
        AttributeError: if ``InterpolationMode`` has no member of that name.
    """
    filter_name = PILImageResampling(resample).name
    return getattr(InterpolationMode, filter_name)
class CustomCLIPImageProcessor(CLIPImageProcessor):
    """CLIP image processor whose resize / center crop are done by torchvision.

    Only two steps are replaced. One is the resize, when ``size`` is a
    ``{"shortest_edge": ...}`` dict. The other is the center crop. Every
    remaining step is delegated to ``CLIPImageProcessor.preprocess`` with the
    corresponding ``do_*`` flag turned off for the steps already performed
    here.
    """

    def preprocess(
        self,
        images: ImageInput,
        do_resize: bool = None,
        size: dict[str, int] = None,
        resample: PILImageResampling = None,
        do_center_crop: bool = None,
        crop_size: int = None,
        **kwargs,
    ) -> BatchFeature:
        """Preprocess ``images``, performing resize / center crop via torchvision.

        Args:
            images: Image or images accepted by ``make_list_of_images``.
                torchvision's ``Resize``/``CenterCrop`` accept PIL images and
                torch tensors. NOTE(review): other input types (e.g. numpy
                arrays) are presumably rejected by torchvision here; confirm.
            do_resize: Whether to resize; defaults to ``self.do_resize``.
            size: Resize target; defaults to ``self.size``.
            resample: Resampling filter; defaults to ``self.resample``.
            do_center_crop: Whether to center crop; defaults to
                ``self.do_center_crop``.
            crop_size: Crop target (int or size dict); defaults to
                ``self.crop_size``.
            **kwargs: Forwarded unchanged to ``CLIPImageProcessor.preprocess``.

        Returns:
            The result of ``CLIPImageProcessor.preprocess`` on the (possibly
            already resized and cropped) images.
        """
        do_resize = do_resize if do_resize is not None else self.do_resize
        size = size if size is not None else self.size
        resample = resample if resample is not None else self.resample
        do_center_crop = (
            do_center_crop if do_center_crop is not None else self.do_center_crop
        )
        crop_size = crop_size if crop_size is not None else self.crop_size

        images = make_list_of_images(images)

        if do_resize:
            # TODO input_data_format is ignored
            _size = get_size_dict(
                size,
                param_name="size",
                default_to_square=getattr(self, "use_square_size", False),
            )
            if set(_size) == {"shortest_edge"}:
                # Corresponds to `image_transform.transforms[0]`
                resize = Resize(
                    size=_size["shortest_edge"],
                    interpolation=PILImageResampling_to_InterpolationMode(resample),
                )
                images = [resize(image) for image in images]
                # Already resized here; keep the parent class from resizing again.
                do_resize = False

        # Crop here only when no resize is still pending. If the parent class
        # still has to resize (any `size` other than {"shortest_edge": ...}),
        # cropping now would reverse the resize -> crop order, so the crop is
        # left to the parent class as well.
        if do_center_crop and not do_resize:
            # TODO input_data_format is ignored
            _crop_size = get_size_dict(
                crop_size, param_name="crop_size", default_to_square=True
            )
            # Corresponds to `image_transform.transforms[1]`
            center_crop = CenterCrop(
                size=tuple(map(_crop_size.get, ["height", "width"]))
            )
            images = [center_crop(image) for image in images]
            # Already cropped here; keep the parent class from cropping again.
            do_center_crop = False

        return super().preprocess(
            images=images,
            do_resize=do_resize,
            size=size,
            resample=resample,
            do_center_crop=do_center_crop,
            crop_size=crop_size,
            **kwargs,
        )
# Register the processor with transformers' auto-class machinery at import time.
# NOTE(review): `AutoImageProcessor.register` documents its first argument as a
# config class, but a plain string is passed here. Loading through
# `AutoImageProcessor.from_pretrained(..., trust_remote_code=True)` presumably
# resolves this class via the repo's `auto_map` instead; confirm before
# relying on this registration.
AutoImageProcessor.register(
    "CustomCLIPImageProcessor",
    CustomCLIPImageProcessor,
)