Spaces:
Running
Running
File size: 2,398 Bytes
4978916 ed84fe9 4978916 ed84fe9 4978916 dfc6d2e 4978916 e9fbda8 4978916 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 |
# Load YAML file param_groupings.yml:
from yaml import safe_load
import sys
sys.path.append("..")
from pysr import PySRRegressor
import pysr
import re
from docstring_parser import parse
# Names of every parameter rendered by ``str_param_groups``, in visit order.
# The ``__main__`` block compares this against the docstring's parameters to
# verify that the groupings cover every parameter.
found_params = []

# Captures the text between backticks in a "Default is `...`" sentence.
# Non-greedy so a later backticked span on the same line is not swallowed.
_DEFAULT_VALUE_RE = re.compile(r"Default is `(.*?)`")
# Matches the "Default is ..." text up to end of line so it can be removed from
# the rendered description (the default is rendered on its own line instead).
_DEFAULT_SENTENCE_RE = re.compile(r"Default is .*")
# Indentation for description/default lines so they nest under the list item.
_DESC_INDENT = "    "


def str_param_groups(param_groupings, params, cur_heading=2):
    """Recursively render a parameter grouping as Markdown.

    Args:
        param_groupings: A parameter name (``str``), a list of groupings, or a
            dict mapping a heading to a nested grouping. Lists and dicts may be
            nested arbitrarily deep.
        params: Mapping from parameter name to an object exposing a
            ``description`` attribute (str or None).
        cur_heading: Markdown heading level used for dict keys at this depth;
            nested dicts use ``cur_heading + 1``.

    Returns:
        The rendered Markdown; sibling entries are separated by a blank line.

    Raises:
        KeyError: If a parameter name in the groupings is absent from ``params``.
        TypeError: If a grouping node is not a ``str``, ``list`` or ``dict``.

    Side effects:
        Appends every rendered parameter name to the module-level
        ``found_params`` list.
    """
    if isinstance(param_groupings, list):
        return "\n\n".join(
            str_param_groups(param, params, cur_heading) for param in param_groupings
        )

    if isinstance(param_groupings, dict):
        # Render *every* heading. Returning inside the loop would silently drop
        # all but the first key (and return None for an empty mapping).
        return "\n\n".join(
            f"{'#' * cur_heading} {heading}\n\n"
            + str_param_groups(param_grouping, params, cur_heading + 1)
            for heading, param_grouping in param_groupings.items()
        )

    if isinstance(param_groupings, str):
        name = param_groupings
        try:
            param = params[name]
        except KeyError:
            raise KeyError(
                f"Parameter {name!r} appears in the groupings but not in the docstring"
            ) from None
        found_params.append(name)

        # A parameter may have no description text at all (None).
        description = param.description or ""
        default_value = _DEFAULT_VALUE_RE.search(description)
        # Drop the "Default is ..." text; rstrip removes whitespace it leaves.
        clean_desc = _DEFAULT_SENTENCE_RE.sub("", description).rstrip()
        # Prepend every line with 4 spaces so it nests under the list item.
        body = "\n".join(_DESC_INDENT + line for line in clean_desc.splitlines())

        parts = [f" - **`{name}`**"]
        if body:
            parts.append(body)
        if default_value:
            parts.append(f"{_DESC_INDENT}*Default:* `{default_value.group(1)}`")
        return "\n\n".join(parts)

    raise TypeError(f"Unexpected type {type(param_groupings)}")
if __name__ == "__main__":
    # Script entry point: print Markdown documentation for every
    # PySRRegressor parameter, grouped per pysr/param_groupings.yml.
    from pathlib import Path

    # Resolve the YAML path relative to this file rather than the current
    # working directory, so the script works wherever it is launched from.
    path = Path(__file__).resolve().parent / ".." / "pysr" / "param_groupings.yml"
    with open(path, "r", encoding="utf-8") as f:
        param_groupings = safe_load(f)
    # The groupings are nested lists and dicts of parameter names.

    # Load the parameter descriptions from the PySRRegressor docstring,
    # skipping names ending in "_" (presumably sklearn-style fitted
    # attributes) and the catch-all **kwargs entry.
    raw_params = parse(PySRRegressor.__doc__).params
    params = {
        param.arg_name: param
        for param in raw_params
        if not param.arg_name.endswith("_") and param.arg_name != "**kwargs"
    }
    output = str_param_groups(param_groupings, params, cur_heading=3)

    # Every documented parameter must appear in the groupings and vice versa.
    # An explicit check (unlike ``assert``) still runs under ``python -O`` and
    # reports exactly which names are out of sync.
    missing = set(params) - set(found_params)
    unknown = set(found_params) - set(params)
    if missing or unknown:
        raise ValueError(
            "param_groupings.yml is out of sync with the PySRRegressor docstring: "
            f"missing={sorted(missing)}, unknown={sorted(unknown)}"
        )

    print("## PySRRegressor Parameters")
    print(output)
|