File size: 1,745 Bytes
69ad385
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
#!/usr/bin/python3
# -*- coding: utf-8 -*-
import copy
import os
from typing import Any, Dict, Union

import yaml


CONFIG_FILE = "config.yaml"


class PretrainedConfig(object):
    """Base class for configurations that are loaded from / saved to YAML.

    Subclasses are presumably expected to declare their fields as ``__init__``
    keyword arguments and store them as instance attributes (TODO confirm
    against subclasses); ``to_dict`` / ``to_yaml_file`` serialise whatever is
    in the instance ``__dict__``.
    """

    def __init__(self, **kwargs):
        # The base class accepts and ignores arbitrary keyword arguments so that
        # subclasses can forward leftovers via ``super().__init__(**kwargs)``.
        pass

    @classmethod
    def _dict_from_yaml_file(cls, yaml_file: Union[str, os.PathLike]) -> Dict[str, Any]:
        """Load ``yaml_file`` (UTF-8) with ``yaml.safe_load`` and return the result.

        ``safe_load`` is used so the file cannot construct arbitrary Python
        objects. An empty file, which ``safe_load`` parses as ``None``, is
        returned as an empty dict. No other validation is performed: a
        non-mapping top-level document (e.g. a list) is returned unchanged.
        """
        with open(yaml_file, encoding="utf-8") as f:
            config_dict = yaml.safe_load(f)
        # An empty YAML document loads as None; normalise it so downstream code
        # (``from_dict``) always receives a mapping for an empty config.
        return {} if config_dict is None else config_dict

    @classmethod
    def get_config_dict(
        cls, pretrained_model_name_or_path: Union[str, os.PathLike]
    ) -> Dict[str, Any]:
        """Read the raw configuration dictionary.

        Args:
            pretrained_model_name_or_path: either a directory, in which case
                the file ``CONFIG_FILE`` inside it is read, or a path to a YAML
                file that is read directly.

        Returns:
            The parsed YAML content (``{}`` for an empty file).
        """
        if os.path.isdir(pretrained_model_name_or_path):
            config_file = os.path.join(pretrained_model_name_or_path, CONFIG_FILE)
        else:
            config_file = pretrained_model_name_or_path
        return cls._dict_from_yaml_file(config_file)

    @classmethod
    def from_dict(cls, config_dict: Dict[str, Any], **kwargs):
        """Instantiate ``cls`` from ``config_dict``, applying keyword overrides.

        Only kwargs whose key already exists in ``config_dict`` are applied;
        any other kwargs are silently ignored. The caller's ``config_dict`` is
        not modified.
        """
        # Copy first so overrides do not leak back into the caller's dict.
        config_dict = dict(config_dict)
        for key, value in kwargs.items():
            if key in config_dict:
                config_dict[key] = value
        return cls(**config_dict)

    @classmethod
    def from_pretrained(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        **kwargs,
    ):
        """Load a config from a directory or YAML file, then apply ``kwargs``.

        See ``get_config_dict`` for path resolution and ``from_dict`` for how
        ``kwargs`` overrides are applied.
        """
        config_dict = cls.get_config_dict(pretrained_model_name_or_path)
        return cls.from_dict(config_dict, **kwargs)

    def to_dict(self):
        """Return a deep copy of the instance attributes as a plain dict."""
        # Deep copy so callers cannot mutate nested config values in place.
        return copy.deepcopy(self.__dict__)

    def to_yaml_file(self, yaml_file_path: Union[str, os.PathLike]):
        """Write ``self.to_dict()`` to ``yaml_file_path`` (UTF-8) via ``yaml.safe_dump``.

        ``safe_dump`` only serialises standard YAML types, so attributes holding
        arbitrary Python objects will make it raise.
        """
        config_dict = self.to_dict()
        with open(yaml_file_path, "w", encoding="utf-8") as writer:
            yaml.safe_dump(config_dict, writer)


if __name__ == '__main__':
    # Executing this module directly does nothing; it has no script behaviour.
    pass