Spaces:
Sleeping
Sleeping
from torch.utils.data import Dataset | |
from torchvision import transforms as tran | |
from datasets import load_dataset | |
from .utils import log | |
from .config import Config | |
class TrainData(Dataset):
    """Map-style dataset over the ``zaibutcooler/archi-vault`` train split.

    Each item is a ``(image, label)`` pair where the image has been passed
    through the preprocessing pipeline built in ``__init__``.
    """

    def __init__(self, config: Config, gray_scale=True):
        """Build the preprocessing pipeline and load the dataset.

        Args:
            config: Project configuration; only ``config.image_size`` is read
                here and is used as both the target height and width.
            gray_scale: When true, a ``Grayscale`` step is appended to the
                pipeline. When false, no grayscale step is added.
        """
        super().__init__()

        steps = [
            tran.Resize((config.image_size, config.image_size)),
            tran.ToTensor(),
            tran.Normalize([0.5], [0.5]),
        ]
        # Only add the step when requested. Placing ``None`` in the list
        # (as a conditional expression would) makes Compose try to call it
        # and fail with a TypeError on the first sample.
        if gray_scale:
            # Kept last, after normalization, to match the existing step order.
            steps.append(tran.Grayscale())
        self.transforms = tran.Compose(steps)

        # Populated by download_dataset().
        self.images = None
        self.labels = None
        self.download_dataset()

    def download_dataset(self):
        """Load the dataset and keep the train split's images and labels."""
        dataset = load_dataset("zaibutcooler/archi-vault")
        train_split = dataset["train"]
        self.images = train_split["images"]
        self.labels = train_split["labels"]
        log("Successfully loaded the dataset")

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the sample at ``index``."""
        image = self.transforms(self.images[index])
        label = self.labels[index]
        return image, label

    def __len__(self):
        """Return the number of loaded images."""
        return len(self.images)