Spaces:
Sleeping
Sleeping
import os | |
from pathlib import Path | |
from typing import Any, Dict, Iterator, List, Optional, Tuple, Union | |
import torch | |
from torch.utils.data import Dataset, IterableDataset | |
from relik.common.log import get_logger | |
logger = get_logger() | |
class BaseDataset(Dataset):
    """Map-style dataset base class backed by an indexable ``data`` object.

    The class stores a name, an optional source ``path`` and an optional
    ``data`` container. ``__len__`` and ``__getitem__`` delegate directly to
    ``data``; ``load`` and ``collate_fn`` are hooks that subclasses must
    implement.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, os.PathLike, List[str], List[os.PathLike]]] = None,
        data: Any = None,
        **kwargs,
    ):
        """Store the dataset identity and its backing data.

        Args:
            name: Identifier of the dataset, shown in ``repr``.
            path: One path or a list of paths associated with the dataset.
                It is only stored here; this base class never reads it.
            data: Container holding the samples. It must support ``len()``
                and indexing for ``__len__`` / ``__getitem__`` to work.
            **kwargs: Accepted for signature compatibility and ignored here.

        Raises:
            ValueError: If both ``path`` and ``data`` are ``None``.
        """
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Third ancestor directory of this source file.
        self.project_folder = Path(__file__).parent.parent.parent
        self.data = data

    def __len__(self) -> int:
        """Return the number of samples, i.e. ``len(self.data)``."""
        return len(self.data)

    def __getitem__(
        self, index
    ) -> Union[Dict[str, torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]:
        """Return the sample stored at ``index`` in ``self.data``."""
        return self.data[index]

    def __repr__(self) -> str:
        """Return a debug representation showing ``name`` and ``path``."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from a single path or multiple paths into one dataset.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    # The function takes no ``self``; without ``@staticmethod`` an instance
    # call such as ``ds.collate_fn(batch)`` would bind the dataset to
    # ``batch`` and shift the real batch into ``*args``.
    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Collate ``batch`` into the structure consumed downstream.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError
class IterableBaseDataset(IterableDataset):
    """Iterable-style dataset base class backed by an iterable ``data`` object.

    The class stores a name, an optional source ``path`` and an optional
    ``data`` iterable. ``__iter__`` yields the items of ``data`` in order;
    ``load`` and ``collate_fn`` are hooks that subclasses must implement.
    """

    def __init__(
        self,
        name: str,
        path: Optional[Union[str, Path, List[str], List[Path]]] = None,
        data: Any = None,
        *args,
        **kwargs,
    ):
        """Store the dataset identity and its backing data.

        Args:
            name: Identifier of the dataset, shown in ``repr``.
            path: One path or a list of paths associated with the dataset.
                It is only stored here; this base class never reads it.
            data: Iterable holding the samples, consumed by ``__iter__``.
            *args: Accepted for signature compatibility and ignored here.
            **kwargs: Accepted for signature compatibility and ignored here.

        Raises:
            ValueError: If both ``path`` and ``data`` are ``None``.
        """
        super().__init__()
        self.name = name
        if path is None and data is None:
            raise ValueError("Either `path` or `data` must be provided")
        self.path = path
        # Third ancestor directory of this source file.
        self.project_folder = Path(__file__).parent.parent.parent
        self.data = data

    def __iter__(self) -> Iterator[Dict[str, torch.Tensor]]:
        """Yield every sample of ``self.data`` in its iteration order."""
        yield from self.data

    def __repr__(self) -> str:
        """Return a debug representation showing ``name`` and ``path``."""
        return f"Dataset({self.name=}, {self.path=})"

    def load(
        self,
        paths: Union[str, os.PathLike, List[str], List[os.PathLike]],
        *args,
        **kwargs,
    ) -> Any:
        """Load data from a single path or multiple paths into one dataset.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError

    # The function takes no ``self``; without ``@staticmethod`` an instance
    # call such as ``ds.collate_fn(batch)`` would bind the dataset to
    # ``batch`` and shift the real batch into ``*args``.
    @staticmethod
    def collate_fn(batch: Any, *args, **kwargs) -> Any:
        """Collate ``batch`` into the structure consumed downstream.

        Raises:
            NotImplementedError: Always; subclasses must override this hook.
        """
        raise NotImplementedError