MilesCranmer commited on
Commit
e394c0e
·
1 Parent(s): ef3bdd3

Add missing sympy mappings in test

Browse files
Files changed (2) hide show
  1. README.md +3 -1
  2. test/test.py +1 -0
README.md CHANGED
@@ -93,6 +93,7 @@ One can define custom operators in Julia by passing a string:
93
  ```python
94
  equations = pysr.pysr(X, y, niterations=100,
95
  binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
 
96
  unary_operators=["cos"])
97
  ```
98
 
@@ -100,7 +101,8 @@ Now, the symbolic regression code can search using this `special` function
100
  that squares its left argument and adds it to its right. Make sure
101
  all passed functions are valid Julia code, and take one (unary)
102
  or two (binary) float32 scalars as input, and output a float32. Operators
103
- are automatically vectorized.
 
104
 
105
  One can also edit `operators.jl`. See below for more options.
106
 
 
93
  ```python
94
  equations = pysr.pysr(X, y, niterations=100,
95
  binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
96
+ extra_sympy_mappings={'special': lambda x, y: x**2 + y},
97
  unary_operators=["cos"])
98
  ```
99
 
 
101
  that squares its left argument and adds it to its right. Make sure
102
  all passed functions are valid Julia code, and take one (unary)
103
  or two (binary) float32 scalars as input, and output a float32. Operators
104
+ are automatically vectorized. We also define `extra_sympy_mappings`,
105
+ so that the SymPy code can understand the output equation from Julia.
106
 
107
  One can also edit `operators.jl`. See below for more options.
108
 
test/test.py CHANGED
@@ -13,6 +13,7 @@ print("Test 2 - test custom operator")
13
  y = X[:, 0]**2
14
  equations = pysr(X, y,
15
  unary_operators=["square(x) = x^2"], binary_operators=["plus"],
 
16
  niterations=10)
17
  print(equations)
18
  assert equations.iloc[-1]['MSE'] < 1e-10
 
13
  y = X[:, 0]**2
14
  equations = pysr(X, y,
15
  unary_operators=["square(x) = x^2"], binary_operators=["plus"],
16
+ extra_sympy_mappings={'square': lambda x: x**2},
17
  niterations=10)
18
  print(equations)
19
  assert equations.iloc[-1]['MSE'] < 1e-10