# Page-header residue from the hosting site (not code; kept for provenance):
#   ghlee94's picture
#   Init
#   2a13495
import os
import torch
import shutil
import numpy as np
from PIL import Image
from tqdm import tqdm
from urllib.request import urlretrieve
class OxfordPetDataset(torch.utils.data.Dataset):
    """Image / mask / trimap samples read from an Oxford-IIIT Pet style directory.

    Files are read from the following locations under ``root``::

        images/<name>.jpg
        annotations/trimaps/<name>.png
        annotations/trainval.txt
        annotations/test.txt

    Args:
        root: Directory containing the ``images`` and ``annotations`` folders.
        mode: One of ``"train"``, ``"valid"`` or ``"test"``. ``"train"`` and
            ``"valid"`` are both drawn from ``trainval.txt``: entries whose
            index is a multiple of 10 go to ``"valid"``, the rest to
            ``"train"``. ``"test"`` uses every entry of ``test.txt``.
        transform: Optional callable invoked as
            ``transform(image=..., mask=..., trimap=...)``; its return value
            replaces the sample.
    """

    def __init__(self, root, mode="train", transform=None):
        assert mode in {"train", "valid", "test"}, (
            f"mode must be 'train', 'valid' or 'test', got {mode!r}"
        )

        self.root = root
        self.mode = mode
        self.transform = transform

        self.images_directory = os.path.join(self.root, "images")
        self.masks_directory = os.path.join(self.root, "annotations", "trimaps")

        self.filenames = self._read_split()  # read train/valid/test splits

    def __len__(self):
        """Return the number of samples in the selected split."""
        return len(self.filenames)

    def __getitem__(self, idx):
        """Load one sample.

        Returns:
            A dict with keys ``image`` (RGB array), ``mask`` (float32 array
            produced by ``_preprocess_mask``) and ``trimap`` (the annotation
            as stored on disk), or whatever ``transform`` returns when a
            transform was supplied.
        """
        filename = self.filenames[idx]
        image_path = os.path.join(self.images_directory, filename + ".jpg")
        mask_path = os.path.join(self.masks_directory, filename + ".png")

        # Context managers release the file handles as soon as the pixel data
        # has been copied into arrays, instead of waiting for garbage
        # collection.
        with Image.open(image_path) as img:
            image = np.array(img.convert("RGB"))
        with Image.open(mask_path) as img:
            trimap = np.array(img)

        mask = self._preprocess_mask(trimap)

        sample = dict(image=image, mask=mask, trimap=trimap)
        if self.transform is not None:
            sample = self.transform(**sample)

        return sample

    @staticmethod
    def _preprocess_mask(mask):
        """Convert a trimap into a float32 mask.

        Value 2 becomes 0; values 1 and 3 become 1; any other value is left
        unchanged. The input array is not modified (``astype`` copies).
        """
        mask = mask.astype(np.float32)
        mask[mask == 2.0] = 0.0
        mask[(mask == 1.0) | (mask == 3.0)] = 1.0
        return mask

    def _read_split(self):
        """Return the sample names for ``self.mode``.

        Each non-blank line of the split file contributes its first
        whitespace-separated token as a sample name.
        """
        split_filename = "test.txt" if self.mode == "test" else "trainval.txt"
        split_filepath = os.path.join(self.root, "annotations", split_filename)

        with open(split_filepath, encoding="utf-8") as f:
            # Blank lines are skipped so that an empty file yields an empty
            # dataset (rather than one sample named "") and stray blank lines
            # neither create bogus names nor shift the train/valid indices.
            filenames = [line.split()[0] for line in f if line.strip()]

        if self.mode == "train":  # 90% for train
            filenames = [x for i, x in enumerate(filenames) if i % 10 != 0]
        elif self.mode == "valid":  # 10% for validation
            filenames = [x for i, x in enumerate(filenames) if i % 10 == 0]
        return filenames

    @staticmethod
    def download(root):
        """Download the image and annotation archives into ``root`` and unpack them."""
        # load images
        filepath = os.path.join(root, "images.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/images.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)

        # load annotations
        filepath = os.path.join(root, "annotations.tar.gz")
        download_url(
            url="https://www.robots.ox.ac.uk/~vgg/data/pets/data/annotations.tar.gz",
            filepath=filepath,
        )
        extract_archive(filepath)
class SimpleOxfordPetDataset(OxfordPetDataset):
    """``OxfordPetDataset`` variant that returns fixed-size, channel-first arrays.

    Every sample is resized to 256x256. The image is moved from HWC to CHW
    layout, and the mask and trimap gain a leading axis of length 1.
    """

    def __getitem__(self, *args, **kwargs):
        """Load a sample via the parent class, then resize and re-layout it."""
        sample = super().__getitem__(*args, **kwargs)

        size = (256, 256)

        # resize images
        # Image.BILINEAR is the same filter as the legacy alias Image.LINEAR,
        # which was removed in Pillow 10 and raises AttributeError there.
        image = np.array(
            Image.fromarray(sample["image"]).resize(size, Image.BILINEAR)
        )
        # Nearest-neighbour resampling keeps label values intact instead of
        # blending neighbouring classes.
        mask = np.array(
            Image.fromarray(sample["mask"]).resize(size, Image.NEAREST)
        )
        trimap = np.array(
            Image.fromarray(sample["trimap"]).resize(size, Image.NEAREST)
        )

        # convert to other format HWC -> CHW
        sample["image"] = np.moveaxis(image, -1, 0)
        sample["mask"] = np.expand_dims(mask, 0)
        sample["trimap"] = np.expand_dims(trimap, 0)

        return sample
class TqdmUpTo(tqdm):
    """Progress bar that can be driven with absolute progress instead of increments."""

    def update_to(self, b=1, bsize=1, tsize=None):
        """Move the bar forward to ``b * bsize`` units.

        Args:
            b: Number of blocks transferred so far.
            bsize: Size of each block.
            tsize: Total size; when not ``None`` it replaces the bar's total.
        """
        if tsize is not None:
            self.total = tsize
        transferred = b * bsize
        # ``update`` expects an increment, so pass the distance from the
        # bar's current count to the new absolute position.
        self.update(transferred - self.n)
def download_url(url, filepath):
    """Download ``url`` to ``filepath`` with a progress bar, unless it already exists.

    The parent directory of ``filepath`` is created if needed. Data is first
    written to a ``.part`` sibling file and moved into place only after the
    transfer finishes, so an interrupted download is never mistaken for a
    complete file on a later call.

    Args:
        url: Address to fetch.
        filepath: Destination path of the downloaded file.

    Raises:
        Whatever ``urlretrieve`` raises on failure; the partial file is removed
        before the exception propagates.
    """
    directory = os.path.dirname(os.path.abspath(filepath))
    os.makedirs(directory, exist_ok=True)
    if os.path.exists(filepath):
        return

    partial_path = os.fspath(filepath) + ".part"
    try:
        with TqdmUpTo(
            unit="B",
            unit_scale=True,
            unit_divisor=1024,
            miniters=1,
            desc=os.path.basename(filepath),
        ) as t:
            urlretrieve(url, filename=partial_path, reporthook=t.update_to, data=None)
            t.total = t.n
        # Publish the file only once it is complete.
        os.replace(partial_path, filepath)
    except BaseException:
        # Also covers KeyboardInterrupt: never leave a truncated download behind.
        if os.path.exists(partial_path):
            os.remove(partial_path)
        raise
def _strip_archive_extension(filepath):
    """Return ``filepath`` without its archive extension (e.g. ``.tar.gz``)."""
    extensions = [
        ext for _, exts, _ in shutil.get_unpack_formats() for ext in exts
    ]
    # Longest first so a compound suffix wins over any shorter overlapping one.
    for ext in sorted(extensions, key=len, reverse=True):
        if filepath.endswith(ext):
            return filepath[: -len(ext)]
    # Unknown extension: fall back to dropping the last suffix only.
    return os.path.splitext(filepath)[0]


def extract_archive(filepath):
    """Unpack the archive at ``filepath`` into its own directory, once.

    Extraction is skipped when a path named like the archive minus its
    extension already exists (``images.tar.gz`` -> ``images``). The full
    compound extension is stripped; ``os.path.splitext`` alone would turn
    ``images.tar.gz`` into ``images.tar``, so the check would never match and
    the archive would be re-extracted on every call.

    Args:
        filepath: Path to an archive in a format ``shutil.unpack_archive``
            understands.
    """
    filepath = os.fspath(filepath)
    extract_dir = os.path.dirname(os.path.abspath(filepath))
    dst_dir = _strip_archive_extension(filepath)
    if not os.path.exists(dst_dir):
        # NOTE: the archive contents are trusted as-is; no member filtering is
        # requested here.
        shutil.unpack_archive(filepath, extract_dir)