MilesCranmer commited on
Commit
0857108
1 Parent(s): e8c7db3

Fix non-floating dtypes

Browse files
Files changed (1) hide show
  1. pysr/sr.py +2 -4
pysr/sr.py CHANGED
@@ -1622,14 +1622,12 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
1622
 
1623
  # Convert data to desired precision
1624
  test_X = np.array(X)
1625
- is_real = np.issubdtype(test_X.dtype, np.floating)
1626
  is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
 
1627
  if is_real:
1628
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
1629
- elif is_complex:
1630
- np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
1631
  else:
1632
- np_dtype = None
1633
 
1634
  # This converts the data into a Julia array:
1635
  Main.X = np.array(X, dtype=np_dtype).T
 
1622
 
1623
  # Convert data to desired precision
1624
  test_X = np.array(X)
 
1625
  is_complex = np.issubdtype(test_X.dtype, np.complexfloating)
1626
+ is_real = not is_complex
1627
  if is_real:
1628
  np_dtype = {16: np.float16, 32: np.float32, 64: np.float64}[self.precision]
 
 
1629
  else:
1630
+ np_dtype = {32: np.complex64, 64: np.complex128}[self.precision]
1631
 
1632
  # This converts the data into a Julia array:
1633
  Main.X = np.array(X, dtype=np_dtype).T