File size: 4,420 Bytes
2e36228
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import time
import os
import sys
import json
import random
import numpy as np
import torch


def setup_device(gpu_id):
    """Configure GPU visibility through environment variables and pick a device.

    gpu_id: value convertible with ``int()``.
        * ``-2`` while ``CUDA_VISIBLE_DEVICES`` is already set: the existing
          value is left untouched.
        * ``>= 0``: ``CUDA_VISIBLE_DEVICES`` is overwritten with ``str(gpu_id)``.
        * anything else (including ``-2`` when the variable is unset):
          ``CUDA_VISIBLE_DEVICES`` is set to the empty string.

    Returns ``torch.device("cuda:0")`` when ``torch.cuda.is_available()`` is
    true, otherwise ``torch.device("cpu")``.
    """
    os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

    requested = int(gpu_id)
    inherited = os.environ.get('CUDA_VISIBLE_DEVICES')

    if requested == -2 and inherited is not None:
        # Respect whatever the caller's environment already exposes.
        pass
    elif requested >= 0:
        os.environ['CUDA_VISIBLE_DEVICES'] = str(gpu_id)
        print("set CUDA_VISIBLE_DEVICES=", gpu_id)
    else:
        os.environ['CUDA_VISIBLE_DEVICES'] = ""

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
    else:
        device = torch.device("cpu")
    print("using device %s" % device)
    return device


def setup_seed(seed):
    """Resolve the random seed to use and return it.

    A non-negative ``seed`` is returned unchanged. For a negative ``seed``:
    if it is exactly ``-2`` and the ``SATOSHI_SEED`` environment variable is
    set, that variable is parsed with ``int()`` and used; otherwise a seed in
    ``[0, 9999]`` is derived from the fractional part of the current time.

    The chosen seed is printed before it is returned. This function only
    selects the value; it does not seed any random number generator.
    """
    if seed < 0:
        env_seed = os.getenv('SATOSHI_SEED')
        if env_seed is not None and seed == -2:
            seed = int(env_seed)
            print("env seed used")
        else:
            import math
            # Sub-second part of the wall clock, scaled to four digits.
            fraction, _whole = math.modf(time.time())
            seed = int(10**4 * fraction)
    print("random seed", seed)
    return seed


def setup_savedir(prefix="", basedir="./experiments", args=None, append_args=[]):
    """Create a fresh output directory under ``basedir`` and return its path.

    prefix: leading part of the directory name.
    basedir: parent directory; created if missing (``os.makedirs``).
    args: object whose attributes are read with ``getattr`` for each name in
        ``append_args``.
    append_args: attribute names to encode in the directory name as
        ``_<name>-<value>``. When it is empty or ``args`` is None, the
        literal ``"exp"`` is appended to ``prefix`` instead.

    Spaces and quote characters are stripped from the directory name. If the
    directory already exists, ``_num-1``, ``_num-2``, ... is appended until a
    name that does not exist yet is created.
    """
    if len(append_args) == 0 or args is None:
        dirname = prefix + "exp"
    else:
        dirname = prefix + "".join(
            "_" + option + "-" + str(getattr(args, option))
            for option in append_args
        )

    # Drop spaces, single quotes and double quotes in one pass.
    dirname = dirname.translate(str.maketrans("", "", " '\""))
    root = os.path.join(basedir, dirname)

    # Probe root, root_num-1, root_num-2, ... until a directory gets created.
    candidate = root
    attempt = 0
    while True:
        try:
            os.makedirs(candidate)
        except FileExistsError:
            attempt += 1
            candidate = root + "_num-%d" % attempt
        else:
            break

    print("made the log directory", candidate)
    return candidate


def save_args(savedir, args, name="args.json"):
    """Dump ``vars(args)`` as JSON to ``<savedir>/<name>``.

    savedir: directory the file is written into (must already exist).
    args: object accepted by ``vars()``, e.g. an ``argparse.Namespace``.
    name: file name inside ``savedir``.

    Keys are written sorted with a four-space indent.
    """
    target = os.path.join(savedir, name)
    settings = vars(args)
    with open(target, 'w') as handle:
        json.dump(settings, handle, sort_keys=True, indent=4)
    print("args saved as %s" % target)


def save_json(dict, path):
    """Write ``dict`` to ``path`` as JSON (sorted keys, four-space indent).

    NOTE: the first parameter shadows the builtin ``dict``; the name is kept
    because callers may pass it by keyword.
    """
    payload = dict
    with open(path, 'w') as handle:
        json.dump(payload, handle, sort_keys=True, indent=4)
        print("log saved at %s" % path)


def resume_model(model, resume, state_dict_key="model"):
    '''
    Load weights from a checkpoint file into ``model`` and return the model.

    model: pytorch model
    resume: path to the resume file
    state_dict_key: key under which the state dict is stored in the
        checkpoint; pass None when the checkpoint itself is the state dict.

    A strict ``load_state_dict`` is attempted first. If it raises
    RuntimeError, only the checkpoint entries whose name exists in the model
    AND whose shape matches the model's entry are loaded; everything else is
    reported as ignored and the model keeps its current values for those
    entries.

    Raises whatever ``checkpoint[state_dict_key]`` raises when the key is
    absent, and RuntimeError if even the filtered load fails.
    '''
    print("resuming trained weights from %s" % resume)

    # NOTE(security): torch.load unpickles the file; only resume from
    # checkpoints that come from a trusted source.
    checkpoint = torch.load(resume, map_location='cpu')
    if state_dict_key is not None:
        pretrained_dict = checkpoint[state_dict_key]
    else:
        pretrained_dict = checkpoint

    try:
        model.load_state_dict(pretrained_dict)
    except RuntimeError as e:
        print(e)
        print("can't load the all weights due to error above, trying to load part of them!")
        model_dict = model.state_dict()
        # 1. keep only entries the model can actually accept: the key must
        #    exist and the shape must match. Filtering on the key alone would
        #    let a size-mismatched tensor through and make the retry below
        #    fail with the very same error.
        pretrained_dict_use = {}
        pretrained_dict_ignored = {}
        for k, v in pretrained_dict.items():
            loadable = (
                k in model_dict
                and getattr(v, "shape", None) == getattr(model_dict[k], "shape", None)
            )
            if loadable:
                pretrained_dict_use[k] = v
            else:
                pretrained_dict_ignored[k] = v
        pretrained_dict = pretrained_dict_use
        # 2. overwrite entries in the existing state dict
        model_dict.update(pretrained_dict)
        # 3. load the new state dict
        model.load_state_dict(model_dict)
        print("resumed only", pretrained_dict.keys())
        print("ignored:", pretrained_dict_ignored.keys())

    return model


def save_checkpoint(path, model, key="model"):
    """Save ``model.state_dict()`` to ``path`` wrapped in a one-entry dict.

    path: destination handed to ``torch.save``.
    model: object exposing ``state_dict()``.
    key: dictionary key the state dict is stored under.
    """
    torch.save({key: model.state_dict()}, path)
    print("checkpoint saved at", path)


def make_deterministic(seed, strict=False):
    """Seed Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    seed: value passed to every seeding call and stored (as a string) in the
        ``PYTHONHASHSEED`` environment variable.
    strict: when true, cuDNN is additionally disabled altogether and a
        reminder about ``num_workers=0`` is printed.

    Recipe follows
    https://github.com/pytorch/pytorch/issues/7068#issuecomment-487907668
    """
    random.seed(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)

    # NumPy, the torch CPU generator, the current GPU and all GPUs (multi-GPU).
    seeders = (
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    # cudnn.benchmark and torch.set_deterministic are intentionally left alone.
    torch.backends.cudnn.deterministic = True

    if not strict:
        return

    # https://github.com/pytorch/pytorch/issues/7068#issuecomment-515728600
    torch.backends.cudnn.enabled = False
    print(
        "strict reproducability required! cudnn disabled. make sure to set num_workers=0 too!")