Spaces:
Running
Running
import traceback | |
from contextlib import contextmanager | |
from typing import List, Any, Dict | |
from ._compatibility import compatibility | |
# Public API of this module.
__all__ = [
    'preserve_node_meta',
    'has_preserved_node_meta',
    'set_stack_trace',
    'set_grad_fn_seq_nr',
    'reset_grad_fn_seq_nr',
    'format_stack',
    'set_current_meta',
    'get_current_meta',
]
# Metadata dict that the helpers in this module read and write. It is
# replaced wholesale by set_current_meta() and mutated in place by the
# set_*/reset_* helpers.
current_meta: Dict[str, Any] = dict()
# Global switch: while False, the set_*/reset_* helpers are no-ops and
# format_stack() falls back to the real Python stack.
should_preserve_node_meta = False
@contextmanager
def preserve_node_meta():
    """Context manager that enables node-meta preservation for its body.

    While active, the module-level flag ``should_preserve_node_meta`` is
    True, which switches on the other helpers in this module
    (``set_stack_trace``, ``set_grad_fn_seq_nr``, ``set_current_meta``, ...).
    On exit -- normal or via an exception -- the flag is restored to the
    value it had on entry, so nested uses compose correctly.

    The function body is a generator (``yield`` inside ``try/finally``), so
    it must be wrapped with ``@contextmanager`` to work in a ``with``
    statement; without it, ``with preserve_node_meta():`` raises because a
    plain generator has no ``__enter__``/``__exit__``.
    """
    global should_preserve_node_meta

    saved_should_preserve_node_meta = should_preserve_node_meta
    try:
        should_preserve_node_meta = True
        yield
    finally:
        # Restore rather than reset to False so an enclosing
        # preserve_node_meta() block stays enabled after a nested one exits.
        should_preserve_node_meta = saved_should_preserve_node_meta
def set_stack_trace(stack: List[str]):
    """Store ``stack`` (concatenated into one string) as the current stack trace.

    The value is written to ``current_meta["stack_trace"]``. Nothing happens
    unless node-meta preservation is enabled and ``stack`` is non-empty.
    """
    if not (should_preserve_node_meta and stack):
        return
    current_meta["stack_trace"] = "".join(stack)
def set_grad_fn_seq_nr(seq_nr):
    """Push ``seq_nr`` onto the grad_fn sequence-number stack in ``current_meta``.

    Appends ``seq_nr`` to ``current_meta["grad_fn_seq_nr"]`` and increments
    the nesting counter ``current_meta["in_grad_fn"]``. Nothing happens
    unless node-meta preservation is enabled.
    """
    if not should_preserve_node_meta:
        return
    # The seq_nr is captured by eager mode in the grad_fn during forward.
    # A fresh list is stored instead of appending in place, so the previous
    # list object is left unmodified (presumably because it may be shared
    # with a shallow-copied node.meta).
    previous = current_meta.get("grad_fn_seq_nr", [])
    current_meta["grad_fn_seq_nr"] = previous + [seq_nr]
    depth = current_meta.get("in_grad_fn", 0)
    current_meta["in_grad_fn"] = depth + 1
def reset_grad_fn_seq_nr():
    """Pop one level of the grad_fn bookkeeping pushed by set_grad_fn_seq_nr.

    When the nesting counter ``in_grad_fn`` is above 1, it is decremented
    and the last entry of ``grad_fn_seq_nr`` is popped. At the outermost
    level, both keys are removed from ``current_meta``. Raises
    ``AssertionError`` if there is no active level (assertions enabled).
    Nothing happens unless node-meta preservation is enabled.
    """
    # NB: reset state properly, this would be helpful towards supporting
    # reentrant autograd if we actually wanted to do that.
    if not should_preserve_node_meta:
        return
    depth = current_meta.get("in_grad_fn", 0)
    assert depth > 0
    if depth != 1:
        current_meta["in_grad_fn"] = depth - 1
        current_meta["grad_fn_seq_nr"].pop()
        return
    # Leaving the outermost grad_fn: drop the bookkeeping keys entirely.
    del current_meta["in_grad_fn"]
    del current_meta["grad_fn_seq_nr"]
def format_stack() -> List[str]:
    """Return the stack trace to attribute to newly created nodes.

    With node-meta preservation enabled, this is a one-element list holding
    ``current_meta["stack_trace"]`` (or ``""`` if unset). Otherwise it falls
    back to the real Python stack of the caller, formatted like
    ``traceback.format_stack()``.
    """
    if not should_preserve_node_meta:
        # Fallback to traceback.format_stack(); the final entry (this
        # function's own frame) is dropped.
        caller_frames = traceback.extract_stack()[:-1]
        return traceback.format_list(caller_frames)
    recorded = current_meta.get("stack_trace", "")
    return [recorded]
def has_preserved_node_meta() -> bool:
    """Report whether node-meta preservation is currently switched on."""
    enabled = should_preserve_node_meta
    return enabled
@contextmanager
def set_current_meta(node):
    """Context manager that makes a copy of ``node.meta`` the current metadata.

    While node-meta preservation is enabled and ``node.meta`` is non-empty,
    the module-level ``current_meta`` is replaced by a shallow copy of
    ``node.meta`` for the duration of the ``with`` body. It is restored on
    exit, including when an exception is raised. Otherwise the body runs
    with ``current_meta`` unchanged.

    The copy also records ``(node.name, node.target)`` in its ``"from_node"``
    list for provenance tracking, unless the last entry already names this
    node. ``node.meta`` itself is never modified.

    Args:
        node: object with ``meta`` (a mapping with ``.copy()``), ``name`` and
            ``target`` attributes -- presumably a ``torch.fx.Node``; verify
            against callers.
    """
    global current_meta
    if should_preserve_node_meta and node.meta:
        saved_meta = current_meta
        try:
            current_meta = node.meta.copy()

            # Append (node.name, node.target) onto "from_node" for provenance
            # tracking. node.meta.copy() is shallow, so the "from_node" list is
            # shared with node.meta. Build a new list rather than calling
            # .append(), which would permanently mutate node.meta["from_node"].
            provenance = current_meta.get("from_node")
            if not provenance:
                # Missing, None or empty: start a fresh history (indexing [-1]
                # on an empty list would raise).
                current_meta["from_node"] = [(node.name, node.target)]
            elif provenance[-1][0] != node.name:
                current_meta["from_node"] = [*provenance, (node.name, node.target)]

            yield
        finally:
            current_meta = saved_meta
    else:
        yield
def get_current_meta() -> Dict[str, Any]:
    """Return the module's current metadata dict itself (not a copy).

    Mutating the returned dict mutates the shared module state.
    """
    meta = current_meta
    return meta