File size: 1,346 Bytes
7c8fd9c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 |
import argparse
from omegaconf import OmegaConf, DictConfig
import os
def load_config(print_config = True):
parser = argparse.ArgumentParser()
parser.add_argument('--config', type=str,
default='configs/tea-pour.yaml',
help="Config file path")
args = parser.parse_args()
config = OmegaConf.load(args.config)
# Recursively merge base configs
cur_config_path = args.config
cur_config = config
while "base_config" in cur_config and cur_config.base_config != cur_config_path:
base_config = OmegaConf.load(cur_config.base_config)
config = OmegaConf.merge(base_config, config)
cur_config_path = cur_config.base_config
cur_config = base_config
prompt = config.generation.prompt
if isinstance(prompt, str):
prompt = {"edit": prompt}
config.generation.prompt = prompt
OmegaConf.resolve(config)
if print_config:
print("[INFO] loaded config:")
print(OmegaConf.to_yaml(config))
return config
def save_config(config: DictConfig, path, gene = False, inv = False):
os.makedirs(path, exist_ok = True)
config = OmegaConf.create(config)
if gene:
config.pop("inversion")
if inv:
config.pop("generation")
OmegaConf.save(config, os.path.join(path, "config.yaml")) |