File size: 2,383 Bytes
d1ceb73 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 |
from typing import Callable
from torch._utils import CallbackRegistry
EventCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event creation"
)
EventDeletionCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event deletion"
)
EventRecordCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event record"
)
EventWaitCallbacks: "CallbackRegistry[int, int]" = CallbackRegistry(
"CUDA event wait"
)
MemoryAllocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory allocation"
)
MemoryDeallocationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA memory deallocation"
)
StreamCreationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream creation"
)
DeviceSynchronizationCallbacks: "CallbackRegistry[[]]" = CallbackRegistry(
"CUDA device synchronization"
)
StreamSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA stream synchronization"
)
EventSynchronizationCallbacks: "CallbackRegistry[int]" = CallbackRegistry(
"CUDA event synchronization"
)
def register_callback_for_event_creation(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA event creation" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``EventCreationCallbacks.add_callback``. The int is presumably
            the created event's handle; confirm at the fire site.
    """
    registry = EventCreationCallbacks
    registry.add_callback(cb)
def register_callback_for_event_deletion(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA event deletion" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``EventDeletionCallbacks.add_callback``. The int is presumably
            the deleted event's handle; confirm at the fire site.
    """
    registry = EventDeletionCallbacks
    registry.add_callback(cb)
def register_callback_for_event_record(cb: Callable[[int, int], None]) -> None:
    """Register *cb* with the "CUDA event record" registry.

    Args:
        cb: callable taking two ``int`` arguments and returning nothing. It is
            handed to ``EventRecordCallbacks.add_callback``. The ints are
            presumably (event, stream) handles; confirm at the fire site.
    """
    registry = EventRecordCallbacks
    registry.add_callback(cb)
def register_callback_for_event_wait(cb: Callable[[int, int], None]) -> None:
    """Register *cb* with the "CUDA event wait" registry.

    Args:
        cb: callable taking two ``int`` arguments and returning nothing. It is
            handed to ``EventWaitCallbacks.add_callback``. The ints are
            presumably (event, stream) handles; confirm at the fire site.
    """
    registry = EventWaitCallbacks
    registry.add_callback(cb)
def register_callback_for_memory_allocation(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA memory allocation" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``MemoryAllocationCallbacks.add_callback``. The int is
            presumably the allocated block's address; confirm at the fire site.
    """
    registry = MemoryAllocationCallbacks
    registry.add_callback(cb)
def register_callback_for_memory_deallocation(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA memory deallocation" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``MemoryDeallocationCallbacks.add_callback``. The int is
            presumably the freed block's address; confirm at the fire site.
    """
    registry = MemoryDeallocationCallbacks
    registry.add_callback(cb)
def register_callback_for_stream_creation(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA stream creation" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``StreamCreationCallbacks.add_callback``. The int is presumably
            the new stream's handle; confirm at the fire site.
    """
    registry = StreamCreationCallbacks
    registry.add_callback(cb)
def register_callback_for_device_synchronization(cb: Callable[[], None]) -> None:
    """Register *cb* with the "CUDA device synchronization" registry.

    Args:
        cb: zero-argument callable returning nothing. It is handed to
            ``DeviceSynchronizationCallbacks.add_callback``.
    """
    registry = DeviceSynchronizationCallbacks
    registry.add_callback(cb)
def register_callback_for_stream_synchronization(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA stream synchronization" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``StreamSynchronizationCallbacks.add_callback``. The int is
            presumably the synchronized stream's handle; confirm at the fire
            site.
    """
    registry = StreamSynchronizationCallbacks
    registry.add_callback(cb)
def register_callback_for_event_synchronization(cb: Callable[[int], None]) -> None:
    """Register *cb* with the "CUDA event synchronization" registry.

    Args:
        cb: callable taking one ``int`` and returning nothing. It is handed
            to ``EventSynchronizationCallbacks.add_callback``. The int is
            presumably the synchronized event's handle; confirm at the fire
            site.
    """
    registry = EventSynchronizationCallbacks
    registry.add_callback(cb)
|