|
r"""_summary_ |
|
-*- coding: utf-8 -*- |
|
|
|
Module : configs.reader |
|
|
|
File Name : reader.py |
|
|
|
Description : Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown |
|
|
|
Creation Date : 2024-07-13 |
|
|
|
Author : Frank Kang([email protected]) |
|
""" |
|
import pathlib |
|
import json |
|
|
|
import os |
|
import warnings |
|
|
|
from typing import Union, Any, IO |
|
from omegaconf import OmegaConf, DictConfig, ListConfig |
|
|
|
from .utils import get_dir |
|
|
|
class ConfigReader:
    """Load a YAML/JSON config file that may include other config files.

    Every file loaded along an include chain is recorded by name in
    ``included``; loading a name that is already recorded raises
    ``RecursionError`` so circular includes cannot loop forever.

    for examples:

    ```
    config = ConfigReader.load(file)
    ```
    """

    def __init__(
        self,
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None
    ) -> None:
        """Read ``file_`` into a config object; includes are resolved lazily.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): path to a ``.yaml``,
                ``.yml`` or ``.json`` file, or an open file object whose
                ``name`` attribute ends with one of those suffixes.
            included (set | None, optional): names of config files already on
                the current include chain. The set is mutated: this file's
                name is added to it. Defaults to None (empty chain).

        Raises:
            FileNotFoundError: If a path is given, it does not exist, and no
                ``<path>.template`` exists to generate it from.
            RecursionError: If this file is already on the include chain.
            ValueError: If the file suffix is not yaml/yml/json.
        """
        self.included = included if included is not None else set()

        if isinstance(file_, (str, os.PathLike)):
            # Use the full path; ``Path.name`` would drop the directory part.
            fname = os.fspath(file_)
            self._ensure_exists(fname)
            source = fname
        else:
            # Open file object: identify it by its name, but read the stream.
            fname = file_.name
            source = file_

        if fname in self.included:
            raise RecursionError(
                'circular include of {} (include chain: {})'.format(
                    fname, self.included))

        suffix = os.path.splitext(fname)[1].lower()
        if suffix in ('.yaml', '.yml'):
            config = OmegaConf.load(source)
        elif suffix == '.json':
            if isinstance(source, str):
                with open(source, 'r', encoding='utf8') as f:
                    data = json.load(f)
            else:
                data = json.load(source)
            config = DictConfig(data)
        else:
            raise ValueError(
                'unsupported config file type {!r}: {}'.format(suffix, fname))

        self.included.add(fname)
        self.__config = config
        self.complied = False

    @staticmethod
    def _ensure_exists(fname: str) -> None:
        """Create ``fname`` from ``fname + '.template'`` when it is missing.

        Raises:
            FileNotFoundError: If neither the file nor its template exists.
        """
        if os.path.exists(fname):
            return
        template_path = '{}.template'.format(fname)
        if not os.path.exists(template_path):
            raise FileNotFoundError('cannot find file {}'.format(fname))
        # Read the template before opening the target for writing so that a
        # failed read cannot leave an empty config file behind.
        with open(template_path, 'r', encoding='utf8') as rf:
            content = rf.read()
        with open(fname, 'w', encoding='utf8') as wf:
            wf.write(content)
        warnings.warn(
            'cannot find file {}. Auto generate from {}'.format(
                fname, template_path))

    def complie(self, config: DictConfig | None = None):
        """Resolve config to make include effective.

        Recurses into every nested ``DictConfig`` section. When called without
        arguments it works on this reader's own config and marks it compiled.

        Args:
            config (DictConfig | None, optional): section to resolve.
                Defaults to None (the whole loaded config).

        Raises:
            RecursionError: If there is a loop include
        """
        modify_flag = config is None
        if config is None:
            config = self.__config

        # NOTE(review): include_item is never assigned here, so the include
        # merging below is currently unreachable. The config key that is meant
        # to mark an include is not visible in this file — confirm and wire it.
        include_item = None

        for key in config.keys():
            value = config.get(key)
            if isinstance(value, DictConfig):
                self.complie(value)

        if include_item is not None:
            # A single include may be given as a string, several as a list.
            items = [include_item] if isinstance(include_item, str) \
                else include_item
            for item in items:
                self._merge_include(config, item)

        if modify_flag:
            self.complied = True

    def _merge_include(self, config: DictConfig, item: str) -> None:
        """Load the included file ``item`` and merge it into ``config``.

        Raises:
            RecursionError: If ``item`` is already on the include chain.
        """
        if item in self.included:
            raise RecursionError(
                'circular include of {} (include chain: {})'.format(
                    item, self.included))
        # Pass a copy without ``item``: the child reader records its own name,
        # and sibling includes must not see each other as part of the chain.
        config.merge_with(ConfigReader.load(item, self.included.copy()))

    @property
    def config(self) -> DictConfig:
        """Obtain the parsed config, resolving includes on first access.

        Returns:
            DictConfig: parsed dict config
        """
        if not self.complied:
            self.complie()
        return self.__config

    @staticmethod
    def load(
        file_: Union[str, pathlib.Path, IO[Any]],
        included: set | None = None,
        **kwargs
    ) -> DictConfig:
        """Load a configuration file and apply per-section overrides.

        Args:
            file_ (Union[str, pathlib.Path, IO[Any]]): config file path or
                open file object.
            included (set | None, optional): names of config files already on
                the include chain. Defaults to None.
            **kwargs: ``section=mapping`` overrides; each mapping is
                ``update``-d into the existing section of that name.

        Returns:
            DictConfig: parsed dict config
        """
        config = ConfigReader(file_, included).config
        for k, v in kwargs.items():
            # NOTE(review): an override for a section that does not exist
            # updates a throwaway ``{}`` and is silently dropped.
            config.get(k, {}).update(v)
        return config
|
|