|
|
|
|
|
|
|
import inspect |
|
import warnings |
|
from collections import defaultdict |
|
from typing import Any, List, Dict |
|
from pathlib import Path |
|
|
|
|
|
def import_(target: str, allow_none: bool = False) -> Any:
    """
    Resolve a dotted path such as ``package.module.attribute`` to the object it names.

    ``target`` is split at its last dot: the part on the left is imported as a module and
    the part on the right is looked up on that module with ``getattr``.

    ``None`` is passed through and yields ``None``. The ``allow_none`` flag is accepted
    but is not consulted by this function.
    """
    if target is None:
        return None

    module_path, attr_name = target.rsplit('.', 1)
    # A non-empty ``fromlist`` makes ``__import__`` hand back the leaf module
    # instead of the top-level package.
    owner = __import__(module_path, globals(), locals(), [attr_name])
    return getattr(owner, attr_name)
|
|
|
|
|
def version_larger_equal(a: str, b: str) -> bool:
    """
    Return whether version ``a`` is greater than or equal to version ``b``.

    Only the numeric release part of each version is compared:

    * a local version label (everything from the first ``+``, e.g. ``+cu110``) is dropped;
    * a pre-release / dev suffix (``1.9.0a0``, ``1.10.0.dev20210101``) is ignored, so such a
      version compares like the release it is attached to;
    * the shorter version is padded with zeros, so ``1.8`` equals ``1.8.0``.

    Raises:
        ValueError: if a version string does not start with a number.
    """
    import re  # local import: only this helper needs it

    def release_numbers(version: str) -> List[int]:
        numbers = []
        for component in version.split('+')[0].split('.'):
            component = component.strip()
            leading = re.match(r'[0-9]+', component)
            if leading is None:
                # e.g. "dev20210101": the numeric release part is over.
                break
            numbers.append(int(leading.group()))
            if leading.end() != len(component):
                # e.g. "0rc1": keep the number, ignore the tag and anything after it.
                break
        if not numbers:
            raise ValueError(f'Cannot parse version string: {version!r}')
        return numbers

    left, right = release_numbers(a), release_numbers(b)
    # Pad to equal length; otherwise (1, 8) would wrongly compare smaller than (1, 8, 0).
    width = max(len(left), len(right))
    left += [0] * (width - len(left))
    right += [0] * (width - len(right))
    return left >= right
|
|
|
|
|
# Last identifier handed out per namespace; a namespace that has not been used yet reads as 0.
_last_uid: Dict[str, int] = defaultdict(int)
|
|
|
|
|
def uid(namespace: str = 'default') -> int:
    """
    Advance the counter of ``namespace`` by one and return the new value.

    Counters live in the module-level ``_last_uid`` mapping, one per namespace.
    """
    fresh = _last_uid[namespace] + 1
    _last_uid[namespace] = fresh
    return fresh
|
|
|
|
|
def reset_uid(namespace: str = 'default') -> None:
    """
    Set the counter of ``namespace`` in ``_last_uid`` back to zero.

    Other namespaces are left untouched.
    """
    _last_uid.update({namespace: 0})
|
|
|
|
|
def get_module_name(cls_or_func):
    """
    Return the name of the module from which ``cls_or_func`` is expected to be importable.

    Normally this is ``cls_or_func.__module__``. When that is ``'__main__'``, the call stack
    is searched for a frame that belongs to the main module, and the stem of that frame's
    source file (its file name without extension) is returned instead.

    If no such frame is found, a warning is issued and ``'__main__'`` is returned.

    Raises:
        RuntimeError: if the main script does not live in the current working directory.
    """
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        # context=0: only the frame objects are needed, so do not load source lines per frame.
        for frame_info in inspect.stack(0):
            frame = frame_info[0]
            owner = inspect.getmodule(frame)
            # getmodule() returns None for frames it cannot map to a module (e.g. exec'd code).
            if owner is None or owner.__name__ != '__main__':
                continue

            source_file = inspect.getsourcefile(frame)
            if source_file is None:
                # No source file known for this frame; keep looking further up the stack.
                continue

            main_file_path = Path(source_file)
            # Compare resolved directories rather than the literal path, so that an absolute
            # script path located in the current directory is accepted as well.
            if main_file_path.parent.resolve() != Path.cwd().resolve():
                raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                   f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
            module_name = main_file_path.stem
            break

    if module_name == '__main__':
        warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                      'function/class cannot be imported.')

    # Special case for torch's LSTM: always report the class's own ``__module__``.
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name
|
|
|
|
|
def get_importable_name(cls, relocate_module=False):
    """
    Build the dotted name ``<module>.<name>`` for ``cls``.

    The module part is ``cls.__module__`` by default. With ``relocate_module=True`` it is
    obtained from ``get_module_name(cls)`` instead.
    """
    if relocate_module:
        module_part = get_module_name(cls)
    else:
        module_part = cls.__module__
    return '.'.join((module_part, cls.__name__))
|
|
|
|
|
class ContextStack:
    """
    This is to maintain a globally-accessible context environment that is visible everywhere.

    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
    get the corresponding value in the namespace.

    Entering the ``with`` block pushes ``value`` onto the stack of ``namespace``; leaving it
    pops the top entry again, so nested blocks on the same namespace shadow outer ones.
    """

    # One stack per key. This is a class attribute, so it is shared by all instances.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        """Put ``value`` on top of the stack of ``key``."""
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        """Discard the top value of the stack of ``key``."""
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        """
        Return the top value of the stack of ``key`` without removing it.

        Raises:
            AssertionError: if nothing is currently pushed under ``key``.
        """
        # ``.get`` instead of indexing: a failed lookup must not leave an empty
        # list behind in the shared defaultdict.
        stack = cls._stack.get(key)
        if not stack:
            # Raised explicitly (rather than via ``assert``) so the check also runs under ``python -O``.
            raise AssertionError('Context is empty.')
        return stack[-1]
|
|
|
|
|
def get_current_context(key: str) -> Any:
    """
    Return the value currently on top of the ``ContextStack`` for ``key``.

    This simply forwards to ``ContextStack.top``.
    """
    current = ContextStack.top(key)
    return current
|
|