Spaces:
Running
Running
from joblib import Parallel, delayed, cpu_count | |
from tqdm import tqdm | |
import logging | |
from datetime import datetime | |
import os | |
import yaml | |
import random | |
import torch | |
import numpy as np | |
# from gpustat import GPUStatCollection | |
import sys | |
import time | |
from easydict import EasyDict | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
import torch.distributed as dist | |
import torch.multiprocessing as mp | |
from datetime import timedelta | |
def get_free_gpu(mode="memory", memory_need=10000) -> list: | |
r"""Get free gpu according to mode (process-free or memory-free). | |
Args: | |
mode (str, optional): memory-free or process-free. Defaults to "memory". | |
memory_need (int): The memory you need, used if mode=='memory'. Defaults to 10000. | |
Returns: | |
list: free gpu ids sorting by free memory | |
""" | |
assert mode in ["memory", "process"], "mode must be 'memory' or 'process'" | |
if mode == "memory": | |
assert memory_need is not None, \ | |
"'memory_need' if None, 'memory' mode must give the free memory you want to apply for" | |
memory_need = int(memory_need) | |
assert memory_need > 0, "'memory_need' you want must be positive" | |
gpu_stats = GPUStatCollection.new_query() | |
gpu_free_id_list = [] | |
for idx, gpu_stat in enumerate(gpu_stats): | |
if gpu_check_condition(gpu_stat, mode, memory_need): | |
gpu_free_id_list.append([idx, gpu_stat.memory_free]) | |
print("gpu[{}]: {}MB".format(idx, gpu_stat.memory_free)) | |
if gpu_free_id_list: | |
gpu_free_id_list = sorted(gpu_free_id_list, | |
key=lambda x: x[1], | |
reverse=True) | |
gpu_free_id_list = [i[0] for i in gpu_free_id_list] | |
return gpu_free_id_list | |
def gpu_check_condition(gpu_stat, mode, memory_need) -> bool: | |
r"""Check gpu is free or not. | |
Args: | |
gpu_stat (gpustat.core): gpustat to check | |
mode (str): memory-free or process-free. | |
memory_need (int): The memory you need, used if mode=='memory' | |
Returns: | |
bool: gpu is free or not | |
""" | |
if mode == "memory": | |
return gpu_stat.memory_free > memory_need | |
elif mode == "process": | |
for process in gpu_stat.processes: | |
if process["command"] == "python": | |
return False | |
return True | |
else: | |
return False | |
def get_device(gpu_check_list, memory_need=10000): | |
# check device | |
target_gpus = list(filter(lambda x: x is not None, gpu_check_list)) | |
# assert torch.cuda.device_count() >= len(target_gpus), 'do you set the gpus in config correctly?' | |
flag = None | |
while True: | |
# Get the gpu ids which have more than 10000MB memory | |
free_gpu_ids = get_free_gpu('memory', memory_need) | |
if len(free_gpu_ids) < 1: | |
if flag is None: | |
print("No GPU available now. Wait or Exit? y/n") | |
flag = input() | |
if flag.strip() == 'y': | |
continue | |
else: | |
device = torch.device('cpu') | |
print("Using device %s as main device" % device) | |
break | |
time.sleep(60) | |
else: | |
free_target_gpu = list(set(free_gpu_ids) & set(target_gpus)) | |
if len(free_target_gpu) == 0: | |
gpuid = free_gpu_ids[0] | |
print(f"no target GPU is not available") | |
else: | |
gpuid = free_target_gpu[0] | |
device = torch.device('cuda:'+str(gpuid)) | |
print("Using device %s as main device" % device) | |
break | |
return device | |
def get_config_easydict(config_path): | |
with open(config_path, 'r') as f: | |
config = yaml.safe_load(f) | |
return EasyDict(config) | |
def pmap_multi(pickleable_fn, data, n_jobs=None, verbose=1, desc=None, **kwargs): | |
""" | |
Parallel map using joblib. | |
Parameters | |
---------- | |
pickleable_fn : callable | |
Function to map over data. | |
data : iterable | |
Data over which we want to parallelize the function call. | |
n_jobs : int, optional | |
The maximum number of concurrently running jobs. By default, it is one less than | |
the number of CPUs. | |
verbose: int, optional | |
The verbosity level. If nonzero, the function prints the progress messages. | |
The frequency of the messages increases with the verbosity level. If above 10, | |
it reports all iterations. If above 50, it sends the output to stdout. | |
kwargs | |
Additional arguments for :attr:`pickleable_fn`. | |
Returns | |
------- | |
list | |
The i-th element of the list corresponds to the output of applying | |
:attr:`pickleable_fn` to :attr:`data[i]`. | |
""" | |
if n_jobs is None: | |
n_jobs = cpu_count() - 1 | |
results = Parallel(n_jobs=n_jobs, verbose=verbose, timeout=None)( | |
delayed(pickleable_fn)(*d, **kwargs) for i, d in tqdm(enumerate(data),desc=desc) | |
) | |
return results | |
def get_names_from_txt(txt_path): | |
with open(txt_path,'r') as f: | |
lines = f.read().strip().split('\n') | |
return lines | |
def get_logger(run_dir, rank=0): | |
""" | |
Set the logger | |
""" | |
if not os.path.exists(run_dir): | |
os.makedirs(run_dir) | |
logfile_name = os.path.join(run_dir, 'log.txt') | |
fmt = "%(asctime)-15s %(levelname)s %(message)s" | |
filedatefmt = "%a %d %b %Y %H:%M:%S" | |
fileformatter = logging.Formatter(fmt, filedatefmt) | |
sdatefmt = "%d %b %H:%M" | |
sformatter = logging.Formatter(fmt, sdatefmt) | |
fh = logging.FileHandler(logfile_name) | |
fh.setLevel(logging.INFO) | |
fh.setFormatter(fileformatter) | |
sh = logging.StreamHandler() | |
sh.setLevel(logging.INFO) | |
sh.setFormatter(sformatter) | |
logging.basicConfig(level=logging.INFO, handlers=[fh, sh]) | |
return logging.getLogger() | |
def yaml_config_hook(config_file): | |
""" | |
Custom YAML config loader, which can include other yaml files (I like using config files | |
insteaad of using argparser) | |
""" | |
# load yaml files in the nested 'defaults' section, which include defaults for experiments | |
with open(config_file,'rb') as f: | |
cfg = yaml.safe_load(f) | |
for d in cfg.get("defaults", []): | |
config_dir, cf = d.popitem() | |
cf = os.path.join(os.path.dirname(config_file), config_dir, cf + ".yaml") | |
with open(cf) as f: | |
l = yaml.safe_load(f) | |
cfg.update(l) | |
if "defaults" in cfg.keys(): | |
del cfg["defaults"] | |
return cfg | |
def set_seed(seed=1029): | |
random.seed(seed) | |
os.environ['PYTHONHASHSEED'] = str(seed) | |
np.random.seed(seed) | |
torch.manual_seed(seed) | |
torch.cuda.manual_seed(seed) | |
torch.cuda.manual_seed_all(seed) | |
torch.backends.cudnn.benchmark = False | |
torch.backends.cudnn.deterministic = True | |
torch.autograd.set_detect_anomaly(True) | |
#torch.use_deterministic_algorithms(True) | |
def ddp_setup(local_rank): | |
torch.cuda.set_device(local_rank) | |
dist.init_process_group(backend='nccl',timeout=timedelta(seconds=7200)) | |
def init_distributed_mode(args): | |
if "RANK" in os.environ and "WORLD_SIZE" in os.environ: | |
args.rank = int(os.environ["RANK"]) | |
args.world_size = int(os.environ["WORLD_SIZE"]) | |
else: | |
print("Not using distributed mode") | |
args.distributed = False | |
return | |
args.distributed = True | |
torch.cuda.set_device(args.local_rank) | |
args.dist_backend = "nccl" | |
print( | |
"| distributed init (rank {} local rank {}): {}".format( | |
args.rank, args.local_rank, "env://" | |
), | |
flush=True, | |
) | |
torch.distributed.init_process_group( | |
backend=args.dist_backend, init_method="env://", world_size=args.world_size, rank=args.rank | |
) | |
torch.distributed.barrier() | |
setup_for_distributed(args.rank == 0) | |
def setup_for_distributed(is_master): | |
""" | |
This function disables printing when not in master process | |
""" | |
import builtins as __builtin__ | |
builtin_print = __builtin__.print | |
def print(*args, **kwargs): | |
force = kwargs.pop("force", False) | |
if is_master or force: | |
builtin_print(*args, **kwargs) | |
__builtin__.print = print | |
def ddp_cleanup(): | |
dist.destroy_process_group() | |
def get_run_dir(config): | |
run_dir = f'{config.root_dir}/{config.data.dataset_name}' \ | |
f'_{config.target}_model{config.model.model_type}_gtype{config.data.prot_graph_type}' \ | |
f'_lcut{config.data.ligcut}_pcut{config.data.protcut}_icut{config.data.intercut}_ccut{config.data.chaincut}' \ | |
f'_pgmn{config.data.prot_max_neighbors}_lgmn{config.data.lig_max_neighbors}' \ | |
f'_igmn{config.data.inter_min_neighbors}_igmn{config.data.inter_max_neighbors}' \ | |
f'_test2{config.data.test_2}' | |
return run_dir | |
def get_esm_layer(esm_model_name): | |
esm_layer = {'esm2_t48_15B_UR50D':48, | |
'esm2_t36_3B_UR50D':36, | |
'esm2_t33_650M_UR50D':33, | |
'esm2_t30_150M_UR50D':30, | |
'esm2_t12_35M_UR50D':12, | |
'esm2_t6_8M_UR50D':6,} | |
return esm_layer[esm_model_name] | |
def get_esm_dim(esm_model_name): | |
esm_dim = {'esm2_t48_15B_UR50D':5120, | |
'esm2_t36_3B_UR50D':2560, | |
'esm2_t33_650M_UR50D':1280, | |
'esm2_t30_150M_UR50D':640, | |
'esm2_t12_35M_UR50D':480, | |
'esm2_t6_8M_UR50D':320,} | |
return esm_dim[esm_model_name] |