File size: 1,005 Bytes
b84549f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 |
from typing import Any, Callable, Optional
from torchvision.datasets import ImageFolder
from torchvision.datasets.folder import default_loader
class MMImageFolder(ImageFolder):
    """ImageFolder variant that returns each image together with a text prompt.

    Every item is a ``(sample, label)`` pair where ``sample`` is a dict holding
    the loaded image, a caption of the form ``"a photo of a <class name>"`` and
    a ``for_training`` flag that is always ``False``.
    """

    def __init__(
        self,
        root: str,
        classes_exclude_ignored,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        loader: Callable[[str], Any] = default_loader,
        is_valid_file: Optional[Callable[[str], bool]] = None,
    ):
        """Set up the underlying ImageFolder and remember the class-name lookup.

        Args:
            root: Dataset root directory, forwarded to ``ImageFolder``.
            classes_exclude_ignored: Container indexed with the label returned
                by ``ImageFolder`` to obtain the name used in the caption.
                NOTE(review): presumably the class names with ignored classes
                removed; confirm with the caller that its indexing lines up
                with the labels ``ImageFolder`` produces.
            transform: Forwarded unchanged to ``ImageFolder``.
            target_transform: Forwarded unchanged to ``ImageFolder``.
            loader: Forwarded unchanged to ``ImageFolder``.
            is_valid_file: Forwarded unchanged to ``ImageFolder``.
        """
        super().__init__(
            root,
            transform=transform,
            target_transform=target_transform,
            loader=loader,
            is_valid_file=is_valid_file,
        )
        self.classes_exclude_ignored = classes_exclude_ignored

    def __getitem__(self, index: int):
        """Return ``(sample, label)`` for the item at ``index``.

        ``sample`` is a dict with keys ``'images'``, ``'texts'`` and
        ``'for_training'``; ``label`` is whatever ``ImageFolder`` returned.
        """
        image, label = super().__getitem__(index)
        # The caption is derived from the item's own label, i.e. it is meant to
        # be the ground-truth description. The class name is inserted verbatim:
        # no hyphen/underscore replacement and no lower-casing is applied.
        caption = f"a photo of a {self.classes_exclude_ignored[label]}"
        sample = {
            'images': image,
            'texts': caption,
            'for_training': False,
        }
        return sample, label
|