File size: 4,511 Bytes
222619b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
98a2c89
 
 
222619b
 
 
 
98a2c89
222619b
 
 
 
 
 
 
98a2c89
 
 
 
222619b
 
 
 
 
 
 
 
 
 
 
 
 
 
98a2c89
 
 
222619b
 
 
 
 
 
 
98a2c89
 
222619b
98a2c89
222619b
 
98a2c89
 
 
 
 
 
 
 
222619b
98a2c89
 
222619b
 
 
 
 
 
 
98a2c89
 
222619b
98a2c89
 
222619b
 
 
 
 
98a2c89
222619b
 
 
 
 
 
98a2c89
222619b
 
 
 
 
 
 
98a2c89
 
 
 
 
222619b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
import argparse
import os
import yaml

# One-shot switch for the hparams dump: set_hparams() only prints while this is
# True and sets it to False right after the first print.
global_print_hparams = True
# Module-wide hyper-parameter dict. set_hparams() clears and refills it in
# place when called with global_hparams=True.
hparams = {}


class Args:
    """Plain attribute container built from keyword arguments.

    ``set_hparams`` uses it in place of the parsed command-line namespace
    when the configuration is passed programmatically.
    """

    def __init__(self, **kwargs):
        """Store every keyword argument as an instance attribute."""
        for name, value in kwargs.items():
            # Use the builtin rather than calling the dunder directly.
            setattr(self, name, value)

    def __repr__(self):
        """Return ``Args(name=value, ...)`` to ease debugging."""
        fields = ', '.join(f'{k}={v!r}' for k, v in vars(self).items())
        return f'{type(self).__name__}({fields})'


def override_config(old_config: dict, new_config: dict) -> None:
    """Recursively merge ``new_config`` into ``old_config`` in place.

    Nested dicts are merged key by key; every other value in ``new_config``
    replaces the corresponding entry of ``old_config``.

    Args:
        old_config: Dict that is updated in place.
        new_config: Dict whose entries take precedence.
    """
    for key, value in new_config.items():
        # Only descend when BOTH sides are dicts. If the existing value is a
        # scalar or None, recursing into it would raise, so replace instead.
        if isinstance(value, dict) and isinstance(old_config.get(key), dict):
            override_config(old_config[key], value)
        else:
            old_config[key] = value


def set_hparams(config='', exp_name='', hparams_str='', print_hparams=True, global_hparams=True):
    """Build the hyper-parameter dict from YAML configs and overrides.

    Values are layered in this order (later wins): the ``base_config`` chain
    of the config file, the config file itself, the ``config.yaml`` previously
    saved under ``checkpoints/<exp_name>`` (skipped when ``reset`` is set),
    and finally the ``key=value`` overrides.

    Args:
        config: Path of the YAML config. When empty, ``--config``,
            ``--exp_name``, ``--hparams`` and the boolean flags are read
            from the command line instead of from the arguments below.
        exp_name: Experiment name; the work dir becomes
            ``checkpoints/<exp_name>``.
        hparams_str: Comma-separated ``key=value`` overrides. Each key must
            already exist in the loaded config unless the value is the
            literal ``True`` or ``False``.
        print_hparams: Print the resulting hparams (only happens once per
            process and only when ``global_hparams`` is true).
        global_hparams: Replace the content of the module-level ``hparams``
            dict with the result.

    Returns:
        The assembled hyper-parameter dict.

    Raises:
        AssertionError: If neither a config path nor an experiment name is
            available.
        KeyError: If an override names a key missing from the config.
    """
    global hparams
    global global_print_hparams

    if config == '':
        parser = argparse.ArgumentParser(description='neural music')
        parser.add_argument('--config', type=str, default='',
                            help='path of the YAML config file')
        parser.add_argument('--exp_name', type=str, default='', help='exp_name')
        parser.add_argument('--hparams', type=str, default='',
                            help='comma-separated key=value overrides')
        parser.add_argument('--inference', action='store_true', help='inference')
        parser.add_argument('--validate', action='store_true', help='validate')
        parser.add_argument('--reset', action='store_true', help='reset hparams')
        parser.add_argument('--debug', action='store_true', help='debug')
        args, _ = parser.parse_known_args()
    else:
        args = Args(config=config, exp_name=exp_name, hparams=hparams_str,
                    infer=False, validate=False, reset=False, debug=False)

    # The CLI namespace exposes the flag as ``inference`` while the
    # programmatic ``Args`` uses ``infer``; accept either spelling so the CLI
    # path no longer fails with AttributeError.
    infer = bool(getattr(args, 'inference', False) or getattr(args, 'infer', False))

    args_work_dir = ''
    if args.exp_name != '':
        args.work_dir = args.exp_name
        args_work_dir = f'checkpoints/{args.work_dir}'

    config_chains = []
    loaded_config = set()

    def load_config(config_fn):  # deep first
        """Load ``config_fn`` merged on top of its ``base_config`` chain."""
        with open(config_fn) as f:
            # An empty YAML file parses to None; treat it as an empty config.
            hparams_ = yaml.safe_load(f) or {}
        loaded_config.add(config_fn)
        if 'base_config' in hparams_:
            ret_hparams = {}
            if not isinstance(hparams_['base_config'], list):
                hparams_['base_config'] = [hparams_['base_config']]
            for c in hparams_['base_config']:
                # Resolve relative paths BEFORE the visited check, otherwise
                # the raw './x.yaml' never matches the normalised paths
                # recorded in loaded_config and the check is ineffective.
                if c.startswith('.'):
                    c = os.path.normpath(f'{os.path.dirname(config_fn)}/{c}')
                if c not in loaded_config:
                    override_config(ret_hparams, load_config(c))
            override_config(ret_hparams, hparams_)
        else:
            ret_hparams = hparams_
        config_chains.append(config_fn)
        return ret_hparams

    assert args.config != '' or args_work_dir != '', \
        'either a config file or an exp_name is required'
    saved_hparams = {}
    ckpt_config_path = ''
    # Only look for a saved config when there is a work dir; with an empty
    # work dir the path would degenerate to '/config.yaml'.
    if args_work_dir != '':
        ckpt_config_path = f'{args_work_dir}/config.yaml'
        if os.path.exists(ckpt_config_path):
            # Best effort: an unreadable or malformed saved config is ignored.
            try:
                with open(ckpt_config_path) as f:
                    saved = yaml.safe_load(f)
            except (OSError, UnicodeDecodeError, yaml.YAMLError) as e:
                print(f'| Warning: ignoring saved config {ckpt_config_path}: {e}')
                saved = None
            if isinstance(saved, dict):
                saved_hparams.update(saved)
        if args.config == '':
            args.config = ckpt_config_path

    hparams_ = {}
    hparams_.update(load_config(args.config))

    if not args.reset:
        hparams_.update(saved_hparams)
    hparams_['work_dir'] = args_work_dir

    if args.hparams != "":
        for new_hparam in args.hparams.split(","):
            # Split on the first '=' only so values may contain '='.
            k, v = new_hparam.split("=", 1)
            if v in ['True', 'False'] or isinstance(hparams_[k], bool):
                # NOTE(security): eval() runs arbitrary code taken from the
                # override string; only pass trusted input here.
                hparams_[k] = eval(v)
            else:
                # Cast the string to the type of the existing value.
                hparams_[k] = type(hparams_[k])(v)

    if args_work_dir != '' and (not os.path.exists(ckpt_config_path) or args.reset) and not infer:
        os.makedirs(hparams_['work_dir'], exist_ok=True)
        with open(ckpt_config_path, 'w') as f:
            yaml.safe_dump(hparams_, f)

    # Run-mode flags are added after the dump so they are not persisted.
    hparams_['inference'] = infer
    hparams_['debug'] = args.debug
    hparams_['validate'] = args.validate
    if global_hparams:
        hparams.clear()
        hparams.update(hparams_)

    if print_hparams and global_print_hparams and global_hparams:
        print('| Hparams chains: ', config_chains)
        print('| Hparams: ')
        for i, (k, v) in enumerate(sorted(hparams_.items())):
            print(f"\033[;33;m{k}\033[0m: {v}, ", end="\n" if i % 5 == 4 else "")
        print("")
        global_print_hparams = False

    # NOTE: the module-level dict receives exp_name even when global_hparams
    # is False; kept as is for backward compatibility.
    if hparams.get('exp_name') is None:
        hparams['exp_name'] = args.exp_name
    if hparams_.get('exp_name') is None:
        hparams_['exp_name'] = args.exp_name
    return hparams_