ve-gans / vegans /dataset.py
Zai
added readme and cleanup files
f2b0abe
from torch.utils.data import Dataset
from torchvision import transforms as tran
from datasets import load_dataset
from .utils import log
from .config import Config
class TrainData(Dataset):
    """Training split of the ``zaibutcooler/archi-vault`` Hugging Face dataset.

    Each item is an ``(image, label)`` pair. The image is optionally converted
    to a single grayscale channel, resized to
    ``config.image_size`` x ``config.image_size``, converted to a tensor and
    normalized with mean 0.5 / std 0.5.
    """

    def __init__(self, config: Config, gray_scale=True):
        """Build the transform pipeline and load the dataset.

        Args:
            config: Project configuration; only ``config.image_size`` is read.
            gray_scale: When True, images are converted to one grayscale channel
                before the remaining transforms are applied.
        """
        super().__init__()
        steps = []
        if gray_scale:
            # Convert while the item is still a PIL image: PIL's convert("L")
            # accepts any source mode (RGB, RGBA, P, L, ...), whereas the
            # tensor-based grayscale path only accepts 1- or 3-channel input.
            steps.append(tran.Grayscale())
        steps += [
            tran.Resize((config.image_size, config.image_size)),
            tran.ToTensor(),
            # A single-value mean/std broadcasts over however many channels
            # remain; (x - 0.5) / 0.5 maps ToTensor's [0, 1] range to [-1, 1].
            tran.Normalize([0.5], [0.5]),
        ]
        # Compose invokes every entry, so the list must never contain None
        # (that would raise TypeError on the first item when gray_scale=False).
        self.transforms = tran.Compose(steps)
        self.images = None
        self.labels = None
        self.download_dataset()

    def download_dataset(self):
        """Load the dataset via ``load_dataset`` and keep its ``"train"`` split.

        Populates ``self.images`` and ``self.labels`` from the ``"images"`` and
        ``"labels"`` columns of the ``"train"`` split, then logs a success
        message.
        """
        dataset = load_dataset("zaibutcooler/archi-vault")
        train_split = dataset["train"]
        # NOTE(review): selecting a whole column may materialize/decode every
        # row up front depending on the installed `datasets` version; index
        # rows lazily instead if memory use becomes a problem.
        self.images = train_split["images"]
        self.labels = train_split["labels"]
        log("Successfully loaded the dataset")

    def __getitem__(self, index):
        """Return the transformed image and its label at ``index``.

        Assumes ``self.images[index]`` is a PIL image (as produced by the
        Hugging Face ``Image`` feature) -- TODO confirm against the dataset
        schema.
        """
        image = self.transforms(self.images[index])
        label = self.labels[index]
        return image, label

    def __len__(self):
        """Return the number of images in the loaded split."""
        return len(self.images)