Spaces:
Running
Running
MilesCranmer
commited on
Commit
•
2bd7782
1
Parent(s):
dca10d6
refactor: improved type inference in return values
Browse files- pysr/sr.py +24 -14
pysr/sr.py
CHANGED
@@ -2006,11 +2006,13 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2006 |
X = self._validate_data(X, reset=False)
|
2007 |
|
2008 |
try:
|
2009 |
-
if
|
|
|
2010 |
return np.stack(
|
2011 |
[eq["lambda_format"](X) for eq in best_equation], axis=1
|
2012 |
)
|
2013 |
-
|
|
|
2014 |
except Exception as error:
|
2015 |
raise ValueError(
|
2016 |
"Failed to evaluate the expression. "
|
@@ -2040,9 +2042,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2040 |
"""
|
2041 |
self.refresh()
|
2042 |
best_equation = self.get_best(index=index)
|
2043 |
-
if
|
|
|
2044 |
return [eq["sympy_format"] for eq in best_equation]
|
2045 |
-
|
|
|
2046 |
|
2047 |
def latex(self, index=None, precision=3):
|
2048 |
"""
|
@@ -2102,9 +2106,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2102 |
self.set_params(output_jax_format=True)
|
2103 |
self.refresh()
|
2104 |
best_equation = self.get_best(index=index)
|
2105 |
-
if
|
|
|
2106 |
return [eq["jax_format"] for eq in best_equation]
|
2107 |
-
|
|
|
2108 |
|
2109 |
def pytorch(self, index=None):
|
2110 |
"""
|
@@ -2132,9 +2138,10 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2132 |
self.set_params(output_torch_format=True)
|
2133 |
self.refresh()
|
2134 |
best_equation = self.get_best(index=index)
|
2135 |
-
if
|
|
|
|
|
2136 |
return [eq["torch_format"] for eq in best_equation]
|
2137 |
-
return best_equation["torch_format"]
|
2138 |
|
2139 |
def _read_equation_file(self):
|
2140 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
@@ -2233,10 +2240,8 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2233 |
lastComplexity = 0
|
2234 |
sympy_format = []
|
2235 |
lambda_format = []
|
2236 |
-
|
2237 |
-
|
2238 |
-
if self.output_torch_format:
|
2239 |
-
torch_format = []
|
2240 |
|
2241 |
for _, eqn_row in output.iterrows():
|
2242 |
eqn = pysr2sympy(
|
@@ -2348,7 +2353,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2348 |
"""
|
2349 |
self.refresh()
|
2350 |
|
2351 |
-
if self.
|
2352 |
if indices is not None:
|
2353 |
assert isinstance(indices, list)
|
2354 |
assert isinstance(indices[0], list)
|
@@ -2357,7 +2362,7 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2357 |
table_string = sympy2multilatextable(
|
2358 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2359 |
)
|
2360 |
-
|
2361 |
if indices is not None:
|
2362 |
assert isinstance(indices, list)
|
2363 |
assert isinstance(indices[0], int)
|
@@ -2365,6 +2370,11 @@ class PySRRegressor(MultiOutputMixin, RegressorMixin, BaseEstimator):
|
|
2365 |
table_string = sympy2latextable(
|
2366 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2367 |
)
|
|
|
|
|
|
|
|
|
|
|
2368 |
|
2369 |
preamble_string = [
|
2370 |
r"\usepackage{breqn}",
|
|
|
2006 |
X = self._validate_data(X, reset=False)
|
2007 |
|
2008 |
try:
|
2009 |
+
if isinstance(best_equation, list):
|
2010 |
+
assert self.nout_ > 1
|
2011 |
return np.stack(
|
2012 |
[eq["lambda_format"](X) for eq in best_equation], axis=1
|
2013 |
)
|
2014 |
+
else:
|
2015 |
+
return best_equation["lambda_format"](X)
|
2016 |
except Exception as error:
|
2017 |
raise ValueError(
|
2018 |
"Failed to evaluate the expression. "
|
|
|
2042 |
"""
|
2043 |
self.refresh()
|
2044 |
best_equation = self.get_best(index=index)
|
2045 |
+
if isinstance(best_equation, list):
|
2046 |
+
assert self.nout_ > 1
|
2047 |
return [eq["sympy_format"] for eq in best_equation]
|
2048 |
+
else:
|
2049 |
+
return best_equation["sympy_format"]
|
2050 |
|
2051 |
def latex(self, index=None, precision=3):
|
2052 |
"""
|
|
|
2106 |
self.set_params(output_jax_format=True)
|
2107 |
self.refresh()
|
2108 |
best_equation = self.get_best(index=index)
|
2109 |
+
if isinstance(best_equation, list):
|
2110 |
+
assert self.nout_ > 1
|
2111 |
return [eq["jax_format"] for eq in best_equation]
|
2112 |
+
else:
|
2113 |
+
return best_equation["jax_format"]
|
2114 |
|
2115 |
def pytorch(self, index=None):
|
2116 |
"""
|
|
|
2138 |
self.set_params(output_torch_format=True)
|
2139 |
self.refresh()
|
2140 |
best_equation = self.get_best(index=index)
|
2141 |
+
if isinstance(best_equation, pd.Series):
|
2142 |
+
return best_equation["torch_format"]
|
2143 |
+
else:
|
2144 |
return [eq["torch_format"] for eq in best_equation]
|
|
|
2145 |
|
2146 |
def _read_equation_file(self):
|
2147 |
"""Read the hall of fame file created by `SymbolicRegression.jl`."""
|
|
|
2240 |
lastComplexity = 0
|
2241 |
sympy_format = []
|
2242 |
lambda_format = []
|
2243 |
+
jax_format = []
|
2244 |
+
torch_format = []
|
|
|
|
|
2245 |
|
2246 |
for _, eqn_row in output.iterrows():
|
2247 |
eqn = pysr2sympy(
|
|
|
2353 |
"""
|
2354 |
self.refresh()
|
2355 |
|
2356 |
+
if isinstance(self.equations_, list):
|
2357 |
if indices is not None:
|
2358 |
assert isinstance(indices, list)
|
2359 |
assert isinstance(indices[0], list)
|
|
|
2362 |
table_string = sympy2multilatextable(
|
2363 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2364 |
)
|
2365 |
+
elif isinstance(self.equations_, pd.DataFrame):
|
2366 |
if indices is not None:
|
2367 |
assert isinstance(indices, list)
|
2368 |
assert isinstance(indices[0], int)
|
|
|
2370 |
table_string = sympy2latextable(
|
2371 |
self.equations_, indices=indices, precision=precision, columns=columns
|
2372 |
)
|
2373 |
+
else:
|
2374 |
+
raise ValueError(
|
2375 |
+
"Invalid type for equations_ to pass to `latex_table`. "
|
2376 |
+
"Expected a DataFrame or a list of DataFrames."
|
2377 |
+
)
|
2378 |
|
2379 |
preamble_string = [
|
2380 |
r"\usepackage{breqn}",
|