mcding
published version
ad552d8
import numpy as np
import torch
import torchvision
from PIL import Image
import zipfile
from .resize import build_resizer
class ResizeDataset(torch.utils.data.Dataset):
    """
    A placeholder Dataset that loads files and resizes them, so that the
    resize operation can be parallelized across multiple CPU cores.

    files: list of file paths. When ``fdir`` contains ".zip", these are
        member names inside that zip archive; otherwise they are filesystem
        paths (paths containing ".npy" are loaded with ``np.load``, anything
        else is opened with PIL and converted to RGB).
    mode: resize mode passed to ``build_resizer``; the resulting
        ``fn_resize`` takes an np array and returns an np array.
    size: target size; stored on the instance but not read by this class.
    fdir: optional path; only used when it contains ".zip", in which case
        it is the archive that ``files`` are read from.
    """

    def __init__(self, files, mode, size=(299, 299), fdir=None):
        self.files = files
        self.fdir = fdir
        self.transforms = torchvision.transforms.ToTensor()
        self.size = size
        self.fn_resize = build_resizer(mode)
        # Hook applied to the raw np array before resizing; identity by
        # default. (Attribute spelling kept for backward compatibility.)
        self.custom_image_tranform = lambda x: x
        # Zip handle, opened lazily by _get_zipfile on first use.
        self._zipfile = None

    def _get_zipfile(self):
        """Return the (lazily opened, then cached) ZipFile for ``self.fdir``."""
        assert self.fdir is not None and ".zip" in self.fdir
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self.fdir)
        return self._zipfile

    def __len__(self):
        return len(self.files)

    def __getitem__(self, i):
        """Load file ``i``, apply the custom transform, resize it and return a tensor.

        If ``fn_resize`` returns uint8, the ToTensor output is multiplied
        back by 255. If it returns float32, the ToTensor output is returned
        as is.

        Raises:
            TypeError: if ``fn_resize`` returns any other dtype.
        """
        path = str(self.files[i])
        if self.fdir is not None and ".zip" in self.fdir:
            with self._get_zipfile().open(path, "r") as f:
                img_np = np.array(Image.open(f).convert("RGB"))
        elif ".npy" in path:
            img_np = np.load(path)
        else:
            # Context manager guarantees the underlying file is closed.
            with Image.open(path) as img_pil:
                img_np = np.array(img_pil.convert("RGB"))
        # apply a custom image transform before resizing the image
        img_np = self.custom_image_tranform(img_np)
        # fn_resize expects a np array and returns a np array
        img_resized = self.fn_resize(img_np)
        # ToTensor() converts to [0,1] only if input is uint8, so undo that
        # scaling to keep uint8 data on the [0,255] scale.
        if img_resized.dtype == np.uint8:
            return self.transforms(np.array(img_resized)) * 255
        if img_resized.dtype == np.float32:
            return self.transforms(img_resized)
        raise TypeError(
            f"fn_resize returned unsupported dtype {img_resized.dtype}; "
            "expected uint8 or float32"
        )
# File extensions (without the leading dot), presumably used elsewhere in
# this module to select input files. Lower-case spellings come first; only
# some formats also have their upper-case spelling listed.
EXTENSIONS = set(
    "bmp jpg jpeg pgm png ppm tif tiff webp npy".split()
    + "JPEG JPG PNG".split()
)