import torch


def get_device(device=None):
    """Return the device to run on, auto-detecting one when none is given.

    Args:
        device: An explicit device choice. Any value other than ``None`` is
            returned unchanged, without validation.

    Returns:
        ``device`` itself when it is not ``None``; otherwise one of the
        strings ``"cuda"``, ``"mps"`` or ``"cpu"``, chosen in that order of
        preference.
    """
    if device is not None:
        return device

    # Preference order: cuda -> mps -> cpu.
    if torch.cuda.is_available():
        return "cuda"

    # Look the MPS backend up defensively: PyTorch builds that do not expose
    # ``torch.backends.mps`` would otherwise raise AttributeError here rather
    # than falling back to the CPU.
    mps_backend = getattr(torch.backends, "mps", None)
    if (
        mps_backend is not None
        and mps_backend.is_available()
        and mps_backend.is_built()
    ):
        return "mps"

    return "cpu"