LINC-BIT's picture
Upload 1912 files
b84549f verified
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import warnings
from collections import defaultdict
from typing import Any, List, Dict
from pathlib import Path
def import_(target: str, allow_none: bool = False) -> Any:
    """
    Import and return the object named by a dotted path such as ``'package.module.Name'``.

    Parameters
    ----------
    target
        Fully qualified name. Everything before the last dot is imported as a module,
        and the part after it is looked up as an attribute of that module.
    allow_none
        Kept for interface compatibility; this function does not consult it.
        A ``None`` target always yields ``None``.

    Returns
    -------
    The resolved object, or ``None`` when ``target`` is ``None``.

    Raises
    ------
    ValueError
        If ``target`` contains no dot, so it cannot be split into module and attribute.
    ImportError
        If the module part cannot be imported.
    AttributeError
        If the module has no attribute named by the last component.
    """
    if target is None:
        return None
    path, dot, identifier = target.rpartition('.')
    if not dot:
        # Fail with an explicit message rather than an opaque tuple-unpacking error.
        raise ValueError(f'Cannot import "{target}": expected a dotted path like "module.attribute".')
    # A non-empty fromlist makes __import__ return the module named by ``path`` itself
    # (not its top-level package), and lets it load ``identifier`` as a submodule when needed.
    module = __import__(path, globals(), locals(), [identifier])
    return getattr(module, identifier)
def version_larger_equal(a: str, b: str) -> bool:
    """
    Tell whether version ``a`` is greater than or equal to version ``b``.

    Only the numeric release segment is compared. A local version label (anything after ``+``)
    is dropped, and parsing of a version stops at the first component that carries a
    non-numeric suffix (e.g. ``0a0``, ``0rc1``) or is not numeric at all (e.g. ``dev2021``).
    Pre-release markers are therefore ignored: ``1.9.0rc1`` compares equal to ``1.9.0``.
    The shorter version is padded with zeros, so ``1.10`` equals ``1.10.0``.

    Parameters
    ----------
    a, b
        Version strings such as ``'1.8.1'`` or ``'1.9.0a0+gitabc'``.

    Returns
    -------
    ``True`` if ``a >= b`` on the numeric release segment, ``False`` otherwise.

    Raises
    ------
    ValueError
        If a version string does not start with a numeric component.
    """
    def release_numbers(version: str) -> List[int]:
        numbers = []
        # Drop the local version label (e.g. "+cu111") before splitting into components.
        for component in version.split('+')[0].split('.'):
            component = component.strip()
            suffix = component.lstrip('0123456789')
            digits = component[:len(component) - len(suffix)]
            if not digits:
                break
            numbers.append(int(digits))
            if suffix:
                # Something like "0rc1": keep the number, ignore the tag and whatever follows.
                break
        if not numbers:
            raise ValueError(f'Invalid version string: "{version}"')
        return numbers

    left, right = release_numbers(a), release_numbers(b)
    # Pad to a common length so that "1.10" and "1.10.0" compare as equal.
    width = max(len(left), len(right))
    left += [0] * (width - len(left))
    right += [0] * (width - len(right))
    return left >= right
# Last id handed out per namespace; a namespace that was never used starts at 0.
_last_uid = defaultdict(int)


def uid(namespace: str = 'default') -> int:
    """
    Return the next id of ``namespace``.

    Ids of a namespace start at 1 and grow by one on every call.
    """
    issued = _last_uid[namespace] + 1
    _last_uid[namespace] = issued
    return issued


def reset_uid(namespace: str = 'default') -> None:
    """
    Restart the counter of ``namespace`` so that the next ``uid(namespace)`` returns 1.
    """
    _last_uid.update({namespace: 0})
def get_module_name(cls_or_func):
    """
    Return the name of the module that ``cls_or_func`` should be imported from.

    For anything whose ``__module__`` is not ``'__main__'``, that value is returned as is.
    Otherwise the call stack is searched for a frame belonging to the main module, and the
    stem of its source file is used as the module name.

    Parameters
    ----------
    cls_or_func
        A class or function; its ``__module__`` and ``__name__`` attributes are read.

    Returns
    -------
    The module name. It stays ``'__main__'`` (with a warning) if no main-module frame
    with a source file is found on the stack.

    Raises
    ------
    RuntimeError
        If the main script's source path has a parent directory other than ``'.'``.
    """
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        # infer the module name with inspect
        # context=0: only the frame objects are needed, not their source lines
        for frame_info in inspect.stack(0):
            frame = frame_info[0]
            module = inspect.getmodule(frame)
            # getmodule() returns None when the frame's module cannot be determined
            # (e.g. code run through exec); such frames cannot be the main module.
            if module is None or getattr(module, '__name__', None) != '__main__':
                continue
            source_file = inspect.getsourcefile(frame)
            if source_file is None:
                # main-module frame without a source file; keep searching
                continue
            # main module found
            main_file_path = Path(source_file)
            # NOTE(review): this compares against the literal relative path '.', so an absolute
            # source path always fails the check -- confirm this is intended.
            if main_file_path.parents[0] != Path('.'):
                raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                   f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
            module_name = main_file_path.stem
            break
        if module_name == '__main__':
            warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                          'function/class cannot be imported.')

    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
    # the wrapped LSTM's __module__
    # TODO: find out all the modules that have the same requirement as LSTM
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name
def get_importable_name(cls, relocate_module=False):
    """
    Build the dotted name ``<module>.<name>`` of ``cls``.

    With ``relocate_module`` set, the module part comes from ``get_module_name(cls)``;
    otherwise ``cls.__module__`` is used directly.
    """
    if relocate_module:
        prefix = get_module_name(cls)
    else:
        prefix = cls.__module__
    return '.'.join((prefix, cls.__name__))
class ContextStack:
    """
    This is to maintain a globally-accessible context environment that is visible to everywhere.

    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
    get the corresponding value in the namespace.

    Each key has its own stack, so nested ``with`` blocks on the same key shadow the outer value
    until the inner block exits.
    """

    # One stack per key, stored on the class and therefore shared by all instances.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        self.pop(self.key)

    @classmethod
    def push(cls, key: str, value: Any):
        """Put ``value`` on top of the stack of ``key``."""
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        """Discard the top value of the stack of ``key``; raises ``IndexError`` if it is empty."""
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        """
        Return the top value of the stack of ``key`` without removing it.

        Raises
        ------
        AssertionError
            If nothing is currently pushed for ``key``.
        """
        # ``.get`` keeps a plain lookup from inserting an empty list for unknown keys.
        stack = cls._stack.get(key)
        if not stack:
            # Raised explicitly so that the check is not stripped when running with ``-O``.
            raise AssertionError('Context is empty.')
        return stack[-1]
def get_current_context(key: str) -> Any:
    """
    Return the value currently on top of the ``ContextStack`` of ``key``.

    Any error raised by ``ContextStack.top`` (e.g. when nothing is pushed for ``key``) is propagated.
    """
    current = ContextStack.top(key)
    return current