MilesCranmer commited on
Commit
a06bfc4
1 Parent(s): 5b978f9

Create torch export function

Browse files
Files changed (3) hide show
  1. pysr/__init__.py +1 -0
  2. pysr/export_torch.py +172 -0
  3. pysr/sr.py +1 -0
pysr/__init__.py CHANGED
@@ -1,3 +1,4 @@
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
  from .export_jax import sympy2jax
 
 
1
  from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
2
  from .feynman_problems import Problem, FeynmanProblem
3
  from .export_jax import sympy2jax
4
+ from .export_torch import sympy2torch
pysr/export_torch.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #####
2
+ # From https://github.com/patrick-kidger/sympytorch
3
+ # Copied here to allow PySR-specific tweaks
4
+ #####
5
+
6
+ import collections as co
7
+ import functools as ft
8
+ import sympy
9
+ import torch
10
+
11
+
12
+ def _reduce(fn):
13
+ def fn_(*args):
14
+ return ft.reduce(fn, args)
15
+ return fn_
16
+
17
+
18
+ _global_func_lookup = {
19
+ sympy.Mul: _reduce(torch.mul),
20
+ sympy.Add: _reduce(torch.add),
21
+ sympy.div: torch.div,
22
+ sympy.Abs: torch.abs,
23
+ sympy.sign: torch.sign,
24
+ # Note: May raise error for ints.
25
+ sympy.ceiling: torch.ceil,
26
+ sympy.floor: torch.floor,
27
+ sympy.log: torch.log,
28
+ sympy.exp: torch.exp,
29
+ sympy.sqrt: torch.sqrt,
30
+ sympy.cos: torch.cos,
31
+ sympy.acos: torch.acos,
32
+ sympy.sin: torch.sin,
33
+ sympy.asin: torch.asin,
34
+ sympy.tan: torch.tan,
35
+ sympy.atan: torch.atan,
36
+ sympy.atan2: torch.atan2,
37
+ # Note: May give NaN for complex results.
38
+ sympy.cosh: torch.cosh,
39
+ sympy.acosh: torch.acosh,
40
+ sympy.sinh: torch.sinh,
41
+ sympy.asinh: torch.asinh,
42
+ sympy.tanh: torch.tanh,
43
+ sympy.atanh: torch.atanh,
44
+ sympy.Pow: torch.pow,
45
+ sympy.re: torch.real,
46
+ sympy.im: torch.imag,
47
+ sympy.arg: torch.angle,
48
+ # Note: May raise error for ints and complexes
49
+ sympy.erf: torch.erf,
50
+ sympy.loggamma: torch.lgamma,
51
+ sympy.Eq: torch.eq,
52
+ sympy.Ne: torch.ne,
53
+ sympy.StrictGreaterThan: torch.gt,
54
+ sympy.StrictLessThan: torch.lt,
55
+ sympy.LessThan: torch.le,
56
+ sympy.GreaterThan: torch.ge,
57
+ sympy.And: torch.logical_and,
58
+ sympy.Or: torch.logical_or,
59
+ sympy.Not: torch.logical_not,
60
+ sympy.Max: torch.max,
61
+ sympy.Min: torch.min,
62
+ # Matrices
63
+ sympy.MatAdd: torch.add,
64
+ sympy.HadamardProduct: torch.mul,
65
+ sympy.Trace: torch.trace,
66
+ # Note: May raise error for integer matrices.
67
+ sympy.Determinant: torch.det,
68
+ }
69
+
70
+
71
+ class _Node(torch.nn.Module):
72
+ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
73
+ super().__init__(**kwargs)
74
+
75
+ self._sympy_func = expr.func
76
+
77
+ if issubclass(expr.func, sympy.Float):
78
+ self._value = torch.nn.Parameter(torch.tensor(float(expr)))
79
+ self._torch_func = lambda: self._value
80
+ self._args = ()
81
+ elif issubclass(expr.func, sympy.UnevaluatedExpr):
82
+ if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
83
+ raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
84
+ self.register_buffer('_value', torch.tensor(float(expr.args[0])))
85
+ self._torch_func = lambda: self._value
86
+ self._args = ()
87
+ elif issubclass(expr.func, sympy.Integer):
88
+ # Can get here if expr is one of the Integer special cases,
89
+ # e.g. NegativeOne
90
+ self._value = int(expr)
91
+ self._torch_func = lambda: self._value
92
+ self._args = ()
93
+ elif issubclass(expr.func, sympy.Symbol):
94
+ self._name = expr.name
95
+ self._torch_func = lambda value: value
96
+ self._args = ((lambda memodict: memodict[expr.name]),)
97
+ else:
98
+ self._torch_func = _func_lookup[expr.func]
99
+ args = []
100
+ for arg in expr.args:
101
+ try:
102
+ arg_ = _memodict[arg]
103
+ except KeyError:
104
+ arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
105
+ _memodict[arg] = arg_
106
+ args.append(arg_)
107
+ self._args = torch.nn.ModuleList(args)
108
+
109
+ def sympy(self, _memodict):
110
+ if issubclass(self._sympy_func, sympy.Float):
111
+ return self._sympy_func(self._value.item())
112
+ elif issubclass(self._sympy_func, sympy.UnevaluatedExpr):
113
+ return self._sympy_func(self._value.item())
114
+ elif issubclass(self._sympy_func, sympy.Integer):
115
+ return self._sympy_func(self._value)
116
+ elif issubclass(self._sympy_func, sympy.Symbol):
117
+ return self._sympy_func(self._name)
118
+ else:
119
+ if issubclass(self._sympy_func, (sympy.Min, sympy.Max)):
120
+ evaluate = False
121
+ else:
122
+ evaluate = True
123
+ args = []
124
+ for arg in self._args:
125
+ try:
126
+ arg_ = _memodict[arg]
127
+ except KeyError:
128
+ arg_ = arg.sympy(_memodict)
129
+ _memodict[arg] = arg_
130
+ args.append(arg_)
131
+ return self._sympy_func(*args, evaluate=evaluate)
132
+
133
+ def forward(self, memodict):
134
+ args = []
135
+ for arg in self._args:
136
+ try:
137
+ arg_ = memodict[arg]
138
+ except KeyError:
139
+ arg_ = arg(memodict)
140
+ memodict[arg] = arg_
141
+ args.append(arg_)
142
+ return self._torch_func(*args)
143
+
144
+
145
+ class SingleSymPyModule(torch.nn.Module):
146
+ def __init__(self, expression, symbols_in,
147
+ extra_funcs=None, **kwargs):
148
+ super().__init__(**kwargs)
149
+
150
+ if extra_funcs is None:
151
+ extra_funcs = {}
152
+ _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
153
+
154
+ _memodict = {}
155
+ self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
156
+ self._expression_string = str(expression)
157
+ self.symbols_in = [str(symbol) for symbol in symbols_in]
158
+
159
+ def __repr__(self):
160
+ return f"{type(self).__name__}(expression={self._expression_string})"
161
+
162
+ def sympy(self):
163
+ _memodict = {}
164
+ return self._node.sympy(_memodict)
165
+
166
+ def forward(self, X):
167
+ symbols = {symbol: X[:, i]
168
+ for i, symbol in enumerate(self.symbols_in)}
169
+ return self._node(symbols)
170
+
171
+ def sympy2torch(expression, symbols_in):
172
+ return SingleSymPyModule(expression, symbols_in)
pysr/sr.py CHANGED
@@ -14,6 +14,7 @@ from pathlib import Path
14
  from datetime import datetime
15
  import warnings
16
  from .export_jax import sympy2jax
 
17
 
18
  global_equation_file = 'hall_of_fame.csv'
19
  global_n_features = None
 
14
  from datetime import datetime
15
  import warnings
16
  from .export_jax import sympy2jax
17
+ from .export_torch import sympy2torch
18
 
19
  global_equation_file = 'hall_of_fame.csv'
20
  global_n_features = None