File size: 2,492 Bytes
0b32ad6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
"""
Parse command-line arguments into override dictionary

Authors
  * Leo 2022
"""

import logging

logger = logging.getLogger(__name__)

__all__ = [
    "parse_overrides",
]


def parse_override(string):
    """
    Parse a single ``,,``-separated override string into a nested dictionary.

    Each option has the form ``dotted.key=value``. The dotted key is split on
    ``.`` to build nested dictionaries. The value is evaluated as a Python
    expression; if evaluation fails, the raw (whitespace-stripped) string is
    kept instead.

    Example usage:
        -o "optimizer.lr=1.0e-3,,optimizer.name='AdamW',,runner.eval_dataloaders=['dev', 'test']"

    Convert to:
        {
            "optimizer": {"lr": 1.0e-3, "name": "AdamW"},
            "runner": {"eval_dataloaders": ["dev", "test"]}
        }

    Args:
        string (str): options joined by ``,,``. Empty segments (an empty
            string, or a trailing ``,,``) are ignored.

    Returns:
        dict: the nested override dictionary.

    Raises:
        ValueError: if a non-empty option does not contain ``=``.
    """
    config = {}
    for option in string.split(",,"):
        option = option.strip()
        if not option:
            # Tolerate empty segments such as those produced by a trailing ",,".
            continue

        # Split on the first "=" only, so that values may themselves contain
        # "=" (e.g. ``name='a=b'``).
        key, sep, value_str = option.partition("=")
        if not sep:
            raise ValueError(
                f"Override option '{option}' must have the form key=value"
            )
        key, value_str = key.strip(), value_str.strip()

        # SECURITY NOTE: eval executes arbitrary code. The docstring example
        # shows this is meant for values typed on the local command line;
        # never pass untrusted input to this function.
        try:
            value = eval(value_str)
        except Exception:
            # Not a valid Python expression (e.g. a bare word such as AdamW):
            # keep the raw string.
            value = value_str

        logger.info("%s = %s", key, value)

        # Walk/create the intermediate dictionaries, then assign the leaf.
        *parents, leaf = key.split(".")
        target_config = config
        for field_name in parents:
            target_config = target_config.setdefault(field_name, {})
        target_config[leaf] = value
    return config


def parse_overrides(options: list):
    """
    Parse an argv-style list of ``--key value`` pairs into a nested dictionary.

    Items are consumed two at a time: a ``--dotted.key`` followed by its value.
    The dotted key is split on ``.`` to build nested dictionaries. The value
    is evaluated as a Python expression; if evaluation fails, the raw
    (whitespace-stripped) string is kept, except for values mentioning
    ``newdict`` or ``Container``, whose evaluation errors are re-raised.

    Example usage:
        [
            "--optimizer.lr",
            "1.0e-3",
            "--optimizer.name",
            "AdamW",
            "--runner.eval_dataloaders",
            "['dev', 'test']",
        ]

    Convert to:
        {
            "optimizer": {"lr": 1.0e-3, "name": "AdamW"},
            "runner": {"eval_dataloaders": ["dev", "test"]}
        }

    Args:
        options (list): alternating ``--key`` / value strings.

    Returns:
        dict: the nested override dictionary.

    Raises:
        ValueError: if a key does not start with ``--`` or a key has no value.
    """
    config = {}
    for position in range(0, len(options), 2):
        key: str = options[position]
        # Explicit checks instead of ``assert`` so validation survives ``-O``.
        if not key.startswith("--"):
            raise ValueError(f"Expected an option starting with '--', got {key!r}")
        if position + 1 >= len(options):
            raise ValueError(f"Option {key!r} is missing a value")

        # Remove exactly the leading "--". ``str.strip("--")`` would treat its
        # argument as a character set and also eat trailing dashes.
        key = key[2:]
        value_str: str = options[position + 1]
        key, value_str = key.strip(), value_str.strip()

        # SECURITY NOTE: eval executes arbitrary code; only use this with
        # trusted command-line input.
        try:
            value = eval(value_str)
        except Exception:
            # NOTE(review): ``newdict`` / ``Container`` are not defined in this
            # module's visible globals, so such values presumably fail with a
            # NameError here and are deliberately surfaced instead of being
            # silently kept as strings; confirm the intended behavior.
            if "newdict" in value_str or "Container" in value_str:
                raise
            value = value_str

        logger.debug("%s = %s", key, value)

        # Walk/create the intermediate dictionaries, then assign the leaf.
        *parents, leaf = key.split(".")
        target_config = config
        for field_name in parents:
            target_config = target_config.setdefault(field_name, {})
        target_config[leaf] = value
    return config