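"""Small experiment utilities for PyTorch: device and seed setup, unique save
directories, JSON argument/metric logging, and checkpoint save/resume helpers."""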
import json
import os
import random
import sys
import time

import numpy as np
import torch


def setup_device(gpu_id):
    """Pick the torch device.

    gpu_id >= 0 pins CUDA to that GPU, -2 defers to an already exported
    CUDA_VISIBLE_DEVICES, and any other negative value (or -2 without that
    variable) hides all GPUs and falls back to the CPU.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
    if int(gpu_id) == -2 and os.getenv('CUDA_VISIBLE_DEVICES') is not None:
        # Respect the GPU selection already made by the environment.
        gpu_id = os.getenv('CUDA_VISIBLE_DEVICES')
    elif int(gpu_id) >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        print("set CUDA_VISIBLE_DEVICES=", gpu_id)
    else:
        # Hide all GPUs so torch falls back to the CPU.
        os.environ['CUDA_VISIBLE_DEVICES'] = ""
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    print("using device %s" % device)
    return device


def setup_seed(seed):
    # A negative seed means "choose one": -2 reads the SATOSHI_SEED environment
    # variable if it is set; otherwise the fractional part of the current time
    # is used as the seed.
    if seed < 0:
        if os.getenv('SATOSHI_SEED') is not None and seed == -2:
            seed = int(os.getenv('SATOSHI_SEED'))
            print("env seed used")
        else:
            import math
            seed = int(10**4 * math.modf(time.time())[0])
    print("random seed", seed)
    return seed


def setup_savedir(prefix="", basedir="./experiments", args=None, append_args=[]):
    # Build a directory name from the prefix plus selected argument values,
    # e.g. "prefix_lr-0.01_epochs-30".
    savedir = prefix
    if len(append_args) > 0 and args is not None:
        for arg_opt in append_args:
            arg_value = getattr(args, arg_opt)
            savedir += "_" + arg_opt + "-" + str(arg_value)
    else:
        savedir += "exp"

    savedir = savedir.replace(" ", "").replace("'", "").replace('"', '')
    savedir = os.path.join(basedir, savedir)

    # If the directory already exists, append a numeric suffix until a fresh
    # one can be created.
    i = 1
    savedir_ori = savedir
    while True:
        try:
            os.makedirs(savedir)
            break
        except FileExistsError:
            savedir = savedir_ori + "_num-%d" % i
            i += 1

    print("made the log directory", savedir)
    return savedir


def save_args(savedir, args, name="args.json"):
    path = os.path.join(savedir, name)
    with open(path, 'w') as f:
        json.dump(vars(args), f, sort_keys=True, indent=4)
    print("args saved as %s" % path)


def save_json(dictionary, path):
    with open(path, 'w') as f:
        json.dump(dictionary, f, sort_keys=True, indent=4)
    print("log saved at %s" % path)


def resume_model(model, resume, state_dict_key="model"):
    '''
    model: pytorch model
    resume: path to the checkpoint file to resume from
    state_dict_key: key of the state dict inside the checkpoint
    '''
    print("resuming trained weights from %s" % resume)

    checkpoint = torch.load(resume, map_location='cpu')
    if state_dict_key is not None:
        pretrained_dict = checkpoint[state_dict_key]
    else:
        pretrained_dict = checkpoint

    try:
        model.load_state_dict(pretrained_dict)
    except RuntimeError as e:
        print(e)
        print("can't load all the weights due to the error above; trying to load part of them!")
        model_dict = model.state_dict()

        # Keep only the pretrained weights whose names also exist in the current model.
        pretrained_dict_use = {}
        pretrained_dict_ignored = {}
        for k, v in pretrained_dict.items():
            if k in model_dict:
                pretrained_dict_use[k] = v
            else:
                pretrained_dict_ignored[k] = v
        pretrained_dict = pretrained_dict_use

        model_dict.update(pretrained_dict)
        model.load_state_dict(model_dict)
        print("resumed only", pretrained_dict.keys())
        print("ignored:", pretrained_dict_ignored.keys())

    return model


def save_checkpoint(path, model, key="model"):
    checkpoint = {}
    checkpoint[key] = model.state_dict()
    torch.save(checkpoint, path)
    print("checkpoint saved at", path)


def make_deterministic(seed, strict=False):
    # Seed every source of randomness used during training.
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)

    # Keep cuDNN from selecting non-deterministic kernels.
    torch.backends.cudnn.deterministic = True

    if strict:
        # Disabling cuDNN entirely removes its remaining sources of
        # non-determinism, at the cost of speed.
        torch.backends.cudnn.enabled = False
        print("strict reproducibility required! cudnn disabled. make sure to set num_workers=0 too!")
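
# Typical wiring of these helpers in a training script; a minimal sketch only.
# `args` is assumed to come from argparse (with hypothetical fields such as
# `seed`, `gpu`, and `lr`), and `model` / the training loop are placeholders
# for the caller's own code.
#
#     savedir = setup_savedir(prefix="run", basedir="./experiments",
#                             args=args, append_args=["lr"])
#     save_args(savedir, args)
#     seed = setup_seed(args.seed)
#     make_deterministic(seed)
#     device = setup_device(args.gpu)
#     model = model.to(device)
#     ...  # training loop
#     save_checkpoint(os.path.join(savedir, "final_model.pth"), model)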