Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
358f0ab
1
Parent(s):
40f498c
Remove additional changes to internal parameters
Browse files- pysr/sr.py +16 -14
pysr/sr.py
CHANGED
@@ -80,7 +80,8 @@ def pysr(X, y, weights=None, **kwargs): # pragma: no cover
|
|
80 |
return model.equations
|
81 |
|
82 |
|
83 |
-
def
|
|
|
84 |
for op in unary_operators:
|
85 |
if op not in constraints:
|
86 |
constraints[op] = -1
|
@@ -101,10 +102,13 @@ def _handle_constraints(binary_operators, unary_operators, constraints):
|
|
101 |
constraints[op][1],
|
102 |
constraints[op][0],
|
103 |
)
|
|
|
104 |
|
105 |
|
106 |
-
def
|
107 |
global Main
|
|
|
|
|
108 |
for op_list in [binary_operators, unary_operators]:
|
109 |
for i, op in enumerate(op_list):
|
110 |
is_user_defined_operator = "(" in op
|
@@ -123,6 +127,7 @@ def _create_inline_operators(binary_operators, unary_operators):
|
|
123 |
"Only alphanumeric characters, numbers, and underscores are allowed."
|
124 |
)
|
125 |
op_list[i] = function_name
|
|
|
126 |
|
127 |
|
128 |
def _check_assertions(
|
@@ -1214,17 +1219,18 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1214 |
Main.pow = Main.eval("(^)")
|
1215 |
Main.div = Main.eval("(/)")
|
1216 |
|
1217 |
-
|
|
|
1218 |
binary_operators=self.binary_operators, unary_operators=self.unary_operators
|
1219 |
)
|
1220 |
-
|
1221 |
-
binary_operators=
|
1222 |
-
unary_operators=
|
1223 |
constraints=self.constraints,
|
1224 |
)
|
1225 |
|
1226 |
-
una_constraints = [
|
1227 |
-
bin_constraints = [
|
1228 |
|
1229 |
# Parse dict into Julia Dict for nested constraints::
|
1230 |
if self.nested_constraints is not None:
|
@@ -1265,12 +1271,8 @@ class PySRRegressor(BaseEstimator, RegressorMixin, MultiOutputMixin):
|
|
1265 |
# Call to Julia backend.
|
1266 |
# See https://github.com/search?q=%22function+Options%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3AOptions.jl+language%3AJulia&type=Code
|
1267 |
options = Main.Options(
|
1268 |
-
binary_operators=Main.eval(
|
1269 |
-
|
1270 |
-
),
|
1271 |
-
unary_operators=Main.eval(
|
1272 |
-
str(tuple(self.unary_operators)).replace("'", "")
|
1273 |
-
),
|
1274 |
bin_constraints=bin_constraints,
|
1275 |
una_constraints=una_constraints,
|
1276 |
complexity_of_operators=complexity_of_operators,
|
|
|
80 |
return model.equations
|
81 |
|
82 |
|
83 |
+
def _process_constraints(binary_operators, unary_operators, constraints):
|
84 |
+
constraints = constraints.copy()
|
85 |
for op in unary_operators:
|
86 |
if op not in constraints:
|
87 |
constraints[op] = -1
|
|
|
102 |
constraints[op][1],
|
103 |
constraints[op][0],
|
104 |
)
|
105 |
+
return constraints
|
106 |
|
107 |
|
108 |
+
def _maybe_create_inline_operators(binary_operators, unary_operators):
|
109 |
global Main
|
110 |
+
binary_operators = binary_operators.copy()
|
111 |
+
unary_operators = unary_operators.copy()
|
112 |
for op_list in [binary_operators, unary_operators]:
|
113 |
for i, op in enumerate(op_list):
|
114 |
is_user_defined_operator = "(" in op
|
|
|
127 |
"Only alphanumeric characters, numbers, and underscores are allowed."
|
128 |
)
|
129 |
op_list[i] = function_name
|
130 |
+
return binary_operators, unary_operators
|
131 |
|
132 |
|
133 |
def _check_assertions(
|
|
|
1219 |
Main.pow = Main.eval("(^)")
|
1220 |
Main.div = Main.eval("(/)")
|
1221 |
|
1222 |
+
# TODO(mcranmer): These functions should be part of this class.
|
1223 |
+
binary_operators, unary_operators = _maybe_create_inline_operators(
|
1224 |
binary_operators=self.binary_operators, unary_operators=self.unary_operators
|
1225 |
)
|
1226 |
+
constraints = _process_constraints(
|
1227 |
+
binary_operators=binary_operators,
|
1228 |
+
unary_operators=unary_operators,
|
1229 |
constraints=self.constraints,
|
1230 |
)
|
1231 |
|
1232 |
+
una_constraints = [constraints[op] for op in unary_operators]
|
1233 |
+
bin_constraints = [constraints[op] for op in binary_operators]
|
1234 |
|
1235 |
# Parse dict into Julia Dict for nested constraints::
|
1236 |
if self.nested_constraints is not None:
|
|
|
1271 |
# Call to Julia backend.
|
1272 |
# See https://github.com/search?q=%22function+Options%22+repo%3AMilesCranmer%2FSymbolicRegression.jl+path%3A%2Fsrc%2F+filename%3AOptions.jl+language%3AJulia&type=Code
|
1273 |
options = Main.Options(
|
1274 |
+
binary_operators=Main.eval(str(tuple(binary_operators)).replace("'", "")),
|
1275 |
+
unary_operators=Main.eval(str(tuple(unary_operators)).replace("'", "")),
|
|
|
|
|
|
|
|
|
1276 |
bin_constraints=bin_constraints,
|
1277 |
una_constraints=una_constraints,
|
1278 |
complexity_of_operators=complexity_of_operators,
|