Doa-doa's picture
Upload folder using huggingface_hub
72268ee
raw
history blame contribute delete
829 Bytes
import os
import torch
import gc
import logging
def auto_parallel(args):
    """Choose a GPU count from the model-size suffix and restrict visible GPUs.

    The size token is the text after the last ``-`` in ``args.model_path``
    (for example ``"opt-6.7b"`` yields ``6.7``). A token ending in ``m`` or
    ``M`` is treated as size 1. Sizes below 20 select 1 GPU, sizes below 50
    select 4 GPUs, and anything larger selects 8 GPUs.

    Side effects:
        * ``args.parallel`` is set to ``True`` when more than one GPU is
          selected, otherwise ``False``.
        * ``os.environ["CUDA_VISIBLE_DEVICES"]`` is overwritten with the
          first ``n_gpu`` entries of the device list.

    Args:
        args: Object exposing a ``model_path`` string attribute; a
            ``parallel`` attribute is assigned on it.

    Returns:
        The full device list *before* truncation: a list of strings parsed
        from ``CUDA_VISIBLE_DEVICES`` if that variable was set, otherwise
        ``list(range(8))``.

    Raises:
        ValueError: If the size token cannot be parsed as a number.
    """
    # Drop a trailing "/" so "models/opt-6.7b/" parses like "models/opt-6.7b".
    model_size = args.model_path.rstrip("/").split("-")[-1]
    if model_size.lower().endswith("m"):
        # Million-scale models are all treated as the smallest size class.
        model_gb = 1
    else:
        try:
            # Strip the one-character unit suffix (e.g. the "b" in "6.7b").
            model_gb = float(model_size[:-1])
        except ValueError as err:
            raise ValueError(
                f"cannot infer model size from model_path {args.model_path!r}; "
                "expected a name ending in a size suffix such as '-7b' or '-125m'"
            ) from err

    if model_gb < 20:
        n_gpu = 1
    elif model_gb < 50:
        n_gpu = 4
    else:
        n_gpu = 8
    args.parallel = n_gpu > 1

    # Respect a pre-existing device restriction; otherwise assume devices 0-7.
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES", None)
    if isinstance(cuda_visible_devices, str):
        cuda_visible_devices = cuda_visible_devices.split(",")
    else:
        cuda_visible_devices = list(range(8))

    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(
        str(dev) for dev in cuda_visible_devices[:n_gpu]
    )
    # The value must be passed through a %s placeholder; without one the
    # logging module fails to format the record when DEBUG is enabled.
    logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
    return cuda_visible_devices