# Llama-3.1-8B-DALv0.1/venv/lib/python3.12/site-packages/torch/_dynamo/variables/optimizer.py
# mypy: ignore-errors | |
import weakref | |
from typing import Dict, List, TYPE_CHECKING | |
import torch | |
from torch.utils._pytree import tree_map_only | |
from ..guards import GuardBuilder, install_guard | |
from ..source import ( | |
AttrSource, | |
ConstDictKeySource, | |
GetItemSource, | |
GlobalWeakRefSource, | |
GradSource, | |
) | |
from ..utils import GLOBAL_KEY_PREFIX | |
from .constant import ConstantVariable | |
from .dicts import ConstDictVariable | |
from .lists import ListVariable | |
from .misc import GetAttrVariable | |
from .user_defined import UserDefinedObjectVariable | |
if TYPE_CHECKING: | |
from .base import VariableTracker | |
class ArgMappingException(Exception):
    """Raised when a traced argument has no equivalent Python value.

    ``OptimizerVariable.get_python_args`` raises it, and
    ``OptimizerVariable.call_method`` catches it to fall back to tracing
    ``_init_group`` normally.
    """
class GuardInstallException(Exception):
    """Error type for a failure to install guards for the optimizer.

    ``OptimizerVariable.call_method`` catches it to fall back to tracing
    ``_init_group`` normally.
    NOTE(review): nothing in this module raises it directly - confirm whether
    a callee is expected to.
    """
class OptimizerVariable(UserDefinedObjectVariable):
    """Variable tracker for an optimizer object that special-cases ``_init_group``.

    Instead of tracing ``_init_group``, :meth:`call_method` runs it eagerly on
    the real optimizer (``self.value``), then installs the guards and rebuilds
    the list arguments by hand. When the arguments cannot be mapped back to
    Python values it falls back to tracing the call normally.
    """

    # Extra per-instance bookkeeping registered next to the base class fields:
    #   grad_to_source:      gradient tensor -> GradSource of its parameter
    #   tensor_to_source:    parameter / state tensor -> source it is built from
    #   static_tensor_names: module key names of tensors wrapped by wrap_tensor,
    #                        dropped from the graph module by the finalizer
    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        """Wrap the optimizer ``value``.

        Args:
            value: the optimizer object being tracked.
            grad_to_source: optional initial gradient -> source mapping.
            static_tensor_names: optional initial set of static tensor names.
            tensor_to_source: optional initial tensor -> source mapping.
            **kwargs: forwarded to ``UserDefinedObjectVariable.__init__``.
        """
        super().__init__(value, **kwargs)
        # Falsy arguments (None or empty) are replaced by fresh containers.
        self.grad_to_source = grad_to_source or {}
        self.tensor_to_source = tensor_to_source or {}
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer

        For ``name == "_init_group"`` the real ``_init_group`` is executed on
        ``self.value`` with Python equivalents of ``args``/``kwargs`` and its
        return value is wrapped as a constant. Any other method, or an
        ``_init_group`` call whose arguments cannot be mapped, is delegated to
        the base class.
        """
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException):
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        """Resolve attribute ``name`` on the tracked optimizer.

        ``_init_group`` and ``step`` are returned as ``GetAttrVariable`` so the
        later call is routed through :meth:`call_method`. Reading
        ``param_groups`` first marks every parameter as a static address and
        updates the ``capturable`` flags.
        """
        # Note: this allows us to intercept the call in call_method
        # in the typical case, we return a UserMethodVariable
        # which will directly inline
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        """Raise ``Unsupported`` if any optimizer parameter has a pending mutation."""
        # If there are pending mutations on a parameter (due to using closure)
        # then we need to graph break to allow the python version of the parameter
        # to update, so that running _init_group will initialize the states with
        # the correct values
        side_effects = tx.output.side_effects
        for g in self.value.param_groups:
            for p in g["params"]:
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        """Turn on ``capturable`` for param groups where it is safe to do so.

        The real param group dicts are updated only when the group already has
        a ``"capturable"`` key, all of its params are on CUDA and none of them
        has an entry in the optimizer state. The ``"capturable"`` entry of every
        param-group variable tracker is then set to ``True``.
        """
        from . import LazyVariableTracker
        from .builder import VariableBuilder

        # We only set capturable if params are on cuda
        # and the state is not initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_cuda = True

            for p in group.get("params", []):
                all_cuda &= p.is_cuda
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_cuda

        # track indices to not set so we don't need to
        # in the variable tracker realize the whole state
        # we handle guarding the state specially
        for group in self.value.param_groups:
            if safe_to_set_capturable(group):
                group["capturable"] = True

        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )
        for param_group_vt in param_groups_vt.items:
            key = ConstDictVariable._HashableTracker(
                ConstantVariable.create("capturable")
            )
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args

        Returns:
            ``(new_args, new_kwargs)``: a list and a dict of Python values.

        Raises:
            ArgMappingException: if an argument is not a constant, an empty
                list, or a dict sourced from an item of ``param_groups``.
        """

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                # The dict is `<something>.param_groups[index]`: hand back the
                # real param group at that index.
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    # If users load an old state dictionary,
    # it's possible that step could be on the cpu
    # if this is the case, move it to the GPU
    # corresponding to the parameter
    # in most cases this is a no-op because the state is empty
    def move_step_if_cpu(self):
        """Move every CPU ``step`` entry in the state to its parameter's device."""
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

    def map_sources_and_install_guards(self, tx):
        """Rebuild the tensor/grad source maps and install guards for the optimizer.

        Resets ``self.grad_to_source`` and ``self.tensor_to_source``, marks the
        state tensors as static addresses, realizes the variable trackers for
        ``param_groups`` and ``state``, and records a source for every
        parameter, gradient and state tensor.
        """
        from ..decorators import mark_static_address
        from .builder import VariableBuilder
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Tracing the _init_group is expensive. But we still have to insert the
        # necessary guards for _init_group. So, we manually handle insertion of
        # guards. We also want to mark all the tensors inside the state dict to
        # be static address.

        # Mark all the tensors in the state dict to be static address. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )

        state_source = AttrSource(self.source, "state")
        state_vt = VariableBuilder(tx, state_source)(self.value.state)

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually update_list_args
        for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
            # we assume here that all params within a param group
            # are initialized similarly
            for param in group["params"]:
                if param.grad is None:
                    continue
                # Position of `param` among the state keys (matched by
                # identity), or None when it has no state entry.
                key_index = next(
                    (i for i, k in enumerate(self.value.state) if k is param),
                    None,
                )
                # Compare against None explicitly: index 0 is a valid position
                # and must not be treated as "not found".
                if key_index is not None:
                    LazyVariableTracker.realize_all(
                        VariableBuilder(
                            tx,
                            GetItemSource(
                                state_source,
                                ConstDictKeySource(state_source, key_index),
                            ),
                        )(self.value.state[param])
                    )
                    # One realized state entry per group is enough.
                    break

            params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
            for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )

                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    # No gradient yet: guard on the grad staying constant.
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address
        from .builder import VariableBuilder

        # If we have a source for a tensor already use it,
        # if we have not seen a tensor before, stash and use a
        # global weak ref source, since it must be an optimizer tensor
        # that we have missed
        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call

        For every positional ``ListVariable`` argument, appends a variable
        tracker for each element of the matching Python list in ``py_args``.
        ``kwargs`` and ``py_kwargs`` are accepted but not read.
        """
        from .builder import SourcelessBuilder, VariableBuilder

        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    elif arg.source:
                        arg.items.append(
                            VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                        )
                    else:
                        arg.items.append(SourcelessBuilder.create(tx, val))

    def create_finalizer(self, tx):
        """Register a graph finalizer that cleans up when the optimizer dies.

        The finalizer attaches a ``weakref.finalize`` callback to the optimizer
        which removes the recorded static tensor names from the graph module's
        buffers and parameters.
        """
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                    if tc.params_flat:
                        tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)