Spaces:
Sleeping
Sleeping
File size: 4,284 Bytes
202eff6 6ba63c9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 |
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import yaml
import json
import argparse
import logging
# Module-level logger; config overrides in this module are reported via logger.warning.
logger = logging.getLogger(__name__)
def load_config_dict_to_opt(opt, config_dict):
    """
    Merge the key/value pairs of ``config_dict`` into ``opt`` in place.

    A key may be a dotted path such as ``"a.b.c"``: every component except the
    last selects a nested dict inside ``opt`` (created empty when missing), and
    the last component is assigned the value. Values already present are
    overwritten, and each overwrite is logged as a warning.

    Args:
        opt (dict): option dict, updated in place.
        config_dict (dict): mapping from (possibly dotted) keys to new values.

    Raises:
        TypeError: if ``config_dict`` is not a dict, or if a dotted key passes
            through an existing value that is not a dict.
    """
    if not isinstance(config_dict, dict):
        raise TypeError("Config must be a Python dictionary")
    for k, v in config_dict.items():
        *parents, leaf = k.split('.')
        pointer = opt
        for k_part in parents:
            pointer = pointer.setdefault(k_part, {})
            # Validate every level, not just the last one: a string at an
            # intermediate level would otherwise be probed with a substring
            # ``in`` test and then indexed/assigned with a string key.
            if not isinstance(pointer, dict):
                raise TypeError("Overriding key needs to be inside a Python dict.")
        # Test membership rather than truthiness so that overwriting falsy
        # values (False, 0, "", None) is reported too.
        had_value = leaf in pointer
        ori_value = pointer.get(leaf)
        pointer[leaf] = v
        if had_value:
            logger.warning(f"Overrided {k} from {ori_value} to {v}")
def load_opt_from_config_files(conf_files):
    """
    Load opt from YAML config files; settings in later files override earlier ones.

    Args:
        conf_files (list | str | os.PathLike): config file path(s). A single
            path is accepted and treated as a one-element list.

    Returns:
        dict: the merged option dict (empty if no files contribute settings).

    Raises:
        TypeError: if a non-empty file's top-level YAML value is not a mapping
            (raised by ``load_config_dict_to_opt``).
    """
    import os  # local import: only needed to recognise a single path-like argument

    if isinstance(conf_files, (str, os.PathLike)):
        # A bare path would otherwise be iterated character by character.
        conf_files = [conf_files]
    opt = {}
    for conf_file in conf_files:
        with open(conf_file, encoding='utf-8') as f:
            config_dict = yaml.safe_load(f)
        if config_dict is None:
            # safe_load yields None for an empty or comment-only file:
            # treat it as contributing no settings instead of failing.
            continue
        load_config_dict_to_opt(opt, config_dict)
    return opt
def _lookup_override_target(opt, dotted_key):
    """Return the existing value at dotted path ``dotted_key`` in ``opt``; raise KeyError if absent."""
    node = opt
    for part in dotted_key.split('.'):
        if not isinstance(node, dict) or part not in node:
            raise KeyError(f"Override key '{dotted_key}' is not present in the loaded config")
        node = node[part]
    return node


def _cast_override_value(raw, template):
    """Convert the command-line string ``raw`` to the type of the existing config value ``template``."""
    # bool must be checked before the generic path: bool("false") is True.
    if isinstance(template, bool):
        lowered = raw.strip().lower()
        if lowered in ('true', '1', 'yes', 'on'):
            return True
        if lowered in ('false', '0', 'no', 'off', ''):
            return False
        raise ValueError(f"Cannot interpret override value {raw!r} as a boolean")
    if template is None or isinstance(template, (list, tuple, dict)):
        # Parse structured / untyped values with the same YAML loader used for
        # the config files; type(template)(raw) would split a string into
        # characters (list) or fail outright (NoneType, dict).
        return yaml.safe_load(raw)
    return type(template)(raw)


def load_opt_command(args):
    """
    Parse command-line arguments and build the merged option dict.

    Sources are applied in increasing priority:
      1. ``--conf_files`` YAML files, later files overriding earlier ones;
      2. ``--config_overrides``, a JSON object string whose dotted keys address
         nested dicts;
      3. ``--overrides key value [key value ...]`` pairs; each key must already
         exist in the config and its value is cast to the existing value's type;
      4. every non-None argparse attribute, copied to the top level of the dict.

    Args:
        args (list[str] | None): arguments to parse; when falsy (None or empty),
            argparse reads ``sys.argv`` instead.

    Returns:
        tuple: ``(opt, cmdline_args)``, the merged option dict and the
        argparse Namespace.
    """
    parser = argparse.ArgumentParser(description='Pretrain or fine-tune models for NLP tasks.')
    parser.add_argument('command', help='Command: train/evaluate/train-and-evaluate')
    parser.add_argument('--conf_files', nargs='+', required=True, help='Path(s) to the config file(s).')
    parser.add_argument('--user_dir', help='Path to the user defined module for tasks (models, criteria), optimizers, and lr schedulers.')
    parser.add_argument('--config_overrides', nargs='*', help='Override parameters on config with a json style string, e.g. {"<PARAM_NAME_1>": <PARAM_VALUE_1>, "<PARAM_GROUP_2>.<PARAM_SUBGROUP_2>.<PARAM_2>": <PARAM_VALUE_2>}. A key with "." updates the object in the corresponding nested dict. Remember to escape " in command line.')
    parser.add_argument('--overrides', help='arguments that used to override the config file in cmdline', nargs=argparse.REMAINDER)
    # parse_args(None) reads sys.argv; a falsy ``args`` (None or []) falls back to it.
    cmdline_args = parser.parse_args(args or None)

    opt = load_opt_from_config_files(cmdline_args.conf_files)

    if cmdline_args.config_overrides:
        config_overrides_string = ' '.join(cmdline_args.config_overrides)
        logger.warning(f"Command line config overrides: {config_overrides_string}")
        config_dict = json.loads(config_overrides_string)
        load_config_dict_to_opt(opt, config_dict)

    overrides = cmdline_args.overrides
    if overrides:
        if len(overrides) % 2 != 0:
            # Report like any other malformed CLI input (usage + exit), not via assert.
            parser.error("overrides arguments is not paired, required: key value")
        keys, vals = overrides[0::2], overrides[1::2]
        # Cast against the config as it stands before any of these overrides apply.
        config_dict = {
            key: _cast_override_value(val, _lookup_override_target(opt, key))
            for key, val in zip(keys, vals)
        }
        load_config_dict_to_opt(opt, config_dict)

    # Expose the raw command-line arguments alongside the config values.
    opt.update({key: val for key, val in vars(cmdline_args).items() if val is not None})
    return opt, cmdline_args