|
|
|
from .abstract_impl import AbstractImplHolder |
|
|
|
# Public names exported by a star-import of this module.
__all__ = [
    "SimpleLibraryRegistry",
    "SimpleOperatorEntry",
    "singleton",
]
|
|
|
|
|
class SimpleLibraryRegistry:
    """Registry backing the "simple" torch.library APIs.

    The "simple" torch.library APIs are a higher-level layer over the raw
    PyTorch DispatchKey registration APIs. They currently include:
    - fake impl

    These registrations are not stored in the PyTorch dispatcher's table,
    since they may not directly involve a DispatchKey. The fake impl, for
    example, is a Python function that FakeTensor invokes. They are
    tracked here instead.

    A SimpleLibraryRegistry maps a fully qualified operator name
    (including the overload) to its SimpleOperatorEntry.
    """

    def __init__(self):
        """Start with an empty qualname -> entry mapping."""
        self._data = {}

    def find(self, qualname: str) -> "SimpleOperatorEntry":
        """Return the entry for ``qualname``, creating it on first lookup.

        A newly created SimpleOperatorEntry is stored in the registry, so
        later calls with the same ``qualname`` return the same object.
        """
        entries = self._data
        if qualname in entries:
            return entries[qualname]

        # First request for this operator: build its entry and remember it.
        created = SimpleOperatorEntry(qualname)
        entries[qualname] = created
        return created


# Module-level registry instance created at import time.
singleton: SimpleLibraryRegistry = SimpleLibraryRegistry()
|
|
|
|
|
class SimpleOperatorEntry:
    """Record that corresponds 1:1 to an operator overload.

    Each field of a SimpleOperatorEntry is a Holder that kernels can be
    registered to.
    """

    def __init__(self, qualname: str):
        """Create the entry for ``qualname`` together with its holder."""
        # Construct the holder first, then attach both fields.
        holder: AbstractImplHolder = AbstractImplHolder(qualname)
        self.qualname: str = qualname
        self.abstract_impl: AbstractImplHolder = holder
|
|