MilesCranmer commited on
Commit
b0e1209
1 Parent(s): f544d25

Switch to using sympytorch

Browse files
Files changed (2) hide show
  1. pysr/export_torch.py +25 -129
  2. setup.py +14 -3
pysr/export_torch.py CHANGED
@@ -14,151 +14,46 @@ def _reduce(fn):
14
 
15
  torch_initialized = False
16
  torch = None
17
- _global_func_lookup = None
18
- _Node = None
19
- SingleSymPyModule = None
20
 
21
  def _initialize_torch():
22
  global torch_initialized
23
  global torch
24
- global _global_func_lookup
25
- global _Node
26
- global SingleSymPyModule
27
 
28
  # Way to lazy load torch, only if this is called,
29
  # but still allow this module to be loaded in __init__
30
  if not torch_initialized:
31
- import torch as _torch
32
- torch = _torch
 
 
 
 
33
 
34
- _global_func_lookup = {
35
- sympy.Mul: _reduce(torch.mul),
36
- sympy.Add: _reduce(torch.add),
37
- sympy.div: torch.div,
38
- sympy.Abs: torch.abs,
39
- sympy.sign: torch.sign,
40
- # Note: May raise error for ints.
41
- sympy.ceiling: torch.ceil,
42
- sympy.floor: torch.floor,
43
- sympy.log: torch.log,
44
- sympy.exp: torch.exp,
45
- sympy.sqrt: torch.sqrt,
46
- sympy.cos: torch.cos,
47
- sympy.acos: torch.acos,
48
- sympy.sin: torch.sin,
49
- sympy.asin: torch.asin,
50
- sympy.tan: torch.tan,
51
- sympy.atan: torch.atan,
52
- sympy.atan2: torch.atan2,
53
- # Note: May give NaN for complex results.
54
- sympy.cosh: torch.cosh,
55
- sympy.acosh: torch.acosh,
56
- sympy.sinh: torch.sinh,
57
- sympy.asinh: torch.asinh,
58
- sympy.tanh: torch.tanh,
59
- sympy.atanh: torch.atanh,
60
- sympy.Pow: torch.pow,
61
- sympy.re: torch.real,
62
- sympy.im: torch.imag,
63
- sympy.arg: torch.angle,
64
- # Note: May raise error for ints and complexes
65
- sympy.erf: torch.erf,
66
- sympy.loggamma: torch.lgamma,
67
- sympy.Eq: torch.eq,
68
- sympy.Ne: torch.ne,
69
- sympy.StrictGreaterThan: torch.gt,
70
- sympy.StrictLessThan: torch.lt,
71
- sympy.LessThan: torch.le,
72
- sympy.GreaterThan: torch.ge,
73
- sympy.And: torch.logical_and,
74
- sympy.Or: torch.logical_or,
75
- sympy.Not: torch.logical_not,
76
- sympy.Max: torch.max,
77
- sympy.Min: torch.min,
78
- # Matrices
79
- sympy.MatAdd: torch.add,
80
- sympy.HadamardProduct: torch.mul,
81
- sympy.Trace: torch.trace,
82
- # Note: May raise error for integer matrices.
83
- sympy.Determinant: torch.det,
84
- }
85
 
86
- class _Node(torch.nn.Module):
87
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
88
- def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
89
- super().__init__(**kwargs)
90
-
91
- self._sympy_func = expr.func
92
-
93
- if issubclass(expr.func, sympy.Float):
94
- self._value = torch.nn.Parameter(torch.tensor(float(expr)))
95
- self._torch_func = lambda: self._value
96
- self._args = ()
97
- elif issubclass(expr.func, sympy.UnevaluatedExpr):
98
- if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
99
- raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
100
- self.register_buffer('_value', torch.tensor(float(expr.args[0])))
101
- self._torch_func = lambda: self._value
102
- self._args = ()
103
- elif issubclass(expr.func, sympy.Integer):
104
- # Can get here if expr is one of the Integer special cases,
105
- # e.g. NegativeOne
106
- self._value = int(expr)
107
- self._torch_func = lambda: self._value
108
- self._args = ()
109
- elif issubclass(expr.func, sympy.Symbol):
110
- self._name = expr.name
111
- self._torch_func = lambda value: value
112
- self._args = ((lambda memodict: memodict[expr.name]),)
113
- else:
114
- self._torch_func = _func_lookup[expr.func]
115
- args = []
116
- for arg in expr.args:
117
- try:
118
- arg_ = _memodict[arg]
119
- except KeyError:
120
- arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
121
- _memodict[arg] = arg_
122
- args.append(arg_)
123
- self._args = torch.nn.ModuleList(args)
124
-
125
- def forward(self, memodict):
126
- args = []
127
- for arg in self._args:
128
- try:
129
- arg_ = memodict[arg]
130
- except KeyError:
131
- arg_ = arg(memodict)
132
- memodict[arg] = arg_
133
- args.append(arg_)
134
- return self._torch_func(*args)
135
-
136
-
137
- class SingleSymPyModule(torch.nn.Module):
138
- """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
- def __init__(self, expression, symbols_in,
140
  selection=None, extra_funcs=None, **kwargs):
141
  super().__init__(**kwargs)
142
-
143
- if extra_funcs is None:
144
- extra_funcs = {}
145
- _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
146
-
147
- _memodict = {}
148
- self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
149
- self._expression_string = str(expression)
150
  self._selection = selection
151
- self.symbols_in = [str(symbol) for symbol in symbols_in]
152
-
153
  def __repr__(self):
154
  return f"{type(self).__name__}(expression={self._expression_string})"
155
 
156
  def forward(self, X):
157
  if self._selection is not None:
158
  X = X[:, self._selection]
159
- symbols = {symbol: X[:, i]
160
- for i, symbol in enumerate(self.symbols_in)}
161
- return self._node(symbols)
162
 
163
 
164
  def sympy2torch(expression, symbols_in,
@@ -168,10 +63,11 @@ def sympy2torch(expression, symbols_in,
168
  This function will assume the input to the module is a matrix X, where
169
  each column corresponds to each symbol you pass in `symbols_in`.
170
  """
171
- global SingleSymPyModule
172
 
173
  _initialize_torch()
174
 
175
- return SingleSymPyModule(expression, symbols_in,
176
- selection=selection,
177
- extra_funcs=extra_torch_mappings)
 
 
14
 
15
  torch_initialized = False
16
  torch = None
17
+ sympytorch = None
18
+ PySRTorchModule = None
 
19
 
20
  def _initialize_torch():
21
  global torch_initialized
22
  global torch
23
+ global sympytorch
24
+ global PySRTorchModule
 
25
 
26
  # Way to lazy load torch, only if this is called,
27
  # but still allow this module to be loaded in __init__
28
  if not torch_initialized:
29
+ try:
30
+ import torch
31
+ import sympytorch
32
+ except ImportError:
33
+ raise ImportError("You need to pip install `torch` and `sympytorch` before exporting to pytorch.")
34
+ torch_initialized = True
35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ class PySRTorchModule(torch.nn.Module):
38
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
39
+ def __init__(self, *, expression, symbols_in,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
40
  selection=None, extra_funcs=None, **kwargs):
41
  super().__init__(**kwargs)
42
+ self._module = sympytorch.SymPyModule(
43
+ expressions=[expression],
44
+ extra_funcs=extra_funcs)
 
 
 
 
 
45
  self._selection = selection
46
+ self._symbols = symbols_in
47
+
48
  def __repr__(self):
49
  return f"{type(self).__name__}(expression={self._expression_string})"
50
 
51
  def forward(self, X):
52
  if self._selection is not None:
53
  X = X[:, self._selection]
54
+ symbols = {str(symbol): X[:, i]
55
+ for i, symbol in enumerate(self._symbols)}
56
+ return self._module(**symbols)[..., 0]
57
 
58
 
59
  def sympy2torch(expression, symbols_in,
 
63
  This function will assume the input to the module is a matrix X, where
64
  each column corresponds to each symbol you pass in `symbols_in`.
65
  """
66
+ global PySRTorchModule
67
 
68
  _initialize_torch()
69
 
70
+ return PySRTorchModule(expression=expression,
71
+ symbols_in=symbols_in,
72
+ selection=selection,
73
+ extra_funcs=extra_torch_mappings)
setup.py CHANGED
@@ -1,8 +1,19 @@
 
1
  import setuptools
2
 
3
  with open("README.md", "r") as fh:
4
  long_description = fh.read()
5
 
 
 
 
 
 
 
 
 
 
 
6
  setuptools.setup(
7
  name="pysr", # Replace with your own username
8
  version="0.6.0rc1",
@@ -12,11 +23,11 @@ setuptools.setup(
12
  long_description=long_description,
13
  long_description_content_type="text/markdown",
14
  url="https://github.com/MilesCranmer/pysr",
15
- install_requires=[
16
  "numpy",
17
  "pandas",
18
  "sympy"
19
- ],
20
  packages=setuptools.find_packages(),
21
  package_data={
22
  'pysr': ['../Project.toml', '../datasets/*']
@@ -26,5 +37,5 @@ setuptools.setup(
26
  "Programming Language :: Python :: 3",
27
  "Operating System :: OS Independent",
28
  ],
29
- python_requires='>=3.3',
30
  )
 
1
+ import importlib.util
2
  import setuptools
3
 
4
  with open("README.md", "r") as fh:
5
  long_description = fh.read()
6
 
7
+ extra_installs = []
8
+
9
+ torch_installed = (importlib.util.find_spec('torch') is not None)
10
+ install_sympytorch = torch_installed
11
+
12
+ if install_sympytorch:
13
+ extra_installs.append('sympytorch')
14
+
15
+ print(extra_installs)
16
+
17
  setuptools.setup(
18
  name="pysr", # Replace with your own username
19
  version="0.6.0rc1",
 
23
  long_description=long_description,
24
  long_description_content_type="text/markdown",
25
  url="https://github.com/MilesCranmer/pysr",
26
+ install_requires=([
27
  "numpy",
28
  "pandas",
29
  "sympy"
30
+ ] + extra_installs),
31
  packages=setuptools.find_packages(),
32
  package_data={
33
  'pysr': ['../Project.toml', '../datasets/*']
 
37
  "Programming Language :: Python :: 3",
38
  "Operating System :: OS Independent",
39
  ],
40
+ python_requires='>=3.7',
41
  )