import torch | |
def get_device(device=None):
    """Resolve the device to run on.

    If *device* is given, it is returned unchanged. Otherwise the first
    available backend is picked in priority order: CUDA, then Apple MPS,
    then CPU.

    Args:
        device: An explicit device spec (e.g. ``"cuda:1"`` or a
            ``torch.device``), or ``None`` to auto-detect.

    Returns:
        *device* itself when it is not ``None``; otherwise one of the
        strings ``"cuda"``, ``"mps"`` or ``"cpu"``.
    """
    if device is not None:
        return device
    if torch.cuda.is_available():
        return "cuda"
    # ``torch.backends.mps`` was added in PyTorch 1.12; look it up
    # defensively so older builds fall back to CPU instead of raising
    # AttributeError.
    mps = getattr(torch.backends, "mps", None)
    if mps is not None and mps.is_available() and mps.is_built():
        return "mps"
    return "cpu"