PySR / pysr /export_sympy.py
MilesCranmer's picture
fix: deal with edgecase in sympify
36e5dde unverified
raw
history blame
3.43 kB
"""Define utilities to export to sympy"""
from typing import Callable, Dict, List, Optional
import sympy
from sympy import sympify
from .utils import ArrayLike
# Lookup table from operator names (as they appear in equation strings) to
# callables that build the matching sympy expression. `pysr2sympy` merges this
# table into the `locals` it hands to `sympify`.
# NOTE(review): some entries are lambdas and others are sympy callables passed
# directly. `pysr2sympy` has a fallback keyed on the error text
# "got an unexpected keyword argument 'evaluate'", so the form of an entry may
# influence parsing -- confirm before converting one form into the other.
sympy_mappings = {
# --- Arithmetic and powers ---
"div": lambda x, y: x / y,
"mult": lambda x, y: x * y,
"sqrt": lambda x: sympy.sqrt(x),
# Square root of the absolute value of the argument.
"sqrt_abs": lambda x: sympy.sqrt(abs(x)),
"square": lambda x: x**2,
"cube": lambda x: x**3,
"plus": lambda x, y: x + y,
"sub": lambda x, y: x - y,
"neg": lambda x: -x,
"pow": lambda x, y: x**y,
# Power with the absolute value of the base.
"pow_abs": lambda x, y: abs(x) ** y,
# --- Trigonometric, hyperbolic, exponential and their inverses ---
"cos": sympy.cos,
"sin": sympy.sin,
"tan": sympy.tan,
"cosh": sympy.cosh,
"sinh": sympy.sinh,
"tanh": sympy.tanh,
"exp": sympy.exp,
"acos": sympy.acos,
"asin": sympy.asin,
"atan": sympy.atan,
"acosh": lambda x: sympy.acosh(x),
# acosh applied to abs(x) + 1, i.e. the argument is never below 1.
"acosh_abs": lambda x: sympy.acosh(abs(x) + 1),
"asinh": sympy.asinh,
# "atanh" and "atanh_clip" share one expression: the argument is first mapped
# through Mod(x + 1, 2) - 1 and atanh is applied to the result.
"atanh": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
"atanh_clip": lambda x: sympy.atanh(sympy.Mod(x + 1, 2) - sympy.S(1)),
# --- Miscellaneous elementary functions ---
"abs": abs,
"mod": sympy.Mod,
"erf": sympy.erf,
"erfc": sympy.erfc,
# --- Logarithms (second argument of sympy.log is the base) ---
"log": lambda x: sympy.log(x),
"log10": lambda x: sympy.log(x, 10),
"log2": lambda x: sympy.log(x, 2),
"log1p": lambda x: sympy.log(x + 1),
# "_abs" variants take the absolute value of the argument first.
"log_abs": lambda x: sympy.log(abs(x)),
"log10_abs": lambda x: sympy.log(abs(x), 10),
"log2_abs": lambda x: sympy.log(abs(x), 2),
"log1p_abs": lambda x: sympy.log(abs(x) + 1),
# --- Rounding, sign, gamma ---
"floor": sympy.floor,
"ceil": sympy.ceiling,
"sign": sympy.sign,
"gamma": sympy.gamma,
# Expressed as ceiling(x - 0.5); an exact half such as 2.5 therefore maps to 2.
"round": lambda x: sympy.ceiling(x - 0.5),
# --- Piecewise-defined operators ---
# max: y where x < y, otherwise x.  min: x where x < y, otherwise y.
"max": lambda x, y: sympy.Piecewise((y, x < y), (x, True)),
"min": lambda x, y: sympy.Piecewise((x, x < y), (y, True)),
# 1.0 where x > y, otherwise 0.0.
"greater": lambda x, y: sympy.Piecewise((1.0, x > y), (0.0, True)),
# y where x > 0, otherwise 0.0.
"cond": lambda x, y: sympy.Piecewise((y, x > 0), (0.0, True)),
# Logical operators treat an operand as true when it is > 0 and yield 1.0/0.0.
"logical_or": lambda x, y: sympy.Piecewise((1.0, (x > 0) | (y > 0)), (0.0, True)),
"logical_and": lambda x, y: sympy.Piecewise((1.0, (x > 0) & (y > 0)), (0.0, True)),
# 0.0 where x < 0, otherwise x.
"relu": lambda x: sympy.Piecewise((0.0, x < 0), (x, True)),
}
def create_sympy_symbols_map(
    feature_names_in: ArrayLike[str],
) -> Dict[str, sympy.Symbol]:
    """Build a lookup from each feature name to a ``sympy.Symbol`` of that name.

    Args:
        feature_names_in: Iterable of feature (variable) names.

    Returns:
        A dict keyed by name, in iteration order of ``feature_names_in``. If a
        name occurs more than once, the entry created last is the one kept.
    """
    symbols_by_name: Dict[str, sympy.Symbol] = {}
    for name in feature_names_in:
        symbols_by_name[name] = sympy.Symbol(name)
    return symbols_by_name
def create_sympy_symbols(
    feature_names_in: ArrayLike[str],
) -> List[sympy.Symbol]:
    """Create one ``sympy.Symbol`` per feature name, preserving input order.

    Args:
        feature_names_in: Iterable of feature (variable) names.

    Returns:
        A list with one symbol per element of ``feature_names_in`` (duplicates
        in the input produce duplicate entries in the output).
    """
    symbols: List[sympy.Symbol] = []
    for name in feature_names_in:
        symbols.append(sympy.Symbol(name))
    return symbols
def pysr2sympy(
    equation: str,
    *,
    feature_names_in: Optional[ArrayLike[str]] = None,
    extra_sympy_mappings: Optional[Dict[str, Callable]] = None,
):
    """Convert an equation string into a sympy expression via ``sympify``.

    The names available while parsing are assembled from three sources. When
    the same name appears in more than one source, the later source wins:

    1. symbols created from ``feature_names_in``,
    2. ``extra_sympy_mappings``,
    3. the module-level ``sympy_mappings`` table.

    Args:
        equation: The equation text handed to ``sympify``.
        feature_names_in: Variable names to expose as sympy symbols. ``None``
            is treated as no variables.
        extra_sympy_mappings: Additional name -> callable entries. ``None`` is
            treated as an empty mapping.

    Returns:
        Whatever ``sympify`` returns for ``equation``.

    Raises:
        TypeError: If ``sympify`` raises a ``TypeError`` that is not about an
            unexpected ``evaluate`` keyword; the original error is chained.
    """
    names = [] if feature_names_in is None else feature_names_in
    extras = {} if extra_sympy_mappings is None else extra_sympy_mappings

    # Order matters: later unpackings override earlier ones on key clashes.
    namespace = {
        **create_sympy_symbols_map(names),
        **extras,
        **sympy_mappings,
    }

    try:
        return sympify(equation, locals=namespace, evaluate=False)
    except TypeError as err:
        if "got an unexpected keyword argument 'evaluate'" not in str(err):
            raise TypeError(f"Error processing equation '{equation}'") from err
        # Some callable rejected an `evaluate` keyword; retry with sympify's
        # default evaluation behaviour instead.
        return sympify(equation, locals=namespace)
def assert_valid_sympy_symbol(var_name: str) -> None:
    """Reject a variable name that clashes with an existing function name.

    Args:
        var_name: Candidate variable name.

    Raises:
        ValueError: If ``var_name`` is a key of ``sympy_mappings`` or a name in
            the ``sympy`` module's namespace.
    """
    # Guard clause: nothing to do when the name is free in both namespaces.
    if var_name not in sympy_mappings and var_name not in vars(sympy):
        return
    raise ValueError(f"Variable name {var_name} is already a function name.")