# Source: Hugging Face Hub file view — uploaded by chicelli
# (commit 9f68e7c, "Upload 21 files", verified), 770 bytes.
from typing import Optional

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms
class ImageRetrievalDataset(Dataset):
    """Dataset of (input image, label image) pairs read from file paths.

    ``data`` is a two-row table of paths: row 0 holds input-image paths and
    row 1 the matching label-image paths, so column ``i`` is the ``i``-th pair.
    """

    def __init__(
        self, data: np.ndarray, transform: Optional[transforms.Compose] = None
    ) -> None:
        """Store the path table and the optional per-image transform.

        Args:
            data: Array-like of shape ``(2, N)``; ``data[0]`` are input paths
                and ``data[1]`` the corresponding label paths.
            transform: Callable applied to each image of a pair (e.g. a
                ``transforms.Compose``). If ``None``, the RGB PIL images are
                returned untransformed.

        Raises:
            ValueError: If ``data`` is not two-dimensional with exactly two rows.
        """
        # asarray is a no-op for an existing ndarray and also accepts lists.
        self.data = np.asarray(data)
        # Catch a transposed (N, 2) table early; otherwise len() would be
        # silently wrong and pairs would be mis-indexed.
        if self.data.ndim != 2 or self.data.shape[0] != 2:
            raise ValueError(
                f"data must have shape (2, N), got shape {self.data.shape}"
            )
        self.transform = transform

    def __len__(self) -> int:
        """Return the number of pairs (columns of ``data``)."""
        return self.data.shape[1]

    @staticmethod
    def _load_rgb(path) -> Image.Image:
        """Open ``path`` and return an RGB copy, closing the file handle."""
        # ``convert`` returns a new, fully loaded image, so the source file
        # can be closed right away instead of relying on PIL's lazy close.
        with Image.open(path) as img:
            return img.convert("RGB")

    def __getitem__(self, idx: int) -> tuple:
        """Return ``(input_image, label_image, input_path, label_path)`` for pair ``idx``."""
        # Column ``idx`` holds the (input, label) path pair.
        input_path, label_path = self.data[:, idx]
        input_image = self._load_rgb(input_path)
        label_image = self._load_rgb(label_path)
        if self.transform is not None:
            # NOTE(review): the transform is invoked separately for each image,
            # so any random transform draws independent parameters per image —
            # confirm this is intended.
            input_image = self.transform(input_image)
            label_image = self.transform(label_image)
        return input_image, label_image, input_path, label_path