File size: 2,391 Bytes
786f6a6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
from abc import ABC, abstractmethod
from typing import Dict, List, Optional, Union
class DatasetInfo(ABC):
    """Abstract interface describing the class labels of a dataset.

    Implementations map integer class indices to label names and label names to
    human-readable descriptions.
    """

    def __init__(self):
        pass

    @abstractmethod
    def num_classes(self) -> int:
        """Return the number of classes in the dataset."""

    @abstractmethod
    def label_names(self) -> Union[List[str], Dict[int, str]]:
        """Return the label names, as a list (position == class index) or an index -> name dict."""

    @abstractmethod
    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return descriptions for all labels.

        Args:
            detailed: Request a more detailed description; presumably ignored by
                implementations that only have one level of detail.
            as_dict: If True, return a label name -> description dict; otherwise a list.
        """

    @abstractmethod
    def index_to_label_name(self, index: int) -> str:
        """Return the label name for class ``index``."""

    @abstractmethod
    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description for class ``index``."""

    @abstractmethod
    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the description for the label named ``label``."""
class CustomDatasetInfo(DatasetInfo):
    """DatasetInfo that wraps passed values for custom datasets.

    Args:
        label_names: Label names, either a list (position == class index) or a dict
            mapping class index -> label name.
        label_descriptions: Optional dict mapping label name -> description. When
            given, it must contain an entry for every label name.
    """

    def __init__(
            self,
            label_names: Union[List[str], Dict[int, str]],
            label_descriptions: Optional[Dict[str, str]] = None,
    ):
        super().__init__()
        assert len(label_names) > 0, 'at least one label name is required'
        self._label_names = label_names  # label index => label name mapping
        self._label_descriptions = label_descriptions  # label name => label description mapping
        if self._label_descriptions is not None:
            # validate descriptions (label names required)
            assert isinstance(self._label_descriptions, dict), 'label_descriptions must be a dict'
            # Iterate label *names*: iterating a Dict[int, str] directly yields the integer
            # indices, which never match the name-keyed description dict.
            for n in self._ordered_label_names():
                assert n in self._label_descriptions, f'missing description for label {n!r}'

    def _ordered_label_names(self) -> List[str]:
        """Return the label names as a list in class-index order."""
        if isinstance(self._label_names, dict):
            return [self._label_names[i] for i in sorted(self._label_names)]
        return list(self._label_names)

    def num_classes(self) -> int:
        """Return the number of classes (number of label names)."""
        return len(self._label_names)

    def label_names(self) -> Union[List[str], Dict[int, str]]:
        """Return the label names exactly as passed to the constructor."""
        return self._label_names

    def label_descriptions(self, detailed: bool = False, as_dict: bool = False) -> Union[List[str], Dict[str, str]]:
        """Return descriptions for every label, in class-index order.

        When no descriptions were provided, each label name stands in for its own
        description (same fallback as ``label_name_to_description``).

        Args:
            detailed: Forwarded to ``label_name_to_description``; custom descriptions
                have a single level of detail, so it has no effect here.
            as_dict: If True, return a dict mapping label name -> description;
                otherwise return a list of descriptions ordered by class index.
        """
        names = self._ordered_label_names()
        if as_dict:
            return {n: self.label_name_to_description(n, detailed=detailed) for n in names}
        return [self.label_name_to_description(n, detailed=detailed) for n in names]

    def label_name_to_description(self, label: str, detailed: bool = False) -> str:
        """Return the description for ``label``, or ``label`` itself if none were provided."""
        if self._label_descriptions:
            return self._label_descriptions[label]
        return label  # return label name itself if a descriptions is not present

    def index_to_label_name(self, index: int) -> str:
        """Return the label name for class ``index`` (must satisfy 0 <= index < num_classes)."""
        assert 0 <= index < len(self._label_names), f'class index {index} out of range'
        return self._label_names[index]

    def index_to_description(self, index: int, detailed: bool = False) -> str:
        """Return the description for class ``index``."""
        label = self.index_to_label_name(index)
        return self.label_name_to_description(label, detailed=detailed)
|