r""" |
|
This package enables an interface for accessing MPS (Metal Performance Shaders) backend in Python. |
|
Metal is Apple's API for programming metal GPU (graphics processor unit). Using MPS means that increased |
|
performance can be achieved, by running work on the metal GPU(s). |
|
See https://developer.apple.com/documentation/metalperformanceshaders for more details. |
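
Example (a minimal, illustrative sketch; it assumes this package is importable as
``torch.mps`` and that an MPS-capable GPU is available)::

    >>> if torch.backends.mps.is_available():
    ...     x = torch.ones(5, device="mps")  # tensor allocated on the Metal GPU
    ...     y = (x * 2).cpu()                # the multiply runs on the MPS device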
|
""" |
|
from typing import Optional, Union

import torch
from .. import Tensor

_is_in_bad_fork = getattr(torch._C, "_mps_is_in_bad_fork", lambda: False)
_default_mps_generator: Optional[torch._C.Generator] = None


def _get_default_mps_generator() -> torch._C.Generator:
    # Lazily create and cache the default MPS generator on first use.
    global _default_mps_generator
    if _default_mps_generator is None:
        _default_mps_generator = torch._C._mps_get_default_generator()
    return _default_mps_generator


def device_count() -> int:
    r"""Returns the number of available MPS devices.
|
    """
    return int(torch._C._has_mps and torch._C._mps_is_available())


def synchronize() -> None:
    r"""Waits for all kernels in all streams on the MPS device to complete.
|
    """
    return torch._C._mps_deviceSynchronize()


def get_rng_state(device: Union[int, str, torch.device] = "mps") -> Tensor:
    r"""Returns the random number generator state as a ByteTensor.

    Args:
        device (torch.device or int, optional): The device to return the RNG state of.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
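
    Example (illustrative; assumes MPS is available)::

        >>> state = torch.mps.get_rng_state()  # ByteTensor snapshot of the generator state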
|
""" |
|
return _get_default_mps_generator().get_state() |
|
|
|
|
|
def set_rng_state(
    new_state: Tensor, device: Union[int, str, torch.device] = "mps"
) -> None:
    r"""Sets the random number generator state.

    Args:
        new_state (torch.ByteTensor): The desired state.
        device (torch.device or int, optional): The device to set the RNG state for.
            Default: ``'mps'`` (i.e., ``torch.device('mps')``, the current MPS device).
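
    Example (a minimal sketch; pairs with ``torch.mps.get_rng_state()``)::

        >>> state = torch.mps.get_rng_state()
        >>> _ = torch.randn(3, device="mps")  # advances the default MPS generator
        >>> torch.mps.set_rng_state(state)    # restore the saved state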
|
""" |
|
new_state_copy = new_state.clone(memory_format=torch.contiguous_format) |
|
_get_default_mps_generator().set_state(new_state_copy) |
|
|
|
|
|
def manual_seed(seed: int) -> None:
    r"""Sets the seed for generating random numbers.

    Args:
        seed (int): The desired seed.
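
    Example (illustrative; assumes an MPS device is available)::

        >>> torch.mps.manual_seed(2147483647)
        >>> g = torch.randn(2, device="mps")  # same values whenever the same seed is set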
|
""" |
|
|
|
|
|
|
|
|
|
if not torch._C._has_mps: |
|
return |
|
seed = int(seed) |
|
_get_default_mps_generator().manual_seed(seed) |
|
|
|
|
|
def seed() -> None:
    r"""Sets the seed for generating random numbers to a random number."""
    _get_default_mps_generator().seed()


def empty_cache() -> None:
    r"""Releases all unoccupied cached memory currently held by the caching
    allocator so that it can be used in other GPU applications.
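
    Example (a minimal sketch; frees cached blocks after large temporaries are dropped)::

        >>> tmp = torch.empty(1024, 1024, device="mps")
        >>> del tmp
        >>> torch.mps.empty_cache()  # release the now-unused cached blocks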
|
""" |
|
torch._C._mps_emptyCache() |
|
|
|
|
|
def set_per_process_memory_fraction(fraction) -> None:
    r"""Set a memory fraction that limits the process's memory allocation on the MPS device.

    The allowed amount equals the fraction multiplied by the recommended maximum device memory
    (obtained from the Metal API ``device.recommendedMaxWorkingSetSize``).
    If a process tries to allocate more than the allowed amount, the allocator raises an
    out of memory error.

    Args:
        fraction (float): Range: 0~2. Allowed memory equals total_memory * fraction.

    .. note::
        Passing 0 to fraction means unlimited allocations
        (may cause system failure if out of memory).
        Passing a fraction greater than 1.0 allows limits beyond the value
        returned from ``device.recommendedMaxWorkingSetSize``.
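
    Example (illustrative; caps allocations at half of the recommended working set)::

        >>> torch.mps.set_per_process_memory_fraction(0.5)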
|
""" |
|
|
|
if not isinstance(fraction, float): |
|
raise TypeError("Invalid type for fraction argument, must be `float`") |
|
if fraction < 0 or fraction > 2: |
|
raise ValueError(f"Invalid fraction value: {fraction}. Allowed range: 0~2") |
|
|
|
torch._C._mps_setMemoryFraction(fraction) |
|
|
|
|
|
def current_allocated_memory() -> int:
    r"""Returns the current GPU memory occupied by tensors in bytes.

    .. note::
        The returned size does not include cached allocations in
        memory pools of MPSAllocator.
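
    Example (illustrative; actual numbers depend on the machine)::

        >>> x = torch.zeros(1024, 1024, device="mps")    # roughly 4 MiB of float32 data
        >>> bytes_in_use = torch.mps.current_allocated_memory()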
|
""" |
|
return torch._C._mps_currentAllocatedMemory() |
|
|
|
|
|
def driver_allocated_memory() -> int:
    r"""Returns the total GPU memory allocated by the Metal driver for the process in bytes.

    .. note::
        The returned size includes cached allocations in MPSAllocator pools
        as well as allocations from MPS/MPSGraph frameworks.
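
    Example (illustrative; typically at least as large as ``torch.mps.current_allocated_memory()``)::

        >>> total = torch.mps.driver_allocated_memory()  # bytes held by the Metal driver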
|
""" |
|
return torch._C._mps_driverAllocatedMemory() |
|
|
|
|
|
from . import profiler
from .event import Event

__all__ = [
    "device_count",
    "get_rng_state",
    "manual_seed",
    "seed",
    "set_rng_state",
    "synchronize",
    "empty_cache",
    "set_per_process_memory_fraction",
    "current_allocated_memory",
    "driver_allocated_memory",
    "Event",
    "profiler",
]