from typing import Callable from torch._utils import CallbackRegistry EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA event creation" ) EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA event deletion" ) EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( "CUDA event record" ) EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( "CUDA event wait" ) MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA memory allocation" ) MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA memory deallocation" ) StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA stream creation" ) DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry( "CUDA device synchronization" ) StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA stream synchronization" ) EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( "CUDA event synchronization" ) def register_callback_for_event_creation(cb: Callable[[int], None]) -> None: EventCreationCallbacks.add_callback(cb) def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None: EventDeletionCallbacks.add_callback(cb) def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None: EventRecordCallbacks.add_callback(cb) def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None: EventWaitCallbacks.add_callback(cb) def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None: MemoryAllocationCallbacks.add_callback(cb) def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None: MemoryDeallocationCallbacks.add_callback(cb) def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None: StreamCreationCallbacks.add_callback(cb) def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None: DeviceSynchronizationCallbacks.add_callback(cb) def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None: StreamSynchronizationCallbacks.add_callback(cb) def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None: EventSynchronizationCallbacks.add_callback(cb)