File size: 14,021 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 |
# mypy: ignore-errors
import weakref
from typing import Dict, List, TYPE_CHECKING
import torch
from torch.utils._pytree import tree_map_only
from ..guards import GuardBuilder, install_guard
from ..source import (
AttrSource,
ConstDictKeySource,
GetItemSource,
GlobalWeakRefSource,
GradSource,
)
from ..utils import GLOBAL_KEY_PREFIX
from .constant import ConstantVariable
from .dicts import ConstDictVariable
from .lists import ListVariable
from .misc import GetAttrVariable
from .user_defined import UserDefinedObjectVariable
if TYPE_CHECKING:
from .base import VariableTracker
class ArgMappingException(Exception):
    """Signals that a VariableTracker arg could not be mapped to a python value.

    Raised by OptimizerVariable.get_python_args; OptimizerVariable.call_method
    catches it and falls back to tracing `_init_group` normally.
    """

    pass
class GuardInstallException(Exception):
    """Signals that guards for the optimizer could not be installed.

    Caught (together with ArgMappingException) in
    OptimizerVariable.call_method, which then falls back to tracing
    `_init_group` normally.
    """

    pass
class OptimizerVariable(UserDefinedObjectVariable):
    """VariableTracker specialization for ``torch.optim.Optimizer`` objects.

    Tracing an optimizer's ``_init_group`` is very slow, so instead this
    tracker runs it eagerly in Python (see ``call_method``) and then manually
    installs the guards and sources that tracing would otherwise have
    produced (see ``map_sources_and_install_guards``). If either step cannot
    be performed, it falls back to normal tracing.
    """

    _nonvar_fields = {
        "grad_to_source",
        "tensor_to_source",
        "static_tensor_names",
        *UserDefinedObjectVariable._nonvar_fields,
    }

    def __init__(
        self,
        value,
        grad_to_source=None,
        static_tensor_names=None,
        tensor_to_source=None,
        **kwargs,
    ):
        super().__init__(value, **kwargs)
        # Maps a .grad tensor to the Source used to reconstruct it.
        self.grad_to_source = grad_to_source or {}
        # Maps a param/state tensor to the Source used to reconstruct it.
        self.tensor_to_source = tensor_to_source or {}
        # Attribute names registered on the output graph module for
        # static-address tensors; cleared by the finalizer installed in
        # create_finalizer once the optimizer object dies.
        self.static_tensor_names = static_tensor_names or set()

    def call_method(
        self,
        tx,
        name,
        args: "List[VariableTracker]",
        kwargs: "Dict[str, VariableTracker]",
    ) -> "VariableTracker":
        """This is an optimization to avoid tracing the very slow initialization of the optimizer"""
        if name == "_init_group":
            try:
                self.graph_break_if_pending_mutation(tx)
                self.move_step_if_cpu()
                py_args, py_kwargs = self.get_python_args(*args, **kwargs)
                # Run the real _init_group eagerly instead of tracing it.
                ret_val = self.value._init_group(*py_args, **py_kwargs)
                self.map_sources_and_install_guards(tx)
                self.update_list_args(tx, args, kwargs, py_args, py_kwargs)
                # stash a weak_ptr to optimizer to invalidate code
                # if the optimizer object dies
                mangled_name = f"__optimizer_{id(self.value)}"
                tx.store_global_weakref_by_id(mangled_name, self.value)
                self.create_finalizer(tx)

                # This is currently safe only because the only actual `ret_val`s returned
                # by the `_init_group` of existing optimizers are properties that are invariant
                # to the input tensors (e.g. dtype, layout). Changing these would trigger a
                # recompilation and hence never result in the wrong specialization of `ret_val`.
                return ConstantVariable.create(ret_val)
            except (ArgMappingException, GuardInstallException):
                # trace normally if we can't map args or install guards correctly
                pass

        return super().call_method(tx, name, args, kwargs)

    def var_getattr(self, tx, name):
        # Note: this allows us to intercept the call in call_method
        # in the typical case, we return a UserMethodVariable
        # which will directly inline
        if name in ("_init_group", "step"):
            return GetAttrVariable(self, name, source=AttrSource(self.source, name))

        if name == "param_groups":
            from ..decorators import mark_static_address

            # Parameters must be static addresses for cudagraphs; mark them
            # before the param_groups variable trackers are built.
            for group in self.value.param_groups:
                for p in group["params"]:
                    mark_static_address(p)

            self._set_capturable(tx)

        return super().var_getattr(tx, name)

    def graph_break_if_pending_mutation(self, tx):
        """Raise Unsupported if any parameter has a pending (untracked) mutation.

        If there are pending mutations on a parameter (due to using closure)
        then we need to graph break to allow the python version of the
        parameter to update, so that running _init_group will initialize the
        states with the correct values.
        """
        # Invariant across the loops -- look it up once.
        side_effects = tx.output.side_effects
        for g in self.value.param_groups:
            for p in g["params"]:
                variable = side_effects.id_to_variable.get(id(p), None)
                if variable and side_effects.has_pending_mutation(variable):
                    from ..exc import Unsupported

                    raise Unsupported("Pending mutation on parameter")

    def _set_capturable(self, tx):
        """Flip ``capturable`` to True on param groups where it is safe to do so."""
        from . import LazyVariableTracker
        from .builder import VariableBuilder

        # We only set capturable if params are on cuda
        # and the state is not initialized
        def safe_to_set_capturable(group):
            all_uninitialized = True
            all_cuda = True

            for p in group.get("params", []):
                all_cuda &= p.is_cuda
                all_uninitialized &= p not in self.value.state

            return "capturable" in group and all_uninitialized and all_cuda

        # track indices to not set so we don't need to
        # in the variable tracker realize the whole state
        # we handle guarding the state specially
        for group in self.value.param_groups:
            if safe_to_set_capturable(group):
                group["capturable"] = True

        # Mirror the change into the (realized) variable trackers so that the
        # traced view of param_groups agrees with the python objects.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )
        # The key tracker is loop-invariant; build it once.
        key = ConstDictVariable._HashableTracker(ConstantVariable.create("capturable"))
        for param_group_vt in param_groups_vt.items:
            param_group_vt.items[key] = ConstantVariable.create(True)

    def get_python_args(self, *args, **kwargs):
        """Get python values equivalent to the variable tracker args.

        Raises ArgMappingException for any arg shape we do not know how to
        map; call_method then falls back to tracing.
        """

        def map_arg(arg):
            if isinstance(arg, ConstantVariable):
                return arg.as_python_constant()
            elif isinstance(arg, ListVariable) and not arg.items:
                return []
            elif (
                isinstance(arg, ConstDictVariable)
                and isinstance(arg.source, GetItemSource)
                and isinstance(arg.source.base, AttrSource)
                and arg.source.base.member == "param_groups"
            ):
                return self.value.param_groups[arg.source.index]

            raise ArgMappingException

        new_args = [map_arg(arg) for arg in args]
        new_kwargs = {k: map_arg(v) for k, v in kwargs.items()}

        return new_args, new_kwargs

    # If users load an old state dictionary,
    # it's possible that step could be on the cpu
    # if this is the case, move it to the GPU
    # corresponding to the parameter
    # in most cases this is a no-op because the state is empty
    def move_step_if_cpu(self):
        for p, state in self.value.state.items():
            if "step" in state and state["step"].is_cpu:
                state["step"] = state["step"].to(p.device)

    def map_sources_and_install_guards(self, tx):
        """Populate grad_to_source/tensor_to_source and install guards.

        Tracing the _init_group is expensive. But we still have to insert the
        necessary guards for _init_group. So, we manually handle insertion of
        guards. We also want to mark all the tensors inside the state dict to
        be static address.
        """
        from ..decorators import mark_static_address
        from .builder import VariableBuilder
        from .lazy import LazyVariableTracker

        self.grad_to_source = {}
        self.tensor_to_source = {}

        # Mark all the tensors in the state dict to be static address. This has
        # to be done first because the variable builder relies on the static
        # address annotation.
        def mark_static(x):
            mark_static_address(x)

        tree_map_only(torch.Tensor, mark_static, self.value.state)

        # Recursively realize the variable trackers for optim.state and
        # optim.param_groups, which recursively install the necessary guards.
        param_groups_vt = LazyVariableTracker.realize_all(
            VariableBuilder(tx, AttrSource(self.source, "param_groups"))(
                self.value.param_groups
            )
        )

        state_vt = VariableBuilder(tx, AttrSource(self.source, "state"))(
            self.value.state
        )

        # We need to realize the top level state dict to populate
        # the guard locals
        state_vt.realize()

        # Populate self.grad_to_source and self.tensor_to_source so that we can
        # manually update_list_args
        for group, group_vt in zip(self.value.param_groups, param_groups_vt.items):
            # we assume here that all params within a param group
            # are initialized similarly
            if len(group["params"]) > 0:
                for param in group["params"]:
                    if param.grad is not None:
                        # Find the position of this param's entry in the
                        # (insertion-ordered) state dict, if any.
                        key_index = None
                        for i, k in enumerate(self.value.state.keys()):
                            if k is param:
                                key_index = i
                                break
                        # BUGFIX: compare against None instead of relying on
                        # truthiness -- key_index == 0 (param stored first in
                        # the state dict) is a valid hit, not a miss.
                        if key_index is not None:
                            state_source = AttrSource(self.source, "state")
                            LazyVariableTracker.realize_all(
                                VariableBuilder(
                                    tx,
                                    GetItemSource(
                                        state_source,
                                        ConstDictKeySource(state_source, key_index),
                                    ),
                                )(self.value.state[param])
                            )
                            break

            params_vt = group_vt.getitem_const(ConstantVariable.create("params"))
            for p, p_vt in zip(group["params"], params_vt.unpack_var_sequence(tx)):
                param_source = p_vt.source
                self.tensor_to_source[p] = param_source
                grad_source = GradSource(
                    param_source,
                    "grad",
                )
                if p.grad is not None:
                    self.grad_to_source[p.grad] = grad_source
                else:
                    # Guard that the grad stays absent (a new grad would
                    # require recompilation).
                    install_guard(grad_source.make_guard(GuardBuilder.CONSTANT_MATCH))

        # We have to again iterate over the state dict to collect the
        # tensor_to_source dict. This is used for the finalizer.
        state_source = AttrSource(self.source, "state")
        for idx, (p, value) in enumerate(self.value.state.items()):
            p_state_source = GetItemSource(
                state_source, ConstDictKeySource(state_source, idx)
            )
            for k, v in value.items():
                if (
                    isinstance(v, torch.Tensor)
                    and v not in self.grad_to_source
                    and v not in self.tensor_to_source
                ):
                    self.tensor_to_source[v] = GetItemSource(p_state_source, k)

    def wrap_tensor(self, tx, tensor_value):
        """Wrap state tensor in a TensorVariable"""
        from ..decorators import mark_static_address
        from .builder import VariableBuilder

        # If we have a source for a tensor already use it,
        # if we have not seen a tensor before, stash and use a
        # global weak ref source, since it must be an optimizer tensor
        # that we have missed
        if tensor_value in self.tensor_to_source:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)
            builder = VariableBuilder(tx, self.tensor_to_source[tensor_value])
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))
        elif tensor_value in self.grad_to_source:
            builder = VariableBuilder(tx, self.grad_to_source[tensor_value])
        else:
            # mark these tensors as static for cudagraphs
            mark_static_address(tensor_value)

            global_name = tx.store_global_weakref_by_id(GLOBAL_KEY_PREFIX, tensor_value)
            builder = VariableBuilder(tx, GlobalWeakRefSource(global_name))
            self.static_tensor_names.add(tx.output.module_key_name(builder.name))

        result = builder(tensor_value)
        return result

    def update_list_args(self, tx, args, kwargs, py_args, py_kwargs):
        """Update the args and kwargs to the traced optimizer call"""
        for arg, py_arg in zip(args, py_args):
            if isinstance(arg, ListVariable):
                assert isinstance(
                    py_arg, list
                ), "py_arg should be a list in optimizer variable"
                for i, val in enumerate(py_arg):
                    # Record the mutation so the tracked list stays in sync
                    # with the python list _init_group just populated.
                    tx.output.side_effects.mutation(arg)
                    if isinstance(val, torch.Tensor):
                        arg.items.append(self.wrap_tensor(tx, val))
                    else:
                        from .builder import SourcelessBuilder, VariableBuilder

                        if arg.source:
                            arg.items.append(
                                VariableBuilder(tx, GetItemSource(arg.source, i))(val)
                            )
                        else:
                            arg.items.append(SourcelessBuilder.create(tx, val))

    def create_finalizer(self, tx):
        """Install a graph finalizer that drops static tensor refs when the
        optimizer object is garbage collected."""
        names_to_delete = self.static_tensor_names
        value = self.value
        tc = tx.output.tracing_context

        def init_finalizer(gm):
            def clear_static_tensor_refs():
                for name in names_to_delete:
                    gm._buffers.pop(name, None)
                    gm._parameters.pop(name, None)
                if tc.params_flat:
                    tc.params_flat.clear()

            weakref.finalize(value, clear_static_tensor_refs)

        tx.output.add_graph_finalizer(init_finalizer)
|