|
import importlib |
|
import json |
|
import logging |
|
import os |
|
from datetime import datetime |
|
from functools import reduce, partial |
|
from operator import getitem |
|
from pathlib import Path |
|
|
|
from hw_asr import text_encoder as text_encoder_module |
|
from hw_asr.base.base_text_encoder import BaseTextEncoder |
|
from hw_asr.logger import setup_logging |
|
from hw_asr.text_encoder import CTCCharTextEncoder |
|
from hw_asr.utils import read_json, write_json, ROOT_PATH |
|
|
|
|
|
class ConfigParser:
    """Experiment configuration together with the run's output directories.

    Holds the configuration dict, derives the checkpoint and log directories
    for one run, and offers helpers that build objects and partial functions
    from ``{"type": ..., "args": {...}}`` entries of the config.
    """

    def __init__(self, config, resume=None, modification=None, run_id=None):
        """
        Parse the configuration and prepare the run directories.

        Creates the checkpoint and log directories, stores a copy of the config
        in the checkpoint directory via ``write_json`` and calls
        ``setup_logging`` on the log directory.

        :param config: Dict containing configurations, hyperparameters for training.
            Contents of `config.json` file for example. It is updated in place
            with ``modification``.
        :param resume: String, path to the checkpoint being loaded.
        :param modification: Dict {keychain: value}, specifying position values to be
            replaced in the config dict. Entries whose value is None are ignored.
        :param run_id: Unique identifier for training processes. Used to save
            checkpoints and training log. A timestamp is used by default.
        """
        # Apply CLI overrides to the config (mutates and returns `config`).
        self._config = _update_config(config, modification)
        self.resume = resume
        # Built lazily by get_text_encoder().
        self._text_encoder = None

        # Root directory under which checkpoints and logs are stored.
        save_dir = Path(self.config["trainer"]["save_dir"])

        exper_name = self.config["name"]
        if run_id is None:
            # Default run id is a month/day_hour/minute/second timestamp.
            run_id = datetime.now().strftime(r"%m%d_%H%M%S")
        self._save_dir = str(save_dir / "models" / exper_name / run_id)
        self._log_dir = str(save_dir / "log" / exper_name / run_id)

        # Only an empty run_id may reuse existing directories; for any other
        # run_id an already existing directory makes mkdir raise FileExistsError.
        exist_ok = run_id == ""
        self.save_dir.mkdir(parents=True, exist_ok=exist_ok)
        self.log_dir.mkdir(parents=True, exist_ok=exist_ok)

        # Keep the effective config next to the checkpoints.
        write_json(self.config, self.save_dir / "config.json")

        # Configure the logging module; verbosity index -> logging level.
        setup_logging(self.log_dir)
        self.log_levels = {0: logging.WARNING, 1: logging.INFO, 2: logging.DEBUG}

    @classmethod
    def from_args(cls, args, options=""):
        """
        Initialize this class from some cli arguments. Used in train, test.

        :param args: Argument parser exposing ``add_argument``/``parse_args``,
            or an already parsed tuple (presumably a namedtuple) providing
            ``device``, ``resume`` and ``config`` attributes.
        :param options: Iterable of custom options; each one provides ``flags``,
            ``type`` and ``target`` (the keychain overridden in the config).
        :return: A new instance built from the loaded configuration.
        :raises AssertionError: If neither a checkpoint nor a config file is given.
        """
        for opt in options:
            args.add_argument(*opt.flags, default=None, type=opt.type)
        if not isinstance(args, tuple):
            args = args.parse_args()

        if args.device is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = args.device
        if args.resume is not None:
            # When resuming, the config stored beside the checkpoint is the base.
            resume = Path(args.resume)
            cfg_fname = resume.parent / "config.json"
        else:
            msg_no_cfg = (
                "Configuration file need to be specified. "
                "Add '-c config.json', for example."
            )
            assert args.config is not None, msg_no_cfg
            resume = None
            cfg_fname = Path(args.config)

        config = read_json(cfg_fname)
        if args.config and resume:
            # A config given together with a checkpoint overrides the top-level
            # keys of the checkpoint's config (e.g. for fine-tuning).
            config.update(read_json(args.config))

        # Map each custom option's target keychain to the parsed CLI value.
        modification = {
            opt.target: getattr(args, _get_opt_name(opt.flags)) for opt in options
        }
        return cls(config, resume, modification)

    @staticmethod
    def _merge_args(spec, overrides):
        """Return a copy of ``spec["args"]`` extended with ``overrides``.

        A missing or null "args" entry counts as no arguments. Raises
        AssertionError when an override would replace a configured argument.
        """
        merged = dict(spec.get("args") or {})
        assert all(
            key not in merged for key in overrides
        ), "Overwriting kwargs given in config file is not allowed"
        merged.update(overrides)
        return merged

    @staticmethod
    def init_obj(obj_dict, default_module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        instance initialized with corresponding arguments given.

        `object = config.init_obj(config['param'], module, a, b=1)`
        is equivalent to
        `object = module.name(a, b=1)`

        If ``obj_dict`` has a "module" key, that module is imported and used
        instead of ``default_module``.

        :raises AssertionError: If a keyword argument collides with a configured one.
        """
        if "module" in obj_dict:
            default_module = importlib.import_module(obj_dict["module"])

        module_name = obj_dict["type"]
        module_args = ConfigParser._merge_args(obj_dict, kwargs)
        return getattr(default_module, module_name)(*args, **module_args)

    def init_ftn(self, name, module, *args, **kwargs):
        """
        Finds a function handle with the name given as 'type' in config, and returns the
        function with given arguments fixed with functools.partial.

        `function = config.init_ftn('name', module, a, b=1)`
        is equivalent to
        `function = lambda *args, **kwargs: module.name(a, *args, b=1, **kwargs)`.

        :raises AssertionError: If a keyword argument collides with a configured one.
        """
        spec = self[name]
        module_name = spec["type"]
        module_args = self._merge_args(spec, kwargs)
        return partial(getattr(module, module_name), *args, **module_args)

    def __getitem__(self, name):
        """Access items like ordinary dict."""
        return self.config[name]

    def get_logger(self, name, verbosity=2):
        """
        Return the logger called ``name`` with its level set from ``verbosity``.

        :param verbosity: Key of ``self.log_levels`` (0, 1 or 2).
        :raises AssertionError: If ``verbosity`` is not a known key.
        """
        assert verbosity in self.log_levels, (
            "verbosity option {} is invalid. Valid options are {}.".format(
                verbosity, list(self.log_levels)
            )
        )
        logger = logging.getLogger(name)
        logger.setLevel(self.log_levels[verbosity])
        return logger

    def get_text_encoder(self) -> "BaseTextEncoder":
        """
        Return the text encoder for this config, building it on first use.

        Without a "text_encoder" entry a default ``CTCCharTextEncoder()`` is
        created. Otherwise the entry is instantiated through ``init_obj`` from
        the text encoder module; a bare string entry is taken as the type name
        with no arguments.
        """
        if self._text_encoder is None:
            if "text_encoder" not in self._config:
                self._text_encoder = CTCCharTextEncoder()
            else:
                spec = self._config["text_encoder"]
                if isinstance(spec, str):
                    # Shorthand: "text_encoder": "SomeEncoder".
                    spec = {"type": spec, "args": {}}
                self._text_encoder = self.init_obj(
                    spec, default_module=text_encoder_module
                )
        return self._text_encoder

    # Read-only attributes.
    @property
    def config(self):
        """The configuration dict (after modifications were applied)."""
        return self._config

    @property
    def save_dir(self):
        """Checkpoint directory of this run as a ``Path``."""
        return Path(self._save_dir)

    @property
    def log_dir(self):
        """Log directory of this run as a ``Path``."""
        return Path(self._log_dir)

    @classmethod
    def get_default_configs(cls):
        """Build an instance from ``hw_asr/config.json`` under ``ROOT_PATH``."""
        config_path = ROOT_PATH / "hw_asr" / "config.json"
        with config_path.open() as f:
            return cls(json.load(f))

    @classmethod
    def get_test_configs(cls):
        """Build an instance from ``hw_asr/tests/config.json`` under ``ROOT_PATH``."""
        config_path = ROOT_PATH / "hw_asr" / "tests" / "config.json"
        with config_path.open() as f:
            return cls(json.load(f))
|
|
|
|
|
|
|
def _update_config(config, modification): |
|
if modification is None: |
|
return config |
|
|
|
for k, v in modification.items(): |
|
if v is not None: |
|
_set_by_path(config, k, v) |
|
return config |
|
|
|
|
|
def _get_opt_name(flags): |
|
for flg in flags: |
|
if flg.startswith("--"): |
|
return flg.replace("--", "") |
|
return flags[0].replace("--", "") |
|
|
|
|
|
def _set_by_path(tree, keys, value):
    """Assign ``value`` inside the nested ``tree`` at the ";"-separated path ``keys``."""
    # All but the last key locate the container; the last key is the slot to set.
    *parent_path, leaf = keys.split(";")
    container = _get_by_path(tree, parent_path)
    container[leaf] = value
|
|
|
|
|
def _get_by_path(tree, keys): |
|
"""Access a nested object in tree by sequence of keys.""" |
|
return reduce(getitem, keys, tree) |
|
|