import inspect

from typing import Any, Callable, Dict, Iterable, Optional, Tuple, Type, Union

import torch
from torch._streambase import _EventBase, _StreamBase

get_cuda_stream: Optional[Callable[[int], int]]
if torch.cuda._is_compiled():
    from torch._C import _cuda_getCurrentRawStream as get_cuda_stream
else:
    get_cuda_stream = None

_device_t = Union[torch.device, str, int, None]

# Device properties are recorded in the main process but used in worker processes.
caching_worker_device_properties: Dict[str, Any] = {}
caching_worker_current_devices: Dict[str, int] = {}
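# Illustrative sketch (not part of the original module): the parent process
# records device state into these caches before forking, so workers can answer
# queries without touching the CUDA runtime, e.g.
#
#   CudaInterface.Worker.set_device(0)     # main process, pre-fork
#   CudaInterface.Worker.current_device()  # worker: answered from the cache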
class DeviceInterfaceMeta(type):
    def __new__(metacls, *args, **kwargs):
        class_member = args[2]
        if "Event" in class_member:
            assert inspect.isclass(class_member["Event"]) and issubclass(
                class_member["Event"], _EventBase
            ), "DeviceInterface member Event should inherit from _EventBase"
        if "Stream" in class_member:
            assert inspect.isclass(class_member["Stream"]) and issubclass(
                class_member["Stream"], _StreamBase
            ), "DeviceInterface member Stream should inherit from _StreamBase"
        return super().__new__(metacls, *args, **kwargs)
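# Hedged example (added for illustration; the class name below is
# hypothetical): the metaclass validates Event/Stream at class-creation time,
# so a bad member fails as soon as the class body is evaluated:
#
#   class BrokenInterface(metaclass=DeviceInterfaceMeta):
#       Stream = int  # AssertionError: ... should inherit from _StreamBase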
class DeviceInterface(metaclass=DeviceInterfaceMeta):
    """
    This is a simple device runtime interface for Inductor. It enables custom
    backends to be integrated with Inductor in a device-agnostic manner.
    """

    class device:
        def __new__(cls, device: _device_t):
            raise NotImplementedError()

    class Worker:
        """
        Worker API to query device properties that will work in multiprocessing
        workers that cannot use the GPU APIs (due to fork() and
        initialization-time issues). Properties are recorded in the main process
        before we fork the workers.
        """

        @staticmethod
        def set_device(device: int):
            raise NotImplementedError()

        @staticmethod
        def current_device() -> int:
            raise NotImplementedError()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            raise NotImplementedError()

    @staticmethod
    def current_device():
        raise NotImplementedError()

    @staticmethod
    def set_device(device: _device_t):
        raise NotImplementedError()

    @staticmethod
    def device_count():
        raise NotImplementedError()

    @staticmethod
    def is_available() -> bool:
        raise NotImplementedError()

    @staticmethod
    def stream(stream: torch.Stream):
        raise NotImplementedError()

    @staticmethod
    def current_stream():
        raise NotImplementedError()

    @staticmethod
    def set_stream(stream: torch.Stream):
        raise NotImplementedError()

    @staticmethod
    def _set_stream_by_id(stream_id: int, device_index: int, device_type: int):
        raise NotImplementedError()

    @staticmethod
    def get_raw_stream():
        raise NotImplementedError()

    @staticmethod
    def synchronize(device: _device_t = None):
        raise NotImplementedError()

    @staticmethod
    def get_device_properties(device: _device_t = None):
        raise NotImplementedError()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        raise NotImplementedError()
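# A minimal sketch of how a custom backend plugs in (assumption: the name
# _MyBackendInterface and its method bodies are hypothetical, not part of this
# module). A real backend would forward each stub to its own runtime:
#
#   class _MyBackendInterface(DeviceInterface):
#       @staticmethod
#       def is_available() -> bool:
#           return False  # query the backend's runtime here
#
#       @staticmethod
#       def device_count():
#           return 0  # report the backend's device count here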
class CudaInterface(DeviceInterface):
    device = torch.cuda.device

    # Register the Event and Stream classes into the backend interface.
    # The metaclass checks that they inherit from _EventBase and _StreamBase.
    Event = torch.cuda.Event
    Stream = torch.cuda.Stream

    class Worker:
        @staticmethod
        def set_device(device: int):
            caching_worker_current_devices["cuda"] = device

        @staticmethod
        def current_device() -> int:
            if "cuda" in caching_worker_current_devices:
                return caching_worker_current_devices["cuda"]
            return torch.cuda.current_device()

        @staticmethod
        def get_device_properties(device: _device_t = None):
            if device is not None:
                if isinstance(device, str):
                    device = torch.device(device)
                    assert device.type == "cuda"
                if isinstance(device, torch.device):
                    device = device.index
            if device is None:
                device = CudaInterface.Worker.current_device()

            if "cuda" not in caching_worker_device_properties:
                device_prop = [
                    torch.cuda.get_device_properties(i)
                    for i in range(torch.cuda.device_count())
                ]
                caching_worker_device_properties["cuda"] = device_prop

            return caching_worker_device_properties["cuda"][device]

    current_device = staticmethod(torch.cuda.current_device)
    set_device = staticmethod(torch.cuda.set_device)
    device_count = staticmethod(torch.cuda.device_count)
    stream = staticmethod(torch.cuda.stream)  # type: ignore[assignment]
    current_stream = staticmethod(torch.cuda.current_stream)
    set_stream = staticmethod(torch.cuda.set_stream)  # type: ignore[assignment]
    _set_stream_by_id = staticmethod(torch.cuda._set_stream_by_id)  # type: ignore[assignment]
    synchronize = staticmethod(torch.cuda.synchronize)
    get_device_properties = staticmethod(torch.cuda.get_device_properties)  # type: ignore[assignment]
    get_raw_stream = staticmethod(get_cuda_stream)  # type: ignore[arg-type]

    # Can be mock patched by @patch decorator.
    @staticmethod
    def is_available() -> bool:
        return torch.cuda.is_available()

    @staticmethod
    def get_compute_capability(device: _device_t = None):
        major, minor = torch.cuda.get_device_capability(device)
        return major * 10 + minor
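# Usage sketch (illustrative): every member is class-level, so the interface
# is used without instantiation, e.g.
#
#   if CudaInterface.is_available():
#       props = CudaInterface.Worker.get_device_properties(0)
#       cc = CudaInterface.get_compute_capability(0)  # e.g. 80 for sm_80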
device_interfaces: Dict[str, Type[DeviceInterface]] = {}


def register_interface_for_device(
    device: Union[str, torch.device], device_interface: Type[DeviceInterface]
):
    if isinstance(device, torch.device):
        device = str(device)
    device_interfaces[device] = device_interface


def get_interface_for_device(device: Union[str, torch.device]) -> Type[DeviceInterface]:
    if isinstance(device, torch.device):
        device = str(device)
    if device in device_interfaces:
        return device_interfaces[device]
    raise NotImplementedError(f"No interface for device {device}")


def get_registered_device_interfaces() -> Iterable[Tuple[str, Type[DeviceInterface]]]:
    return device_interfaces.items()
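# Registration sketch (illustrative; MyBackendInterface is a hypothetical
# backend class, and "xpu" is just an example device-type string):
#
#   register_interface_for_device("xpu", MyBackendInterface)
#   iface = get_interface_for_device("xpu")
#   iface.synchronize()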
register_interface_for_device("cuda", CudaInterface) | |
for i in range(torch.cuda.device_count()): | |
register_interface_for_device(f"cuda:{i}", CudaInterface) | |