File size: 829 Bytes
72268ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import os
import torch
import gc
import logging


def auto_parallel(args):
    """Pick a GPU count from the model name and restrict CUDA_VISIBLE_DEVICES.

    The model size is read from the last ``-``-separated token of
    ``args.model_path`` (e.g. ``pythia-70m`` -> ``70m``, ``opt-30b`` -> ``30b``).
    A token ending in ``m``/``M`` is treated as a small model (size 1);
    otherwise the trailing suffix character is dropped and the remainder is
    parsed as a number (presumably billions of parameters -- confirm against
    the model naming scheme in use). The GPU count is 1 for sizes below 20,
    4 below 50, and 8 otherwise.

    Side effects:
        * sets ``args.parallel`` to ``True`` when more than one GPU is chosen;
        * overwrites ``os.environ["CUDA_VISIBLE_DEVICES"]`` with the first
          ``n_gpu`` entries of the currently visible devices (or of ``0..7``
          when the variable is unset).

    Args:
        args: object with a string ``model_path`` attribute; its ``parallel``
            attribute is written.

    Returns:
        The full list of devices visible *before* truncation: the strings from
        ``CUDA_VISIBLE_DEVICES`` when it was set, otherwise ``list(range(8))``.

    Raises:
        ValueError: if no numeric size can be parsed from the model name.
    """
    # Strip a trailing "/" so directory-style paths still yield the size token.
    model_size = args.model_path.rstrip("/").split("-")[-1]
    if model_size.lower().endswith("m"):
        # "m"/"M" suffix: small model, always fits on one GPU.
        model_gb = 1
    else:
        try:
            model_gb = float(model_size[:-1])
        except ValueError as err:
            raise ValueError(
                f"cannot infer model size from {args.model_path!r}: expected "
                f"a final '-<number><suffix>' token, got {model_size!r}"
            ) from err

    if model_gb < 20:
        n_gpu = 1
    elif model_gb < 50:
        n_gpu = 4
    else:
        n_gpu = 8
    args.parallel = n_gpu > 1

    env_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    if env_devices is not None:
        cuda_visible_devices = env_devices.split(",")
    else:
        # Nothing restricted yet: assume an 8-GPU node.
        cuda_visible_devices = list(range(8))

    selected = cuda_visible_devices[:n_gpu]
    if len(selected) < n_gpu:
        logging.warning(
            "Model needs %d GPUs but only %d device(s) are visible: %s",
            n_gpu, len(selected), cuda_visible_devices)
    os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(dev) for dev in selected)
    # The %s placeholder is required for logging to interpolate the value.
    logging.debug("CUDA_VISIBLE_DEVICES: %s", os.environ["CUDA_VISIBLE_DEVICES"])
    return cuda_visible_devices