# SwinTExCo/src/data/functional.py
# Retrieved from the Hugging Face file viewer: uploaded by duongttr
# ("Upload folder using huggingface_hub", commit 62ef5f4, 2.5 kB).
from __future__ import division
import torch
import numbers
import collections
import numpy as np
from PIL import Image, ImageOps
def _is_pil_image(img):
    """Tell whether ``img`` is an instance of PIL's ``Image.Image`` class."""
    pil_image_type = Image.Image
    return isinstance(img, pil_image_type)
def _is_tensor_image(img):
    """Tell whether ``img`` is a torch tensor with exactly three dimensions."""
    # Guard first so ``ndimension`` is only called on real tensors.
    if not torch.is_tensor(img):
        return False
    return img.ndimension() == 3
def _is_numpy_image(img):
    """Tell whether ``img`` is a NumPy array with two or three dimensions."""
    # Guard first so ``ndim`` is only read from real ndarrays.
    if not isinstance(img, np.ndarray):
        return False
    return img.ndim in (2, 3)
def to_mytensor(pic):
    """Convert an image-like object into a float tensor with the last axis first.

    ``pic`` is materialised with ``np.array``. A 2-D array gets a trailing
    axis of length one, then the axes are reordered as ``(2, 0, 1)`` and the
    result is wrapped in a torch tensor.

    Values are NOT rescaled (there is no division by 255); the tensor is only
    cast to float when it is not already a ``torch.FloatTensor``.

    Args:
        pic: Anything ``np.array`` turns into a 2-D or 3-D array.

    Returns:
        A ``torch.FloatTensor`` whose first axis is the input's last axis.
    """
    arr = np.array(pic)
    if arr.ndim == 2:
        # Give single-plane images an explicit trailing axis.
        arr = np.expand_dims(arr, axis=-1)

    chw = torch.from_numpy(arr.transpose((2, 0, 1)))
    # Only cast when needed; values keep their original range.
    return chw if isinstance(chw, torch.FloatTensor) else chw.float()
def normalize(tensor, mean, std):
    """Subtract ``mean`` and divide by ``std`` on a 3-D tensor, in place.

    When the first dimension has size one, ``mean`` and ``std`` are applied
    to the whole tensor directly. Otherwise the tensor is iterated along its
    first dimension in lockstep with ``mean`` and ``std``, and each slice is
    normalised with its own pair of values.

    Args:
        tensor: 3-D torch tensor; it is modified in place.
        mean: Value(s) to subtract.
        std: Value(s) to divide by.

    Returns:
        The same ``tensor`` object, after in-place normalisation.

    Raises:
        TypeError: If ``tensor`` is not a 3-D torch tensor.
    """
    if not _is_tensor_image(tensor):
        raise TypeError("tensor is not a torch image.")

    if tensor.size(0) != 1:
        # One (mean, std) pair per slice along the first dimension.
        for channel, channel_mean, channel_std in zip(tensor, mean, std):
            channel.sub_(channel_mean).div_(channel_std)
        return tensor

    # Single leading slice: apply mean/std to the whole tensor at once.
    tensor.sub_(mean).div_(std)
    return tensor
def resize(img, size, interpolation=Image.BILINEAR):
    """Resize a PIL image to ``size``.

    Args:
        img: PIL image to resize.
        size: Either an ``int`` or a 2-element sequence. A sequence is
            reversed before being handed to ``img.resize`` (PIL expects
            ``(width, height)``, so the sequence is ``(height, width)``).
            An ``int`` is the target length of the shorter side; the other
            side is scaled to keep the aspect ratio.
        interpolation: PIL resampling filter passed through to ``img.resize``.

    Returns:
        The resized image, or ``img`` itself when ``size`` is an ``int`` and
        the shorter side already equals it.

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``size`` is neither an
            ``int`` nor a 2-element iterable.
    """
    # ``collections.Iterable`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so this block is self-sufficient.
    from collections.abc import Iterable

    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    size_is_int = isinstance(size, int)
    if not size_is_int and (not isinstance(size, Iterable) or len(size) != 2):
        raise TypeError("Got inappropriate size arg: {}".format(size))

    if not size_is_int:
        # Explicit 2-element size: flip the order for PIL.
        return img.resize(size[::-1], interpolation)

    w, h = img.size
    # Shorter side already matches: nothing to do.
    if (w <= h and w == size) or (h <= w and h == size):
        return img

    # Pin the shorter side to ``size`` and scale the longer one to match.
    if w < h:
        ow = size
        oh = int(round(size * h / w))
    else:
        oh = size
        ow = int(round(size * w / h))
    return img.resize((ow, oh), interpolation)
def pad(img, padding, fill=0):
    """Add a border around a PIL image using ``ImageOps.expand``.

    Args:
        img: PIL image to pad.
        padding: A number, or a tuple of 2 or 4 elements, passed to
            ``ImageOps.expand`` as ``border``.
        fill: Border fill value (number, string or tuple), passed to
            ``ImageOps.expand`` as ``fill``.

    Returns:
        The padded image returned by ``ImageOps.expand``.

    Raises:
        TypeError: If ``img`` is not a PIL image, or ``padding`` / ``fill``
            have an unsupported type.
        ValueError: If ``padding`` is a tuple whose length is not 2 or 4.
    """
    # ``collections.Sequence`` was removed in Python 3.10; the ABC lives in
    # ``collections.abc``. Imported locally so this block is self-sufficient.
    from collections.abc import Sequence

    if not _is_pil_image(img):
        raise TypeError("img should be PIL Image. Got {}".format(type(img)))

    if not isinstance(padding, (numbers.Number, tuple)):
        raise TypeError("Got inappropriate padding arg")
    if not isinstance(fill, (numbers.Number, str, tuple)):
        raise TypeError("Got inappropriate fill arg")

    if isinstance(padding, Sequence) and len(padding) not in [2, 4]:
        raise ValueError(
            "Padding must be an int or a 2, or 4 element tuple, not a "
            + "{} element tuple".format(len(padding))
        )

    return ImageOps.expand(img, border=padding, fill=fill)
def crop(img, i, j, h, w):
    """Crop a PIL image to the box ``(j, i, j + w, i + h)``.

    Args:
        img: PIL image to crop.
        i: First coordinate of the box's second component (box starts at ``i``
            and ends at ``i + h`` on that axis).
        j: First coordinate of the box's first component (box starts at ``j``
            and ends at ``j + w`` on that axis).
        h: Extent added to ``i``.
        w: Extent added to ``j``.

    Returns:
        The result of ``img.crop`` for that box.

    Raises:
        TypeError: If ``img`` is not a PIL image.
    """
    if _is_pil_image(img):
        left, upper = j, i
        box = (left, upper, left + w, upper + h)
        return img.crop(box)
    raise TypeError("img should be PIL Image. Got {}".format(type(img)))