Spaces:
Runtime error
Runtime error
import functools | |
import gc | |
import torch | |
# Accelerator availability flags, probed once at import time. Every probe is
# best-effort: each flag starts out False and is only overwritten when its
# probe call returns, so any exception leaves that backend marked unavailable.
HAS_CUDA = HAS_MPS = HAS_XPU = False

try:
    HAS_CUDA = torch.cuda.is_available()
except Exception:
    pass  # probe raised: keep HAS_CUDA = False

try:
    HAS_MPS = torch.backends.mps.is_available()
except Exception:
    pass  # probe raised: keep HAS_MPS = False

try:
    # The IPEX import comes first (presumably so that `torch.xpu` is usable on
    # torch builds without native XPU support); if the package is missing the
    # ImportError is swallowed below and HAS_XPU stays False.
    import intel_extension_for_pytorch as ipex  # noqa

    HAS_XPU = torch.xpu.is_available()
except Exception:
    pass  # IPEX missing or probe raised: keep HAS_XPU = False
def clean_memory():
    """Collect Python garbage, then empty the allocator cache of every
    accelerator backend whose module-level HAS_* flag is set.

    Backends are visited in the order CUDA, XPU, MPS.
    """
    gc.collect()
    detected = (("cuda", HAS_CUDA), ("xpu", HAS_XPU), ("mps", HAS_MPS))
    for backend_name, available in detected:
        if available:
            # torch.cuda / torch.xpu / torch.mps each expose empty_cache()
            getattr(torch, backend_name).empty_cache()
def clean_memory_on_device(device: torch.device):
    r"""
    Collect Python garbage, then empty the allocator cache of the backend that
    ``device`` belongs to. Called from training scripts.

    Only "cuda", "xpu" and "mps" devices have a cache emptied; any other
    device type (e.g. "cpu") only gets the garbage-collection pass.
    """
    gc.collect()

    # Compare on device.type rather than the full device string: a device such
    # as "cuda:0" carries an index, but its type is plain "cuda".
    backend_name = device.type
    if backend_name in ("cuda", "xpu", "mps"):
        getattr(torch, backend_name).empty_cache()
@functools.lru_cache(maxsize=None)
def get_preferred_device() -> torch.device:
    r"""
    Return the preferred torch device: CUDA, then XPU, then MPS, else CPU.

    Do not call this function from training scripts. Use accelerator.device instead.

    The choice depends only on the module-level HAS_* flags, which are set once
    at import time. The result is therefore memoized: the device is selected
    and the "get_preferred_device() -> ..." line is printed only on the first
    call, and later calls return the same device without re-logging.
    """
    if HAS_CUDA:
        device = torch.device("cuda")
    elif HAS_XPU:
        device = torch.device("xpu")
    elif HAS_MPS:
        device = torch.device("mps")
    else:
        device = torch.device("cpu")
    print(f"get_preferred_device() -> {device}")
    return device
def init_ipex():
    """
    Apply the IPEX CUDA hijacks from `library.ipex.ipex_init` when an XPU was detected.

    This function should run right after importing torch and before doing
    anything else. It does nothing when HAS_XPU is False. It never raises:
    a failed initialisation, or any exception along the way, is only reported
    with print.
    """
    try:
        if not HAS_XPU:
            return
        from library.ipex import ipex_init

        ok, reason = ipex_init()
        if not ok:
            print("failed to initialize ipex:", reason)
    except Exception as exc:
        print("failed to initialize ipex:", exc)