Spaces:
Runtime error
Runtime error
import numpy as np | |
import torch | |
import torchvision | |
from PIL import Image | |
import zipfile | |
from .resize import build_resizer | |
class ResizeDataset(torch.utils.data.Dataset):
    """
    A placeholder Dataset that enables parallelizing the resize operation
    using multiple CPU cores.

    Each item is loaded from disk (or from a zip archive), passed through
    ``custom_image_tranform``, resized with ``fn_resize`` and converted to a
    tensor.

    files: list of all files in the folder (or member names inside the zip
        archive when ``fdir`` contains ".zip")
    mode: forwarded to ``build_resizer`` to obtain ``fn_resize``
    size: stored on the instance as ``self.size``; it is not read anywhere
        else inside this class
    fdir: optional path; when it contains ".zip" every entry of ``files`` is
        opened as a member of that archive
    fn_resize: function that takes an np_array as input [0,255]
    """

    def __init__(self, files, mode, size=(299, 299), fdir=None):
        self.files = files
        self.fdir = fdir
        self.transforms = torchvision.transforms.ToTensor()
        self.size = size
        self.fn_resize = build_resizer(mode)
        # Hook that callers may overwrite; the attribute name (including its
        # spelling) is part of the public interface and must not change.
        self.custom_image_tranform = lambda x: x
        # Opened lazily on first access, see _get_zipfile().
        self._zipfile = None

    def _get_zipfile(self):
        """Return the ZipFile for ``self.fdir``, opening it on first use."""
        assert self.fdir is not None and ".zip" in self.fdir
        if self._zipfile is None:
            self._zipfile = zipfile.ZipFile(self.fdir)
        return self._zipfile

    def __len__(self):
        """Return the number of entries in ``self.files``."""
        return len(self.files)

    def __getitem__(self, i):
        """
        Load, transform and resize the i-th file and return it as a tensor.

        Raises TypeError when ``fn_resize`` returns an array whose dtype is
        neither uint8 nor float32.
        """
        path = str(self.files[i])
        if self.fdir is not None and ".zip" in self.fdir:
            with self._get_zipfile().open(path, "r") as f:
                img_np = np.array(Image.open(f).convert("RGB"))
        elif ".npy" in path:
            img_np = np.load(path)
        else:
            img_pil = Image.open(path).convert("RGB")
            img_np = np.array(img_pil)

        # apply a custom image transform before resizing the image to 299x299
        img_np = self.custom_image_tranform(img_np)
        # fn_resize expects a np array and returns a np array
        img_resized = self.fn_resize(img_np)

        # ToTensor() converts to [0,1] only if input in uint8
        if img_resized.dtype == "uint8":
            # undo the [0,1] scaling so both branches share the same range
            img_t = self.transforms(np.array(img_resized)) * 255
        elif img_resized.dtype == "float32":
            img_t = self.transforms(img_resized)
        else:
            # Previously this fell through to `return img_t` with the name
            # unbound, producing an opaque UnboundLocalError.
            raise TypeError(
                f"fn_resize returned unsupported dtype {img_resized.dtype} "
                f"for {path!r}; expected uint8 or float32"
            )
        return img_t
# File extensions, written without the leading dot. Matching is presumably
# case-sensitive at the call site (not visible here), which would explain why
# upper-case spellings are listed explicitly for JPEG, JPG and PNG only.
EXTENSIONS = set(
    "bmp jpg jpeg pgm png ppm tif tiff webp npy JPEG JPG PNG".split()
)