|
|
|
r""" |
|
This package adds support for CUDA tensor types. |
|
|
|
These tensor types implement the same functions as CPU tensors, but they
utilize GPUs for computation.
|
|
|
It is lazily initialized, so you can always import it, and use |
|
:func:`is_available()` to determine if your system supports CUDA. |
|
|
|
:ref:`cuda-semantics` has more details about working with CUDA. |
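
A minimal usage sketch (assumes at least one CUDA-capable device is visible)::

    import torch

    if torch.cuda.is_available():
        device = torch.device("cuda")
        x = torch.ones(3, device=device)  # allocated on the current CUDA device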
|
""" |
|
|
|
|
|
import contextlib |
|
import importlib |
|
import os |
|
import sys |
|
import threading |
|
import traceback |
|
import warnings |
|
from functools import lru_cache |
|
from typing import Any, Callable, cast, List, Optional, Tuple, Union |
|
|
|
import torch |
|
import torch._C |
|
from torch.types import Device |
|
from .. import device as _device |
|
from .._utils import _dummy_type, _LazySeedTracker, classproperty |
|
from ._utils import _get_device_index |
|
from .graphs import ( |
|
CUDAGraph, |
|
graph, |
|
graph_pool_handle, |
|
is_current_stream_capturing, |
|
make_graphed_callables, |
|
) |
|
from .streams import Event, ExternalStream, Stream |
|
|
|
try: |
|
from torch._C import _cudart |
|
except ImportError: |
|
_cudart = None |
|
|
|
_initialized = False |
|
_tls = threading.local() |
|
_initialization_lock = threading.Lock() |
|
_queued_calls: List[ |
|
Tuple[Callable[[], None], List[str]] |
|
] = [] |
|
_is_in_bad_fork = getattr(torch._C, "_cuda_isInBadFork", lambda: False) |
|
_device_t = Union[_device, str, int, None] |
|
|
|
_HAS_PYNVML = False |
|
_PYNVML_ERR = None |
|
try: |
|
try: |
|
import pynvml |
|
|
|
_HAS_PYNVML = True |
|
except ModuleNotFoundError: |
|
pass |
|
try: |
|
import amdsmi |
|
|
|
_HAS_PYNVML = True |
|
except ModuleNotFoundError: |
|
pass |
|
except ImportError as err: |
|
_PYNVML_ERR = err |
|
|
|
_lazy_seed_tracker = _LazySeedTracker() |
|
|
|
|
|
if hasattr(torch._C, "_CudaDeviceProperties"): |
|
_CudaDeviceProperties = torch._C._CudaDeviceProperties |
|
else: |
|
_CudaDeviceProperties = _dummy_type("_CudaDeviceProperties") |
|
|
|
if hasattr(torch._C, "_cuda_exchangeDevice"): |
|
_exchange_device = torch._C._cuda_exchangeDevice |
|
else: |
|
|
|
def _exchange_device(device: int) -> int: |
|
if device < 0: |
|
return -1 |
|
raise RuntimeError("PyTorch was compiled without CUDA support") |
|
|
|
|
|
if hasattr(torch._C, "_cuda_maybeExchangeDevice"): |
|
_maybe_exchange_device = torch._C._cuda_maybeExchangeDevice |
|
else: |
|
|
|
def _maybe_exchange_device(device: int) -> int: |
|
if device < 0: |
|
return -1 |
|
raise RuntimeError("PyTorch was compiled without CUDA support") |
|
|
|
|
|
has_half: bool = True |
|
has_magma: bool = torch._C._has_magma |
|
|
|
default_generators: Tuple[torch._C.Generator, ...] = ()
|
|
|
|
|
def _is_compiled() -> bool: |
|
r"""Return true if compile with CUDA support.""" |
|
return hasattr(torch._C, "_cuda_getDeviceCount") |
|
|
|
|
|
def _nvml_based_avail() -> bool: |
|
return os.getenv("PYTORCH_NVML_BASED_CUDA_CHECK") == "1" |
|
|
|
|
|
def is_available() -> bool: |
|
r"""Return a bool indicating if CUDA is currently available.""" |
|
if not _is_compiled(): |
|
return False |
|
if _nvml_based_avail(): |
|
|
|
|
|
|
|
return device_count() > 0 |
|
else: |
|
|
|
|
|
|
|
return torch._C._cuda_getDeviceCount() > 0 |
|
|
|
|
|
def is_bf16_supported(including_emulation: bool = True): |
|
r"""Return a bool indicating if the current CUDA/ROCm device supports dtype bfloat16.""" |
|
|
|
|
|
if torch.version.hip: |
|
return True |
|
|
|
device = torch.cuda.current_device() |
|
|
|
|
|
|
|
cuda_version = torch.version.cuda |
|
if ( |
|
cuda_version is not None |
|
and int(cuda_version.split(".")[0]) >= 11 |
|
and torch.cuda.get_device_properties(device).major >= 8 |
|
): |
|
return True |
|
|
|
if not including_emulation: |
|
return False |
|
|
|
|
|
return _check_bf16_tensor_supported(device) |
|
|
|
|
|
@lru_cache(maxsize=16) |
|
def _check_bf16_tensor_supported(device: _device_t): |
|
try: |
|
torch.tensor([1.0], dtype=torch.bfloat16, device=device) |
|
return True |
|
except Exception: |
|
return False |
|
|
|
|
|
def _sleep(cycles): |
|
torch._C._cuda_sleep(cycles) |
|
|
|
|
|
def _extract_arch_version(arch_string: str): |
|
"""Extracts the architecture string from a CUDA version""" |
|
base = arch_string.split("_")[1] |
|
if base.endswith("a"): |
|
base = base[:-1] |
|
return int(base) |
|
|
|
|
|
def _check_capability(): |
|
incorrect_binary_warn = """ |
|
Found GPU%d %s which requires CUDA_VERSION >= %d to |
|
work properly, but your PyTorch was compiled |
|
with CUDA_VERSION %d. Please install the correct PyTorch binary |
|
using instructions from https://pytorch.org |
|
""" |
|
|
|
old_gpu_warn = """ |
|
Found GPU%d %s which is of cuda capability %d.%d. |
|
PyTorch no longer supports this GPU because it is too old. |
|
The minimum cuda capability supported by this library is %d.%d. |
|
""" |
|
|
|
if torch.version.cuda is not None: |
|
CUDA_VERSION = torch._C._cuda_getCompiledVersion() |
|
for d in range(device_count()): |
|
capability = get_device_capability(d) |
|
major = capability[0] |
|
minor = capability[1] |
|
name = get_device_name(d) |
|
current_arch = major * 10 + minor |
|
min_arch = min( |
|
(_extract_arch_version(arch) for arch in torch.cuda.get_arch_list()), |
|
default=35, |
|
) |
|
if current_arch < min_arch: |
|
warnings.warn( |
|
old_gpu_warn |
|
% (d, name, major, minor, min_arch // 10, min_arch % 10) |
|
) |
|
|
|
|
|
def _check_cubins(): |
|
incompatible_device_warn = """ |
|
{} with CUDA capability sm_{} is not compatible with the current PyTorch installation. |
|
The current PyTorch install supports CUDA capabilities {}. |
|
If you want to use the {} GPU with PyTorch, please check the instructions at https://pytorch.org/get-started/locally/ |
|
""" |
|
if torch.version.cuda is None: |
|
return |
|
arch_list = get_arch_list() |
|
if len(arch_list) == 0: |
|
return |
|
supported_sm = [_extract_arch_version(arch) for arch in arch_list if "sm_" in arch] |
|
for idx in range(device_count()): |
|
cap_major, cap_minor = get_device_capability(idx) |
|
|
|
supported = any(sm // 10 == cap_major for sm in supported_sm) |
|
if not supported: |
|
device_name = get_device_name(idx) |
|
capability = cap_major * 10 + cap_minor |
|
warnings.warn( |
|
incompatible_device_warn.format( |
|
device_name, capability, " ".join(arch_list), device_name |
|
) |
|
) |
|
|
|
|
|
def is_initialized(): |
|
r"""Return whether PyTorch's CUDA state has been initialized.""" |
|
return _initialized and not _is_in_bad_fork() |
|
|
|
|
|
def _lazy_call(callable, **kwargs): |
|
if is_initialized(): |
|
callable() |
|
else: |
|
|
|
|
|
|
|
global _lazy_seed_tracker |
|
if kwargs.get("seed_all", False): |
|
_lazy_seed_tracker.queue_seed_all(callable, traceback.format_stack()) |
|
elif kwargs.get("seed", False): |
|
_lazy_seed_tracker.queue_seed(callable, traceback.format_stack()) |
|
else: |
|
|
|
_queued_calls.append((callable, traceback.format_stack())) |
|
|
|
|
|
_lazy_call(_check_capability) |
|
_lazy_call(_check_cubins) |
|
|
|
|
|
class DeferredCudaCallError(Exception): |
|
pass |
|
|
|
|
|
OutOfMemoryError = torch._C.OutOfMemoryError |
|
|
|
|
|
def init(): |
|
r"""Initialize PyTorch's CUDA state. |
|
|
|
You may need to call this explicitly if you are interacting with |
|
PyTorch via its C API, as Python bindings for CUDA functionality |
|
will not be available until this initialization takes place. |
|
Ordinary users should not need this, as all of PyTorch's CUDA methods |
|
automatically initialize CUDA state on-demand. |
|
|
|
Does nothing if the CUDA state is already initialized. |
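
    Example::

        >>> torch.cuda.init()  # safe to call repeatedly; no-op once initialized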
|
""" |
|
_lazy_init() |
|
|
|
|
|
def _lazy_init(): |
|
global _initialized, _queued_calls |
|
if is_initialized() or hasattr(_tls, "is_initializing"): |
|
return |
|
with _initialization_lock: |
|
|
|
|
|
|
|
|
|
|
|
if is_initialized(): |
|
return |
|
|
|
|
|
|
|
if _is_in_bad_fork(): |
|
raise RuntimeError( |
|
"Cannot re-initialize CUDA in forked subprocess. To use CUDA with " |
|
"multiprocessing, you must use the 'spawn' start method" |
|
) |
|
if not hasattr(torch._C, "_cuda_getDeviceCount"): |
|
raise AssertionError("Torch not compiled with CUDA enabled") |
|
if _cudart is None: |
|
raise AssertionError( |
|
"libcudart functions unavailable. It looks like you have a broken build?" |
|
) |
|
|
|
|
|
if "CUDA_MODULE_LOADING" not in os.environ: |
|
os.environ["CUDA_MODULE_LOADING"] = "LAZY" |
|
torch._C._cuda_init() |
|
|
|
|
|
|
|
_tls.is_initializing = True |
|
|
|
for calls in _lazy_seed_tracker.get_calls(): |
|
if calls: |
|
_queued_calls.append(calls) |
|
|
|
try: |
|
for queued_call, orig_traceback in _queued_calls: |
|
try: |
|
queued_call() |
|
except Exception as e: |
|
msg = ( |
|
f"CUDA call failed lazily at initialization with error: {str(e)}\n\n" |
|
f"CUDA call was originally invoked at:\n\n{''.join(orig_traceback)}" |
|
) |
|
raise DeferredCudaCallError(msg) from e |
|
finally: |
|
delattr(_tls, "is_initializing") |
|
_initialized = True |
|
|
|
|
|
def cudart(): |
|
_lazy_init() |
|
return _cudart |
|
|
|
|
|
class cudaStatus: |
|
SUCCESS: int = 0 |
|
ERROR_NOT_READY: int = 34 |
|
|
|
|
|
class CudaError(RuntimeError): |
|
def __init__(self, code: int) -> None: |
|
msg = _cudart.cudaGetErrorString(_cudart.cudaError(code)) |
|
super().__init__(f"{msg} ({code})") |
|
|
|
|
|
def check_error(res: int) -> None: |
|
if res != _cudart.cudaError.success: |
|
raise CudaError(res) |
|
|
|
|
|
class _DeviceGuard: |
|
def __init__(self, index: int): |
|
self.idx = index |
|
self.prev_idx = -1 |
|
|
|
def __enter__(self): |
|
self.prev_idx = torch.cuda._exchange_device(self.idx) |
|
|
|
def __exit__(self, type: Any, value: Any, traceback: Any): |
|
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) |
|
return False |
|
|
|
|
|
class device: |
|
r"""Context-manager that changes the selected device. |
|
|
|
Args: |
|
device (torch.device or int): device index to select. It's a no-op if |
|
this argument is a negative integer or ``None``. |
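
    Example (a minimal sketch; assumes at least two visible CUDA devices)::

        >>> with torch.cuda.device(1):
        ...     b = torch.zeros(4, device="cuda")  # allocated on device 1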
|
""" |
|
|
|
def __init__(self, device: Any): |
|
self.idx = _get_device_index(device, optional=True) |
|
self.prev_idx = -1 |
|
|
|
def __enter__(self): |
|
self.prev_idx = torch.cuda._exchange_device(self.idx) |
|
|
|
def __exit__(self, type: Any, value: Any, traceback: Any): |
|
self.idx = torch.cuda._maybe_exchange_device(self.prev_idx) |
|
return False |
|
|
|
|
|
class device_of(device): |
|
r"""Context-manager that changes the current device to that of given object. |
|
|
|
You can use both tensors and storages as arguments. If a given object is |
|
not allocated on a GPU, this is a no-op. |
|
|
|
Args: |
|
obj (Tensor or Storage): object allocated on the selected device. |
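
    Example (a minimal sketch; assumes a visible CUDA device)::

        >>> x = torch.randn(2, device="cuda")
        >>> with torch.cuda.device_of(x):
        ...     y = torch.empty(2, device="cuda")  # same device as ``x``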
|
""" |
|
|
|
def __init__(self, obj): |
|
idx = obj.get_device() if obj.is_cuda else -1 |
|
super().__init__(idx) |
|
|
|
|
|
def set_device(device: _device_t) -> None: |
|
r"""Set the current device. |
|
|
|
Usage of this function is discouraged in favor of :any:`device`. In most |
|
cases it's better to use ``CUDA_VISIBLE_DEVICES`` environmental variable. |
|
|
|
Args: |
|
device (torch.device or int): selected device. This function is a no-op |
|
if this argument is negative. |
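
    Example (a minimal sketch; assumes at least two visible CUDA devices)::

        >>> torch.cuda.set_device(1)
        >>> assert torch.cuda.current_device() == 1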
|
""" |
|
device = _get_device_index(device) |
|
if device >= 0: |
|
torch._C._cuda_setDevice(device) |
|
|
|
|
|
def get_device_name(device: Optional[_device_t] = None) -> str: |
|
r"""Get the name of a device. |
|
|
|
Args: |
|
device (torch.device or int or str, optional): device for which to return the |
|
name. This function is a no-op if this argument is a negative |
|
integer. It uses the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Returns: |
|
str: the name of the device |
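
    Example (a minimal sketch; the reported name is device dependent)::

        >>> name = torch.cuda.get_device_name(0)  # e.g. 'NVIDIA A100-SXM4-40GB'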
|
""" |
|
return get_device_properties(device).name |
|
|
|
|
|
def get_device_capability(device: Optional[_device_t] = None) -> Tuple[int, int]: |
|
r"""Get the cuda capability of a device. |
|
|
|
Args: |
|
device (torch.device or int or str, optional): device for which to return the |
|
device capability. This function is a no-op if this argument is |
|
a negative integer. It uses the current device, given by |
|
:func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` |
|
(default). |
|
|
|
Returns: |
|
tuple(int, int): the major and minor cuda capability of the device |
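
    Example (a minimal sketch; the reported capability is device dependent)::

        >>> major, minor = torch.cuda.get_device_capability(0)  # e.g. (8, 0) on Ampere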
|
""" |
|
prop = get_device_properties(device) |
|
return prop.major, prop.minor |
|
|
|
|
|
def get_device_properties(device: _device_t) -> _CudaDeviceProperties: |
|
r"""Get the properties of a device. |
|
|
|
Args: |
|
device (torch.device or int or str): device for which to return the |
|
properties of the device. |
|
|
|
Returns: |
|
_CudaDeviceProperties: the properties of the device |
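
    Example (a minimal sketch; attribute values are device dependent)::

        >>> props = torch.cuda.get_device_properties(0)
        >>> props.name, props.total_memory  # device name and total memory in bytes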
|
""" |
|
_lazy_init() |
|
device = _get_device_index(device, optional=True) |
|
if device < 0 or device >= device_count(): |
|
raise AssertionError("Invalid device id") |
|
return _get_device_properties(device) |
|
|
|
|
|
def can_device_access_peer(device: _device_t, peer_device: _device_t) -> bool: |
|
r"""Check if peer access between two devices is possible.""" |
|
_lazy_init() |
|
device = _get_device_index(device, optional=True) |
|
peer_device = _get_device_index(peer_device) |
|
if device < 0 or device >= device_count(): |
|
raise AssertionError("Invalid device id") |
|
if peer_device < 0 or peer_device >= device_count(): |
|
raise AssertionError("Invalid peer device id") |
|
return torch._C._cuda_canDeviceAccessPeer(device, peer_device) |
|
|
|
|
|
class StreamContext: |
|
r"""Context-manager that selects a given stream. |
|
|
|
All CUDA kernels queued within its context will be enqueued on a selected |
|
stream. |
|
|
|
Args: |
|
        stream (Stream): selected stream. This manager is a no-op if it's
|
``None``. |
|
.. note:: Streams are per-device. |
|
""" |
|
cur_stream: Optional["torch.cuda.Stream"] |
|
|
|
def __init__(self, stream: Optional["torch.cuda.Stream"]): |
|
self.stream = stream |
|
self.idx = _get_device_index(None, True) |
|
if not torch.jit.is_scripting(): |
|
if self.idx is None: |
|
self.idx = -1 |
|
|
|
self.src_prev_stream = ( |
|
None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) |
|
) |
|
self.dst_prev_stream = ( |
|
None if not torch.jit.is_scripting() else torch.cuda.default_stream(None) |
|
) |
|
|
|
def __enter__(self): |
|
|
|
cur_stream = self.stream |
|
|
|
if cur_stream is None or self.idx == -1: |
|
return |
|
self.src_prev_stream = torch.cuda.current_stream(None) |
|
|
|
|
|
|
|
if self.src_prev_stream.device != cur_stream.device: |
|
with device(cur_stream.device): |
|
self.dst_prev_stream = torch.cuda.current_stream(cur_stream.device) |
|
torch.cuda.set_stream(cur_stream) |
|
|
|
def __exit__(self, type: Any, value: Any, traceback: Any): |
|
|
|
cur_stream = self.stream |
|
|
|
if cur_stream is None or self.idx == -1: |
|
return |
|
|
|
|
|
|
|
if self.src_prev_stream.device != cur_stream.device: |
|
torch.cuda.set_stream(self.dst_prev_stream) |
|
torch.cuda.set_stream(self.src_prev_stream) |
|
|
|
|
|
def stream(stream: Optional["torch.cuda.Stream"]) -> StreamContext: |
|
r"""Wrap around the Context-manager StreamContext that selects a given stream. |
|
|
|
Arguments: |
|
stream (Stream): selected stream. This manager is a no-op if it's |
|
``None``. |
|
    .. note:: In eager mode ``stream`` is of type :class:`Stream`, while in JIT it is
        an object of the custom class ``torch.classes.cuda.Stream``.
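
    Example (a minimal sketch; assumes a visible CUDA device)::

        >>> s = torch.cuda.Stream()
        >>> with torch.cuda.stream(s):
        ...     y = torch.ones(4, device="cuda") * 2  # enqueued on stream ``s``
        >>> torch.cuda.current_stream().wait_stream(s)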
|
""" |
|
return StreamContext(stream) |
|
|
|
|
|
def _set_stream_by_id(stream_id, device_index, device_type): |
|
r"""set stream specified by the stream id, device index and |
|
device type |
|
|
|
Args: stream_id (int): stream id in stream pool |
|
device_index (int): device index in topo |
|
device_type (int): enum device type |
|
""" |
|
torch._C._cuda_setStream( |
|
stream_id=stream_id, |
|
device_index=device_index, |
|
device_type=device_type, |
|
) |
|
|
|
|
|
def set_stream(stream: Stream): |
|
r"""Set the current stream.This is a wrapper API to set the stream. |
|
Usage of this function is discouraged in favor of the ``stream`` |
|
context manager. |
|
|
|
Args: |
|
stream (Stream): selected stream. This function is a no-op |
|
if this argument is ``None``. |
|
""" |
|
if stream is None: |
|
return |
|
_set_stream_by_id( |
|
stream_id=stream.stream_id, |
|
device_index=stream.device_index, |
|
device_type=stream.device_type, |
|
) |
|
|
|
|
|
def _parse_visible_devices() -> Union[List[int], List[str]]: |
|
r"""Parse CUDA_VISIBLE_DEVICES environment variable.""" |
|
var = os.getenv( |
|
"CUDA_VISIBLE_DEVICES" if not torch.version.hip else "HIP_VISIBLE_DEVICES" |
|
) |
|
if var is None: |
|
return list(range(64)) |
|
|
|
def _strtoul(s: str) -> int: |
|
"""Return -1 or positive integer sequence string starts with.""" |
|
if not s: |
|
return -1 |
|
for idx, c in enumerate(s): |
|
if not (c.isdigit() or (idx == 0 and c in "+-")): |
|
break |
|
if idx + 1 == len(s): |
|
idx += 1 |
|
return int(s[:idx]) if idx > 0 else -1 |
|
|
|
def parse_list_with_prefix(lst: str, prefix: str) -> List[str]: |
|
rcs: List[str] = [] |
|
for elem in lst.split(","): |
|
|
|
if elem in rcs: |
|
return cast(List[str], []) |
|
|
|
if not elem.startswith(prefix): |
|
break |
|
rcs.append(elem) |
|
return rcs |
|
|
|
if var.startswith("GPU-"): |
|
return parse_list_with_prefix(var, "GPU-") |
|
if var.startswith("MIG-"): |
|
return parse_list_with_prefix(var, "MIG-") |
|
|
|
|
|
rc: List[int] = [] |
|
for elem in var.split(","): |
|
x = _strtoul(elem.strip()) |
|
|
|
if x in rc: |
|
return cast(List[int], []) |
|
|
|
if x < 0: |
|
break |
|
rc.append(x) |
|
return rc |
|
|
|
|
|
def _raw_device_count_amdsmi() -> int: |
|
if not _HAS_PYNVML: |
|
return -1 |
|
try: |
|
amdsmi.amdsmi_init() |
|
except amdsmi.AmdSmiException as e: |
|
warnings.warn(f"Can't initialize amdsmi - Error code: {e.err_code}") |
|
return -1 |
|
socket_handles = amdsmi.amdsmi_get_processor_handles() |
|
return len(socket_handles) |
|
|
|
|
|
def _raw_device_count_nvml() -> int: |
|
r"""Return number of devices as reported by NVML or negative value if NVML discovery/initialization failed.""" |
|
from ctypes import byref, c_int, CDLL |
|
|
|
nvml_h = CDLL("libnvidia-ml.so.1") |
|
rc = nvml_h.nvmlInit() |
|
if rc != 0: |
|
warnings.warn("Can't initialize NVML") |
|
return -1 |
|
dev_count = c_int(-1) |
|
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) |
|
if rc != 0: |
|
warnings.warn("Can't get nvml device count") |
|
return -1 |
|
del nvml_h |
|
return dev_count.value |
|
|
|
|
|
def _raw_device_uuid_amdsmi() -> Optional[List[str]]: |
|
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer |
|
|
|
if not _HAS_PYNVML: |
|
return None |
|
try: |
|
amdsmi.amdsmi_init() |
|
except amdsmi.AmdSmiException: |
|
warnings.warn("Can't initialize amdsmi") |
|
return None |
|
try: |
|
socket_handles = amdsmi.amdsmi_get_processor_handles() |
|
dev_count = len(socket_handles) |
|
except amdsmi.AmdSmiException: |
|
warnings.warn("Can't get amdsmi device count") |
|
return None |
|
uuids: List[str] = [] |
|
for idx in range(dev_count): |
|
try: |
|
handler = amdsmi.amdsmi_get_processor_handles()[idx] |
|
except amdsmi.AmdSmiException: |
|
warnings.warn("Cannot get amd device handler") |
|
return None |
|
try: |
|
uuid = amdsmi.amdsmi_get_gpu_device_uuid(handler) |
|
except amdsmi.AmdSmiException: |
|
warnings.warn("Cannot get uuid for amd device") |
|
return None |
|
uuids.append(str(uuid)) |
|
return uuids |
|
|
|
|
|
def _raw_device_uuid_nvml() -> Optional[List[str]]: |
|
r"""Return list of device UUID as reported by NVML or None if NVM discovery/initialization failed.""" |
|
from ctypes import byref, c_int, c_void_p, CDLL, create_string_buffer |
|
|
|
nvml_h = CDLL("libnvidia-ml.so.1") |
|
rc = nvml_h.nvmlInit() |
|
if rc != 0: |
|
warnings.warn("Can't initialize NVML") |
|
return None |
|
dev_count = c_int(-1) |
|
rc = nvml_h.nvmlDeviceGetCount_v2(byref(dev_count)) |
|
if rc != 0: |
|
warnings.warn("Can't get nvml device count") |
|
return None |
|
uuids: List[str] = [] |
|
for idx in range(dev_count.value): |
|
dev_id = c_void_p() |
|
rc = nvml_h.nvmlDeviceGetHandleByIndex_v2(idx, byref(dev_id)) |
|
if rc != 0: |
|
warnings.warn("Can't get device handle") |
|
return None |
|
buf_len = 96 |
|
buf = create_string_buffer(buf_len) |
|
rc = nvml_h.nvmlDeviceGetUUID(dev_id, buf, buf_len) |
|
if rc != 0: |
|
warnings.warn("Can't get device UUID") |
|
return None |
|
uuids.append(buf.raw.decode("ascii").strip("\0")) |
|
del nvml_h |
|
return uuids |
|
|
|
|
|
def _transform_uuid_to_ordinals(candidates: List[str], uuids: List[str]) -> List[int]: |
|
r"""Given the set of partial uuids and list of known uuids builds a set of ordinals excluding ambiguous partials IDs.""" |
|
|
|
    def uuid_to_ordinal(candidate: str, uuids: List[str]) -> int:
|
best_match = -1 |
|
for idx, uuid in enumerate(uuids): |
|
if not uuid.startswith(candidate): |
|
continue |
|
|
|
if best_match != -1: |
|
return -1 |
|
best_match = idx |
|
return best_match |
|
|
|
rc: List[int] = [] |
|
for candidate in candidates: |
|
        idx = uuid_to_ordinal(candidate, uuids)
|
|
|
if idx < 0: |
|
break |
|
|
|
if idx in rc: |
|
return cast(List[int], []) |
|
rc.append(idx) |
|
return rc |
|
|
|
|
|
def _device_count_amdsmi() -> int: |
|
visible_devices = _parse_visible_devices() |
|
if not visible_devices: |
|
return 0 |
|
try: |
|
if type(visible_devices[0]) is str: |
|
return -1 |
|
else: |
|
raw_cnt = _raw_device_count_amdsmi() |
|
if raw_cnt <= 0: |
|
return raw_cnt |
|
|
|
for idx, val in enumerate(visible_devices): |
|
if cast(int, val) >= raw_cnt: |
|
return idx |
|
except OSError: |
|
return -1 |
|
except AttributeError: |
|
return -1 |
|
return len(visible_devices) |
|
|
|
|
|
def _device_count_nvml() -> int: |
|
r"""Return number of devices as reported by NVML taking CUDA_VISIBLE_DEVICES into account. |
|
|
|
    A negative value is returned if NVML discovery or initialization has failed.
|
""" |
|
visible_devices = _parse_visible_devices() |
|
if not visible_devices: |
|
return 0 |
|
try: |
|
if type(visible_devices[0]) is str: |
|
|
|
if visible_devices[0].startswith("MIG-"): |
|
return -1 |
|
uuids = _raw_device_uuid_nvml() |
|
if uuids is None: |
|
return -1 |
|
visible_devices = _transform_uuid_to_ordinals( |
|
cast(List[str], visible_devices), uuids |
|
) |
|
else: |
|
raw_cnt = _raw_device_count_nvml() |
|
if raw_cnt <= 0: |
|
return raw_cnt |
|
|
|
for idx, val in enumerate(visible_devices): |
|
if cast(int, val) >= raw_cnt: |
|
return idx |
|
except OSError: |
|
return -1 |
|
except AttributeError: |
|
return -1 |
|
return len(visible_devices) |
|
|
|
|
|
def _get_nvml_device_index(device: Optional[Union[int, Device]]) -> int: |
|
r"""Return the NVML index of the device, taking CUDA_VISIBLE_DEVICES into account.""" |
|
idx = _get_device_index(device, optional=True) |
|
visible_devices = _parse_visible_devices() |
|
if type(visible_devices[0]) is str: |
|
uuids = _raw_device_uuid_nvml() |
|
if uuids is None: |
|
raise RuntimeError("Can't get device UUIDs") |
|
visible_devices = _transform_uuid_to_ordinals( |
|
cast(List[str], visible_devices), uuids |
|
) |
|
visible_devices = cast(List[int], visible_devices) |
|
if idx < 0 or idx >= len(visible_devices): |
|
raise RuntimeError( |
|
f"device {idx} is not visible (CUDA_VISIBLE_DEVICES={visible_devices})" |
|
) |
|
return visible_devices[idx] |
|
|
|
|
|
_cached_device_count: Optional[int] = None |
|
|
|
|
|
def device_count() -> int: |
|
r"""Return the number of GPUs available.""" |
|
global _cached_device_count |
|
if not _is_compiled(): |
|
return 0 |
|
if _cached_device_count is not None: |
|
return _cached_device_count |
|
|
|
nvml_count = _device_count_amdsmi() if torch.version.hip else _device_count_nvml() |
|
r = torch._C._cuda_getDeviceCount() if nvml_count < 0 else nvml_count |
|
|
|
|
|
|
|
if _initialized: |
|
_cached_device_count = r |
|
return r |
|
|
|
|
|
def get_arch_list() -> List[str]: |
|
r"""Return list CUDA architectures this library was compiled for.""" |
|
if not is_available(): |
|
return [] |
|
arch_flags = torch._C._cuda_getArchFlags() |
|
if arch_flags is None: |
|
return [] |
|
return arch_flags.split() |
|
|
|
|
|
def get_gencode_flags() -> str: |
|
r"""Return NVCC gencode flags this library was compiled with.""" |
|
arch_list = get_arch_list() |
|
if len(arch_list) == 0: |
|
return "" |
|
arch_list_ = [arch.split("_") for arch in arch_list] |
|
return " ".join( |
|
[ |
|
f"-gencode compute=compute_{arch},code={kind}_{arch}" |
|
for (kind, arch) in arch_list_ |
|
] |
|
) |
|
|
|
|
|
def current_device() -> int: |
|
r"""Return the index of a currently selected device.""" |
|
_lazy_init() |
|
return torch._C._cuda_getDevice() |
|
|
|
|
|
def synchronize(device: _device_t = None) -> None: |
|
r"""Wait for all kernels in all streams on a CUDA device to complete. |
|
|
|
Args: |
|
device (torch.device or int, optional): device for which to synchronize. |
|
It uses the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
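
    Example (a minimal sketch; assumes a visible CUDA device)::

        >>> a = torch.randn(64, 64, device="cuda")
        >>> b = torch.randn(64, 64, device="cuda")
        >>> c = a @ b                   # launched asynchronously
        >>> torch.cuda.synchronize()    # block until the matmul has finished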
|
""" |
|
_lazy_init() |
|
with torch.cuda.device(device): |
|
return torch._C._cuda_synchronize() |
|
|
|
|
|
def ipc_collect(): |
|
r"""Force collects GPU memory after it has been released by CUDA IPC. |
|
|
|
.. note:: |
|
        Checks if any sent CUDA tensors could be cleaned from memory. Force-closes
        the shared memory file used for reference counting if there are no
        active counters. Useful when the producer process has stopped actively sending
        tensors and wants to release unused memory.
|
""" |
|
_lazy_init() |
|
return torch._C._cuda_ipc_collect() |
|
|
|
|
|
def current_stream(device: Optional[_device_t] = None) -> Stream: |
|
r"""Return the currently selected :class:`Stream` for a given device. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
the currently selected :class:`Stream` for the current device, given |
|
by :func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` |
|
(default). |
|
""" |
|
_lazy_init() |
|
streamdata = torch._C._cuda_getCurrentStream( |
|
_get_device_index(device, optional=True) |
|
) |
|
return Stream( |
|
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] |
|
) |
|
|
|
|
|
def default_stream(device: Optional[_device_t] = None) -> Stream: |
|
r"""Return the default :class:`Stream` for a given device. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
the default :class:`Stream` for the current device, given by |
|
:func:`~torch.cuda.current_device`, if :attr:`device` is ``None`` |
|
(default). |
|
""" |
|
_lazy_init() |
|
streamdata = torch._C._cuda_getDefaultStream( |
|
_get_device_index(device, optional=True) |
|
) |
|
return Stream( |
|
stream_id=streamdata[0], device_index=streamdata[1], device_type=streamdata[2] |
|
) |
|
|
|
|
|
def current_blas_handle(): |
|
r"""Return cublasHandle_t pointer to current cuBLAS handle""" |
|
_lazy_init() |
|
return torch._C._cuda_getCurrentBlasHandle() |
|
|
|
|
|
def set_sync_debug_mode(debug_mode: Union[int, str]) -> None: |
|
r"""Set the debug mode for cuda synchronizing operations. |
|
|
|
Args: |
|
        debug_mode (str or int): if "default" or 0, don't error or warn on synchronizing operations,
            if "warn" or 1, warn on synchronizing operations, if "error" or 2, error out on synchronizing operations.
|
|
|
Warning: |
|
        This is an experimental feature, and not all synchronizing operations will trigger a warning or an error. In
|
particular, operations in torch.distributed and torch.sparse namespaces are not covered yet. |
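
    Example (a minimal sketch)::

        >>> torch.cuda.set_sync_debug_mode("warn")     # warn on implicit synchronization
        >>> torch.cuda.set_sync_debug_mode("default")  # restore the default behavior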
|
""" |
|
_lazy_init() |
|
if isinstance(debug_mode, str): |
|
if debug_mode == "default": |
|
debug_mode = 0 |
|
elif debug_mode == "warn": |
|
debug_mode = 1 |
|
elif debug_mode == "error": |
|
debug_mode = 2 |
|
else: |
|
raise RuntimeError( |
|
"invalid value of debug_mode, expected one of `default`, `warn`, `error`" |
|
) |
|
|
|
torch._C._cuda_set_sync_debug_mode(debug_mode) |
|
|
|
|
|
def get_sync_debug_mode() -> int: |
|
r"""Return current value of debug mode for cuda synchronizing operations.""" |
|
_lazy_init() |
|
return torch._C._cuda_get_sync_debug_mode() |
|
|
|
|
|
def _get_pynvml_handler(device: Optional[Union[Device, int]] = None): |
|
if not _HAS_PYNVML: |
|
raise ModuleNotFoundError( |
|
"pynvml does not seem to be installed or it can't be imported." |
|
) from _PYNVML_ERR |
|
from pynvml import NVMLError_DriverNotLoaded |
|
|
|
try: |
|
pynvml.nvmlInit() |
|
except NVMLError_DriverNotLoaded as e: |
|
raise RuntimeError("cuda driver can't be loaded, is cuda enabled?") from e |
|
|
|
device = _get_nvml_device_index(device) |
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device) |
|
return handle |
|
|
|
|
|
def _get_amdsmi_handler(device: Optional[Union[Device, int]] = None): |
|
if not _HAS_PYNVML: |
|
raise ModuleNotFoundError( |
|
"amdsmi does not seem to be installed or it can't be imported." |
|
) from _PYNVML_ERR |
|
try: |
|
amdsmi.amdsmi_init() |
|
except amdsmi.AmdSmiException as e: |
|
raise RuntimeError( |
|
"amdsmi driver can't be loaded, requires >=ROCm5.6 installation" |
|
) from e |
|
device = _get_amdsmi_device_index(device) |
|
handle = amdsmi.amdsmi_get_processor_handles()[device] |
|
return handle |
|
|
|
|
|
def _get_amdsmi_device_index(device: Optional[Union[int, Device]]) -> int: |
|
r"""Return the amdsmi index of the device, taking HIP_VISIBLE_DEVICES into account.""" |
|
idx = _get_device_index(device, optional=True) |
|
visible_devices = _parse_visible_devices() |
|
if type(visible_devices[0]) is str: |
|
raise RuntimeError("HIP_VISIBLE_DEVICES should be indices and not strings") |
|
idx_map = dict(enumerate(cast(List[int], visible_devices))) |
|
if idx not in idx_map: |
|
raise RuntimeError( |
|
f"device {idx} is not visible (HIP_VISIBLE_DEVICES={visible_devices})" |
|
) |
|
return idx_map[idx] |
|
|
|
|
|
def _get_amdsmi_memory_usage(device: Optional[Union[Device, int]] = None) -> int: |
|
handle = _get_amdsmi_handler() |
|
device = _get_amdsmi_device_index(device) |
|
return amdsmi.amdsmi_get_gpu_vram_usage(handle)["vram_used"] |
|
|
|
|
|
def _get_amdsmi_utilization(device: Optional[Union[Device, int]] = None) -> int: |
|
handle = _get_amdsmi_handler() |
|
device = _get_amdsmi_device_index(device) |
|
handle = amdsmi.amdsmi_get_processor_handles()[device] |
|
return amdsmi.amdsmi_get_gpu_activity(handle)["gfx_activity"] |
|
|
|
|
|
def _get_amdsmi_temperature(device: Optional[Union[Device, int]] = None) -> int: |
|
handle = _get_amdsmi_handler(device) |
|
return amdsmi.amdsmi_get_temp_metric( |
|
handle, |
|
amdsmi.AmdSmiTemperatureType.JUNCTION, |
|
amdsmi.AmdSmiTemperatureMetric.CURRENT, |
|
) |
|
|
|
|
|
def _get_amdsmi_power_draw(device: Optional[Union[Device, int]] = None) -> int: |
|
handle = _get_amdsmi_handler(device) |
|
return amdsmi.amdsmi_get_power_info(handle)["current_socket_power"] |
|
|
|
|
|
def _get_amdsmi_clock_rate(device: Optional[Union[Device, int]] = None) -> int: |
|
handle = _get_amdsmi_handler(device) |
|
return amdsmi.amdsmi_get_clock_info(handle, amdsmi.AmdSmiClkType.GFX)["cur_clk"] |
|
|
|
|
|
def memory_usage(device: Optional[Union[Device, int]] = None) -> int: |
|
r"""Return the percent of time over the past sample period during which global (device) |
|
memory was being read or written as given by `nvidia-smi`. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
statistic for the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Warning: Each sample period may be between 1 second and 1/6 second, |
|
depending on the product being queried. |
|
""" |
|
if not torch.version.hip: |
|
handle = _get_pynvml_handler() |
|
device = _get_nvml_device_index(device) |
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device) |
|
return pynvml.nvmlDeviceGetUtilizationRates(handle).memory |
|
else: |
|
return _get_amdsmi_memory_usage(device) |
|
|
|
|
|
def utilization(device: Optional[Union[Device, int]] = None) -> int: |
|
r"""Return the percent of time over the past sample period during which one or |
|
    more kernels were executing on the GPU as given by `nvidia-smi`.
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
statistic for the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Warning: Each sample period may be between 1 second and 1/6 second, |
|
depending on the product being queried. |
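
    Example (a minimal sketch; requires ``pynvml`` on NVIDIA or ``amdsmi`` on ROCm)::

        >>> busy_percent = torch.cuda.utilization()  # percent of the last sample period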
|
""" |
|
if not torch.version.hip: |
|
handle = _get_pynvml_handler(device) |
|
device = _get_nvml_device_index(device) |
|
handle = pynvml.nvmlDeviceGetHandleByIndex(device) |
|
return pynvml.nvmlDeviceGetUtilizationRates(handle).gpu |
|
else: |
|
return _get_amdsmi_utilization(device) |
|
|
|
|
|
def temperature(device: Optional[Union[Device, int]] = None) -> int: |
|
r"""Return the average temperature of the GPU sensor in Degrees C (Centigrades). |
|
|
|
The average temperature is computed based on past sample period as given by `nvidia-smi`. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
statistic for the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Warning: Each sample period may be between 1 second and 1/6 second, |
|
depending on the product being queried. |
|
""" |
|
if not torch.version.hip: |
|
handle = _get_pynvml_handler(device) |
|
|
|
return pynvml.nvmlDeviceGetTemperature(handle, 0) |
|
else: |
|
return _get_amdsmi_temperature(device) |
|
|
|
|
|
def power_draw(device: Optional[Union[Device, int]] = None) -> int: |
|
r"""Return the average power draw of the GPU sensor in mW (MilliWatts) |
|
over the past sample period as given by `nvidia-smi` for Fermi or newer fully supported devices. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
statistic for the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Warning: Each sample period may be between 1 second and 1/6 second, |
|
depending on the product being queried. |
|
""" |
|
if not torch.version.hip: |
|
handle = _get_pynvml_handler(device) |
|
return pynvml.nvmlDeviceGetPowerUsage(handle) |
|
else: |
|
return _get_amdsmi_power_draw(device) |
|
|
|
|
|
def clock_rate(device: Optional[Union[Device, int]] = None) -> int: |
|
r"""Return the clock speed of the GPU SM in Hz Hertz over the past sample period as given by `nvidia-smi`. |
|
|
|
Args: |
|
device (torch.device or int, optional): selected device. Returns |
|
statistic for the current device, given by :func:`~torch.cuda.current_device`, |
|
if :attr:`device` is ``None`` (default). |
|
|
|
Warning: Each sample period may be between 1 second and 1/6 second, |
|
depending on the product being queried. |
|
""" |
|
if not torch.version.hip: |
|
handle = _get_pynvml_handler(device) |
|
return pynvml.nvmlDeviceGetClockInfo(handle, 1) |
|
else: |
|
return _get_amdsmi_clock_rate(device) |
|
|
|
|
|
def _get_device(device: Union[int, str, torch.device]) -> torch.device: |
|
r"""Return the torch.device type object from the passed in device. |
|
|
|
Args: |
|
device (torch.device or int): selected device. |
|
""" |
|
if isinstance(device, str): |
|
device = torch.device(device) |
|
elif isinstance(device, int): |
|
device = torch.device("cuda", device) |
|
return device |
|
|
|
|
|
def _get_generator(device: torch.device) -> torch._C.Generator: |
|
r"""Return the CUDA Generator object for the given device. |
|
|
|
Args: |
|
device (torch.device): selected device. |
|
""" |
|
idx = device.index |
|
if idx is None: |
|
idx = current_device() |
|
return torch.cuda.default_generators[idx] |
|
|
|
|
|
def _set_rng_state_offset( |
|
offset: int, device: Union[int, str, torch.device] = "cuda" |
|
) -> None: |
|
r"""Set the random number generator state offset of the specified GPU. |
|
|
|
Args: |
|
offset (int): The desired offset |
|
device (torch.device or int, optional): The device to set the RNG state. |
|
Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). |
|
""" |
|
final_device = _get_device(device) |
|
|
|
def cb(): |
|
default_generator = _get_generator(final_device) |
|
default_generator.set_offset(offset) |
|
|
|
_lazy_call(cb) |
|
|
|
|
|
def _get_rng_state_offset(device: Union[int, str, torch.device] = "cuda") -> int: |
|
r"""Return the random number generator state offset of the specified GPU. |
|
|
|
Args: |
|
device (torch.device or int, optional): The device to return the RNG state offset of. |
|
Default: ``'cuda'`` (i.e., ``torch.device('cuda')``, the current CUDA device). |
|
|
|
.. warning:: |
|
This function eagerly initializes CUDA. |
|
""" |
|
_lazy_init() |
|
final_device = _get_device(device) |
|
default_generator = _get_generator(final_device) |
|
return default_generator.get_offset() |
|
|
|
|
|
from .memory import * |
|
|
|
|
|
from .random import * |
|
|
|
|
|
|
|
|
|
|
|
|
|
@staticmethod |
|
def _lazy_new(cls, *args, **kwargs): |
|
_lazy_init() |
|
|
|
|
|
return super(_CudaBase, cls).__new__(cls, *args, **kwargs) |
|
|
|
|
|
class _CudaBase: |
|
is_cuda = True |
|
is_sparse = False |
|
|
|
def type(self, *args, **kwargs): |
|
|
|
|
|
|
|
with device(self.get_device()): |
|
return super().type(*args, **kwargs) |
|
|
|
__new__ = _lazy_new |
|
|
|
|
|
from torch.storage import _LegacyStorage, _warn_typed_storage_removal |
|
|
|
|
|
class _CudaLegacyStorage(_LegacyStorage): |
|
@classmethod |
|
def from_buffer(cls, *args, **kwargs): |
|
_warn_typed_storage_removal() |
|
raise RuntimeError("from_buffer: Not available for CUDA storage") |
|
|
|
@classmethod |
|
def _new_with_weak_ptr(cls, *args, **kwargs): |
|
raise RuntimeError("_new_with_weak_ptr: Not available for CUDA storage") |
|
|
|
@classmethod |
|
def _new_shared_filename(cls, manager, obj, size, *, device=None, dtype=None): |
|
raise RuntimeError("_new_shared_filename: Not available for CUDA storage") |
|
|
|
|
|
class ByteStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.uint8 |
|
|
|
|
|
class DoubleStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.double |
|
|
|
|
|
class FloatStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.float |
|
|
|
|
|
class HalfStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.half |
|
|
|
|
|
class LongStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.long |
|
|
|
|
|
class IntStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.int |
|
|
|
|
|
class ShortStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.short |
|
|
|
|
|
class CharStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.int8 |
|
|
|
|
|
class BoolStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.bool |
|
|
|
|
|
class BFloat16Storage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.bfloat16 |
|
|
|
|
|
class ComplexDoubleStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.cdouble |
|
|
|
|
|
class ComplexFloatStorage(_CudaLegacyStorage): |
|
@classproperty |
|
def dtype(self): |
|
_warn_typed_storage_removal() |
|
return self._dtype |
|
|
|
@classproperty |
|
def _dtype(self): |
|
return torch.cfloat |
|
|
|
|
|
del _LegacyStorage |
|
del _CudaLegacyStorage |
|
|
|
torch._storage_classes.add(DoubleStorage) |
|
torch._storage_classes.add(FloatStorage) |
|
torch._storage_classes.add(LongStorage) |
|
torch._storage_classes.add(IntStorage) |
|
torch._storage_classes.add(ShortStorage) |
|
torch._storage_classes.add(CharStorage) |
|
torch._storage_classes.add(ByteStorage) |
|
torch._storage_classes.add(HalfStorage) |
|
torch._storage_classes.add(BoolStorage) |
|
torch._storage_classes.add(BFloat16Storage) |
|
torch._storage_classes.add(ComplexDoubleStorage) |
|
torch._storage_classes.add(ComplexFloatStorage) |
|
|
|
|
|
class _WrappedTritonKernel: |
|
"""Just a simple wrapper to store some metadata for testing purposes.""" |
|
|
|
def __init__(self, kernel): |
|
self.kernel = kernel |
|
self.kernel_invoked = False |
|
|
|
def __call__(self, *args, **kwargs): |
|
res = self.kernel(*args, **kwargs) |
|
self.kernel_invoked = True |
|
return res |
|
|
|
|
|
def _register_triton_kernels(): |
|
if torch._running_with_deploy(): |
|
return |
|
|
|
@_WrappedTritonKernel |
|
def kernel_impl(*args, **kwargs): |
|
from torch.sparse._triton_ops import bsr_dense_mm |
|
|
|
return bsr_dense_mm(*args, skip_checks=True, **kwargs) |
|
|
|
@_WrappedTritonKernel |
|
def addmm_kernel_impl(*args, **kwargs): |
|
from torch.sparse._triton_ops import bsr_dense_addmm |
|
|
|
return bsr_dense_addmm(*args, skip_checks=True, **kwargs) |
|
|
|
has_triton = importlib.util.find_spec("triton") is not None |
|
if has_triton: |
|
torch._TritonLibrary.registerOp( |
|
"_triton_bsr_dense_mm_out", |
|
"_triton_bsr_dense_mm_out(Tensor bsr, Tensor dense, *, Tensor(a!) out) -> Tensor(a!)", |
|
kernel_impl, |
|
"SparseCsrCUDA", |
|
) |
|
|
|
torch._TritonLibrary.registerOp( |
|
"_triton_bsr_dense_addmm_out", |
|
( |
|
"_triton_bsr_dense_addmm_out(Tensor input, Tensor bsr, Tensor dense," |
|
" *, Scalar beta, Scalar alpha, Tensor(a!) out) -> Tensor(a!)" |
|
), |
|
addmm_kernel_impl, |
|
"SparseCsrCUDA", |
|
) |
|
|
|
|
|
_lazy_call(_register_triton_kernels) |
|
|
|
|
|
from . import amp, jiterator, nvtx, profiler, sparse, tunable |
|
|
|
__all__ = [ |
|
|
|
"BFloat16Storage", |
|
"BFloat16Tensor", |
|
"BoolStorage", |
|
"BoolTensor", |
|
"ByteStorage", |
|
"ByteTensor", |
|
"CharStorage", |
|
"CharTensor", |
|
"ComplexDoubleStorage", |
|
"ComplexFloatStorage", |
|
"DoubleStorage", |
|
"DoubleTensor", |
|
"FloatStorage", |
|
"FloatTensor", |
|
"HalfStorage", |
|
"HalfTensor", |
|
"IntStorage", |
|
"IntTensor", |
|
"LongStorage", |
|
"LongTensor", |
|
"ShortStorage", |
|
"ShortTensor", |
|
"CUDAGraph", |
|
"CudaError", |
|
"DeferredCudaCallError", |
|
"Event", |
|
"ExternalStream", |
|
"Stream", |
|
"StreamContext", |
|
"amp", |
|
"caching_allocator_alloc", |
|
"caching_allocator_delete", |
|
"can_device_access_peer", |
|
"check_error", |
|
"cudaStatus", |
|
"cudart", |
|
"current_blas_handle", |
|
"current_device", |
|
"current_stream", |
|
"default_generators", |
|
"default_stream", |
|
"device", |
|
"device_count", |
|
"device_of", |
|
"empty_cache", |
|
"get_allocator_backend", |
|
"CUDAPluggableAllocator", |
|
"change_current_allocator", |
|
"get_arch_list", |
|
"get_device_capability", |
|
"get_device_name", |
|
"get_device_properties", |
|
"get_gencode_flags", |
|
"get_rng_state", |
|
"get_rng_state_all", |
|
"get_sync_debug_mode", |
|
"graph", |
|
"graph_pool_handle", |
|
"graphs", |
|
"has_half", |
|
"has_magma", |
|
"init", |
|
"initial_seed", |
|
"ipc_collect", |
|
"is_available", |
|
"is_bf16_supported", |
|
"is_current_stream_capturing", |
|
"is_initialized", |
|
"jiterator", |
|
"list_gpu_processes", |
|
"make_graphed_callables", |
|
"manual_seed", |
|
"manual_seed_all", |
|
"max_memory_allocated", |
|
"max_memory_cached", |
|
"max_memory_reserved", |
|
"mem_get_info", |
|
"memory", |
|
"memory_allocated", |
|
"memory_cached", |
|
"memory_reserved", |
|
"memory_snapshot", |
|
"memory_stats", |
|
"memory_stats_as_nested_dict", |
|
"memory_summary", |
|
"memory_usage", |
|
"temperature", |
|
"power_draw", |
|
"clock_rate", |
|
"nccl", |
|
"nvtx", |
|
"profiler", |
|
"random", |
|
"reset_accumulated_memory_stats", |
|
"reset_max_memory_allocated", |
|
"reset_max_memory_cached", |
|
"reset_peak_memory_stats", |
|
"seed", |
|
"seed_all", |
|
"set_device", |
|
"set_per_process_memory_fraction", |
|
"set_rng_state", |
|
"set_rng_state_all", |
|
"set_stream", |
|
"set_sync_debug_mode", |
|
"sparse", |
|
"stream", |
|
"streams", |
|
"synchronize", |
|
"tunable", |
|
"utilization", |
|
] |
|
|