MilesCranmer committed
Commit fb950bb
1 Parent(s): 68b3673

Refactor lazy loading of torch and jax

Files changed (4)
  1. pysr/__init__.py +2 -0
  2. pysr/export_jax.py +28 -3
  3. pysr/export_torch.py +176 -152
  4. test/test_jax.py +1 -1
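
The change in all three source files follows the same deferred-import pattern: the heavy dependency is replaced by a module-level placeholder, a small `_initialize_*` helper performs the real import the first time it is needed, and the public entry point calls that helper before touching the module. A minimal sketch of the idea, with illustrative names rather than the exact code from the diffs below:

```python
# Sketch of the lazy-import pattern applied in this commit (illustrative only).
heavy_initialized = False
heavy = None  # placeholder; filled in on first use


def _initialize_heavy():
    """Import the heavy dependency only when it is actually needed."""
    global heavy_initialized
    global heavy
    if not heavy_initialized:
        import math as _heavy  # stand-in for torch or jax
        heavy = _heavy
        heavy_initialized = True


def public_entry_point(x):
    _initialize_heavy()   # first call triggers the import
    return heavy.sqrt(x)  # safe to use the module afterwards
```

This keeps `import pysr` working even when torch or jax is not installed, while `sympy2torch` and `sympy2jax` still work when they are.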
pysr/__init__.py CHANGED
@@ -1,2 +1,4 @@
 from .sr import pysr, get_hof, best, best_tex, best_callable, best_row
 from .feynman_problems import Problem, FeynmanProblem
+from .export_jax import sympy2jax
+from .export_torch import sympy2torch
pysr/export_jax.py CHANGED
@@ -2,9 +2,6 @@ import functools as ft
 import sympy
 import string
 import random
-import jax
-from jax import numpy as jnp
-from jax.scipy import special as jsp
 
 # Special since need to reduce arguments.
 MUL = 0
@@ -53,6 +50,7 @@ _jnp_func_lookup = {
     sympy.Mod: "jnp.mod",
 }
 
+
 def sympy2jaxtext(expr, parameters, symbols_in):
     if issubclass(expr.func, sympy.Float):
         parameters.append(float(expr))
@@ -71,6 +69,27 @@ def sympy2jaxtext(expr, parameters, symbols_in):
     else:
         return f'{_func}({", ".join(args)})'
 
+
+jax_initialized = False
+jax = None
+jnp = None
+jsp = None
+
+def _initialize_jax():
+    global jax_initialized
+    global jax
+    global jnp
+    global jsp
+
+    if not jax_initialized:
+        import jax as _jax
+        from jax import numpy as _jnp
+        from jax.scipy import special as _jsp
+        jax = _jax
+        jnp = _jnp
+        jsp = _jsp
+
+
 def sympy2jax(expression, symbols_in):
     """Returns a function f and its parameters;
     the function takes an input matrix, and a list of arguments:
@@ -142,6 +161,12 @@ def sympy2jax(expression, symbols_in):
     # 3.5427954 , -2.7479894 ], dtype=float32)
     ```
     """
+    _initialize_jax()
+    global jax_initialized
+    global jax
+    global jnp
+    global jsp
+
     parameters = []
     functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
     hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
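
For context, a small usage sketch of the jax export after this change; it assumes jax and sympy are installed and follows the `(f, parameters)` return convention described in the `sympy2jax` docstring above:

```python
import sympy
from jax import numpy as jnp
from pysr import sympy2jax  # importing pysr no longer imports jax itself

x, y = sympy.symbols("x y")
expression = 2.0 * sympy.sin(x) + y

# The first call runs _initialize_jax(), which performs the actual jax import.
f, parameters = sympy2jax(expression, [x, y])
X = jnp.ones((10, 2))
out = f(X, parameters)  # evaluates the expression column-wise over X
```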
pysr/export_torch.py CHANGED
@@ -6,165 +6,189 @@
 import collections as co
 import functools as ft
 import sympy
-import torch
 
 def _reduce(fn):
     def fn_(*args):
         return ft.reduce(fn, args)
     return fn_
 
-
-_global_func_lookup = {
-    sympy.Mul: _reduce(torch.mul),
-    sympy.Add: _reduce(torch.add),
-    sympy.div: torch.div,
-    sympy.Abs: torch.abs,
-    sympy.sign: torch.sign,
-    # Note: May raise error for ints.
-    sympy.ceiling: torch.ceil,
-    sympy.floor: torch.floor,
-    sympy.log: torch.log,
-    sympy.exp: torch.exp,
-    sympy.sqrt: torch.sqrt,
-    sympy.cos: torch.cos,
-    sympy.acos: torch.acos,
-    sympy.sin: torch.sin,
-    sympy.asin: torch.asin,
-    sympy.tan: torch.tan,
-    sympy.atan: torch.atan,
-    sympy.atan2: torch.atan2,
-    # Note: May give NaN for complex results.
-    sympy.cosh: torch.cosh,
-    sympy.acosh: torch.acosh,
-    sympy.sinh: torch.sinh,
-    sympy.asinh: torch.asinh,
-    sympy.tanh: torch.tanh,
-    sympy.atanh: torch.atanh,
-    sympy.Pow: torch.pow,
-    sympy.re: torch.real,
-    sympy.im: torch.imag,
-    sympy.arg: torch.angle,
-    # Note: May raise error for ints and complexes
-    sympy.erf: torch.erf,
-    sympy.loggamma: torch.lgamma,
-    sympy.Eq: torch.eq,
-    sympy.Ne: torch.ne,
-    sympy.StrictGreaterThan: torch.gt,
-    sympy.StrictLessThan: torch.lt,
-    sympy.LessThan: torch.le,
-    sympy.GreaterThan: torch.ge,
-    sympy.And: torch.logical_and,
-    sympy.Or: torch.logical_or,
-    sympy.Not: torch.logical_not,
-    sympy.Max: torch.max,
-    sympy.Min: torch.min,
-    # Matrices
-    sympy.MatAdd: torch.add,
-    sympy.HadamardProduct: torch.mul,
-    sympy.Trace: torch.trace,
-    # Note: May raise error for integer matrices.
-    sympy.Determinant: torch.det,
-}
-
-class _Node(torch.nn.Module):
-    def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
-        super().__init__(**kwargs)
-
-        self._sympy_func = expr.func
-
-        if issubclass(expr.func, sympy.Float):
-            self._value = torch.nn.Parameter(torch.tensor(float(expr)))
-            self._torch_func = lambda: self._value
-            self._args = ()
-        elif issubclass(expr.func, sympy.UnevaluatedExpr):
-            if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
-                raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
-            self.register_buffer('_value', torch.tensor(float(expr.args[0])))
-            self._torch_func = lambda: self._value
-            self._args = ()
-        elif issubclass(expr.func, sympy.Integer):
-            # Can get here if expr is one of the Integer special cases,
-            # e.g. NegativeOne
-            self._value = int(expr)
-            self._torch_func = lambda: self._value
-            self._args = ()
-        elif issubclass(expr.func, sympy.Symbol):
-            self._name = expr.name
-            self._torch_func = lambda value: value
-            self._args = ((lambda memodict: memodict[expr.name]),)
-        else:
-            self._torch_func = _func_lookup[expr.func]
-            args = []
-            for arg in expr.args:
-                try:
-                    arg_ = _memodict[arg]
-                except KeyError:
-                    arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
-                    _memodict[arg] = arg_
-                args.append(arg_)
-            self._args = torch.nn.ModuleList(args)
-
-    def sympy(self, _memodict):
-        if issubclass(self._sympy_func, sympy.Float):
-            return self._sympy_func(self._value.item())
-        elif issubclass(self._sympy_func, sympy.UnevaluatedExpr):
-            return self._sympy_func(self._value.item())
-        elif issubclass(self._sympy_func, sympy.Integer):
-            return self._sympy_func(self._value)
-        elif issubclass(self._sympy_func, sympy.Symbol):
-            return self._sympy_func(self._name)
-        else:
-            if issubclass(self._sympy_func, (sympy.Min, sympy.Max)):
-                evaluate = False
-            else:
-                evaluate = True
-            args = []
-            for arg in self._args:
-                try:
-                    arg_ = _memodict[arg]
-                except KeyError:
-                    arg_ = arg.sympy(_memodict)
-                    _memodict[arg] = arg_
-                args.append(arg_)
-            return self._sympy_func(*args, evaluate=evaluate)
-
-    def forward(self, memodict):
-        args = []
-        for arg in self._args:
-            try:
-                arg_ = memodict[arg]
-            except KeyError:
-                arg_ = arg(memodict)
-                memodict[arg] = arg_
-            args.append(arg_)
-        return self._torch_func(*args)
-
-
-class SingleSymPyModule(torch.nn.Module):
-    def __init__(self, expression, symbols_in,
-                 extra_funcs=None, **kwargs):
-        super().__init__(**kwargs)
-
-        if extra_funcs is None:
-            extra_funcs = {}
-        _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
-
-        _memodict = {}
-        self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
-        self._expression_string = str(expression)
-        self.symbols_in = [str(symbol) for symbol in symbols_in]
-
-    def __repr__(self):
-        return f"{type(self).__name__}(expression={self._expression_string})"
-
-    def sympy(self):
-        _memodict = {}
-        return self._node.sympy(_memodict)
-
-    def forward(self, X):
-        symbols = {symbol: X[:, i]
-                   for i, symbol in enumerate(self.symbols_in)}
-        return self._node(symbols)
+torch_initialized = False
+torch = None
+_global_func_lookup = None
+_Node = None
+SingleSymPyModule = None
+
+def _initialize_torch():
+    global torch_initialized
+    global torch
+    global _global_func_lookup
+    global _Node
+    global SingleSymPyModule
+
+    # Way to lazy load torch, only if this is called,
+    # but still allow this module to be loaded in __init__
+    if not torch_initialized:
+        import torch as _torch
+        torch = _torch
+
+        _global_func_lookup = {
+            sympy.Mul: _reduce(torch.mul),
+            sympy.Add: _reduce(torch.add),
+            sympy.div: torch.div,
+            sympy.Abs: torch.abs,
+            sympy.sign: torch.sign,
+            # Note: May raise error for ints.
+            sympy.ceiling: torch.ceil,
+            sympy.floor: torch.floor,
+            sympy.log: torch.log,
+            sympy.exp: torch.exp,
+            sympy.sqrt: torch.sqrt,
+            sympy.cos: torch.cos,
+            sympy.acos: torch.acos,
+            sympy.sin: torch.sin,
+            sympy.asin: torch.asin,
+            sympy.tan: torch.tan,
+            sympy.atan: torch.atan,
+            sympy.atan2: torch.atan2,
+            # Note: May give NaN for complex results.
+            sympy.cosh: torch.cosh,
+            sympy.acosh: torch.acosh,
+            sympy.sinh: torch.sinh,
+            sympy.asinh: torch.asinh,
+            sympy.tanh: torch.tanh,
+            sympy.atanh: torch.atanh,
+            sympy.Pow: torch.pow,
+            sympy.re: torch.real,
+            sympy.im: torch.imag,
+            sympy.arg: torch.angle,
+            # Note: May raise error for ints and complexes
+            sympy.erf: torch.erf,
+            sympy.loggamma: torch.lgamma,
+            sympy.Eq: torch.eq,
+            sympy.Ne: torch.ne,
+            sympy.StrictGreaterThan: torch.gt,
+            sympy.StrictLessThan: torch.lt,
+            sympy.LessThan: torch.le,
+            sympy.GreaterThan: torch.ge,
+            sympy.And: torch.logical_and,
+            sympy.Or: torch.logical_or,
+            sympy.Not: torch.logical_not,
+            sympy.Max: torch.max,
+            sympy.Min: torch.min,
+            # Matrices
+            sympy.MatAdd: torch.add,
+            sympy.HadamardProduct: torch.mul,
+            sympy.Trace: torch.trace,
+            # Note: May raise error for integer matrices.
+            sympy.Determinant: torch.det,
+        }
+
+        class _Node(torch.nn.Module):
+            def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
+                super().__init__(**kwargs)
+
+                self._sympy_func = expr.func
+
+                if issubclass(expr.func, sympy.Float):
+                    self._value = torch.nn.Parameter(torch.tensor(float(expr)))
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.UnevaluatedExpr):
+                    if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
+                        raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
+                    self.register_buffer('_value', torch.tensor(float(expr.args[0])))
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Integer):
+                    # Can get here if expr is one of the Integer special cases,
+                    # e.g. NegativeOne
+                    self._value = int(expr)
+                    self._torch_func = lambda: self._value
+                    self._args = ()
+                elif issubclass(expr.func, sympy.Symbol):
+                    self._name = expr.name
+                    self._torch_func = lambda value: value
+                    self._args = ((lambda memodict: memodict[expr.name]),)
+                else:
+                    self._torch_func = _func_lookup[expr.func]
+                    args = []
+                    for arg in expr.args:
+                        try:
+                            arg_ = _memodict[arg]
+                        except KeyError:
+                            arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
+                            _memodict[arg] = arg_
+                        args.append(arg_)
+                    self._args = torch.nn.ModuleList(args)
+
+            def sympy(self, _memodict):
+                if issubclass(self._sympy_func, sympy.Float):
+                    return self._sympy_func(self._value.item())
+                elif issubclass(self._sympy_func, sympy.UnevaluatedExpr):
+                    return self._sympy_func(self._value.item())
+                elif issubclass(self._sympy_func, sympy.Integer):
+                    return self._sympy_func(self._value)
+                elif issubclass(self._sympy_func, sympy.Symbol):
+                    return self._sympy_func(self._name)
+                else:
+                    if issubclass(self._sympy_func, (sympy.Min, sympy.Max)):
+                        evaluate = False
+                    else:
+                        evaluate = True
+                    args = []
+                    for arg in self._args:
+                        try:
+                            arg_ = _memodict[arg]
+                        except KeyError:
+                            arg_ = arg.sympy(_memodict)
+                            _memodict[arg] = arg_
+                        args.append(arg_)
+                    return self._sympy_func(*args, evaluate=evaluate)
+
+            def forward(self, memodict):
+                args = []
+                for arg in self._args:
+                    try:
+                        arg_ = memodict[arg]
+                    except KeyError:
+                        arg_ = arg(memodict)
+                        memodict[arg] = arg_
+                    args.append(arg_)
+                return self._torch_func(*args)
+
+
+        class SingleSymPyModule(torch.nn.Module):
+            def __init__(self, expression, symbols_in,
+                         extra_funcs=None, **kwargs):
+                super().__init__(**kwargs)
+
+                if extra_funcs is None:
+                    extra_funcs = {}
+                _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
+
+                _memodict = {}
+                self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
+                self._expression_string = str(expression)
+                self.symbols_in = [str(symbol) for symbol in symbols_in]
+
+            def __repr__(self):
+                return f"{type(self).__name__}(expression={self._expression_string})"
+
+            def sympy(self):
+                _memodict = {}
+                return self._node.sympy(_memodict)
+
+            def forward(self, X):
+                symbols = {symbol: X[:, i]
+                           for i, symbol in enumerate(self.symbols_in)}
+                return self._node(symbols)
+
 
 def sympy2torch(expression, symbols_in):
+    global torch
+    global _Node
+    global SingleSymPyModule
+
+    _initialize_torch()
+
     return SingleSymPyModule(expression, symbols_in)
test/test_jax.py CHANGED
@@ -1,6 +1,6 @@
 import unittest
 import numpy as np
-from pysr import pysr, sympy2jax
+from pysr import sympy2jax
 from jax import numpy as jnp
 from jax import random
 from jax import grad
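
A quick way to see the effect of the refactor is a hypothetical check like the following (assuming torch and sympy are installed, and that nothing else in the environment imports torch as a side effect):

```python
import sys
import sympy

import pysr  # after this commit, this succeeds without importing torch or jax
assert "torch" not in sys.modules  # heavy backend not loaded yet

x = sympy.symbols("x")
module = pysr.sympy2torch(x ** 2, [x])  # first use triggers _initialize_torch()
assert "torch" in sys.modules
print(module)  # e.g. SingleSymPyModule(expression=x**2)
```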