|
from typing import Callable |
|
|
|
from torch._utils import CallbackRegistry |
|
|
|
|
|
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA event creation" |
|
) |
|
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA event deletion" |
|
) |
|
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( |
|
"CUDA event record" |
|
) |
|
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry( |
|
"CUDA event wait" |
|
) |
|
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA memory allocation" |
|
) |
|
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA memory deallocation" |
|
) |
|
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA stream creation" |
|
) |
|
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry( |
|
"CUDA device synchronization" |
|
) |
|
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA stream synchronization" |
|
) |
|
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry( |
|
"CUDA event synchronization" |
|
) |
|
|
|
|
|
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event creation notifications.

    ``cb`` is added to ``EventCreationCallbacks`` and must accept one ``int``.
    NOTE(review): the int is presumably an identifier of the created event;
    confirm against the code that fires this registry.
    """
    registry = EventCreationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event deletion notifications.

    ``cb`` is added to ``EventDeletionCallbacks`` and must accept one ``int``.
    NOTE(review): the int is presumably an identifier of the deleted event;
    confirm against the code that fires this registry.
    """
    registry = EventDeletionCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    """Subscribe ``cb`` to CUDA event record notifications.

    ``cb`` is added to ``EventRecordCallbacks`` and must accept two ``int``
    arguments.
    NOTE(review): presumably (event id, stream id); confirm against the code
    that fires this registry.
    """
    registry = EventRecordCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    """Subscribe ``cb`` to CUDA event wait notifications.

    ``cb`` is added to ``EventWaitCallbacks`` and must accept two ``int``
    arguments.
    NOTE(review): presumably (event id, stream id); confirm against the code
    that fires this registry.
    """
    registry = EventWaitCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA memory allocation notifications.

    ``cb`` is added to ``MemoryAllocationCallbacks`` and must accept one
    ``int``.
    NOTE(review): the int is presumably the allocated address; confirm
    against the code that fires this registry.
    """
    registry = MemoryAllocationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA memory deallocation notifications.

    ``cb`` is added to ``MemoryDeallocationCallbacks`` and must accept one
    ``int``.
    NOTE(review): the int is presumably the freed address; confirm against
    the code that fires this registry.
    """
    registry = MemoryDeallocationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA stream creation notifications.

    ``cb`` is added to ``StreamCreationCallbacks`` and must accept one
    ``int``.
    NOTE(review): the int is presumably an identifier of the new stream;
    confirm against the code that fires this registry.
    """
    registry = StreamCreationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    """Subscribe ``cb`` to CUDA device synchronization notifications.

    ``cb`` is added to ``DeviceSynchronizationCallbacks`` and takes no
    arguments.
    """
    registry = DeviceSynchronizationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA stream synchronization notifications.

    ``cb`` is added to ``StreamSynchronizationCallbacks`` and must accept one
    ``int``.
    NOTE(review): the int is presumably an identifier of the synchronized
    stream; confirm against the code that fires this registry.
    """
    registry = StreamSynchronizationCallbacks
    registry.add_callback(cb)
|
|
|
|
|
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    """Subscribe ``cb`` to CUDA event synchronization notifications.

    ``cb`` is added to ``EventSynchronizationCallbacks`` and must accept one
    ``int``.
    NOTE(review): the int is presumably an identifier of the synchronized
    event; confirm against the code that fires this registry.
    """
    registry = EventSynchronizationCallbacks
    registry.add_callback(cb)
|
|