Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
b2fc69c
1
Parent(s):
ab9ae60
Enable complex numbers
Browse files- pysr/sr.py +11 -1
- pysr/version.py +2 -2
pysr/sr.py
CHANGED
@@ -498,6 +498,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
498 |
What precision to use for the data. By default this is `32`
|
499 |
(float32), but you can select `64` or `16` as well, giving
|
500 |
you 64 or 16 bits of floating point precision, respectively.
|
|
|
|
|
501 |
Default is `32`.
|
502 |
random_state : int, Numpy RandomState instance or None
|
503 |
Pass an int for reproducible results across multiple function calls.
|
@@ -1619,7 +1621,15 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
1619 |
)
|
1620 |
|
1621 |
# Convert data to desired precision
|
1622 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1623 |
|
1624 |
# This converts the data into a Julia array:
|
1625 |
Main.X = np.array(X, dtype=np_dtype).T
|
|
|
498 |
What precision to use for the data. By default this is `32`
|
499 |
(float32), but you can select `64` or `16` as well, giving
|
500 |
you 64 or 16 bits of floating point precision, respectively.
|
501 |
+
If you pass complex data, the corresponding complex precision
|
502 |
+
will be used (i.e., `64` for complex128, `32` for complex64).
|
503 |
Default is `32`.
|
504 |
random_state : int, Numpy RandomState instance or None
|
505 |
Pass an int for reproducible results across multiple function calls.
|
|
|
1621 |
)
|
1622 |
|
1623 |
# Convert data to desired precision
|
1624 |
+
test_X = np.array(X)
|
1625 |
+
is_real = np.issubdtype(test_X.dtype, np.floating)
|
1626 |
+
is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
|
1627 |
+
if is_real:
|
1628 |
+
np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
|
1629 |
+
elif is_complex:
|
1630 |
+
np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
|
1631 |
+
else:
|
1632 |
+
np_dtype = None
|
1633 |
|
1634 |
# This converts the data into a Julia array:
|
1635 |
Main.X = np.array(X, dtype=np_dtype).T
|
pysr/version.py
CHANGED
@@ -1,2 +1,2 @@
|
|
1 |
-
__version__ = "0.
|
2 |
-
__symbolic_regression_jl_version__ = "0.
|
|
|
1 |
+
__version__ = "0.12.0"
|
2 |
+
__symbolic_regression_jl_version__ = "0.16.0"
|