File size: 1,738 Bytes
5a01e6f
75c23d4
976f8d8
505af8d
976f8d8
5a01e6f
 
505af8d
 
5a01e6f
 
b2d7f41
 
 
 
5a01e6f
 
 
505af8d
 
 
 
b2d7f41
5a01e6f
 
 
 
 
 
 
 
 
 
 
 
b2d7f41
5a01e6f
505af8d
5a01e6f
505af8d
83d8e67
 
 
 
 
 
 
505af8d
5a01e6f
bd90cfc
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
"""Code for exporting discovered expressions to numpy"""

import warnings
from typing import List, Union

import numpy as np
import pandas as pd
from numpy.typing import NDArray
from sympy import Expr, Symbol, lambdify


def sympy2numpy(eqn, sympy_symbols, *, selection=None):
    return CallableEquation(eqn, sympy_symbols, selection=selection)


class CallableEquation:
    """Simple wrapper for numpy lambda functions built with sympy"""

    _sympy: Expr
    _sympy_symbols: List[Symbol]
    _selection: Union[NDArray[np.bool_], None]

    def __init__(self, eqn, sympy_symbols, selection=None):
        self._sympy = eqn
        self._sympy_symbols = sympy_symbols
        self._selection = selection

    def __repr__(self):
        return f"PySRFunction(X=>{self._sympy})"

    def __call__(self, X):
        expected_shape = (X.shape[0],)
        if isinstance(X, pd.DataFrame):
            # Lambda function takes as argument:
            return self._lambda(
                **{k: X[k].values for k in map(str, self._sympy_symbols)}
            ) * np.ones(expected_shape)

        if self._selection is not None:
            if X.shape[1] != self._selection.sum():
                warnings.warn(
                    "`X` should be of shape (n_samples, len(self._selection)). "
                    "Automatically filtering `X` to selection. "
                    "Note: Filtered `X` column order may not match column order in fit "
                    "this may lead to incorrect predictions and other errors."
                )
                X = X[:, self._selection]

        return self._lambda(*X.T) * np.ones(expected_shape)

    @property
    def _lambda(self):
        return lambdify(self._sympy_symbols, self._sympy)