File size: 4,420 Bytes
2e36228 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 |
import time
import os
import sys
import json
import random
import numpy as np
import torch
def setup_device(gpu_id):
    """Configure CUDA device visibility and return the torch device to use.

    ``gpu_id`` is converted with ``int()`` and interpreted as:
      * ``-2``: leave ``CUDA_VISIBLE_DEVICES`` untouched if it is already set
        (if it is unset, the "anything else" rule below applies);
      * ``>= 0``: set ``CUDA_VISIBLE_DEVICES`` to ``str(gpu_id)``;
      * anything else: set ``CUDA_VISIBLE_DEVICES`` to ``""`` (no GPUs).

    Also sets ``CUDA_DEVICE_ORDER=PCI_BUS_ID`` so device indices follow
    PCI bus order.

    Returns:
        ``torch.device("cuda:0")`` if CUDA is available after the above,
        otherwise ``torch.device("cpu")``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    requested = int(gpu_id)
    inherited = os.getenv('CUDA_VISIBLE_DEVICES')
    keep_inherited = requested == -2 and inherited is not None
    if not keep_inherited:
        if requested >= 0:
            os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
            print("set CUDA_VISIBLE_DEVICES=", gpu_id)
        else:
            os.environ['CUDA_VISIBLE_DEVICES'] = ""
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("using device %s" % device)
    return device
def setup_seed(seed):
    """Resolve the random seed for this run, print it and return it.

    Non-negative seeds are returned unchanged. For ``seed == -2`` the value
    of the ``SATOSHI_SEED`` environment variable is used when it is set.
    Any other negative seed (or ``-2`` without the variable) is replaced by
    a time-derived integer in ``[0, 10000)``: the fractional part of
    ``time.time()`` scaled by ``10**4``.
    """
    if seed >= 0:
        print("random seed", seed)
        return seed
    env_seed = os.getenv('SATOSHI_SEED')
    if seed == -2 and env_seed is not None:
        seed = int(env_seed)
        print("env seed used")
    else:
        import math
        fractional_part, _ = math.modf(time.time())
        seed = int(10**4 * fractional_part)
    print("random seed", seed)
    return seed
def setup_savedir(prefix="", basedir="./experiments", args=None, append_args=None):
    """Create a fresh experiment directory under ``basedir`` and return its path.

    The directory name starts with ``prefix``. If ``args`` is given and
    ``append_args`` names at least one attribute, ``"_<name>-<value>"`` is
    appended for each of them, in order. Otherwise ``"exp"`` is appended.
    Spaces, single quotes and double quotes are then removed from the name.
    If ``basedir/<name>`` already exists, ``"_num-1"``, ``"_num-2"``, ... is
    appended until an unused path is found.

    Args:
        prefix: leading part of the directory name.
        basedir: parent directory; created (with parents) if missing.
        args: object whose attributes are read, e.g. an argparse namespace.
        append_args: iterable of attribute names of ``args`` to encode in the
            name. ``None`` (the default) behaves like an empty list.

    Returns:
        Path of the newly created directory.

    Raises:
        AttributeError: if a name in ``append_args`` is not an attribute of ``args``.
        OSError: any ``os.makedirs`` failure other than ``FileExistsError``.
    """
    # None sentinel instead of a shared mutable ``[]`` default.
    arg_names = [] if append_args is None else list(append_args)
    name = prefix
    if arg_names and args is not None:
        name += "".join(
            "_" + opt + "-" + str(getattr(args, opt)) for opt in arg_names)
    else:
        name += "exp"
    # Strip characters that are awkward in paths (space, ', ") in one pass.
    name = name.translate(str.maketrans("", "", " '\""))
    base_path = os.path.join(basedir, name)

    # If the directory exists, append _num-<k> with the smallest free k.
    # makedirs is the existence check itself, so there is no check-then-create gap.
    savedir = base_path
    suffix = 1
    while True:
        try:
            os.makedirs(savedir)
            break
        except FileExistsError:
            savedir = "%s_num-%d" % (base_path, suffix)
            suffix += 1
    print("made the log directory", savedir)
    return savedir
def save_args(savedir, args, name="args.json"):
    """Write ``vars(args)`` as JSON to ``os.path.join(savedir, name)``.

    Keys are sorted and indented by 4 spaces; an existing file is
    overwritten. The written path is printed.
    """
    target = os.path.join(savedir, name)
    with open(target, 'w') as handle:
        # vars() is evaluated after the file is opened, as in the original flow.
        json.dump(vars(args), handle, indent=4, sort_keys=True)
    print("args saved as %s" % target)
def save_json(dict, path):
    """Serialize ``dict`` to ``path`` as JSON with sorted keys and 4-space indent.

    The parameter name shadows the builtin ``dict``; it is kept so that
    keyword callers (``save_json(dict=..., path=...)``) keep working.
    """
    with open(path, 'w') as out_file:
        json.dump(dict, out_file, indent=4, sort_keys=True)
    print("log saved at %s" % path)
def resume_model(model, resume, state_dict_key="model"):
    '''
    Load weights from a checkpoint file into ``model``, tolerating mismatches.

    A strict ``model.load_state_dict`` is tried first. If it raises
    ``RuntimeError`` (missing/unexpected keys or size mismatches), only the
    checkpoint entries whose key exists in the model AND whose shape matches
    the model's entry are loaded. All other entries are skipped and printed
    as ignored, and the model keeps its current values for them.

    model: pytorch model (anything with ``state_dict``/``load_state_dict``)
    resume: path to the resume file; loaded with ``torch.load`` onto CPU
    state_dict_key: key under which the state dict is stored in the
        checkpoint, or ``None`` if the file holds the state dict directly

    Returns ``model`` (updated in place).
    Raises ``KeyError`` if ``state_dict_key`` is not in the checkpoint.
    '''
    print("resuming trained weights from %s" % resume)
    # NOTE: torch.load unpickles the file - only resume from trusted checkpoints.
    checkpoint = torch.load(resume, map_location='cpu')
    if state_dict_key is not None:
        pretrained_dict = checkpoint[state_dict_key]
    else:
        pretrained_dict = checkpoint
    try:
        model.load_state_dict(pretrained_dict)
    except RuntimeError as e:
        print(e)
        print("can't load the all weights due to error above, trying to load part of them!")
        model_dict = model.state_dict()
        # 1. keep only entries the model can accept: same key AND same shape.
        #    Shape-mismatched entries (e.g. a resized head) would otherwise make
        #    the load below raise the same RuntimeError again.
        pretrained_dict_use = {}
        pretrained_dict_ignored = {}
        for k, v in pretrained_dict.items():
            if k in model_dict and (
                    getattr(v, "shape", None) == getattr(model_dict[k], "shape", None)):
                pretrained_dict_use[k] = v
            else:
                pretrained_dict_ignored[k] = v
        # 2. overwrite matching entries in the model's current state dict, so
        #    keys absent from the checkpoint keep the model's own values.
        model_dict.update(pretrained_dict_use)
        # 3. load the merged state dict (contains exactly the model's keys/shapes).
        model.load_state_dict(model_dict)
        print("resumed only", pretrained_dict_use.keys())
        print("ignored:", pretrained_dict_ignored.keys())
    return model
def save_checkpoint(path, model, key="model"):
    """Save ``model.state_dict()`` to ``path`` wrapped as ``{key: state_dict}``.

    The default ``key`` matches the default ``state_dict_key`` of
    ``resume_model`` in this module. The saved path is printed.
    """
    torch.save({key: model.state_dict()}, path)
    print("checkpoint saved at", path)
def make_deterministic(seed, strict=False):
    """Seed all RNGs and put cuDNN into deterministic mode.

    Seeds Python's ``random``, NumPy's global RNG and torch (CPU and all CUDA
    devices), then sets ``cudnn.benchmark = False`` and
    ``cudnn.deterministic = True``.

    Args:
        seed: integer seed applied to every generator.
        strict: additionally disable cuDNN entirely. The printed message
            advises also using ``num_workers=0`` for data loading.
    """
    # https://github.com/pytorch/pytorch/issues/7068#issuecomment-487907668
    random.seed(seed)
    # PYTHONHASHSEED is read at interpreter startup, so this only affects
    # child processes launched after this call, not the current one.
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    # Benchmark mode lets cuDNN time candidate algorithms and pick one per run,
    # which can vary between runs; disable it explicitly in case a caller
    # enabled it earlier.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    if strict:
        # https://github.com/pytorch/pytorch/issues/7068#issuecomment-515728600
        torch.backends.cudnn.enabled = False
        print(
            "strict reproducability required! cudnn disabled. make sure to set num_workers=0 too!")
|