#!/usr/bin/python3 # -*- coding: utf-8 -*- import copy import os from typing import Any, Dict, Union import yaml CONFIG_FILE = "config.yaml" class PretrainedConfig(object): def __init__(self, **kwargs): pass @classmethod def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]): with open(yaml_file, encoding="utf-8") as f: config_dict = yaml.safe_load(f) return config_dict @classmethod def get_config_dict( cls, pretrained_model_name_or_path: Union[str, os.PathLike] ) -> Dict[str, Any]: if os.path.isdir(pretrained_model_name_or_path): config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE) else: config_file = pretrained_model_name_or_path config_dict = cls._dict_from_yaml_file(config_file) return config_dict @classmethod def from_dict(cls, config_dict: Dict[str, Any], **kwargs): for k, v in kwargs.items(): if k in config_dict.keys(): config_dict[k] = v config = cls(**config_dict) return config @classmethod def from_pretrained( cls, pretrained_model_name_or_path: Union[str, os.PathLike], **kwargs, ): config_dict = cls.get_config_dict(pretrained_model_name_or_path) return cls.from_dict(config_dict, **kwargs) def to_dict(self): output = copy.deepcopy(self.__dict__) return output def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]): config_dict = self.to_dict() with open(yaml_file_path, "w", encoding="utf-8") as writer: yaml.safe_dump(config_dict, writer) if __name__ == '__main__': pass