File size: 6,438 Bytes
a06bfc4
 
 
 
 
 
 
 
 
 
 
 
 
 
fb950bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a06bfc4
 
 
9068541
 
 
 
 
fb950bb
 
 
 
a06bfc4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
#####
# From https://github.com/patrick-kidger/sympytorch
# Copied here to allow PySR-specific tweaks
#####

import collections as co
import functools as ft
import sympy

def _reduce(fn):
    """Lift a binary callable to a variadic one by left-folding its arguments.

    ``_reduce(f)(a, b, c)`` computes ``f(f(a, b), c)``; a single argument is
    returned unchanged, and calling with no arguments raises ``TypeError``
    (from ``functools.reduce``). Used so SymPy ``Add``/``Mul`` nodes, which can
    carry more than two terms, map onto binary torch ops.
    """
    return lambda *operands: ft.reduce(fn, operands)

# Lazy-torch state. ``torch_initialized`` is the guard ``_initialize_torch``
# checks before importing torch; the other names are placeholders that
# ``_initialize_torch`` rebinds once torch has been imported.
torch_initialized = False
torch = _global_func_lookup = _Node = SingleSymPyModule = None

def _initialize_torch():
    """Import torch lazily and build the torch-dependent parts of this module.

    On the first successful call this imports ``torch`` and binds the
    module-level globals ``torch``, ``_global_func_lookup`` (SymPy function
    class -> torch callable), ``_Node`` and ``SingleSymPyModule``. It then sets
    ``torch_initialized`` so later calls return immediately. Deferring the
    import lets this module be imported (e.g. from a package ``__init__``)
    without loading torch.
    """
    global torch_initialized
    global torch
    global _global_func_lookup
    global _Node
    global SingleSymPyModule

    # Already done: without this guard every call would re-import torch and
    # create brand-new _Node / SingleSymPyModule classes.
    if torch_initialized:
        return

    import torch as _torch
    torch = _torch

    _global_func_lookup = {
        sympy.Mul: _reduce(torch.mul),
        sympy.Add: _reduce(torch.add),
        # NOTE: sympy.div (polynomial division) and sympy.sqrt are plain
        # functions, not expression classes, so these two keys never equal an
        # ``expr.func``. sqrt(x) reaches torch as Pow(x, 1/2) instead.
        sympy.div: torch.div,
        sympy.Abs: torch.abs,
        sympy.sign: torch.sign,
        # Note: May raise error for ints.
        sympy.ceiling: torch.ceil,
        sympy.floor: torch.floor,
        sympy.log: torch.log,
        sympy.exp: torch.exp,
        sympy.sqrt: torch.sqrt,
        sympy.cos: torch.cos,
        sympy.acos: torch.acos,
        sympy.sin: torch.sin,
        sympy.asin: torch.asin,
        sympy.tan: torch.tan,
        sympy.atan: torch.atan,
        sympy.atan2: torch.atan2,
        # Note: May give NaN for complex results.
        sympy.cosh: torch.cosh,
        sympy.acosh: torch.acosh,
        sympy.sinh: torch.sinh,
        sympy.asinh: torch.asinh,
        sympy.tanh: torch.tanh,
        sympy.atanh: torch.atanh,
        sympy.Pow: torch.pow,
        sympy.re: torch.real,
        sympy.im: torch.imag,
        sympy.arg: torch.angle,
        # Note: May raise error for ints and complexes
        sympy.erf: torch.erf,
        sympy.loggamma: torch.lgamma,
        sympy.Eq: torch.eq,
        sympy.Ne: torch.ne,
        sympy.StrictGreaterThan: torch.gt,
        sympy.StrictLessThan: torch.lt,
        sympy.LessThan: torch.le,
        sympy.GreaterThan: torch.ge,
        sympy.And: torch.logical_and,
        sympy.Or: torch.logical_or,
        sympy.Not: torch.logical_not,
        sympy.Max: torch.max,
        sympy.Min: torch.min,
        # Matrices
        sympy.MatAdd: torch.add,
        sympy.HadamardProduct: torch.mul,
        sympy.Trace: torch.trace,
        # Note: May raise error for integer matrices.
        sympy.Determinant: torch.det,
    }

    class _Node(torch.nn.Module):
        """One node of a SymPy expression tree, evaluated with torch ops.

        Leaf handling:
        - Float constants become trainable ``torch.nn.Parameter`` objects.
        - UnevaluatedExpr-wrapped floats and non-integer Rationals become
          non-trainable buffers.
        - Integers stay Python ints.
        - Symbols are looked up by name in the ``memodict`` given to
          ``forward``.

        Every other node maps its ``expr.func`` to a torch callable through
        ``_func_lookup``.
        """

        def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
            """Build the node for ``expr`` and, recursively, its children.

            Args:
                expr: the SymPy expression this node represents.
                _memodict: cache from SymPy sub-expression to already-built
                    node, so repeated sub-expressions share one module (and
                    one parameter).
                _func_lookup: mapping from SymPy function class to torch
                    callable.

            Raises:
                ValueError: an ``UnevaluatedExpr`` does not wrap exactly one
                    Float.
                KeyError: ``expr.func`` has no entry in ``_func_lookup``.
            """
            super().__init__(**kwargs)

            self._sympy_func = expr.func

            if issubclass(expr.func, sympy.Float):
                self._value = torch.nn.Parameter(torch.tensor(float(expr)))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.UnevaluatedExpr):
                if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
                    raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
                self.register_buffer('_value', torch.tensor(float(expr.args[0])))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Integer):
                # Can get here if expr is one of the Integer special cases,
                # e.g. NegativeOne
                self._value = int(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Rational):
                # Non-integer rationals, e.g. the 1/2 exponent produced by
                # sympy.sqrt(x) == Pow(x, 1/2). This branch must come after
                # Integer because Integer subclasses Rational. Keep the exact
                # p/q for round-tripping back to SymPy.
                self._numerator = int(expr.p)
                self._denominator = int(expr.q)
                self.register_buffer('_value', torch.tensor(float(expr)))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Symbol):
                self._name = expr.name
                # Keep the original Symbol so sympy() preserves its assumptions.
                self._symbol = expr
                self._torch_func = lambda value: value
                self._args = ((lambda memodict: memodict[expr.name]),)
            else:
                try:
                    self._torch_func = _func_lookup[expr.func]
                except KeyError:
                    raise KeyError(
                        f"SymPy function {expr.func!r} has no torch equivalent; "
                        "add a mapping for it via `extra_funcs`."
                    ) from None
                args = []
                for arg in expr.args:
                    try:
                        arg_ = _memodict[arg]
                    except KeyError:
                        arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
                        _memodict[arg] = arg_
                    args.append(arg_)
                self._args = torch.nn.ModuleList(args)

        def forward(self, memodict):
            """Evaluate this subtree.

            ``memodict`` holds the symbol-name -> tensor inputs. It also caches
            each child's result, so shared children are evaluated only once per
            call.
            """
            args = []
            for arg in self._args:
                try:
                    arg_ = memodict[arg]
                except KeyError:
                    arg_ = arg(memodict)
                    memodict[arg] = arg_
                args.append(arg_)
            return self._torch_func(*args)

        def sympy(self, memodict):
            """Rebuild the SymPy expression for this subtree.

            Trainable Float constants are read from their current parameter
            values, so the result reflects any training done since
            construction. ``memodict`` caches the results of children shared
            by several parents.
            """
            func = self._sympy_func
            if issubclass(func, sympy.Float):
                return sympy.Float(self._value.item())
            if issubclass(func, sympy.UnevaluatedExpr):
                return sympy.UnevaluatedExpr(sympy.Float(self._value.item()))
            if issubclass(func, sympy.Integer):
                # sympy.Integer returns the singleton for -1/0/1 cases.
                return sympy.Integer(self._value)
            if issubclass(func, sympy.Rational):
                return sympy.Rational(self._numerator, self._denominator)
            if issubclass(func, sympy.Symbol):
                return self._symbol
            args = []
            for arg in self._args:
                try:
                    arg_ = memodict[arg]
                except KeyError:
                    arg_ = arg.sympy(memodict)
                    memodict[arg] = arg_
                args.append(arg_)
            return func(*args)

    class SingleSymPyModule(torch.nn.Module):
        """Torch module evaluating one SymPy expression.

        ``forward`` reads the values of ``symbols_in[i]`` from column ``X[:, i]``
        of its input. ``extra_funcs`` supplies additional SymPy-class -> torch
        callable mappings. Because of the ChainMap order, the built-in table
        is consulted first, so ``extra_funcs`` can add entries but not
        override existing ones.
        """

        def __init__(self, expression, symbols_in,
                     extra_funcs=None, **kwargs):
            super().__init__(**kwargs)

            if extra_funcs is None:
                extra_funcs = {}
            _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

            _memodict = {}
            self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
            self._expression_string = str(expression)
            self.symbols_in = [str(symbol) for symbol in symbols_in]

        def __repr__(self):
            return f"{type(self).__name__}(expression={self._expression_string})"

        def sympy(self):
            """Return the expression as SymPy, using current parameter values."""
            _memodict = {}
            return self._node.sympy(_memodict)

        def forward(self, X):
            symbols = {symbol: X[:, i]
                       for i, symbol in enumerate(self.symbols_in)}
            return self._node(symbols)

    # Only mark as done once everything above succeeded, so a failure here
    # (e.g. a missing SymPy attribute) is retried on the next call.
    torch_initialized = True


def sympy2torch(expression, symbols_in):
    """Wrap a SymPy expression in a torch module with trainable parameters.

    The returned module takes a matrix ``X`` as input; column ``i`` of ``X``
    provides the values for ``symbols_in[i]``. torch is imported lazily, on
    the first call, via ``_initialize_torch``.
    """
    _initialize_torch()
    # Read the class only after initialization has (re)bound the global.
    module_cls = SingleSymPyModule
    return module_cls(expression, symbols_in)