File size: 2,423 Bytes
4978916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ed84fe9
 
 
4978916
ed84fe9
4978916
 
dfc6d2e
 
 
 
 
4978916
 
 
 
 
 
 
 
e9fbda8
4978916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
# Load YAML file param_groupings.yml:
from pathlib import Path
from yaml import safe_load
import sys

sys.path.append("..")
from pysr import PySRRegressor
import pysr
import re
from docstring_parser import parse

# Module-level record of every parameter name rendered by str_param_groups.
# It accumulates across calls and is never reset here; callers that render
# more than once should clear it themselves.
found_params = []


def str_param_groups(param_groupings, params, cur_heading=2):
    """Render a nested parameter grouping as Markdown.

    Parameters
    ----------
    param_groupings : list | dict | str
        The grouping structure (as loaded from YAML). A list renders each
        element and joins the results with blank lines. A dict renders each
        key as a Markdown heading followed by its value, rendered one heading
        level deeper. A str is a parameter name looked up in ``params``.
    params : dict
        Maps parameter names to objects exposing a ``description`` string
        (here, entries parsed from a docstring).
    cur_heading : int, default 2
        Number of ``#`` characters used for headings at this level.

    Returns
    -------
    str
        The rendered Markdown text.

    Raises
    ------
    KeyError
        If a parameter name is not present in ``params``.
    TypeError
        If a node is not a list, dict, or str.

    Every rendered parameter name is appended to the module-level
    ``found_params`` list (only mutated, so no ``global`` is required).
    """
    if isinstance(param_groupings, list):
        return "\n\n".join(
            str_param_groups(item, params, cur_heading) for item in param_groupings
        )
    if isinstance(param_groupings, dict):
        # Render *every* heading in the mapping. Returning from inside the
        # loop would emit only the first key and silently drop the rest.
        return "\n\n".join(
            f"{'#' * cur_heading} {heading}\n\n"
            + str_param_groups(sub_grouping, params, cur_heading + 1)
            for heading, sub_grouping in param_groupings.items()
        )
    if isinstance(param_groupings, str):
        found_params.append(param_groupings)
        description = params[param_groupings].description
        # Defaults are written inline as "Default is `<value>`"; extract the
        # value so it can be shown on its own line below the description.
        default_value = re.search(r"Default is `(.*)`", description)
        # Strip from "Default is" to the end of that line only ('.' does not
        # match newlines), keeping the rest of the description.
        clean_desc = re.sub(r"Default is .*", "", description)
        # Indent every line by 4 spaces so the text nests under the list item.
        clean_desc = "\n".join("    " + line for line in clean_desc.splitlines())
        rendered = f"  - **`{param_groupings}`**\n\n{clean_desc}"
        if default_value:
            rendered += f"\n\n    *Default:* `{default_value.group(1)}`"
        return rendered
    raise TypeError(f"Unexpected type {type(param_groupings)}")


if __name__ == "__main__":
    # Resolve param_groupings.yml relative to this script's own location
    # (as intended), so the result does not depend on the working directory.
    path = Path(__file__).resolve().parent.parent / "pysr" / "param_groupings.yml"
    with open(path, "r", encoding="utf-8") as f:
        # A nested structure of lists and dicts of headings / parameter names.
        param_groupings = safe_load(f)

    # Load the parameter descriptions from the docstring of PySRRegressor.
    # Names ending in "_" (presumably fitted attributes, sklearn-style) and
    # "**kwargs" are excluded from the documented parameters.
    raw_params = parse(PySRRegressor.__doc__).params
    params = {
        param.arg_name: param
        for param in raw_params
        if not param.arg_name.endswith("_") and param.arg_name != "**kwargs"
    }

    output = str_param_groups(param_groupings, params, cur_heading=3)

    # Every documented parameter must appear in the groupings and vice versa.
    # Use an explicit check rather than `assert` so it still runs under -O.
    mismatched = set(found_params) ^ set(params)
    if mismatched:
        raise ValueError(
            "param_groupings.yml is out of sync with the PySRRegressor "
            f"docstring; mismatched parameters: {sorted(mismatched)}"
        )
    print("## PySRRegressor Parameters")
    print(output)