Llama-3.1-8B-DALv0.1
/
venv
/lib
/python3.12
/site-packages
/torchgen
/packaged
/autograd
/context.py
import functools | |
from typing import Callable | |
from torchgen.api.autograd import NativeFunctionWithDifferentiabilityInfo as NFWDI | |
from torchgen.context import native_function_manager | |
from torchgen.utils import T | |
# Like tools.api.context.with_native_function, but for | |
# NativeFunctionWithDifferentiabilityInfo. | |
def with_native_function_with_differentiability_info( | |
func: Callable[[NFWDI], T] | |
) -> Callable[[NFWDI], T]: | |
def wrapper(f: NFWDI) -> T: | |
with native_function_manager(f.func): | |
return func(f) | |
return wrapper | |
# Like the above but with an additional dispatch key string argument | |
def with_native_function_with_differentiability_info_and_key( | |
func: Callable[[NFWDI, str], T] | |
) -> Callable[[NFWDI, str], T]: | |
def wrapper(f: NFWDI, key: str) -> T: | |
with native_function_manager(f.func): | |
return func(f, key) | |
return wrapper | |