import dataclasses
from enum import auto, Enum
from typing import Collection, Dict, List, Mapping, Optional, Set, Tuple, Union


__all__ = [
    "ConstantArgument",
    "CustomObjArgument",
    "ExportBackwardSignature",
    "ExportGraphSignature",
    "InputKind",
    "InputSpec",
    "OutputKind",
    "OutputSpec",
    "SymIntArgument",
    "TensorArgument",
]


@dataclasses.dataclass
class TensorArgument:
    name: str


@dataclasses.dataclass
class SymIntArgument:
    name: str


@dataclasses.dataclass
class CustomObjArgument:
    name: str
    class_fqn: str


@dataclasses.dataclass
class ConstantArgument:
    value: Union[int, float, bool, None]


ArgumentSpec = Union[
    TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument
]
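

# Minimal, illustrative sketch (names invented for the example): an
# ArgumentSpec is simply one of the dataclasses above, so specs can be
# constructed directly.
#
#   tensor_arg: ArgumentSpec = TensorArgument(name="arg0_1")
#   const_arg: ArgumentSpec = ConstantArgument(value=None)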


class InputKind(Enum):
    USER_INPUT = auto()
    PARAMETER = auto()
    BUFFER = auto()
    CONSTANT_TENSOR = auto()
    CUSTOM_OBJ = auto()
    TOKEN = auto()


@dataclasses.dataclass
class InputSpec:
    kind: InputKind
    arg: ArgumentSpec
    target: Optional[str]
    persistent: Optional[bool] = None

    def __post_init__(self):
        if self.kind == InputKind.BUFFER:
            assert (
                self.persistent is not None
            ), "Failed to specify persistent flag on BUFFER."
        assert isinstance(
            self.arg,
            (TensorArgument, SymIntArgument, ConstantArgument, CustomObjArgument),
        ), f"got {type(self.arg)}"


class OutputKind(Enum):
    USER_OUTPUT = auto()
    LOSS_OUTPUT = auto()
    BUFFER_MUTATION = auto()
    GRADIENT_TO_PARAMETER = auto()
    GRADIENT_TO_USER_INPUT = auto()
    USER_INPUT_MUTATION = auto()
    TOKEN = auto()


@dataclasses.dataclass
class OutputSpec:
    kind: OutputKind
    arg: ArgumentSpec
    target: Optional[str]

    def __post_init__(self):
        assert isinstance(self.arg, (TensorArgument, SymIntArgument, ConstantArgument))


def _sig_to_specs(
    *,
    user_inputs: Set[str],
    inputs_to_parameters: Mapping[str, str],
    inputs_to_buffers: Mapping[str, str],
    user_outputs: Set[str],
    buffer_mutations: Mapping[str, str],
    user_input_mutations: Mapping[str, str],
    grad_params: Mapping[str, str],
    grad_user_inputs: Mapping[str, str],
    loss_output: Optional[str],
    inputs: List[ArgumentSpec],
    outputs: List[ArgumentSpec],
    input_tokens: List[str],
    output_tokens: List[str],
) -> Tuple[List[InputSpec], List[OutputSpec]]:
    def to_input_spec(inp: ArgumentSpec) -> InputSpec:
        if not isinstance(inp, TensorArgument):
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        name = inp.name
        if name in user_inputs:
            return InputSpec(kind=InputKind.USER_INPUT, arg=inp, target=None)
        elif name in inputs_to_parameters:
            return InputSpec(
                kind=InputKind.PARAMETER,
                arg=inp,
                target=inputs_to_parameters[name],
            )
        elif name in inputs_to_buffers:
            return InputSpec(
                kind=InputKind.BUFFER,
                arg=inp,
                target=inputs_to_buffers[name],
                # Mark as True for now; we will fix this up to distinguish
                # persistent from non-persistent later in tracing.
                # See: rewrite_non_persistent_buffers()
                # TODO(suo): this is horrible.
                persistent=True,
            )
        elif name in input_tokens:
            return InputSpec(kind=InputKind.TOKEN, arg=inp, target=None)
        else:
            raise AssertionError(f"Unknown tensor input kind: {name}")

    def to_output_spec(idx: int, o: ArgumentSpec) -> OutputSpec:
        if not isinstance(o, TensorArgument):
            return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
        name = o.name
        if idx < len(buffer_mutations) + len(user_input_mutations) + len(output_tokens):
            if name in buffer_mutations:
                return OutputSpec(
                    kind=OutputKind.BUFFER_MUTATION,
                    arg=o,
                    target=buffer_mutations[name],
                )
            elif name in user_input_mutations:
                return OutputSpec(
                    kind=OutputKind.USER_INPUT_MUTATION,
                    arg=o,
                    target=user_input_mutations[name],
                )
            elif name in output_tokens:
                return OutputSpec(kind=OutputKind.TOKEN, arg=o, target=None)
            else:
                raise AssertionError(f"Unknown tensor mutation kind: {name}")
        else:
            if name in user_outputs:
                return OutputSpec(kind=OutputKind.USER_OUTPUT, arg=o, target=None)
            elif name in grad_params:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_PARAMETER,
                    arg=o,
                    target=grad_params[name],
                )
            elif name in grad_user_inputs:
                return OutputSpec(
                    kind=OutputKind.GRADIENT_TO_USER_INPUT,
                    arg=o,
                    target=grad_user_inputs[name],
                )
            elif name == loss_output:
                return OutputSpec(kind=OutputKind.LOSS_OUTPUT, arg=o, target=None)
            else:
                raise AssertionError(f"Unknown tensor output kind: {name}")

    input_specs = [to_input_spec(inp) for inp in inputs]
    output_specs = [to_output_spec(idx, o) for idx, o in enumerate(outputs)]
    return input_specs, output_specs
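

# Illustrative sketch (all names invented) of how _sig_to_specs classifies
# placeholders and outputs: inputs are matched by name against the provided
# mappings, while outputs are split positionally (mutations and tokens come
# first) and then matched by name.
#
#   input_specs, output_specs = _sig_to_specs(
#       user_inputs={"arg1_1"},
#       inputs_to_parameters={},
#       inputs_to_buffers={"arg0_1": "my_buffer"},
#       user_outputs={"mul"},
#       buffer_mutations={"add": "my_buffer"},
#       user_input_mutations={},
#       grad_params={},
#       grad_user_inputs={},
#       loss_output=None,
#       inputs=[TensorArgument("arg0_1"), TensorArgument("arg1_1")],
#       outputs=[TensorArgument("add"), TensorArgument("mul")],
#       input_tokens=[],
#       output_tokens=[],
#   )
#   # -> arg0_1 becomes a BUFFER spec and arg1_1 a USER_INPUT spec; "add"
#   #    (index 0, inside the leading mutation block) becomes a
#   #    BUFFER_MUTATION spec and "mul" a USER_OUTPUT spec.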


@dataclasses.dataclass
class ExportBackwardSignature:
    gradients_to_parameters: Dict[str, str]
    gradients_to_user_inputs: Dict[str, str]
    loss_output: str


@dataclasses.dataclass
class ExportGraphSignature:
    """
    :class:`ExportGraphSignature` models the input/output signature of Export Graph,
    which is an fx.Graph with stronger invariant guarantees.

    Export Graph is functional and does not access "states" like parameters
    or buffers within the graph via ``getattr`` nodes. Instead, :func:`export`
    guarantees that parameters, buffers, and constant tensors are lifted out of
    the graph as inputs. Similarly, any mutations to buffers are not included
    in the graph either; instead the updated values of mutated buffers are
    modeled as additional outputs of Export Graph.

    The ordering of all inputs and outputs is::

        Inputs = [*parameters_buffers_constant_tensors, *flattened_user_inputs]
        Outputs = [*mutated_inputs, *flattened_user_outputs]

    e.g. if the following module is exported::

        class CustomModule(nn.Module):
            def __init__(self):
                super(CustomModule, self).__init__()

                # Define a parameter
                self.my_parameter = nn.Parameter(torch.tensor(2.0))

                # Define two buffers
                self.register_buffer('my_buffer1', torch.tensor(3.0))
                self.register_buffer('my_buffer2', torch.tensor(4.0))

            def forward(self, x1, x2):
                # Use the parameter, buffers, and both inputs in the forward method
                output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2

                # Mutate one of the buffers (e.g., increment it by 1)
                self.my_buffer2.add_(1.0)  # In-place addition

                return output

    the resulting Graph would be::

        graph():
            %arg0_1 := placeholder[target=arg0_1]
            %arg1_1 := placeholder[target=arg1_1]
            %arg2_1 := placeholder[target=arg2_1]
            %arg3_1 := placeholder[target=arg3_1]
            %arg4_1 := placeholder[target=arg4_1]
            %add_tensor := call_function[target=torch.ops.aten.add.Tensor](args = (%arg3_1, %arg0_1), kwargs = {})
            %mul_tensor := call_function[target=torch.ops.aten.mul.Tensor](args = (%add_tensor, %arg1_1), kwargs = {})
            %mul_tensor_1 := call_function[target=torch.ops.aten.mul.Tensor](args = (%arg4_1, %arg2_1), kwargs = {})
            %add_tensor_1 := call_function[target=torch.ops.aten.add.Tensor](args = (%mul_tensor, %mul_tensor_1), kwargs = {})
            %add_tensor_2 := call_function[target=torch.ops.aten.add.Tensor](args = (%arg2_1, 1.0), kwargs = {})
            return (add_tensor_2, add_tensor_1)

    and the resulting ExportGraphSignature would be::

        ExportGraphSignature(
            input_specs=[
                InputSpec(kind=<InputKind.PARAMETER: 2>, arg=TensorArgument(name='arg0_1'), target='my_parameter'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg1_1'), target='my_buffer1'),
                InputSpec(kind=<InputKind.BUFFER: 3>, arg=TensorArgument(name='arg2_1'), target='my_buffer2'),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg3_1'), target=None),
                InputSpec(kind=<InputKind.USER_INPUT: 1>, arg=TensorArgument(name='arg4_1'), target=None)
            ],
            output_specs=[
                OutputSpec(kind=<OutputKind.BUFFER_MUTATION: 3>, arg=TensorArgument(name='add_tensor_2'), target='my_buffer2'),
                OutputSpec(kind=<OutputKind.USER_OUTPUT: 1>, arg=TensorArgument(name='add_tensor_1'), target=None)
            ]
        )
    """

    input_specs: List[InputSpec]
    output_specs: List[OutputSpec]

    # A list of parameters uniquely identified by mangled fully qualified name
    @property
    def parameters(self) -> Collection[str]:
        # TODO Make this tuple.
        return [
            s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            if isinstance(s.target, str)
        ]

    # A list of buffers uniquely identified by mangled fully qualified name
    @property
    def buffers(self) -> Collection[str]:
        # TODO Make this tuple.
        return [
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if isinstance(s.target, str)
        ]

    # A list of non-persistent buffers, identified the same way
    @property
    def non_persistent_buffers(self) -> Collection[str]:
        return [
            s.target
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            if s.persistent is False
            if isinstance(s.target, str)
        ]

    # A list of lifted constant tensors
    @property
    def lifted_tensor_constants(self) -> Collection[str]:
        # TODO Make this tuple.
        return [
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            if isinstance(s.target, str)
        ]

    # A list of lifted custom objects
    @property
    def lifted_custom_objs(self) -> Collection[str]:
        # TODO Make this tuple.
        return [
            s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            if isinstance(s.target, str)
        ]

    # Graph node names of pytree-flattened inputs of original program
    @property
    def user_inputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_inputs: List[Union[int, float, bool, None, str]] = []
        for s in self.input_specs:
            if s.kind != InputKind.USER_INPUT:
                continue

            if isinstance(s.arg, (TensorArgument, SymIntArgument, CustomObjArgument)):
                user_inputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_inputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user input")
        return tuple(user_inputs)

    # Graph node names of pytree-flattened outputs of original program
    @property
    def user_outputs(self) -> Collection[Union[int, float, bool, None, str]]:
        user_outputs: List[Union[int, float, bool, None, str]] = []
        for s in self.output_specs:
            if s.kind != OutputKind.USER_OUTPUT:
                continue

            if isinstance(s.arg, (TensorArgument, SymIntArgument)):
                user_outputs.append(s.arg.name)
            elif isinstance(s.arg, ConstantArgument):
                user_outputs.append(s.arg.value)
            else:
                raise RuntimeError(f"{s.arg} is not a valid user output")
        return tuple(user_outputs)

    # A dictionary mapping graph input node names to parameters. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted parameter.
    @property
    def inputs_to_parameters(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target
            for s in self.input_specs
            if s.kind == InputKind.PARAMETER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        }

    # A dictionary mapping graph input node names to buffers. If a graph input
    # name is found in this dictionary, it is guaranteed to be a lifted buffer.
    @property
    def inputs_to_buffers(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target  # type: ignore[union-attr, misc]
            for s in self.input_specs
            if s.kind == InputKind.BUFFER
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        }

    # A dictionary mapping graph output node names to buffers that are mutated in the
    # original program. Buffers that are not mutated will not be found in this dictionary.
    @property
    def buffers_to_mutate(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target
            for s in self.output_specs
            if s.kind == OutputKind.BUFFER_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        }

    # A dictionary mapping graph output node names to user inputs that are mutated
    # in the original program.
    @property
    def user_inputs_to_mutate(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target
            for s in self.output_specs
            if s.kind == OutputKind.USER_INPUT_MUTATION
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        }

    # A dictionary mapping graph input node names to lifted tensor constants.
    @property
    def inputs_to_lifted_tensor_constants(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target
            for s in self.input_specs
            if s.kind == InputKind.CONSTANT_TENSOR
            and isinstance(s.arg, TensorArgument)
            and isinstance(s.target, str)
        }

    # A dictionary mapping graph input node names to lifted custom objects.
    @property
    def inputs_to_lifted_custom_objs(self) -> Mapping[str, str]:
        return {
            s.arg.name: s.target
            for s in self.input_specs
            if s.kind == InputKind.CUSTOM_OBJ
            and isinstance(s.arg, CustomObjArgument)
            and isinstance(s.target, str)
        }

    @property
    def backward_signature(self) -> Optional[ExportBackwardSignature]:
        loss_output = None
        gradients_to_parameters: Dict[str, str] = {}
        gradients_to_user_inputs: Dict[str, str] = {}
        for spec in self.output_specs:
            if spec.kind == OutputKind.LOSS_OUTPUT:
                assert loss_output is None
                assert isinstance(spec.arg, TensorArgument)
                loss_output = spec.arg.name
            elif spec.kind == OutputKind.GRADIENT_TO_PARAMETER:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_parameters[spec.arg.name] = spec.target
            elif spec.kind == OutputKind.GRADIENT_TO_USER_INPUT:
                assert isinstance(spec.target, str)
                assert isinstance(spec.arg, TensorArgument)
                gradients_to_user_inputs[spec.arg.name] = spec.target

        if loss_output is None:
            return None

        return ExportBackwardSignature(
            loss_output=loss_output,
            gradients_to_parameters=gradients_to_parameters,
            gradients_to_user_inputs=gradients_to_user_inputs,
        )
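
    # Illustrative sketch (invented names): a joint forward-backward graph
    # whose outputs include a loss and a parameter gradient yields a non-None
    # backward signature.
    #
    #   output_specs = [
    #       OutputSpec(OutputKind.LOSS_OUTPUT, TensorArgument("loss"), None),
    #       OutputSpec(
    #           OutputKind.GRADIENT_TO_PARAMETER,
    #           TensorArgument("grad"),
    #           "my_parameter",
    #       ),
    #   ]
    #   # .backward_signature ->
    #   # ExportBackwardSignature(
    #   #     gradients_to_parameters={"grad": "my_parameter"},
    #   #     gradients_to_user_inputs={},
    #   #     loss_output="loss",
    #   # )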

    # Map from assertion dependency token index to assertion dep token output
    # name in output. The shape of output after aot_autograd will be like:
    # (updated_inputs, user_outputs, dep_token).
    @property
    def assertion_dep_token(self) -> Optional[Mapping[int, str]]:
        return None

    @property
    def input_tokens(self) -> List[str]:
        input_tokens = []
        for s in self.input_specs:
            if s.kind == InputKind.TOKEN:
                assert isinstance(s.arg, TensorArgument)
                input_tokens.append(s.arg.name)
        return input_tokens

    @property
    def output_tokens(self) -> List[str]:
        output_tokens = []
        for s in self.output_specs:
            if s.kind == OutputKind.TOKEN:
                assert isinstance(s.arg, TensorArgument)
                output_tokens.append(s.arg.name)
        return output_tokens

    def __post_init__(self) -> None:
        assertion_dep_token = self.assertion_dep_token
        if assertion_dep_token is None:
            return
        assert len(assertion_dep_token) == 1
        assertion_dep_token_index = next(iter(assertion_dep_token.keys()))
        assert (
            len(self.user_outputs) + len(self.buffers_to_mutate)
            == assertion_dep_token_index
        )

    def replace_all_uses(self, old: str, new: str):
        """
        Replace all uses of the old name with the new name in the signature.
        """
        assert isinstance(old, str)
        assert isinstance(new, str)
        arg_types = (TensorArgument, SymIntArgument, CustomObjArgument)
        for o in self.output_specs:
            if isinstance(o.arg, arg_types):
                if o.arg.name == old:
                    o.arg.name = new
        for i in self.input_specs:
            if isinstance(i.arg, arg_types):
                if i.arg.name == old:
                    i.arg.name = new

    def get_replace_hook(self):
        def _(old, new, user):
            if user.op in ("output", "input"):
                self.replace_all_uses(old.name, new)

        return _
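

# A minimal, runnable self-check (all node names invented for the example;
# this is a sketch, not part of the public API). Kept behind a __main__
# guard so that importing the module has no side effects.
if __name__ == "__main__":
    sig = ExportGraphSignature(
        input_specs=[
            InputSpec(
                kind=InputKind.USER_INPUT,
                arg=TensorArgument(name="arg0_1"),
                target=None,
            )
        ],
        output_specs=[
            OutputSpec(
                kind=OutputKind.USER_OUTPUT,
                arg=TensorArgument(name="add"),
                target=None,
            )
        ],
    )
    # Renaming a node updates every matching input/output spec in place.
    sig.replace_all_uses("arg0_1", "x")
    assert sig.user_inputs == ("x",)
    assert sig.user_outputs == ("add",)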