File size: 2,138 Bytes
e94100d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
import importlib
import logging
from collections import defaultdict
from typing import Callable, Dict, List, Optional, Type, TypeVar

# Module logger; `Registrable.register` logs here when an existing name is
# overwritten with exist_ok=True.
logger = logging.getLogger("toolbox")

# Type variable so `register`/`by_name` are typed in terms of the class they
# are called on.
T = TypeVar("T")


class Registrable(object):
    """Mixin giving a class hierarchy a string-name -> subclass registry.

    A base class inherits from ``Registrable``. Concrete subclasses are
    registered under a name with the ``@Base.register("name")`` decorator and
    looked up later with ``Base.by_name("name")``. Registries are keyed by the
    class that ``register``/``by_name`` is called on, so each base class has
    its own independent namespace of names.
    """

    # base class -> {registered name -> subclass}. A single dict shared by all
    # Registrable subclasses; defaultdict creates a base's inner dict on its
    # first registration.
    _registry: Dict[Type, Dict[str, Type]] = defaultdict(dict)
    # Name that `list_available` reports first; if set, it must be registered.
    default_implementation: Optional[str] = None
    # Overwritten on a subclass by `register`; "unknown" means never registered.
    register_name: str = "unknown"

    @classmethod
    def register(
        cls: Type[T], name: str, exist_ok: bool = False
    ) -> Callable[[Type[T]], Type[T]]:
        """Return a class decorator that registers a subclass under *name*.

        Args:
            name: Key for the decorated subclass in ``cls``'s registry. On
                success it is also stored on the subclass as ``register_name``.
            exist_ok: If True, replace a class already registered under *name*
                and log an info message. If False, raise instead.

        Returns:
            A decorator that registers its argument and returns it unchanged.

        Raises:
            ValueError: (raised by the returned decorator) if *name* is already
                registered for ``cls`` and ``exist_ok`` is False.
        """
        registry = Registrable._registry[cls]

        def add_subclass_to_registry(subclass: Type[T]) -> Type[T]:
            if name in registry:
                if exist_ok:
                    # Report the class actually being installed (the subclass),
                    # not the base class the registry belongs to.
                    logger.info(
                        "%s has already been registered as %s, but "
                        "exist_ok=True, so overwriting with %s",
                        name, registry[name].__name__, subclass.__name__,
                    )
                else:
                    raise ValueError(
                        f"Cannot register {name} as {cls.__name__}; "
                        f"name already in use for {registry[name].__name__}"
                    )
            # Tag the subclass only once registration is certain to succeed,
            # so a rejected duplicate does not leave a stale register_name.
            setattr(subclass, "register_name", name)
            registry[name] = subclass
            return subclass

        return add_subclass_to_registry

    @classmethod
    def by_name(cls: Type[T], name: str) -> Type[T]:
        """Return the subclass registered under *name* for ``cls``.

        Raises:
            ValueError: if *name* is not registered for ``cls``. The message
                lists the names that are available.
        """
        # .get (rather than indexing the defaultdict) avoids inserting an empty
        # entry for a base class that has never registered anything.
        registry = Registrable._registry.get(cls, {})
        try:
            return registry[name]
        except KeyError:
            raise ValueError(
                f"{name} is not a registered name for {cls.__name__}. "
                f"Available names: {list(registry)}"
            ) from None

    @classmethod
    def list_available(cls) -> List[str]:
        """Return the names registered for ``cls``.

        If ``cls.default_implementation`` is set, it is listed first. The
        remaining names follow in the registry's insertion order.

        Raises:
            ValueError: if ``default_implementation`` is set but not registered.
        """
        keys = list(Registrable._registry.get(cls, {}))
        default = cls.default_implementation
        if default is None:
            return keys
        if default not in keys:
            raise ValueError(f"Default implementation {default} is not registered")
        return [default] + [k for k in keys if k != default]