|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
|
|
import numpy as np |
|
import os |
|
from typing import Optional, Callable, Set |
|
|
|
import torch |
|
|
|
from torchvision.datasets.vision import VisionDataset |
|
from torchvision.transforms import ToTensor |
|
|
|
from fairseq.data import FairseqDataset |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class ImageDataset(FairseqDataset, VisionDataset):
    """Dataset of image files discovered by walking a directory tree.

    Files under ``root`` are kept when their lower-cased extension (as
    returned by ``os.path.splitext``, i.e. including the leading dot) is a
    member of ``extensions``. When ``load_classes`` is true, every immediate
    sub-directory of ``root`` is treated as a class: the sorted directory
    names are mapped to consecutive integer ids and each file found below a
    class directory is labelled with that id.
    """

    def __init__(
        self,
        root: str,
        extensions: Set[str],
        load_classes: bool,
        transform: Optional[Callable] = None,
        shuffle=True,
    ):
        """Index the image files below ``root``.

        Args:
            root: Directory that is walked (following symlinks) for images.
            extensions: Accepted file extensions, compared against the
                lower-cased extension of each file name.
            load_classes: If true, derive class ids from the immediate
                sub-directories of ``root`` and record a label per file.
            transform: Optional callable applied to the loaded RGB image in
                ``__getitem__``; when omitted, ``ToTensor`` is used instead.
            shuffle: Whether ``ordered_indices`` returns a random permutation.
        """
        FairseqDataset.__init__(self)
        VisionDataset.__init__(self, root=root, transform=transform)

        self.shuffle = shuffle
        self.tensor_transform = ToTensor()

        if load_classes:
            # Sorting the directory names keeps the name -> id mapping stable.
            class_names = sorted(
                entry.name for entry in os.scandir(root) if entry.is_dir()
            )
            self.classes = {name: idx for idx, name in enumerate(class_names)}
            logger.info(f"loaded {len(self.classes)} classes")
        else:
            self.classes = None

        def find_images(top):
            """Yield paths of files below ``top`` with an accepted extension."""
            for folder, _, names in sorted(os.walk(top, followlinks=True)):
                for name in sorted(names):
                    _, ext = os.path.splitext(name)
                    if ext.lower() in extensions:
                        yield os.path.join(folder, name)

        logger.info(f"finding images in {root}")
        if self.classes is None:
            self.labels = None
            self.files = list(find_images(root))
        else:
            self.files = []
            self.labels = []
            for class_name, class_id in self.classes.items():
                class_files = list(find_images(os.path.join(root, class_name)))
                self.files.extend(class_files)
                # One label per file, kept index-aligned with ``self.files``.
                self.labels.extend([class_id] * len(class_files))

        logger.info(f"loaded {len(self.files)} examples")

    def __getitem__(self, index):
        """Load one image and return it as a sample dict.

        Returns:
            A dict with ``"id"`` (the index) and ``"img"`` (the transformed
            image), plus ``"label"`` when labels were loaded.
        """
        # Imported lazily, as in the original code, so PIL is only required
        # once an image is actually read.
        from PIL import Image

        with open(self.files[index], "rb") as handle:
            image = Image.open(handle).convert("RGB")

        if self.transform is not None:
            image = self.transform(image)
            # A user-supplied transform must hand back a tensor.
            assert torch.is_tensor(image)
        else:
            image = self.tensor_transform(image)

        sample = {"id": index, "img": image}
        if self.labels is not None:
            sample["label"] = self.labels[index]
        return sample

    def __len__(self):
        """Return the number of indexed image files."""
        return len(self.files)

    def collater(self, samples):
        """Merge a list of samples into a mini-batch.

        Returns:
            An empty dict for an empty ``samples`` list; otherwise a dict with
            ``"id"`` (LongTensor of sample ids) and ``"net_input"`` holding
            the stacked ``"img"`` tensor and, if the first sample carries a
            label, a LongTensor ``"label"``.
        """
        if len(samples) == 0:
            return {}

        images = torch.stack([sample["img"] for sample in samples], dim=0)
        ids = torch.LongTensor([sample["id"] for sample in samples])
        batch = {"id": ids, "net_input": {"img": images}}

        # Presence of a label on the first sample decides for the whole batch.
        if "label" in samples[0]:
            batch["net_input"]["label"] = torch.LongTensor(
                [sample["label"] for sample in samples]
            )

        return batch

    def num_tokens(self, index):
        """Return 1 for every index."""
        return 1

    def size(self, index):
        """Return 1 for every index."""
        return 1

    def ordered_indices(self):
        """Return an ordered list of indices. Batches will be constructed based
        on this order."""
        if self.shuffle:
            return np.random.permutation(len(self))
        return np.arange(len(self))
|
|