Spaces:
Running
Running
Merge pull request #52 from MilesCranmer/deepsource-fix-147c2d73
Browse filesRefactor unnecessary `else` / `elif` when `if` block has a `return` statement
- pysr/export_jax.py +9 -11
- pysr/sr.py +8 -15
pysr/export_jax.py
CHANGED
@@ -55,21 +55,19 @@ def sympy2jaxtext(expr, parameters, symbols_in):
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
-
|
59 |
return f"{int(expr)}"
|
60 |
-
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
-
|
65 |
-
|
66 |
-
|
67 |
-
|
68 |
-
|
69 |
-
|
70 |
-
|
71 |
-
else:
|
72 |
-
return f'{_func}({", ".join(args)})'
|
73 |
|
74 |
|
75 |
jax_initialized = False
|
|
|
55 |
if issubclass(expr.func, sympy.Float):
|
56 |
parameters.append(float(expr))
|
57 |
return f"parameters[{len(parameters) - 1}]"
|
58 |
+
if issubclass(expr.func, sympy.Integer):
|
59 |
return f"{int(expr)}"
|
60 |
+
if issubclass(expr.func, sympy.Symbol):
|
61 |
return (
|
62 |
f"X[:, {[i for i in range(len(symbols_in)) if symbols_in[i] == expr][0]}]"
|
63 |
)
|
64 |
+
_func = _jnp_func_lookup[expr.func]
|
65 |
+
args = [sympy2jaxtext(arg, parameters, symbols_in) for arg in expr.args]
|
66 |
+
if _func == MUL:
|
67 |
+
return " * ".join(["(" + arg + ")" for arg in args])
|
68 |
+
if _func == ADD:
|
69 |
+
return " + ".join(["(" + arg + ")" for arg in args])
|
70 |
+
return f'{_func}({", ".join(args)})'
|
|
|
|
|
71 |
|
72 |
|
73 |
jax_initialized = False
|
pysr/sr.py
CHANGED
@@ -643,10 +643,9 @@ def _make_hyperparams_julia_str(
|
|
643 |
def tuple_fix(ops):
|
644 |
if len(ops) > 1:
|
645 |
return ", ".join(ops)
|
646 |
-
|
647 |
return ""
|
648 |
-
|
649 |
-
return ops[0] + ","
|
650 |
|
651 |
def_hyperparams += f"""\n
|
652 |
plus=(+)
|
@@ -1024,8 +1023,7 @@ def get_hof(
|
|
1024 |
|
1025 |
if multioutput:
|
1026 |
return ret_outputs
|
1027 |
-
|
1028 |
-
return ret_outputs[0]
|
1029 |
|
1030 |
|
1031 |
def best_row(equations=None):
|
@@ -1036,8 +1034,7 @@ def best_row(equations=None):
|
|
1036 |
equations = get_hof()
|
1037 |
if isinstance(equations, list):
|
1038 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
1039 |
-
|
1040 |
-
return equations.iloc[np.argmax(equations["score"])]
|
1041 |
|
1042 |
|
1043 |
def best_tex(equations=None):
|
@@ -1050,8 +1047,7 @@ def best_tex(equations=None):
|
|
1050 |
return [
|
1051 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
1052 |
]
|
1053 |
-
|
1054 |
-
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
1055 |
|
1056 |
|
1057 |
def best(equations=None):
|
@@ -1062,8 +1058,7 @@ def best(equations=None):
|
|
1062 |
equations = get_hof()
|
1063 |
if isinstance(equations, list):
|
1064 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
1065 |
-
|
1066 |
-
return best_row(equations)["sympy_format"].simplify()
|
1067 |
|
1068 |
|
1069 |
def best_callable(equations=None):
|
@@ -1074,8 +1069,7 @@ def best_callable(equations=None):
|
|
1074 |
equations = get_hof()
|
1075 |
if isinstance(equations, list):
|
1076 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
1077 |
-
|
1078 |
-
return best_row(equations)["lambda_format"]
|
1079 |
|
1080 |
|
1081 |
def _escape_filename(filename):
|
@@ -1113,5 +1107,4 @@ class CallableEquation(object):
|
|
1113 |
def __call__(self, X):
|
1114 |
if self._selection is not None:
|
1115 |
return self._lambda(*X[:, self._selection].T)
|
1116 |
-
|
1117 |
-
return self._lambda(*X.T)
|
|
|
643 |
def tuple_fix(ops):
|
644 |
if len(ops) > 1:
|
645 |
return ", ".join(ops)
|
646 |
+
if len(ops) == 0:
|
647 |
return ""
|
648 |
+
return ops[0] + ","
|
|
|
649 |
|
650 |
def_hyperparams += f"""\n
|
651 |
plus=(+)
|
|
|
1023 |
|
1024 |
if multioutput:
|
1025 |
return ret_outputs
|
1026 |
+
return ret_outputs[0]
|
|
|
1027 |
|
1028 |
|
1029 |
def best_row(equations=None):
|
|
|
1034 |
equations = get_hof()
|
1035 |
if isinstance(equations, list):
|
1036 |
return [eq.iloc[np.argmax(eq["score"])] for eq in equations]
|
1037 |
+
return equations.iloc[np.argmax(equations["score"])]
|
|
|
1038 |
|
1039 |
|
1040 |
def best_tex(equations=None):
|
|
|
1047 |
return [
|
1048 |
sympy.latex(best_row(eq)["sympy_format"].simplify()) for eq in equations
|
1049 |
]
|
1050 |
+
return sympy.latex(best_row(equations)["sympy_format"].simplify())
|
|
|
1051 |
|
1052 |
|
1053 |
def best(equations=None):
|
|
|
1058 |
equations = get_hof()
|
1059 |
if isinstance(equations, list):
|
1060 |
return [best_row(eq)["sympy_format"].simplify() for eq in equations]
|
1061 |
+
return best_row(equations)["sympy_format"].simplify()
|
|
|
1062 |
|
1063 |
|
1064 |
def best_callable(equations=None):
|
|
|
1069 |
equations = get_hof()
|
1070 |
if isinstance(equations, list):
|
1071 |
return [best_row(eq)["lambda_format"] for eq in equations]
|
1072 |
+
return best_row(equations)["lambda_format"]
|
|
|
1073 |
|
1074 |
|
1075 |
def _escape_filename(filename):
|
|
|
1107 |
def __call__(self, X):
|
1108 |
if self._selection is not None:
|
1109 |
return self._lambda(*X[:, self._selection].T)
|
1110 |
+
return self._lambda(*X.T)
|
|