Spaces:
Sleeping
Sleeping
import numpy as np | |
from PIL import Image | |
from torch.utils.data import Dataset | |
from torchvision import transforms | |
class ImageRetrievalDataset(Dataset):
    """Dataset of paired (input, label) images for image retrieval.

    ``data`` holds the image locations column-wise: ``data[0]`` lists the
    input images and ``data[1]`` the matching label images, so sample ``i``
    is the pair ``(data[0][i], data[1][i])``. Each entry is handed to
    ``PIL.Image.open``, so it is presumably a file path (confirm with callers).
    """

    def __init__(
        self,
        data: np.ndarray,
        transform: "transforms.Compose | None" = None,
    ) -> None:
        """Store the path table and an optional transform.

        Args:
            data: 2-row array of image paths; row 0 = inputs, row 1 = labels.
            transform: Callable applied to each RGB PIL image. When ``None``,
                the PIL images are returned untransformed.
        """
        self.data = data
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of (input, label) pairs, i.e. columns of ``data``."""
        return len(self.data[0])

    def __getitem__(self, idx: int) -> tuple:
        """Load the ``idx``-th pair.

        Returns:
            ``(input_image, label_image, input_path, label_path)``, where the
            images are the (optionally transformed) RGB images.
        """
        # Transposing turns column ``idx`` (one input/label pair) into a row.
        input_path, label_path = self.data.T[idx]
        input_image = self._load_image(input_path)
        label_image = self._load_image(label_path)
        return input_image, label_image, input_path, label_path

    def _load_image(self, path):
        """Open ``path`` as RGB, close the file, then apply the transform if set."""
        # The context manager guarantees the file handle is released; convert()
        # returns a fully loaded copy, so it stays usable after the file closes.
        with Image.open(path) as image:
            rgb_image = image.convert("RGB")
        if self.transform is not None:
            return self.transform(rgb_image)
        return rgb_image