File size: 10,753 Bytes
89c0b51
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# Copyright 2024 ByteDance and/or its affiliates.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import copy
import sys
from typing import Any, Optional, Union

import yaml
from ml_collections.config_dict import ConfigDict

from protenix.config.extend_types import (
    DefaultNoneWithType,
    GlobalConfigValue,
    ListValue,
    RequiredValue,
    ValueMaybeNone,
    get_bool_value,
)


class ArgumentNotSet:
    """Sentinel default for command-line options that the user did not pass.

    Instances are used as the argparse default so that an omitted option can be
    told apart from an option explicitly set to any string (including "None").
    """


class ConfigManager(object):
    """Resolve a declarative, nested config spec into concrete values.

    ``global_configs`` is a nested dict whose leaves are either plain default
    values or marker objects from ``protenix.config.extend_types``
    (``DefaultNoneWithType``, ``ValueMaybeNone``, ``RequiredValue``,
    ``GlobalConfigValue``, ``ListValue``). Leaves are addressed by their dotted
    path, e.g. ``"model.evoformer.c_z"``.
    """

    def __init__(self, global_configs: dict, fill_required_with_null: bool = False):
        """
        Initialize the ConfigManager instance.

        Args:
            global_configs (dict): The nested config spec. This class only reads
                it; ``merge_configs`` works on a deep copy.
            fill_required_with_null (bool, optional):
                If True, ``RequiredValue`` leaves are treated as optional and
                default to ``None`` instead of being mandatory. Defaults to False.
        """
        self.global_configs = global_configs
        self.fill_required_with_null = fill_required_with_null
        # config_infos: dotted key -> (dtype, default, allow_none, required)
        # default_configs: nested dict mirroring the spec, holding the defaults
        self.config_infos, self.default_configs = self.get_config_infos()

    def get_value_info(
        self, value
    ) -> tuple[Any, Optional[Any], Optional[bool], Optional[bool]]:
        """
        Return the type, default value, whether it allows None, and whether it
        is required for a given spec leaf.

        ``GlobalConfigValue`` leaves are resolved through ``self.global_configs``
        and inherit the info of the global key they reference.

        Args:
            value: The spec leaf to describe.

        Returns:
            tuple: ``(dtype, default_value, allow_none, required)`` where
                - dtype: The type used to convert command-line strings. For
                  list leaves this is the element type.
                - default_value: The default value for the leaf.
                - allow_none: A boolean indicating whether the value can be None.
                - required: A boolean indicating whether the value is required.
        """
        if isinstance(value, DefaultNoneWithType):
            return value.dtype, None, True, False
        elif isinstance(value, ValueMaybeNone):
            return value.dtype, value.value, True, False
        elif isinstance(value, RequiredValue):
            if self.fill_required_with_null:
                return value.dtype, None, True, False
            else:
                return value.dtype, None, False, True
        elif isinstance(value, GlobalConfigValue):
            return self.get_value_info(self.global_configs[value.global_key])
        elif isinstance(value, ListValue):
            return (value.dtype, value.value, False, False)
        elif isinstance(value, list):
            # The element type is inferred from the first item. An empty default
            # list has no item to inspect (value[0] would raise IndexError), so
            # fall back to str and keep overrides as raw comma-separated tokens.
            elem_type = type(value[0]) if value else str
            return (elem_type, value, False, False)
        else:
            return type(value), value, False, False

    def _get_config_infos(self, config_dict: dict) -> tuple[dict, dict]:
        """
        Recursively extract configuration information from a given dictionary.

        Args:
            config_dict (dict): The (sub-)dictionary of the config spec.

        Returns:
            tuple: A tuple containing two dictionaries:
                - all_keys: A flat dict mapping dotted keys to the
                  ``(dtype, default, allow_none, required)`` info of each leaf.
                - default_configs: A nested dict mapping keys to their default
                  values.

        Raises:
            AssertionError: If a key contains a period (.), which is not allowed.
        """
        all_keys = {}
        default_configs = {}
        for key, value in config_dict.items():
            # Dotted paths are used as flat keys, so a literal "." inside a key
            # would make them ambiguous. Raised explicitly (not via ``assert``)
            # so the check is not stripped under ``python -O``.
            if "." in key:
                raise AssertionError(f"config key {key!r} must not contain '.'")
            if isinstance(value, dict):
                children_keys, children_configs = self._get_config_infos(value)
                all_keys.update(
                    {
                        f"{key}.{child_key}": child_info
                        for child_key, child_info in children_keys.items()
                    }
                )
                default_configs[key] = children_configs
            else:
                value_info = self.get_value_info(value)
                all_keys[key] = value_info
                default_configs[key] = value_info[1]
        return all_keys, default_configs

    def get_config_infos(self) -> tuple[dict, dict]:
        """Return ``(config_infos, default_configs)`` for the whole spec."""
        return self._get_config_infos(self.global_configs)

    def _resolve_spec(self, value):
        """Follow ``GlobalConfigValue`` references to the spec leaf they name."""
        while isinstance(value, GlobalConfigValue):
            value = self.global_configs[value.global_key]
        return value

    def _convert_override(self, spec, raw, dtype, allow_none):
        """Convert one raw override value into the leaf's declared type.

        Args:
            spec: The leaf as it appears in the (copied) spec.
            raw: The override value, normally a string from the command line.
            dtype: The leaf's dtype (element type for list leaves).
            allow_none: Whether "None"/"none"/"null" may map to ``None``.
        """
        if allow_none and raw in ("None", "none", "null"):
            return None
        # Lists are checked before bool so that a list of bools is parsed
        # element-wise rather than as a single bool. The spec is resolved first
        # so a GlobalConfigValue pointing at a list is also parsed as a list.
        if isinstance(self._resolve_spec(spec), (ListValue, list)):
            convert = get_bool_value if dtype is bool else dtype
            stripped = raw.strip()
            return [convert(s) for s in stripped.split(",")] if stripped else []
        if dtype is bool:
            return get_bool_value(raw)
        return dtype(raw)

    def _merge_configs(
        self,
        new_configs: dict,
        global_configs: dict,
        local_configs: dict,
        prefix="",
    ) -> None:
        """Overwrite default configs with new configs recursively, in place.

        Args:
            new_configs: flat config dict with all hierarchical config keys joined by '.', e.g.
                {
                    'c_z': 32,
                    'model.evoformer.c_z': 16,
                    ...
                }
            global_configs: global hierarchical merging configs, e.g.
                {
                    'c_z': 32,
                    'c_m': 128,
                    'model': {
                        'evoformer': {
                            ...
                        }
                    }
                }
            local_configs: hierarchical merging config dict at the current level, e.g. for the 'model' level this may be
                {
                    'evoformer': {
                        'c_z': GlobalConfigValue("c_z"),
                    },
                    'embedder': {
                        ...
                    }
                }
            prefix (str, optional): A prefix string to prepend to keys during recursion. Defaults to an empty string.

        Raises:
            Exception: If a leaf that does not allow None has no value.
        """
        # Merge configs in the current level first, since these configs may be
        # referenced by lower levels.
        for key, value in local_configs.items():
            if isinstance(value, dict):
                continue
            full_key = f"{prefix}.{key}" if prefix else key
            dtype, default_value, allow_none, _ = self.config_infos[full_key]
            if full_key in new_configs and not isinstance(
                new_configs[full_key], ArgumentNotSet
            ):
                local_configs[key] = self._convert_override(
                    value, new_configs[full_key], dtype, allow_none
                )
            elif isinstance(value, GlobalConfigValue):
                local_configs[key] = global_configs[value.global_key]
            else:
                if not allow_none and default_value is None:
                    raise Exception(f"config {full_key} not allowed to be none")
                local_configs[key] = default_value
        for key, value in local_configs.items():
            if not isinstance(value, dict):
                continue
            self._merge_configs(
                new_configs, global_configs, value, f"{prefix}.{key}" if prefix else key
            )

    def merge_configs(self, new_configs: dict) -> ConfigDict:
        """Apply flat dotted-key overrides on top of the spec defaults.

        Args:
            new_configs (dict): Mapping of dotted keys to raw override values.
                Entries whose value is an ``ArgumentNotSet`` instance are ignored.

        Returns:
            ConfigDict: The fully resolved configuration. ``self.global_configs``
            is not modified because merging happens on a deep copy.
        """
        configs = copy.deepcopy(self.global_configs)
        self._merge_configs(new_configs, configs, configs)
        return ConfigDict(configs)


def parse_configs(
    configs: dict,
    arg_str: Optional[str] = None,
    fill_required_with_null: bool = False,
) -> ConfigDict:
    """
    Parses and merges configuration settings from a dictionary and command-line arguments.

    Args:
        configs (dict): A dictionary containing initial configuration settings.
        arg_str (str, optional): A string of command-line arguments such as
            ``"--c_z 16 --model.evoformer.c_m 64"``. It is split on whitespace.
            Defaults to None, meaning no overrides.
        fill_required_with_null (bool, optional):
            A boolean flag indicating whether required values should be filled with `None` if not provided. Defaults to False.

    Returns:
        ConfigDict: The merged configuration dictionary.
    """
    manager = ConfigManager(configs, fill_required_with_null=fill_required_with_null)
    parser = argparse.ArgumentParser()
    # Register one ``--<dotted.key>`` option per leaf. Every option is read as a
    # plain string; ConfigManager converts it to the leaf's real dtype later.
    # ArgumentNotSet marks options the user did not pass, so they keep their
    # spec defaults.
    for key, (_dtype, _default, _allow_none, required) in manager.config_infos.items():
        parser.add_argument(
            "--" + key, type=str, default=ArgumentNotSet(), required=required
        )
    # With no argument string, argparse is skipped entirely and only the spec
    # defaults are merged.
    parsed_args = vars(parser.parse_args(arg_str.split())) if arg_str else {}
    return manager.merge_configs(parsed_args)


def parse_sys_args() -> str:
    """
    Check whether command-line arguments are valid and join them into one string.
    Each argument is expected to be in the format `--key value`.

    ``sys.argv[1:]`` is consumed in (key, value) pairs; an unpaired trailing
    argument is ignored.

    Returns:
        str: The pairs formatted as ``"--k1 v1 --k2 v2 "`` (with a trailing space).

    Raises:
        AssertionError: If any key does not start with `--`.
    """
    args = sys.argv[1:]
    parts = []
    for k, v in zip(args[::2], args[1::2]):
        # Reject malformed keys instead of passing them through silently.
        # Raised explicitly (not via ``assert``) so the check survives ``python -O``.
        if not k.startswith("--"):
            raise AssertionError(f"expected an argument of the form --key, got {k!r}")
        parts.append(f"{k} {v} ")
    return "".join(parts)


def load_config(path: str) -> dict:
    """
    Loads a configuration from a YAML file.

    The file is decoded as UTF-8 explicitly, so the result does not depend on
    the platform's default locale encoding.

    Args:
        path (str): The path to the YAML file containing the configuration.

    Returns:
        dict: The configuration parsed by ``yaml.safe_load``.
    """
    with open(path, "r", encoding="utf-8") as f:
        return yaml.safe_load(f)


def save_config(config: Union[ConfigDict, dict], path: str) -> None:
    """
    Saves a configuration to a YAML file.

    The document is serialized before the file is opened, so a conversion or
    serialization error leaves any existing file at ``path`` intact instead of
    truncating it.

    Args:
        config (ConfigDict or dict): The configuration to be saved.
            If it is a ConfigDict, it will be converted to a dictionary.
        path (str): The path to the YAML file where the configuration will be saved.
    """
    if isinstance(config, ConfigDict):
        config = config.to_dict()
    text = yaml.safe_dump(config)
    with open(path, "w", encoding="utf-8") as f:
        f.write(text)