import os
import torch
import gc
import logging
def auto_parallel(args):
    """Pick a GPU count from the model size encoded in ``args.model_path``
    (e.g. "llama-13b") and trim CUDA_VISIBLE_DEVICES accordingly."""
    # The size suffix is the last dash-separated token, e.g. "13b" or "770m".
    model_size = args.model_path.split("-")[-1]
    if model_size.endswith("m"):
        # Million-parameter models are treated as roughly 1 GB.
        model_gb = 1
    else:
        # Strip the trailing "b" and read the size in billions as a rough GB count.
        model_gb = float(model_size[:-1])

    if model_gb < 20:
        n_gpu = 1
    elif model_gb < 50:
        n_gpu = 4
    else:
        n_gpu = 8
    args.parallel = n_gpu > 1

    # Respect an existing CUDA_VISIBLE_DEVICES setting; otherwise assume
    # devices 0-7 are available.
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if isinstance(cuda_visible_devices, str):
        cuda_visible_devices = cuda_visible_devices.split(",")
    else:
        cuda_visible_devices = list(range(8))

    # Expose only the first n_gpu devices to this process.
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(dev) for dev in cuda_visible_devices[:n_gpu])
    logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
    return cuda_visible_devices
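

# Hypothetical usage sketch, not part of the original module: it assumes `args`
# is an argparse-style namespace carrying a `model_path` attribute whose last
# dash-separated token encodes the model size (e.g. "llama-13b"), which is all
# auto_parallel inspects.
if __name__ == "__main__":
    from argparse import Namespace

    args = Namespace(model_path="llama-13b", parallel=False)
    visible = auto_parallel(args)
    # A 13B model stays under the 20 GB threshold, so a single GPU is kept visible.
    print("parallel:", args.parallel, "visible devices:", visible)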