"""Module for handling symbolic function registration."""

import warnings
from typing import (
    Callable,
    Collection,
    Dict,
    Generic,
    Optional,
    Sequence,
    Set,
    TypeVar,
    Union,
)

from torch.onnx import _constants, errors
from torch.onnx._internal import _beartype

OpsetVersion = int


def _dispatch_opset_version(
    target: OpsetVersion, registered_opsets: Collection[OpsetVersion]
) -> Optional[OpsetVersion]:
    """Finds the registered opset given a target opset version and the available opsets.

    For targets >= ``_constants.ONNX_BASE_OPSET`` the highest registered version
    that is <= ``target`` is selected. For older (legacy) targets, the lowest
    registered version in the range ``[target, ONNX_BASE_OPSET]`` is selected.

    Args:
        target: The target opset version.
        registered_opsets: The available opsets.

    Returns:
        The registered opset version, or None if no registered version applies.
    """
    if not registered_opsets:
        return None

    descending_registered_versions = sorted(registered_opsets, reverse=True)
    # Linear search for the opset version, which is fine since the number of opset
    # versions is small.

    if target >= _constants.ONNX_BASE_OPSET:
        # Always look down toward opset 1 when the target is >= ONNX_BASE_OPSET (opset 9).
        # When a custom op is registered at opset 1, we want to be able to discover it
        # as a fallback for all opsets >= ONNX_BASE_OPSET.
        for version in descending_registered_versions:
            if version <= target:
                return version
        return None

    # target < opset 9. This is the legacy behavior to support opset 7 and opset 8
    # for caffe2 support. We search up toward opset 9.
    for version in reversed(descending_registered_versions):
        # Count back up until _constants.ONNX_BASE_OPSET
        if target <= version <= _constants.ONNX_BASE_OPSET:
            return version

    return None


_K = TypeVar("_K")
_V = TypeVar("_V")


class OverrideDict(Generic[_K, _V], Collection[_K]):
    """A dictionary that merges built-in and custom symbolic functions.

    It supports overriding and un-overriding built-in symbolic functions with custom
    ones. Lookups, iteration, ``len`` and truthiness all operate on the merged view,
    where an override (if present) shadows the base value for the same key.
    """

    def __init__(self):
        # Built-in (base) entries.
        self._base: Dict[_K, _V] = {}
        # Custom entries that shadow base entries with the same key.
        self._overrides: Dict[_K, _V] = {}
        # Cached merged view: overrides take precedence over base entries.
        self._merged: Dict[_K, _V] = {}

    def set_base(self, key: _K, value: _V) -> None:
        """Sets a base value; it becomes visible only if the key is not overridden."""
        self._base[key] = value
        if key not in self._overrides:
            self._merged[key] = value

    def in_base(self, key: _K) -> bool:
        """Checks if a key is in the base dictionary."""
        return key in self._base

    def override(self, key: _K, value: _V) -> None:
        """Overrides a base key-value with a new pair."""
        self._overrides[key] = value
        self._merged[key] = value

    def remove_override(self, key: _K) -> None:
        """Un-overrides a key-value pair, restoring the base value if one exists."""
        self._overrides.pop(key, None)  # type: ignore[arg-type]
        self._merged.pop(key, None)  # type: ignore[arg-type]
        if key in self._base:
            self._merged[key] = self._base[key]

    def overridden(self, key: _K) -> bool:
        """Checks if a key-value pair is overridden."""
        return key in self._overrides

    def __getitem__(self, key: _K) -> _V:
        return self._merged[key]

    def get(self, key: _K, default: Optional[_V] = None) -> Optional[_V]:
        """Returns the merged value for ``key``, or ``default`` if absent."""
        return self._merged.get(key, default)

    def __contains__(self, key: object) -> bool:
        return key in self._merged

    def __iter__(self):
        return iter(self._merged)

    def __len__(self) -> int:
        return len(self._merged)

    def __repr__(self) -> str:
        return f"OverrideDict(base={self._base}, overrides={self._overrides})"

    def __bool__(self) -> bool:
        return bool(self._merged)


class _SymbolicFunctionGroup:
    """Different versions of symbolic functions registered to the same name.

    Dispatching to a version sorts the registered versions and scans them linearly
    (see ``_dispatch_opset_version``); the number of versions per op is expected to
    be small.

    Function overloads with different arguments are not supported.

    Custom op overrides are supported.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        # A dictionary of functions, keyed by the opset version.
        self._functions: OverrideDict[OpsetVersion, Callable] = OverrideDict()

    def __repr__(self) -> str:
        return f"_SymbolicFunctionGroup({self._name}, registered={self._functions})"

    def __getitem__(self, key: OpsetVersion) -> Callable:
        result = self.get(key)
        if result is None:
            raise KeyError(key)
        return result

    # TODO(justinchuby): Add @functools.lru_cache(maxsize=None) if lookup time becomes
    # a problem.
    def get(self, opset: OpsetVersion) -> Optional[Callable]:
        """Find the most recent version of the function."""
        version = _dispatch_opset_version(opset, self._functions)
        if version is None:
            return None

        return self._functions[version]

    def add(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a symbolic function.

        Args:
            func: The function to add.
            opset: The opset version of the function to add.
        """
        if self._functions.in_base(opset):
            warnings.warn(
                f"Symbolic function '{self._name}' already registered for opset {opset}. "
                f"Replacing the existing function with new function. This is unexpected. "
                f"Please report it on {_constants.PYTORCH_GITHUB_ISSUES_URL}.",
                errors.OnnxExporterWarning,
            )
        self._functions.set_base(opset, func)

    def add_custom(self, func: Callable, opset: OpsetVersion) -> None:
        """Adds a custom symbolic function.

        Args:
            func: The symbolic function to register.
            opset: The corresponding opset version.
        """
        self._functions.override(opset, func)

    def remove_custom(self, opset: OpsetVersion) -> None:
        """Removes a custom symbolic function.

        Warns (and does nothing) if no custom function is registered at ``opset``.

        Args:
            opset: The opset version of the custom function to remove.
        """
        if not self._functions.overridden(opset):
            warnings.warn(
                f"No custom function registered for '{self._name}' opset {opset}"
            )
            return
        self._functions.remove_override(opset)

    def get_min_supported(self) -> OpsetVersion:
        """Returns the lowest registered opset version (built-in or custom).

        Raises:
            ValueError: If no function is registered in this group.
        """
        return min(self._functions)


class SymbolicRegistry:
    """Registry for symbolic functions.

    The registry maintains a mapping from qualified names to symbolic functions.
    It is used to register new symbolic functions and to dispatch calls to
    the appropriate function.
    """

    def __init__(self) -> None:
        self._registry: Dict[str, _SymbolicFunctionGroup] = {}

    def register(
        self, name: str, opset: OpsetVersion, func: Callable, custom: bool = False
    ) -> None:
        """Registers a symbolic function.

        Args:
            name: The qualified name of the function to register. In the form of 'domain::op'.
                E.g. 'aten::add'.
            opset: The opset version of the function to register.
            func: The symbolic function to register.
            custom: Whether the function is a custom function that overrides existing ones.

        Raises:
            ValueError: If the separator '::' is not in the name.
        """
        if "::" not in name:
            raise ValueError(
                f"The name must be in the form of 'domain::op', not '{name}'"
            )
        # Create the group only when the name is new; ``dict.setdefault`` would
        # construct (and discard) a group on every call.
        symbolic_functions = self._registry.get(name)
        if symbolic_functions is None:
            symbolic_functions = _SymbolicFunctionGroup(name)
            self._registry[name] = symbolic_functions
        if custom:
            symbolic_functions.add_custom(func, opset)
        else:
            symbolic_functions.add(func, opset)

    def unregister(self, name: str, opset: OpsetVersion) -> None:
        """Unregisters a custom symbolic function.

        Only custom functions are removed; unknown names are ignored.

        Args:
            name: The qualified name of the function to unregister.
            opset: The opset version of the function to unregister.
        """
        if name not in self._registry:
            return
        self._registry[name].remove_custom(opset)

    def get_function_group(self, name: str) -> Optional[_SymbolicFunctionGroup]:
        """Returns the function group for the given name."""
        return self._registry.get(name)

    def is_registered_op(self, name: str, version: int) -> bool:
        """Returns whether the given op is registered for the given opset version."""
        functions = self.get_function_group(name)
        if functions is None:
            return False
        return functions.get(version) is not None

    def all_functions(self) -> Set[str]:
        """Returns the set of all registered function names."""
        return set(self._registry)


@_beartype.beartype
def onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
    custom: bool = False,
) -> Callable:
    """Registers a symbolic function.

    Usage::

        @onnx_symbolic(
            "aten::symbolic_b",
            opset=10,
            decorate=[quantized_aten_handler(scale=1 / 128, zero_point=0)],
        )
        @symbolic_helper.parse_args("v", "v", "b")
        def symbolic_b(g: _C.Graph, x: _C.Value, y: _C.Value, arg1: bool) -> _C.Value:
            ...

    Args:
        name: The qualified name of the function in the form of 'domain::op'.
            E.g. 'aten::add'.
        opset: The opset versions of the function to register at.
        decorate: A sequence of decorators to apply to the function.
        custom: Whether the function is a custom symbolic function.

    Returns:
        A decorator that registers the (decorated) function and returns the
        original, undecorated function.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    # Normalize a single opset version into a one-element tuple once, instead of
    # rebinding the enclosing ``opset`` from inside the closure via ``nonlocal``.
    opsets: Sequence[OpsetVersion] = (
        (opset,) if isinstance(opset, OpsetVersion) else opset
    )

    def wrapper(func: Callable) -> Callable:
        decorated = func
        if decorate is not None:
            for decorate_func in decorate:
                decorated = decorate_func(decorated)

        # ``registry`` is the module-level registry; it is only read here, so no
        # ``global`` declaration is required.
        for opset_version in opsets:
            registry.register(name, opset_version, decorated, custom=custom)

        # Return the original function because the decorators in "decorate" are only
        # specific to the instance being registered.
        return func

    return wrapper


@_beartype.beartype
def custom_onnx_symbolic(
    name: str,
    opset: Union[OpsetVersion, Sequence[OpsetVersion]],
    decorate: Optional[Sequence[Callable]] = None,
) -> Callable:
    """Registers a custom symbolic function.

    Args:
        name: the qualified name of the function.
        opset: the opset version of the function.
        decorate: a sequence of decorators to apply to the function.

    Returns:
        The decorator.

    Raises:
        ValueError: If the separator '::' is not in the name.
    """
    return onnx_symbolic(name, opset, decorate, custom=True)


# The registry for all symbolic functions.
registry = SymbolicRegistry()