Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
e394c0e
1
Parent(s):
ef3bdd3
Add missing sympy mappings in test
Browse files- README.md +3 -1
- test/test.py +1 -0
README.md
CHANGED
@@ -93,6 +93,7 @@ One can define custom operators in Julia by passing a string:
|
|
93 |
```python
|
94 |
equations = pysr.pysr(X, y, niterations=100,
|
95 |
binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
|
|
|
96 |
unary_operators=["cos"])
|
97 |
```
|
98 |
|
@@ -100,7 +101,8 @@ Now, the symbolic regression code can search using this `special` function
|
|
100 |
that squares its left argument and adds it to its right. Make sure
|
101 |
all passed functions are valid Julia code, and take one (unary)
|
102 |
or two (binary) float32 scalars as input, and output a float32. Operators
|
103 |
-
are automatically vectorized.
|
|
|
104 |
|
105 |
One can also edit `operators.jl`. See below for more options.
|
106 |
|
|
|
93 |
```python
|
94 |
equations = pysr.pysr(X, y, niterations=100,
|
95 |
binary_operators=["mult", "plus", "special(x, y) = x^2 + y"],
|
96 |
+
extra_sympy_mappings={'special': lambda x, y: x**2 + y},
|
97 |
unary_operators=["cos"])
|
98 |
```
|
99 |
|
|
|
101 |
that squares its left argument and adds it to its right. Make sure
|
102 |
all passed functions are valid Julia code, and take one (unary)
|
103 |
or two (binary) float32 scalars as input, and output a float32. Operators
|
104 |
+
are automatically vectorized. We also define `extra_sympy_mappings`,
|
105 |
+
so that the SymPy code can understand the output equation from Julia.
|
106 |
|
107 |
One can also edit `operators.jl`. See below for more options.
|
108 |
|
test/test.py
CHANGED
@@ -13,6 +13,7 @@ print("Test 2 - test custom operator")
|
|
13 |
y = X[:, 0]**2
|
14 |
equations = pysr(X, y,
|
15 |
unary_operators=["square(x) = x^2"], binary_operators=["plus"],
|
|
|
16 |
niterations=10)
|
17 |
print(equations)
|
18 |
assert equations.iloc[-1]['MSE'] < 1e-10
|
|
|
13 |
y = X[:, 0]**2
|
14 |
equations = pysr(X, y,
|
15 |
unary_operators=["square(x) = x^2"], binary_operators=["plus"],
|
16 |
+
extra_sympy_mappings={'square': lambda x: x**2},
|
17 |
niterations=10)
|
18 |
print(equations)
|
19 |
assert equations.iloc[-1]['MSE'] < 1e-10
|