MilesCranmer commited on
Commit
bbec288
1 Parent(s): 5969aa3

Force user to define extra_sympy_mappings at init

Browse files
Files changed (1) hide show
  1. pysr/sr.py +17 -4
pysr/sr.py CHANGED
@@ -115,7 +115,9 @@ def _process_constraints(binary_operators, unary_operators, constraints):
115
  return constraints
116
 
117
 
118
- def _maybe_create_inline_operators(binary_operators, unary_operators):
 
 
119
  global Main
120
  binary_operators = binary_operators.copy()
121
  unary_operators = unary_operators.copy()
@@ -137,6 +139,14 @@ def _maybe_create_inline_operators(binary_operators, unary_operators):
137
  "Only alphanumeric characters, numbers, "
138
  "and underscores are allowed."
139
  )
 
 
 
 
 
 
 
 
140
  op_list[i] = function_name
141
  return binary_operators, unary_operators
142
 
@@ -1488,7 +1498,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1488
 
1489
  # TODO(mcranmer): These functions should be part of this class.
1490
  binary_operators, unary_operators = _maybe_create_inline_operators(
1491
- binary_operators=binary_operators, unary_operators=unary_operators
 
 
1492
  )
1493
  constraints = _process_constraints(
1494
  binary_operators=binary_operators,
@@ -1848,8 +1860,9 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1848
  raise ValueError(
1849
  "Failed to evaluate the expression. "
1850
  "If you are using a custom operator, make sure to define it in :param`extra_sympy_mappings`, "
1851
- "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1 / x})`. You can then "
1852
- "run `model.refresh()` to re-load the expressions."
 
1853
  ) from error
1854
 
1855
  def sympy(self, index=None):
 
115
  return constraints
116
 
117
 
118
+ def _maybe_create_inline_operators(
119
+ binary_operators, unary_operators, extra_sympy_mappings
120
+ ):
121
  global Main
122
  binary_operators = binary_operators.copy()
123
  unary_operators = unary_operators.copy()
 
139
  "Only alphanumeric characters, numbers, "
140
  "and underscores are allowed."
141
  )
142
+ if not function_name in extra_sympy_mappings:
143
+ raise ValueError(
144
+ f"Custom function {function_name} is not defined in :param`extra_sympy_mappings`. "
145
+ "You can define it with, "
146
+ "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
147
+ "`lambda x: 1/x` is a valid SymPy function defining the operator. "
148
+ "You can also define these at initialization time."
149
+ )
150
  op_list[i] = function_name
151
  return binary_operators, unary_operators
152
 
 
1498
 
1499
  # TODO(mcranmer): These functions should be part of this class.
1500
  binary_operators, unary_operators = _maybe_create_inline_operators(
1501
+ binary_operators=binary_operators,
1502
+ unary_operators=unary_operators,
1503
+ extra_sympy_mappings=self.extra_sympy_mappings,
1504
  )
1505
  constraints = _process_constraints(
1506
  binary_operators=binary_operators,
 
1860
  raise ValueError(
1861
  "Failed to evaluate the expression. "
1862
  "If you are using a custom operator, make sure to define it in :param`extra_sympy_mappings`, "
1863
+ "e.g., `model.set_params(extra_sympy_mappings={'inv': lambda x: 1/x})`, where "
1864
+ "`lambda x: 1/x` is a valid SymPy function defining the operator. "
1865
+ "You can then run `model.refresh()` to re-load the expressions."
1866
  ) from error
1867
 
1868
  def sympy(self, index=None):