|
import torch |
|
|
|
from .log import logger |
|
|
|
|
|
def _free_cuda_memory(index):
    """Return the number of free bytes on CUDA device ``index``."""
    if hasattr(torch.cuda, "mem_get_info"):
        # Driver-level figure: also accounts for memory held by other
        # processes, which the caching-allocator statistics cannot see.
        free_bytes, _total_bytes = torch.cuda.mem_get_info(index)
        return free_bytes
    # Older torch builds: fall back to an estimate that only subtracts the
    # memory reserved by this process's caching allocator.
    props = torch.cuda.get_device_properties(index)
    return props.total_memory - torch.cuda.memory_reserved(index)


def select_device(min_memory=2047, experimental=False):
    """Pick the torch device to run on.

    When CUDA is available, the GPU with the most free memory is chosen
    (the lowest index wins ties). If even that GPU has less than
    ``min_memory`` MB free, a warning is logged and the CPU is used.
    When only an Apple MPS backend is available, it is used solely if
    ``experimental`` is true; otherwise the CPU is used.

    Args:
        min_memory: Minimum free GPU memory, in MB (bytes / 1024 / 1024),
            required to select a CUDA device.
        experimental: Opt in to the MPS backend on Apple GPUs.

    Returns:
        A ``torch.device`` for ``cuda:<index>``, ``mps`` or ``cpu``.
    """
    if torch.cuda.is_available():
        selected_gpu = 0
        max_free_memory = -1
        for index in range(torch.cuda.device_count()):
            free_memory = _free_cuda_memory(index)
            # Strict comparison keeps the first device on ties.
            if free_memory > max_free_memory:
                selected_gpu = index
                max_free_memory = free_memory

        free_memory_mb = max_free_memory / (1024 * 1024)
        if free_memory_mb < min_memory:
            logger.get_logger().warning(
                f"GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left. Switching to CPU."
            )
            return torch.device("cpu")
        return torch.device(f"cuda:{selected_gpu}")

    # torch.backends.mps does not exist on older torch builds; treat a
    # missing backend the same as an unavailable one.
    mps_backend = getattr(torch.backends, "mps", None)
    if mps_backend is not None and mps_backend.is_available():
        # Currently MPS is slower than CPU while needing more memory and
        # core utilization, so only enable it for experimental use.
        if experimental:
            logger.get_logger().warning("experimantal: found apple GPU, using MPS.")
            return torch.device("mps")
        logger.get_logger().info("found Apple GPU, but use CPU.")
        return torch.device("cpu")

    logger.get_logger().warning("no GPU found, use CPU instead")
    return torch.device("cpu")
|
|