File size: 2,217 Bytes
4978916
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
# Generate the Markdown reference for PySRRegressor's parameters, grouped under
# the headings listed in param_groupings.yml.
from pathlib import Path
from yaml import safe_load
import sys

sys.path.append("..")
from pysr import PySRRegressor
import pysr
import re
from docstring_parser import parse

# Names of every parameter rendered by str_param_groups, in render order.
# The __main__ block compares this against the parameters parsed from the
# PySRRegressor docstring.
found_params = []


def str_param_groups(param_groupings, params, cur_heading=2):
    """Recursively render grouped parameter docs as Markdown.

    Args:
        param_groupings: Nested structure parsed from param_groupings.yml.
            A list renders each element in order, separated by blank lines;
            a dict renders each ``heading: children`` pair as a Markdown
            heading followed by its children; a str is a parameter name
            looked up in ``params``.
        params: Mapping from parameter name to an object exposing a
            ``description`` attribute (the parsed docstring parameter).
        cur_heading: Heading level (number of ``#``) for headings at this
            depth; children of a dict are rendered one level deeper.

    Returns:
        The rendered Markdown text. As a side effect, every rendered
        parameter name is appended to the module-level ``found_params``.

    Raises:
        KeyError: If a parameter name is not a key of ``params``.
        TypeError: If a node is not a list, dict or str (e.g. ``None`` from
            an empty YAML entry).
    """
    if isinstance(param_groupings, list):
        return "\n\n".join(
            str_param_groups(param, params, cur_heading) for param in param_groupings
        )
    if isinstance(param_groupings, dict):
        # Render every heading in the mapping. Returning from inside a loop
        # here would silently drop all but the first key (and return None for
        # an empty mapping).
        return "\n\n".join(
            f"{'#' * cur_heading} {heading}"
            + "\n\n"
            + str_param_groups(param_grouping, params, cur_heading + 1)
            for heading, param_grouping in param_groupings.items()
        )
    if isinstance(param_groupings, str):
        # The list is mutated, not rebound, so no `global` declaration is needed.
        found_params.append(param_groupings)

        # Treat a missing description as empty text instead of crashing in re.
        description = params[param_groupings].description or ""
        # Drop the "Default is ..." text from the prose; the default value is
        # rendered on its own line below.
        clean_desc = re.sub(r"Default is .*", "", description)
        # Non-greedy, so only the first backtick-quoted value is captured:
        # "Default is `None`, see `x`." yields "None", not "None`, see `x".
        default_value = re.search(r"Default is `(.*?)`", description)
        return (
            f"**`{param_groupings}`**"
            + "\n\n"
            + clean_desc
            + ("\n" + f"*Default:* `{default_value.group(1)}`" if default_value else "")
        )
    raise TypeError(f"Unexpected type {type(param_groupings)}")


if __name__ == "__main__":
    # Resolve param_groupings.yml next to this script (not the current working
    # directory), so the script can be run from any directory.
    path = Path(__file__).resolve().parent / "param_groupings.yml"
    with open(path, "r", encoding="utf-8") as f:
        param_groupings = safe_load(f)

    # param_groupings is a nested structure of lists and dicts whose leaves
    # are parameter names; str_param_groups walks it recursively.

    # Load the parameter descriptions from the docstring of PySRRegressor,
    # skipping names ending in "_" (presumably fitted attributes, following
    # the scikit-learn convention) and the **kwargs catch-all.
    raw_params = parse(PySRRegressor.__doc__).params
    params = {
        param.arg_name: param
        for param in raw_params
        if not param.arg_name.endswith("_") and param.arg_name != "**kwargs"
    }

    output = str_param_groups(param_groupings, params, cur_heading=3)

    # Every documented parameter must be placed in the YAML groupings, and
    # vice versa. An explicit check (not `assert`) survives `python -O` and
    # reports exactly which names are out of sync.
    documented = set(params)
    grouped = set(found_params)
    missing = sorted(documented - grouped)
    unknown = sorted(grouped - documented)
    if missing or unknown:
        raise ValueError(
            "param_groupings.yml is out of sync with PySRRegressor's docstring; "
            f"missing from YAML: {missing}; not in docstring: {unknown}"
        )
    print("## PySRRegressor Parameters")
    print(output)