# SciPIP / src/config/reader.py
# lihuigu: "update new version" (c8709b2)
r"""Configuration file reader.
-*- coding: utf-8 -*-
Module : configs.reader
File Name : reader.py
Description : Load a config file, with support for referencing other configuration files. If a circular reference occurs, an exception is raised.
Creation Date : 2024-07-13
Author : Frank Kang([email protected])
"""
import pathlib
import json
import os
import warnings
from typing import Union, Any, IO
from omegaconf import OmegaConf, DictConfig, ListConfig
from .utils import get_dir
class ConfigReader:
"""_summary_
Load the config file, which supports referencing other configuration files. If a circular reference occurs, an exception will be thrown
for examples:
```
config = ConfigReader.load(file)
```
"""
def __init__(
self,
file_: Union[str, pathlib.Path, IO[Any]],
included: set | None = None
) -> None:
"""_summary_
Args:
file_ (Union[str, pathlib.Path, IO[Any]]): config
included (set | None, optional): Include config file. Defaults to None.
Raises:
FileNotFoundError: If the configuration file cannot be found
RecursionError: If there is a loop include
"""
fname = ''
self.included = included if included is not None else set()
if isinstance(file_, str):
fname = file_
if not os.path.exists(fname):
template_path = '{}.template'.format(fname)
if os.path.exists(template_path):
with open(fname, 'w', encoding='utf8') as wf:
with open(template_path, 'r', encoding='utf8') as rf:
wf.write(rf.read())
warnings.warn(
'cannot find file {}. Auto generate from {}'.format(
fname, template_path))
else:
raise FileNotFoundError(
'cannot find file {}'.format(fname))
else:
fname = file_.name
suffix = fname.split('.')[-1]
if suffix == 'yaml':
config = OmegaConf.load(fname)
elif suffix == 'json':
if isinstance(file_, (str, IO[Any])):
with open(file_, 'r', encoding='utf8') as f:
config = json.load(f)
else:
config = json.load(file_)
config = DictConfig(config)
if fname not in self.included:
self.included.add(fname)
else:
raise RecursionError()
self.__config = config
self.complied = False
def complie(self, config: DictConfig | None = None):
"""_summary_
Resolve config to make include effective
Args:
config (DictConfig | None, optional): dict config. Defaults to None.
Raises:
RecursionError: If there is a loop include
"""
modify_flag = False
if config is None:
config = self.__config
modify_flag = True
include_item = None
for key in config.keys():
value = config.get(key)
if isinstance(value, DictConfig):
self.complie(value)
if include_item is not None:
if isinstance(include_item, str):
included = self.included.copy()
if include_item in included:
print(include_item, included)
raise RecursionError()
included.add(include_item)
config.merge_with(ConfigReader.load(include_item, included))
else:
for item in include_item:
included = self.included.copy()
if item in included:
print(include_item, included)
raise RecursionError()
config.merge_with(ConfigReader.load(item, included))
included.add(item)
if modify_flag:
self.complied = True
@property
def config(self) -> DictConfig:
"""_summary_
Obtain parsed dict config
Returns:
DictConfig: parsed dict config
"""
if not self.complied:
self.complie()
return self.__config
@staticmethod
def load(
file_: Union[str, pathlib.Path, IO[Any]],
included: set | None = None,
**kwargs
) -> DictConfig:
"""_summary_
Class method loading configuration file
Args:
file_ (Union[str, pathlib.Path, IO[Any]]): config
included (set | None, optional): Include config file. Defaults to None.
Returns:
DictConfig: parsed dict config
"""
config = ConfigReader(file_, included).config
for k, v in kwargs.items():
config.get(k, {}).update(v)
return config