# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
import inspect
import warnings
from collections import defaultdict
from pathlib import Path
from typing import Any, List, Dict
def import_(target: str, allow_none: bool = False) -> Any:
    """
    Resolve a dotted path such as ``'package.module.attr'`` to the object it names.

    The text after the last dot is looked up as an attribute of the module named
    by the text before it. ``None`` is passed straight through and yields ``None``.
    ``allow_none`` is part of the signature but is currently not consulted by the body.
    """
    if target is not None:
        module_path, attr_name = target.rsplit('.', 1)
        # A non-empty ``fromlist`` makes ``__import__`` return the leaf module
        # rather than the top-level package.
        owner = __import__(module_path, globals(), locals(), [attr_name])
        return getattr(owner, attr_name)
    return None
def _release_segment(version: str) -> tuple:
    """Parse the leading dotted-integer part of ``version`` into a tuple of ints."""
    ascii_digits = '0123456789'
    numbers = []
    # Anything after '+' is a local build tag (e.g. "+cu111") and is ignored.
    for component in version.split('+')[0].strip().split('.'):
        numeric = component[:len(component) - len(component.lstrip(ascii_digits))]
        if not numeric:
            # Component such as "dev20210801" carries no release number.
            break
        numbers.append(int(numeric))
        if numeric != component:
            # A suffix such as "rc1" or "a0" ends the release segment.
            break
    if not numbers:
        raise ValueError(f'Cannot parse version string: {version!r}')
    return tuple(numbers)


def version_larger_equal(a: str, b: str) -> bool:
    """
    Return True when version string ``a`` is greater than or equal to ``b``.

    Only the numeric release segment is compared: the local part after ``+`` is
    dropped, and parsing stops at the first non-numeric suffix (``rc1``, ``a0``,
    ``dev...``), so pre-release ordering is NOT taken into account.
    The shorter version is padded with zeros, hence ``'1.0'`` equals ``'1.0.0'``.

    Raises ``ValueError`` when a version does not start with a number.
    """
    # TODO: refactor later
    left = _release_segment(a)
    right = _release_segment(b)
    # Pad to a common length so that "1.0" and "1.0.0" compare as equal.
    width = max(len(left), len(right))
    left += (0,) * (width - len(left))
    right += (0,) * (width - len(right))
    return left >= right
# Last id handed out in each namespace; namespaces not seen yet start from 0.
_last_uid: Dict[str, int] = defaultdict(int)


def uid(namespace: str = 'default') -> int:
    """Advance the counter of ``namespace`` by one and return the new value."""
    next_id = _last_uid[namespace] + 1
    _last_uid[namespace] = next_id
    return next_id


def reset_uid(namespace: str = 'default') -> None:
    """Set the counter of ``namespace`` back to 0, so the next ``uid`` call returns 1."""
    _last_uid.update({namespace: 0})
def get_module_name(cls_or_func):
    """
    Return the name of the module that ``cls_or_func`` should be imported from.

    Normally this is ``cls_or_func.__module__``. When that is ``'__main__'``, the
    call stack is searched for the frame belonging to the main module and the stem
    of its source file is returned instead.

    Raises ``RuntimeError`` when the main script's recorded path has a directory
    part other than ``.``. Emits a warning, and returns ``'__main__'``, when no
    usable main-module frame is found on the stack.
    """
    module_name = cls_or_func.__module__
    if module_name == '__main__':
        # infer the module name with inspect
        for frame_info in inspect.stack():
            frame_module = inspect.getmodule(frame_info[0])
            # ``inspect.getmodule`` returns None when it cannot determine the module
            # of a frame; such frames cannot be the main module, so skip them
            # instead of failing on ``None.__name__``.
            if frame_module is None or getattr(frame_module, '__name__', None) != '__main__':
                continue
            source_file = inspect.getsourcefile(frame_info[0])
            if source_file is None:
                # No source file can be identified for this frame; keep searching.
                continue
            # main module found
            main_file_path = Path(source_file)
            if main_file_path.parents[0] != Path('.'):
                raise RuntimeError(f'You are using "{main_file_path}" to launch your experiment, '
                                   f'please launch the experiment under the directory where "{main_file_path.name}" is located.')
            module_name = main_file_path.stem
            break
        if module_name == '__main__':
            warnings.warn('Callstack exhausted but main module still not found. This will probably cause issues that the '
                          'function/class cannot be imported.')

    # NOTE: this is hacky. As torchscript retrieves LSTM's source code to do something.
    # to make LSTM's source code can be found, we should assign original LSTM's __module__ to
    # the wrapped LSTM's __module__
    # TODO: find out all the modules that have the same requirement as LSTM
    if f'{cls_or_func.__module__}.{cls_or_func.__name__}' == 'torch.nn.modules.rnn.LSTM':
        module_name = cls_or_func.__module__

    return module_name
def get_importable_name(cls, relocate_module=False):
    """
    Build the dotted name ``<module>.<name>`` for ``cls``.

    With ``relocate_module`` set, the module part comes from ``get_module_name(cls)``;
    otherwise ``cls.__module__`` is used as is.
    """
    if relocate_module:
        owner_module = get_module_name(cls)
    else:
        owner_module = cls.__module__
    return '.'.join((owner_module, cls.__name__))
class ContextStack:
    """
    This is to maintain a globally-accessible context environment that is visible to everywhere.

    Use ``with ContextStack(namespace, value):`` to initiate, and use ``get_current_context(namespace)`` to
    get the corresponding value in the namespace.
    """

    # One stack of values per key, shared by the class and all of its instances.
    _stack: Dict[str, List[Any]] = defaultdict(list)

    def __init__(self, key: str, value: Any):
        self.key = key
        self.value = value

    def __enter__(self):
        self.push(self.key, self.value)
        return self

    def __exit__(self, *args, **kwargs):
        self.pop(self.key)

    # push/pop/top operate on the shared ``_stack`` and are called both on the
    # class (``ContextStack.top(key)``) and on instances (``self.push(...)``),
    # so they have to be classmethods.
    @classmethod
    def push(cls, key: str, value: Any):
        """Put ``value`` on top of the stack of ``key``."""
        cls._stack[key].append(value)

    @classmethod
    def pop(cls, key: str) -> None:
        """Discard the top value of the stack of ``key``; ``IndexError`` if it is empty."""
        cls._stack[key].pop()

    @classmethod
    def top(cls, key: str) -> Any:
        """Return the top value of the stack of ``key``; ``AssertionError`` if it is empty."""
        # ``.get`` avoids inserting an empty list for keys that were never pushed.
        values = cls._stack.get(key)
        if not values:
            # Raised explicitly so that the check also holds under ``python -O``.
            raise AssertionError('Context is empty.')
        return values[-1]
def get_current_context(key: str) -> Any:
    """Return whatever ``ContextStack.top`` reports for ``key``."""
    current = ContextStack.top(key)
    return current