Kano001's picture
Upload 5252 files
c61ccee verified
raw
history blame
9.32 kB
# mypy: ignore-errors
import weakref
from typing import Dict, List
import torch
from ..decorators import mark_static_address
from ..guards import GuardBuilder, install_guard
from ..source import AttrSource, ConstDictKeySource, GetItemSource, GlobalWeakRefSource
from ..utils import GLOBAL_KEY_PREFIX
from .base import VariableTracker
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
class ArgMappingException(Exception):
    """Signals that a traced argument has no plain-Python equivalent.

    Raised by ``OptimizerVariable.get_python_args`` when an argument is not a
    constant, an empty list variable, or a dict variable sourced from
    ``param_groups``. ``OptimizerVariable.call_method`` catches it and falls
    back to tracing ``_init_group`` normally.
    """
class GuardInstallException(Exception):
    """Signals that a guard could not be installed for an optimizer state value.

    Raised by ``OptimizerVariable.map_sources_and_install_guards`` when a
    state entry is neither a not-yet-mapped tensor nor ``None``/``bool``/
    ``int``/``float``/``str``. ``OptimizerVariable.call_method`` catches it
    and falls back to tracing ``_init_group`` normally.
    """
class OptimizerVariable(UserDefinedObjectVariable):
    """Variable tracker for an optimizer object.

    Calls to ``_init_group`` are executed eagerly on the wrapped optimizer
    instead of being traced; afterwards sources are recorded for parameters,
    gradients and state tensors, and the corresponding guards are installed
    by hand.
    """

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        """Wrap ``value`` and mark its parameters as having static addresses.

        Every param group that already has a ``"capturable"`` key gets it set
        to ``True``. The three optional containers default to fresh empty
        ones when the argument is ``None`` (or otherwise falsy).
        """
        super().__init__(value, **kwargs)

        for param_group in self.value.param_groups:
            if "capturable" in param_group:
                param_group["capturable"] = True

            for param in param_group["params"]:
                mark_static_address(param, guard=False)

        self.grad_to_source = grad_to_source if grad_to_source else {}
        self.tensor_to_source = tensor_to_source if tensor_to_source else {}
        self.static_tensor_names = (
            static_tensor_names if static_tensor_names else set()
        )

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """Run ``_init_group`` eagerly rather than tracing it.

        Tracing the optimizer's initialization is very slow, so for
        ``_init_group`` the real method is invoked with Python values and its
        result is returned as a constant. Any other method, or an
        ``_init_group`` call whose arguments cannot be mapped or whose guards
        cannot be installed, is handled by the parent class.
        """
        if name != "_init_group":
            return super().call_method(tx, name, args, kwargs)

        try:
            py_args, py_kwargs = self.get_python_args(*args, **kwargs)
            ret_val = self.value._init_group(*py_args, **py_kwargs)
            self.map_sources_and_install_guards(tx)
            self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
            # Keep a weak reference to the optimizer so that the compiled
            # code is invalidated once the optimizer object dies.
            tx.store_global_weakref_by_id(
                f"__optimizer_{id(self.value)}", self.value
            )
            self.create_finalizer(tx)

            # Returning a constant is safe only because the values that
            # existing optimizers return from `_init_group` are properties
            # invariant to the input tensors (e.g. dtype, layout). A change
            # in those triggers recompilation, so `ret_val` is never
            # specialized incorrectly.
            return ConstantVariable.create(ret_val)
        except (ArgMappingException, GuardInstallException):
            # Could not map the args or install the guards: trace normally.
            pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Return a ``GetAttrVariable`` for ``_init_group``; defer otherwise."""
        if name != "_init_group":
            return super().var_getattr(tx, name)
        return GetAttrVariable(self, name)

    def get_python_args(self, *args, **kwargs):
        """Translate variable tracker arguments into plain Python values.

        Returns:
            A ``(list_of_args, dict_of_kwargs)`` pair.

        Raises:
            ArgMappingException: if an argument is not a constant, an empty
                list variable, or a dict variable sourced from an item of
                ``param_groups``.
        """

        def to_python(tracker):
            if isinstance(tracker, ConstantVariable):
                return tracker.as_python_constant()
            if isinstance(tracker, ListVariable) and not tracker.items:
                return []
            if (
                isinstance(tracker, ConstDictVariable)
                and isinstance(tracker.source, GetItemSource)
                and isinstance(tracker.source.base, AttrSource)
                and tracker.source.base.member == "param_groups"
            ):
                # The dict is one of the optimizer's own param groups, so
                # hand back the real group at the same index.
                return self.value.param_groups[tracker.source.index]
            raise ArgMappingException()

        positional = [to_python(tracker) for tracker in args]
        keyword = {key: to_python(tracker) for key, tracker in kwargs.items()}
        return positional, keyword

    def map_sources_and_install_guards(self, tx):
        """Rebuild the tensor/grad source maps and install guards manually.

        Raises:
            GuardInstallException: if a state value is neither a tensor that
                still needs a source nor ``None``/``bool``/``int``/``float``/
                ``str``.
        """
        self.grad_to_source = {}
        self.tensor_to_source = {}

        from .builder import VariableBuilder

        groups_builder = VariableBuilder(tx, AttrSource(self.source, "param_groups"))
        param_groups_vt = groups_builder(self.value.param_groups).recursive_realize()

        params_key = ConstantVariable.create("params")
        for param_group, group_vt in zip(
            self.value.param_groups, param_groups_vt.items
        ):
            params_vt = group_vt.getitem_const(params_key)
            for param, param_vt in zip(
                param_group["params"], params_vt.unpack_var_sequence(tx)
            ):
                param_source = param_vt.source
                self.tensor_to_source[param] = param_source
                grad_source = AttrSource(param_source, "grad")
                if param.grad is None:
                    # No gradient to map yet: guard on the attribute itself.
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))
                else:
                    self.grad_to_source[param.grad] = grad_source

        # Generating state guards the regular way takes a long time, so they
        # are produced by hand here.
        state_source = AttrSource(self.source, "state")
        install_guard(state_source.make_guard(GuardBuilder.DICT_KEYS))
        for position, (param, param_state) in enumerate(self.value.state.items()):
            tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, param)
            entry_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, position)
            )
            install_guard(entry_source.make_guard(GuardBuilder.DICT_KEYS))
            for key, entry in param_state.items():
                needs_source = (
                    isinstance(entry, torch.Tensor)
                    and entry not in self.grad_to_source
                    and entry not in self.tensor_to_source
                )
                if needs_source:
                    self.tensor_to_source[entry] = GetItemSource(entry_source, key)
                    continue

                if entry is not None and not isinstance(
                    entry, (bool, int, float, str)
                ):
                    raise GuardInstallException()

                install_guard(
                    GetItemSource(entry_source, key).make_guard(
                        GuardBuilder.CONSTANT_MATCH
                    )
                )

    def wrap_tensor(self, tx, tensor_value):
        """Wrap a state tensor in a TensorVariable.

        A tensor with a recorded source is built from that source. A tensor
        that has not been seen before is stashed behind a global weak
        reference and built from that, since it must be an optimizer tensor
        that was missed. Everything except gradients is marked as static for
        cudagraphs and has its module key name recorded.
        """
        from .builder import VariableBuilder

        is_known_tensor = tensor_value in self.tensor_to_source
        if not is_known_tensor and tensor_value in self.grad_to_source:
            # Gradients are neither marked static nor recorded by name.
            grad_builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
            return grad_builder(tensor_value)

        # mark these tensors as static for cudagraphs
        mark_static_address(tensor_value, guard=False)
        if is_known_tensor:
            source = self.tensor_to_source[tensor_value]
        else:
            source = GlobalWeakRefSource(
                tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            )

        builder = VariableBuilder(tx, source)
        self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        return builder(tensor_value)

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the list args of the traced optimizer call.

        For every positional ``ListVariable``, the values now held by the
        matching Python list are appended to it as variable trackers.
        ``kwargs`` and ``py_kwargs`` are accepted but not used here.
        """
        from .builder import SourcelessBuilder, VariableBuilder

        for list_vt, py_list in zip(args, py_args):
            if not isinstance(list_vt, ListVariable):
                continue

            assert isinstance(
                py_list, list
            ), "py_arg should be a list in optimizer variable"

            for index, element in enumerate(py_list):
                # Register the mutation before each append.
                tx.output.side_effects.mutation(list_vt)
                if isinstance(element, torch.Tensor):
                    list_vt.items.append(self.wrap_tensor(tx, element))
                elif list_vt.source:
                    element_builder = VariableBuilder(
                        tx, GetItemSource(list_vt.source, index)
                    )
                    list_vt.items.append(element_builder(element))
                else:
                    list_vt.items.append(SourcelessBuilder()(tx, element))

    def create_finalizer(self, tx):
        """Register a graph finalizer that cleans up when the optimizer dies.

        The finalizer attaches a ``weakref.finalize`` callback to the
        optimizer. When it fires, the callback pops every recorded static
        tensor name from the graph module's buffers and parameters and clears
        the tracing context's ``params_flat`` when that list is non-empty.
        """
        static_names = self.static_tensor_names
        optimizer = self.value
        tracing_context = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for static_name in static_names:
                    gm._buffers.pop(static_name, None)
                    gm._parameters.pop(static_name, None)
                    if tracing_context.params_flat:
                        tracing_context.params_flat.clear()

            weakref.finalize(optimizer, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)