akhaliq3
New message for the combined commit
833ef7e
import os, tarfile, glob, shutil
import yaml
import numpy as np
from tqdm import tqdm
from PIL import Image
import albumentations
from omegaconf import OmegaConf
from torch.utils.data import Dataset
from taming.data.base import ImagePaths
from taming.util import download, retrieve
import taming.data.utils as bdu
def give_synsets_from_indices(indices, path_to_yaml="data/imagenet_idx_to_synset.yaml"):
    """Map class indices to synset identifiers using a YAML lookup table.

    Args:
        indices: Iterable of keys (class indices) into the YAML mapping.
        path_to_yaml: Path to a YAML file holding an index -> synset mapping.

    Returns:
        List with ``str(mapping[idx])`` for every entry of ``indices``, in the
        same order.

    Raises:
        KeyError: if an index is not present in the YAML mapping.
    """
    # safe_load: PyYAML >= 6 rejects yaml.load() without an explicit Loader,
    # and a plain index -> synset mapping needs no arbitrary-object loading.
    with open(path_to_yaml) as f:
        di2s = yaml.safe_load(f)
    synsets = [str(di2s[idx]) for idx in indices]
    print("Using {} different synsets for construction of Restriced Imagenet.".format(len(synsets)))
    return synsets
def str_to_indices(string):
    """Parse a comma-separated list of indices and index ranges.

    Expects a string in the format '32-123, 256, 280-321'. A lone number
    contributes itself. A range ``a-b`` contributes ``a, a+1, ..., b-1``: the
    upper bound is EXCLUSIVE, exactly like ``range(a, b)``. Whitespace around
    numbers is accepted because ``int()`` strips it.

    Returns:
        Sorted list of ints. Duplicates are kept.

    Raises:
        AssertionError: if the string ends with a comma.
        ValueError: if a piece is not an integer (e.g. an empty piece).
    """
    assert not string.endswith(","), "provided string '{}' ends with a comma, pls remove it".format(string)
    result = []
    for piece in string.split(","):
        bounds = piece.split("-")
        assert len(bounds) > 0
        if len(bounds) == 1:
            result.append(int(bounds[0]))
            continue
        # Only the first two components are used; the end is exclusive.
        start, stop = int(bounds[0]), int(bounds[1])
        result.extend(range(start, stop))
    return sorted(result)
class ImageNetBase(Dataset):
    """Shared loading logic for the ImageNet train/validation datasets.

    Subclasses implement ``_prepare``. It must set ``self.root``,
    ``self.datadir``, ``self.txt_filelist`` and ``self.random_crop``, which
    the other methods here read. Items are served by the ``ImagePaths``
    instance built in ``_load``.

    Args:
        config: ``dict`` or OmegaConf config. Keys read here:
            ``"sub_indices"`` (restricts to a class subset, parsed by
            ``str_to_indices``) and ``"size"`` (forwarded to ``ImagePaths``,
            default 0).
    """

    def __init__(self, config=None):
        self.config = config or OmegaConf.create()
        # Convert OmegaConf containers to plain Python containers so the
        # ``"key" in self.config`` / ``self.config[...]`` lookups below behave
        # uniformly. ``isinstance`` also lets dict subclasses (e.g.
        # OrderedDict) through instead of sending them to OmegaConf.
        if not isinstance(self.config, dict):
            self.config = OmegaConf.to_container(self.config)
        self._prepare()
        self._prepare_synset_to_human()
        self._prepare_idx_to_synset()
        self._load()

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        return self.data[i]

    def _prepare(self):
        raise NotImplementedError()

    def _filter_relpaths(self, relpaths):
        """Drop known-bad files and, optionally, restrict to a class subset.

        Args:
            relpaths: Paths relative to ``self.datadir`` of the form
                ``"<synset>/<filename>"``.

        Returns:
            The kept paths, in their original order.
        """
        # Individual files excluded by basename.
        ignore = {
            "n06596364_9591.JPEG",
        }
        relpaths = [rpath for rpath in relpaths if rpath.split("/")[-1] not in ignore]
        if "sub_indices" not in self.config:
            return relpaths
        indices = str_to_indices(self.config["sub_indices"])
        # The helper returns a list of strings. Use a set so the per-path
        # membership test is O(1) rather than a list scan for each of the
        # (potentially over a million) paths.
        synsets = set(give_synsets_from_indices(indices, path_to_yaml=self.idx2syn))
        return [rpath for rpath in relpaths if rpath.split("/")[0] in synsets]

    def _prepare_synset_to_human(self):
        """Ensure ``synset_human.txt`` exists under ``self.root``.

        The file is (re)downloaded when it is missing or its size differs from
        the expected byte count.
        """
        SIZE = 2655750
        URL = "https://heibox.uni-heidelberg.de/f/9f28e956cd304264bb82/?dl=1"
        self.human_dict = os.path.join(self.root, "synset_human.txt")
        if (not os.path.exists(self.human_dict) or
                not os.path.getsize(self.human_dict) == SIZE):
            download(URL, self.human_dict)

    def _prepare_idx_to_synset(self):
        """Ensure ``index_synset.yaml`` exists under ``self.root``, downloading it if missing."""
        URL = "https://heibox.uni-heidelberg.de/f/d835d5b6ceda4d3aa910/?dl=1"
        self.idx2syn = os.path.join(self.root, "index_synset.yaml")
        if not os.path.exists(self.idx2syn):
            download(URL, self.idx2syn)

    def _load(self):
        """Read and filter the file list, derive labels, and build ``self.data``."""
        with open(self.txt_filelist, "r") as f:
            self.relpaths = f.read().splitlines()
        l1 = len(self.relpaths)
        self.relpaths = self._filter_relpaths(self.relpaths)
        print("Removed {} files from filelist during filtering.".format(l1 - len(self.relpaths)))

        # The synset is the first path component of "<synset>/<file>".
        self.synsets = [p.split("/")[0] for p in self.relpaths]
        self.abspaths = [os.path.join(self.datadir, p) for p in self.relpaths]

        # A class label is the position of the synset in the sorted array of
        # synsets remaining after filtering (np.unique returns sorted values).
        unique_synsets = np.unique(self.synsets)
        class_dict = {synset: i for i, synset in enumerate(unique_synsets)}
        self.class_labels = [class_dict[s] for s in self.synsets]

        # Each line of synset_human.txt is "<synset> <human-readable label>".
        with open(self.human_dict, "r") as f:
            human_dict = f.read().splitlines()
        human_dict = dict(line.split(maxsplit=1) for line in human_dict)
        self.human_labels = [human_dict[s] for s in self.synsets]

        labels = {
            "relpath": np.array(self.relpaths),
            "synsets": np.array(self.synsets),
            "class_label": np.array(self.class_labels),
            "human_label": np.array(self.human_labels),
        }
        self.data = ImagePaths(self.abspaths,
                               labels=labels,
                               size=retrieve(self.config, "size", default=0),
                               random_crop=self.random_crop)
class ImageNetTrain(ImageNetBase):
    """ILSVRC2012 training split, cached under ``$XDG_CACHE_HOME/autoencoders/data``.

    On first use:
      * the training tarball is fetched with academictorrents, unless it is
        already present with the expected size;
      * the tarball is extracted into ``data/``;
      * every sub-tarball ``data/X.tar`` is extracted into ``data/X/``;
      * a sorted list of ``*.JPEG`` paths (relative to ``data/``) is written
        to ``filelist.txt``.
    """
    NAME = "ILSVRC2012_train"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "a306397ccf9c2ead27155983c254227c0fd938e2"
    FILES = [
        "ILSVRC2012_img_train.tar",
    ]
    SIZES = [
        147897477120,
    ]

    @staticmethod
    def _extract_tar(tar_path, dest):
        """Extract the uncompressed tarball ``tar_path`` into ``dest``.

        Uses the ``"data"`` extraction filter when the running Python provides
        it (``tarfile.data_filter``). That filter refuses members with absolute
        paths, members escaping ``dest``, and links pointing outside ``dest``.
        Older interpreters without it fall back to plain extraction.
        """
        with tarfile.open(tar_path, "r:") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=dest, filter="data")
            else:
                tar.extractall(path=dest)

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetTrain/random_crop",
                                    default=True)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 1281167
        if bdu.is_prepared(self.root):
            return

        print("Preparing dataset {} in {}".format(self.NAME, self.root))
        datadir = self.datadir
        if not os.path.exists(datadir):
            path = os.path.join(self.root, self.FILES[0])
            if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                import academictorrents as at
                atpath = at.get(self.AT_HASH, datastore=self.root)
                assert atpath == path

            print("Extracting {} to {}".format(path, datadir))
            os.makedirs(datadir, exist_ok=True)
            self._extract_tar(path, datadir)

            print("Extracting sub-tars.")
            subpaths = sorted(glob.glob(os.path.join(datadir, "*.tar")))
            for subpath in tqdm(subpaths):
                subdir = subpath[:-len(".tar")]
                os.makedirs(subdir, exist_ok=True)
                self._extract_tar(subpath, subdir)

        # Without recursive=True, "**" matches exactly one path component,
        # i.e. data/<dir>/<file>.JPEG.
        filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
        filelist = sorted(os.path.relpath(p, start=datadir) for p in filelist)
        with open(self.txt_filelist, "w") as f:
            f.write("\n".join(filelist) + "\n")

        bdu.mark_prepared(self.root)
class ImageNetValidation(ImageNetBase):
    """ILSVRC2012 validation split, cached under ``$XDG_CACHE_HOME/autoencoders/data``.

    On first use:
      * the validation tarball is fetched with academictorrents, unless it is
        already present with the expected size;
      * the tarball is extracted into ``data/``;
      * the images are moved into ``data/<synset>/`` folders according to
        ``validation_synset.txt``, downloaded from ``VS_URL`` when it is
        missing or has an unexpected size;
      * a sorted list of ``*.JPEG`` paths (relative to ``data/``) is written
        to ``filelist.txt``.
    """
    NAME = "ILSVRC2012_validation"
    URL = "http://www.image-net.org/challenges/LSVRC/2012/"
    AT_HASH = "5d6d0df7ed81efd49ca99ea4737e0ae5e3a5f2e5"
    VS_URL = "https://heibox.uni-heidelberg.de/f/3e0f6e9c624e45f2bd73/?dl=1"
    FILES = [
        "ILSVRC2012_img_val.tar",
        "validation_synset.txt",
    ]
    SIZES = [
        6744924160,
        1950000,
    ]

    @staticmethod
    def _extract_tar(tar_path, dest):
        """Extract the uncompressed tarball ``tar_path`` into ``dest``.

        Uses the ``"data"`` extraction filter when the running Python provides
        it (``tarfile.data_filter``). That filter refuses members with absolute
        paths, members escaping ``dest``, and links pointing outside ``dest``.
        Older interpreters without it fall back to plain extraction.
        """
        with tarfile.open(tar_path, "r:") as tar:
            if hasattr(tarfile, "data_filter"):
                tar.extractall(path=dest, filter="data")
            else:
                tar.extractall(path=dest)

    def _prepare(self):
        self.random_crop = retrieve(self.config, "ImageNetValidation/random_crop",
                                    default=False)
        cachedir = os.environ.get("XDG_CACHE_HOME", os.path.expanduser("~/.cache"))
        self.root = os.path.join(cachedir, "autoencoders/data", self.NAME)
        self.datadir = os.path.join(self.root, "data")
        self.txt_filelist = os.path.join(self.root, "filelist.txt")
        self.expected_length = 50000
        if bdu.is_prepared(self.root):
            return

        print("Preparing dataset {} in {}".format(self.NAME, self.root))
        datadir = self.datadir
        if not os.path.exists(datadir):
            path = os.path.join(self.root, self.FILES[0])
            if not os.path.exists(path) or not os.path.getsize(path) == self.SIZES[0]:
                import academictorrents as at
                atpath = at.get(self.AT_HASH, datastore=self.root)
                assert atpath == path

            print("Extracting {} to {}".format(path, datadir))
            os.makedirs(datadir, exist_ok=True)
            self._extract_tar(path, datadir)

            vspath = os.path.join(self.root, self.FILES[1])
            if not os.path.exists(vspath) or not os.path.getsize(vspath) == self.SIZES[1]:
                download(self.VS_URL, vspath)

            # Each non-empty line is "<file relative to datadir> <synset>".
            # Blank lines are skipped so they cannot break the dict build.
            with open(vspath, "r") as f:
                lines = f.read().splitlines()
            synset_dict = dict(line.split() for line in lines if line.strip())

            print("Reorganizing into synset folders")
            synsets = np.unique(list(synset_dict.values()))
            for s in synsets:
                os.makedirs(os.path.join(datadir, s), exist_ok=True)
            for k, v in synset_dict.items():
                src = os.path.join(datadir, k)
                dst = os.path.join(datadir, v)
                shutil.move(src, dst)

        # Without recursive=True, "**" matches exactly one path component,
        # i.e. data/<synset>/<file>.JPEG.
        filelist = glob.glob(os.path.join(datadir, "**", "*.JPEG"))
        filelist = sorted(os.path.relpath(p, start=datadir) for p in filelist)
        with open(self.txt_filelist, "w") as f:
            f.write("\n".join(filelist) + "\n")

        bdu.mark_prepared(self.root)
def get_preprocessor(size=None, random_crop=False, additional_targets=None,
                     crop_size=None):
    """Build an albumentations pipeline, or an identity, for image-like arrays.

    ``size`` takes precedence over ``crop_size``:

    * ``size > 0``: resize so the shorter side equals ``size``, then take a
      ``size x size`` crop. The crop is centred, or random and followed by a
      random horizontal flip when ``random_crop`` is true.
    * otherwise, if ``crop_size > 0``: only take a ``crop_size x crop_size``
      crop (centred or random, no flip).
    * otherwise: return a callable that hands back its keyword arguments as a
      dict.

    ``additional_targets`` is passed to ``albumentations.Compose`` so extra
    inputs (e.g. ``{"depth": "image"}``) receive the same transform.
    """
    if size is not None and size > 0:
        crop_cls = albumentations.RandomCrop if random_crop else albumentations.CenterCrop
        steps = [
            albumentations.SmallestMaxSize(max_size=size),
            crop_cls(height=size, width=size),
        ]
        if random_crop:
            steps.append(albumentations.HorizontalFlip())
        return albumentations.Compose(steps, additional_targets=additional_targets)

    if crop_size is not None and crop_size > 0:
        crop_cls = albumentations.RandomCrop if random_crop else albumentations.CenterCrop
        steps = [crop_cls(height=crop_size, width=crop_size)]
        return albumentations.Compose(steps, additional_targets=additional_targets)

    return lambda **kwargs: kwargs
def rgba_to_depth(x):
    """Decode a depth map whose float32 bytes are stored in an RGBA image.

    The 4 uint8 channels of each pixel are reinterpreted (not converted) as
    one native-endian float32. This is the inverse of viewing a C-contiguous
    (H, W) float32 array as (H, W, 4) uint8.

    Args:
        x: uint8 array of shape (H, W, 4).

    Returns:
        A new C-contiguous float32 array of shape (H, W) that does not share
        memory with ``x``.
    """
    assert x.dtype == np.uint8
    assert len(x.shape) == 3 and x.shape[2] == 4
    # The C-ordered copy lines up each pixel's 4 bytes contiguously.
    # ``.view`` is the supported way to reinterpret a buffer; NumPy
    # discourages assigning to ``ndarray.dtype``.
    y = x.copy(order="C").view(np.float32)  # shape (H, W, 1)
    y = y.reshape(x.shape[:2])
    return np.ascontiguousarray(y)
class BaseWithDepth(Dataset):
    """Pairs examples of a base image dataset with depth maps stored as RGBA PNGs.

    Subclasses provide ``get_base_dset()``, which returns the image dataset and
    is called during ``__init__``. They also provide
    ``get_depth_path(example)``, which maps an example dict to its depth file.
    """
    DEFAULT_DEPTH_ROOT = "data/imagenet_depth"

    def __init__(self, config=None, size=None, random_crop=False,
                 crop_size=None, root=None):
        self.config = config
        self.base_dset = self.get_base_dset()
        # "depth" is registered as an image-type target so it receives the
        # same geometric transform as "image".
        self.preprocessor = get_preprocessor(
            size=size,
            crop_size=crop_size,
            random_crop=random_crop,
            additional_targets={"depth": "image"},
        )
        self.crop_size = crop_size
        if self.crop_size is not None:
            # Used by __getitem__ to upscale inputs smaller than crop_size.
            self.rescaler = albumentations.Compose(
                [albumentations.SmallestMaxSize(max_size=self.crop_size)],
                additional_targets={"depth": "image"},
            )
        if root is not None:
            # Per-instance override of the class-level default.
            self.DEFAULT_DEPTH_ROOT = root

    def __len__(self):
        return len(self.base_dset)

    def preprocess_depth(self, path):
        """Load an RGBA-encoded depth PNG and min-max rescale it to [-1, 1]."""
        decoded = rgba_to_depth(np.array(Image.open(path)))
        lo = decoded.min()
        span = max(1e-8, decoded.max() - lo)
        unit = (decoded - lo) / span
        return 2.0 * unit - 1.0

    def __getitem__(self, i):
        example = self.base_dset[i]
        example["depth"] = self.preprocess_depth(self.get_depth_path(example))
        height, width, _channels = example["image"].shape
        if self.crop_size and min(height, width) < self.crop_size:
            # Too small to crop: upscale image and depth together first
            # (bilinear).
            scaled = self.rescaler(image=example["image"], depth=example["depth"])
            example["image"], example["depth"] = scaled["image"], scaled["depth"]
        result = self.preprocessor(image=example["image"], depth=example["depth"])
        example["image"], example["depth"] = result["image"], result["depth"]
        return example
class ImageNetTrainWithDepth(BaseWithDepth):
    """ImageNet training split paired with depth maps under ``<depth root>/train``.

    Random crops are the default. ``sub_indices``, if given, is passed to
    ``ImageNetTrain`` as its ``"sub_indices"`` config entry.
    """

    def __init__(self, random_crop=True, sub_indices=None, **kwargs):
        # Must be stored before super().__init__, which calls get_base_dset().
        self.sub_indices = sub_indices
        super().__init__(random_crop=random_crop, **kwargs)

    def get_base_dset(self):
        if self.sub_indices is not None:
            return ImageNetTrain({"sub_indices": self.sub_indices})
        return ImageNetTrain()

    def get_depth_path(self, e):
        """Return ``<depth root>/train/<relpath with extension replaced by .png>``."""
        stem, _ext = os.path.splitext(e["relpath"])
        return os.path.join(self.DEFAULT_DEPTH_ROOT, "train", stem + ".png")
class ImageNetValidationWithDepth(BaseWithDepth):
    """ImageNet validation split paired with depth maps under ``<depth root>/val``.

    ``sub_indices``, if given, is passed to ``ImageNetValidation`` as its
    ``"sub_indices"`` config entry.
    """

    def __init__(self, sub_indices=None, **kwargs):
        # Must be stored before super().__init__, which calls get_base_dset().
        self.sub_indices = sub_indices
        super().__init__(**kwargs)

    def get_base_dset(self):
        if self.sub_indices is not None:
            return ImageNetValidation({"sub_indices": self.sub_indices})
        return ImageNetValidation()

    def get_depth_path(self, e):
        """Return ``<depth root>/val/<relpath with extension replaced by .png>``."""
        stem, _ext = os.path.splitext(e["relpath"])
        return os.path.join(self.DEFAULT_DEPTH_ROOT, "val", stem + ".png")
class RINTrainWithDepth(ImageNetTrainWithDepth):
    """Restricted-ImageNet training split with depth, limited to a fixed class subset.

    The subset is a ``str_to_indices`` string; its ranges exclude the upper
    bound.
    """

    def __init__(self, config=None, size=None, random_crop=True, crop_size=None):
        rin_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(
            config=config,
            size=size,
            random_crop=random_crop,
            crop_size=crop_size,
            sub_indices=rin_indices,
        )
class RINValidationWithDepth(ImageNetValidationWithDepth):
    """Restricted-ImageNet validation split with depth, limited to a fixed class subset.

    The subset is a ``str_to_indices`` string; its ranges exclude the upper
    bound.
    """

    def __init__(self, config=None, size=None, random_crop=False, crop_size=None):
        rin_indices = "30-32, 33-37, 151-268, 281-285, 80-100, 365-382, 389-397, 118-121, 300-319"
        super().__init__(
            config=config,
            size=size,
            random_crop=random_crop,
            crop_size=crop_size,
            sub_indices=rin_indices,
        )
class DRINExamples(Dataset):
    """Small fixed set of example images with depth maps.

    Relative paths are read from ``data/drin_examples.txt``. Images come from
    ``data/drin_images/<relpath>``. Depth maps come from
    ``data/drin_depth/<relpath with ".JPEG" replaced by ".png">``.

    Each item is a dict with:
      * ``"image"``: float32, scaled to [-1, 1];
      * ``"depth"``: min-max rescaled to [-1, 1].
    Both go through one shared resize + 256x256 centre crop.
    """

    def __init__(self):
        self.preprocessor = get_preprocessor(size=256, additional_targets={"depth": "image"})
        with open("data/drin_examples.txt", "r") as f:
            relpaths = f.read().splitlines()
        self.image_paths = [os.path.join("data/drin_images", relpath)
                            for relpath in relpaths]
        self.depth_paths = [os.path.join("data/drin_depth", relpath.replace(".JPEG", ".png"))
                            for relpath in relpaths]

    def __len__(self):
        return len(self.image_paths)

    def _load_rgb(self, image_path):
        """Open an image file and return it as an RGB uint8 array."""
        image = Image.open(image_path)
        if not image.mode == "RGB":
            image = image.convert("RGB")
        return np.array(image).astype(np.uint8)

    def preprocess_image(self, image_path):
        """Load an image, resize/crop it on its own, and scale it to float32 [-1, 1]."""
        image = self._load_rgb(image_path)
        image = self.preprocessor(image=image)["image"]
        return (image / 127.5 - 1.0).astype(np.float32)

    def preprocess_depth(self, path):
        """Load an RGBA-encoded depth PNG and min-max rescale it to [-1, 1]."""
        depth = rgba_to_depth(np.array(Image.open(path)))
        depth = (depth - depth.min()) / max(1e-8, depth.max() - depth.min())
        return 2.0 * depth - 1.0

    def __getitem__(self, i):
        """Return example ``i`` with image and depth transformed jointly, once.

        The raw image and the depth map go through the preprocessor together,
        so both get the same resize/crop. The image is normalised afterwards.
        Resizing the image beforehand would transform it twice and give the
        joint transform inputs of different sizes.
        """
        image = self._load_rgb(self.image_paths[i])
        depth = self.preprocess_depth(self.depth_paths[i])
        transformed = self.preprocessor(image=image, depth=depth)
        e = dict()
        e["image"] = (transformed["image"] / 127.5 - 1.0).astype(np.float32)
        e["depth"] = transformed["depth"]
        return e
def imscale(x, factor, keepshapes=False, keepmode="bicubic"):
    """Downscale an image with values in [-1, 1] by ``factor`` using PIL.

    Steps:
      1. Quantise the image to uint8 (clipped to [0, 255], truncated).
      2. Shrink it to ``(h // factor, w // factor)`` with bicubic resampling.
      3. If ``keepshapes``, resize it back to the original ``(h, w)`` using
         ``keepmode`` ("nearest", "bilinear" or "bicubic").
      4. Map it back to [-1, 1] in the input dtype.

    A ``factor`` of None or 1 returns ``x`` itself unchanged.

    Args:
        x: float32/float64 array of shape (h, w, c) with values in [-1, 1].
    """
    if factor is None or factor == 1:
        return x
    dtype = x.dtype
    assert dtype in [np.float32, np.float64]
    assert x.min() >= -1
    assert x.max() <= 1
    up_filter = {
        "nearest": Image.NEAREST,
        "bilinear": Image.BILINEAR,
        "bicubic": Image.BICUBIC,
    }[keepmode]
    as_uint8 = ((x + 1.0) * 127.5).clip(0, 255).astype(np.uint8)
    pil_img = Image.fromarray(as_uint8)
    h, w, _ = x.shape
    small_h, small_w = h // factor, w // factor
    assert small_h > 0 and small_w > 0, (small_h, small_w)
    pil_img = pil_img.resize((small_w, small_h), Image.BICUBIC)
    if keepshapes:
        pil_img = pil_img.resize((w, h), up_filter)
    return (np.array(pil_img) / 127.5 - 1.0).astype(dtype)
class ImageNetScale(Dataset):
    """Wraps a base dataset to produce rescaled/cropped images and optional low-res copies.

    Subclasses provide ``get_base()``. Each item is the base example with:
      * ``"image"``: optionally downscaled by ``hr_factor``, then passed
        through the resize/crop pipeline;
      * ``"lr"`` (only when ``up_factor`` is set): the image downscaled by
        ``up_factor`` and resized back with ``keep_mode`` interpolation,
        transformed jointly with ``"image"``.
    """

    def __init__(self, size=None, crop_size=None, random_crop=False,
                 up_factor=None, hr_factor=None, keep_mode="bicubic"):
        self.base = self.get_base()
        self.size = size
        self.crop_size = self.size if crop_size is None else crop_size
        self.random_crop = random_crop
        self.up_factor = up_factor
        self.hr_factor = hr_factor
        self.keep_mode = keep_mode

        pipeline = []
        if self.size is not None and self.size > 0:
            self.rescaler = albumentations.SmallestMaxSize(max_size=self.size)
            pipeline.append(self.rescaler)
        if self.crop_size is not None and self.crop_size > 0:
            if not pipeline:
                # No resize step in the pipeline. Still keep a rescaler so
                # __getitem__ can upscale images smaller than crop_size.
                self.rescaler = albumentations.SmallestMaxSize(max_size=self.crop_size)
            crop_cls = (albumentations.RandomCrop if self.random_crop
                        else albumentations.CenterCrop)
            pipeline.append(crop_cls(height=self.crop_size, width=self.crop_size))

        if not pipeline:
            self.preprocessor = lambda **kwargs: kwargs
        else:
            extra = {"lr": "image"} if self.up_factor is not None else None
            self.preprocessor = albumentations.Compose(pipeline,
                                                       additional_targets=extra)

    def __len__(self):
        return len(self.base)

    def __getitem__(self, i):
        example = self.base[i]
        # Optionally lower the working resolution first.
        image = imscale(example["image"], self.hr_factor, keepshapes=False)
        height, width, _channels = image.shape
        if self.crop_size and min(height, width) < self.crop_size:
            # Too small to crop: upscale first (bilinear).
            image = self.rescaler(image=image)["image"]
        if self.up_factor is not None:
            low_res = imscale(image, self.up_factor, keepshapes=True,
                              keepmode=self.keep_mode)
            out = self.preprocessor(image=image, lr=low_res)
            example["image"] = out["image"]
            example["lr"] = out["lr"]
            return example
        example["image"] = self.preprocessor(image=image)["image"]
        return example
class ImageNetScaleTrain(ImageNetScale):
    """``ImageNetScale`` over the ImageNet training split; random crops by default."""

    def get_base(self):
        base = ImageNetTrain()
        return base

    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)
class ImageNetScaleValidation(ImageNetScale):
    """``ImageNetScale`` over the ImageNet validation split."""

    def get_base(self):
        base = ImageNetValidation()
        return base
from skimage.feature import canny
from skimage.color import rgb2gray
class ImageNetEdges(ImageNetScale):
    """ImageNet images paired with a Canny edge map stored under ``"lr"``.

    The ``up_factor`` argument is accepted but ignored. ``up_factor=1`` is
    always forwarded, which makes the parent register ``"lr"`` as an
    additional image target of the preprocessor.
    """

    def __init__(self, up_factor=1, **kwargs):
        super().__init__(up_factor=1, **kwargs)

    def __getitem__(self, i):
        example = self.base[i]
        image = example["image"]
        height, width, _channels = image.shape
        if self.crop_size and min(height, width) < self.crop_size:
            # Too small to crop: upscale first (bilinear).
            image = self.rescaler(image=image)["image"]
        # Canny edges (sigma=2) of the grayscale image, as float32,
        # replicated to 3 channels.
        edges = canny(rgb2gray(image), sigma=2).astype(np.float32)
        edges = np.repeat(edges[:, :, None], 3, axis=2)
        out = self.preprocessor(image=image, lr=edges)
        example["image"] = out["image"]
        example["lr"] = out["lr"]
        return example
class ImageNetEdgesTrain(ImageNetEdges):
    """``ImageNetEdges`` over the ImageNet training split; random crops by default."""

    def get_base(self):
        base = ImageNetTrain()
        return base

    def __init__(self, random_crop=True, **kwargs):
        super().__init__(random_crop=random_crop, **kwargs)
class ImageNetEdgesValidation(ImageNetEdges):
    """``ImageNetEdges`` over the ImageNet validation split."""

    def get_base(self):
        base = ImageNetValidation()
        return base