File size: 2,488 Bytes
2f044c1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
import os
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union
import torch
from torch.utils.data import Dataset, IterableDataset
from relik.common.log import get_logger
# Module-level logger, obtained from relik's `get_logger` helper by passing
# this module's `__name__`.
logger = get_logger(__name__)
class BaseDataset(Dataset):
    """Map-style dataset base class backed by an indexable ``data`` container.

    ``len()`` and item lookup are forwarded directly to ``self.data``.
    ``load`` and ``collate_fn`` are placeholders that raise
    ``NotImplementedError``; presumably subclasses are expected to
    override them.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        **kwargs,
    ):
        """Store the dataset name together with its source(s).

        Args:
            name: Identifier of the dataset, kept as ``self.name``.
            path: A single path or a list of paths, kept as ``self.path``.
            data: Already-available samples, kept as ``self.data``.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If both ``path`` and ``data`` are ``None``.
        """
        super().__init__()
        self.name = name

        # At least one of the two sources has to be supplied.
        nothing_provided = data is None and path is None
        if nothing_provided:
            raise ValueError("Either `path` or `data` must be provided")

        # Walk three directory levels up, starting from this file.
        folder = Path(__file__)
        for _ in range(3):
            folder = folder.parent

        self.path = path
        self.project_folder = folder
        self.data = data

    def __len__(self) -> int:
        """Return the number of samples held in ``self.data``."""
        size = len(self.data)
        return size

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Return whatever ``self.data`` stores at ``index``."""
        sample = self.data[index]
        return sample

    def __repr__(self) -> str:
        """Return a short description showing the name and the path."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from one or several paths into one single dataset.

        Not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Collate a batch of samples.

        Not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
class IterableBaseDataset(IterableDataset):
    """Iterable-style dataset base class backed by an iterable ``data`` object.

    Iterating over the dataset yields the elements of ``self.data`` in order.
    ``load`` and ``collate_fn`` are placeholders that raise
    ``NotImplementedError``; presumably subclasses are expected to
    override them.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, Path, List[str], List[Path]]] = None,
        data: Any = None,
        *args,
        **kwargs,
    ):
        """Store the dataset name together with its source(s).

        Args:
            name: Identifier of the dataset, kept as ``self.name``.
            path: A single path or a list of paths, kept as ``self.path``.
            data: Already-available samples, kept as ``self.data``.
            *args: Accepted but not used by this constructor.
            **kwargs: Accepted but not used by this constructor.

        Raises:
            ValueError: If both ``path`` and ``data`` are ``None``.
        """
        super().__init__()
        self.name = name

        # At least one of the two sources has to be supplied.
        nothing_provided = data is None and path is None
        if nothing_provided:
            raise ValueError("Either `path` or `data` must be provided")

        # Walk three directory levels up, starting from this file.
        folder = Path(__file__)
        for _ in range(3):
            folder = folder.parent

        self.path = path
        self.project_folder = folder
        self.data = data

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield every element of ``self.data`` in iteration order."""
        yield from self.data

    def __repr__(self) -> str:
        """Return a short description showing the name and the path."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from one or several paths into one single dataset.

        Not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError

    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Collate a batch of samples.

        Not implemented in this base class.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError
|