MilesCranmer commited on
Commit
d18011f
1 Parent(s): 3bea8e3

Revert to old torch export

Browse files

- Installing a separate but optional library with dependency on torch introduced
too many difficulties. In the end, the simplest solution is to just
maintain a separate codebase here.

.github/workflows/CI.yml CHANGED
@@ -73,7 +73,7 @@ jobs:
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
75
  - name: "Install Torch"
76
- run: pip install torch sympytorch # (optional import)
77
  shell: bash
78
  - name: "Run Torch tests"
79
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
 
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
75
  - name: "Install Torch"
76
+ run: pip install torch # (optional import)
77
  shell: bash
78
  - name: "Run Torch tests"
79
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
.github/workflows/CI_Windows.yml CHANGED
@@ -65,7 +65,7 @@ jobs:
65
  run: python -m unittest test.test
66
  shell: bash
67
  - name: "Install Torch"
68
- run: pip install torch sympytorch # (optional import)
69
  shell: bash
70
  - name: "Run Torch tests"
71
  run: python -m unittest test.test_torch
 
65
  run: python -m unittest test.test
66
  shell: bash
67
  - name: "Install Torch"
68
+ run: pip install torch # (optional import)
69
  shell: bash
70
  - name: "Run Torch tests"
71
  run: python -m unittest test.test_torch
.github/workflows/CI_mac.yml CHANGED
@@ -71,7 +71,7 @@ jobs:
71
  run: python -m unittest test.test_jax
72
  shell: bash
73
  - name: "Install Torch"
74
- run: pip install torch sympytorch # (optional import)
75
  shell: bash
76
  - name: "Run Torch tests"
77
  run: python -m unittest test.test_torch
 
71
  run: python -m unittest test.test_jax
72
  shell: bash
73
  - name: "Install Torch"
74
+ run: pip install torch # (optional import)
75
  shell: bash
76
  - name: "Run Torch tests"
77
  run: python -m unittest test.test_torch
pysr/export_torch.py CHANGED
@@ -1,47 +1,164 @@
 
 
 
 
 
1
  import collections as co
 
2
  import sympy
3
 
 
 
 
 
 
4
  torch_initialized = False
5
  torch = None
6
- sympytorch = None
7
- PySRTorchModule = None
 
8
 
9
  def _initialize_torch():
10
  global torch_initialized
11
  global torch
12
- global sympytorch
13
- global PySRTorchModule
 
14
 
15
- # Way to lazy load torch and sympytorch, only if this is called,
16
  # but still allow this module to be loaded in __init__
17
  if not torch_initialized:
18
- try:
19
- import torch
20
- import sympytorch
21
- except ImportError:
22
- raise ImportError("You need to pip install `torch` and `sympytorch` before exporting to pytorch.")
23
- torch_initialized = True
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
- class PySRTorchModule(torch.nn.Module):
27
- def __init__(self, *, expression, symbols_in,
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  selection=None, extra_funcs=None, **kwargs):
29
  super().__init__(**kwargs)
30
- self._module = sympytorch.SymPyModule(
31
- expressions=[expression],
32
- extra_funcs=extra_funcs)
33
- self._selection = selection
34
- self._symbols = symbols_in
35
 
 
 
 
 
 
 
 
 
 
 
36
  def __repr__(self):
37
  return f"{type(self).__name__}(expression={self._expression_string})"
38
 
39
  def forward(self, X):
40
  if self._selection is not None:
41
  X = X[:, self._selection]
42
- symbols = {str(symbol): X[:, i]
43
- for i, symbol in enumerate(self._symbols)}
44
- return self._module(**symbols)[..., 0]
45
 
46
 
47
  def sympy2torch(expression, symbols_in,
@@ -51,11 +168,10 @@ def sympy2torch(expression, symbols_in,
51
  This function will assume the input to the module is a matrix X, where
52
  each column corresponds to each symbol you pass in `symbols_in`.
53
  """
54
- global PySRTorchModule
55
 
56
  _initialize_torch()
57
 
58
- return PySRTorchModule(expression=expression,
59
- symbols_in=symbols_in,
60
- selection=selection,
61
- extra_funcs=extra_torch_mappings)
 
1
+ #####
2
+ # From https://github.com/patrick-kidger/sympytorch
3
+ # Copied here to allow PySR-specific tweaks
4
+ #####
5
+
6
  import collections as co
7
+ import functools as ft
8
  import sympy
9
 
10
+ def _reduce(fn):
11
+ def fn_(*args):
12
+ return ft.reduce(fn, args)
13
+ return fn_
14
+
15
  torch_initialized = False
16
  torch = None
17
+ _global_func_lookup = None
18
+ _Node = None
19
+ SingleSymPyModule = None
20
 
21
  def _initialize_torch():
22
  global torch_initialized
23
  global torch
24
+ global _global_func_lookup
25
+ global _Node
26
+ global SingleSymPyModule
27
 
28
+ # Way to lazy load torch, only if this is called,
29
  # but still allow this module to be loaded in __init__
30
  if not torch_initialized:
31
+ import torch as _torch
32
+ torch = _torch
33
+
34
+ _global_func_lookup = {
35
+ sympy.Mul: _reduce(torch.mul),
36
+ sympy.Add: _reduce(torch.add),
37
+ sympy.div: torch.div,
38
+ sympy.Abs: torch.abs,
39
+ sympy.sign: torch.sign,
40
+ # Note: May raise error for ints.
41
+ sympy.ceiling: torch.ceil,
42
+ sympy.floor: torch.floor,
43
+ sympy.log: torch.log,
44
+ sympy.exp: torch.exp,
45
+ sympy.sqrt: torch.sqrt,
46
+ sympy.cos: torch.cos,
47
+ sympy.acos: torch.acos,
48
+ sympy.sin: torch.sin,
49
+ sympy.asin: torch.asin,
50
+ sympy.tan: torch.tan,
51
+ sympy.atan: torch.atan,
52
+ sympy.atan2: torch.atan2,
53
+ # Note: May give NaN for complex results.
54
+ sympy.cosh: torch.cosh,
55
+ sympy.acosh: torch.acosh,
56
+ sympy.sinh: torch.sinh,
57
+ sympy.asinh: torch.asinh,
58
+ sympy.tanh: torch.tanh,
59
+ sympy.atanh: torch.atanh,
60
+ sympy.Pow: torch.pow,
61
+ sympy.re: torch.real,
62
+ sympy.im: torch.imag,
63
+ sympy.arg: torch.angle,
64
+ # Note: May raise error for ints and complexes
65
+ sympy.erf: torch.erf,
66
+ sympy.loggamma: torch.lgamma,
67
+ sympy.Eq: torch.eq,
68
+ sympy.Ne: torch.ne,
69
+ sympy.StrictGreaterThan: torch.gt,
70
+ sympy.StrictLessThan: torch.lt,
71
+ sympy.LessThan: torch.le,
72
+ sympy.GreaterThan: torch.ge,
73
+ sympy.And: torch.logical_and,
74
+ sympy.Or: torch.logical_or,
75
+ sympy.Not: torch.logical_not,
76
+ sympy.Max: torch.max,
77
+ sympy.Min: torch.min,
78
+ # Matrices
79
+ sympy.MatAdd: torch.add,
80
+ sympy.HadamardProduct: torch.mul,
81
+ sympy.Trace: torch.trace,
82
+ # Note: May raise error for integer matrices.
83
+ sympy.Determinant: torch.det,
84
+ }
85
+
86
+ class _Node(torch.nn.Module):
87
+ """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
88
+ def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
89
+ super().__init__(**kwargs)
90
+
91
+ self._sympy_func = expr.func
92
 
93
+ if issubclass(expr.func, sympy.Float):
94
+ self._value = torch.nn.Parameter(torch.tensor(float(expr)))
95
+ self._torch_func = lambda: self._value
96
+ self._args = ()
97
+ elif issubclass(expr.func, sympy.UnevaluatedExpr):
98
+ if len(expr.args) != 1 or not issubclass(expr.args[0].func, sympy.Float):
99
+ raise ValueError("UnevaluatedExpr should only be used to wrap floats.")
100
+ self.register_buffer('_value', torch.tensor(float(expr.args[0])))
101
+ self._torch_func = lambda: self._value
102
+ self._args = ()
103
+ elif issubclass(expr.func, sympy.Integer):
104
+ # Can get here if expr is one of the Integer special cases,
105
+ # e.g. NegativeOne
106
+ self._value = int(expr)
107
+ self._torch_func = lambda: self._value
108
+ self._args = ()
109
+ elif issubclass(expr.func, sympy.Symbol):
110
+ self._name = expr.name
111
+ self._torch_func = lambda value: value
112
+ self._args = ((lambda memodict: memodict[expr.name]),)
113
+ else:
114
+ self._torch_func = _func_lookup[expr.func]
115
+ args = []
116
+ for arg in expr.args:
117
+ try:
118
+ arg_ = _memodict[arg]
119
+ except KeyError:
120
+ arg_ = type(self)(expr=arg, _memodict=_memodict, _func_lookup=_func_lookup, **kwargs)
121
+ _memodict[arg] = arg_
122
+ args.append(arg_)
123
+ self._args = torch.nn.ModuleList(args)
124
 
125
+ def forward(self, memodict):
126
+ args = []
127
+ for arg in self._args:
128
+ try:
129
+ arg_ = memodict[arg]
130
+ except KeyError:
131
+ arg_ = arg(memodict)
132
+ memodict[arg] = arg_
133
+ args.append(arg_)
134
+ return self._torch_func(*args)
135
+
136
+
137
+ class SingleSymPyModule(torch.nn.Module):
138
+ """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
+ def __init__(self, expression, symbols_in,
140
  selection=None, extra_funcs=None, **kwargs):
141
  super().__init__(**kwargs)
 
 
 
 
 
142
 
143
+ if extra_funcs is None:
144
+ extra_funcs = {}
145
+ _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)
146
+
147
+ _memodict = {}
148
+ self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
149
+ self._expression_string = str(expression)
150
+ self._selection = selection
151
+ self.symbols_in = [str(symbol) for symbol in symbols_in]
152
+
153
  def __repr__(self):
154
  return f"{type(self).__name__}(expression={self._expression_string})"
155
 
156
  def forward(self, X):
157
  if self._selection is not None:
158
  X = X[:, self._selection]
159
+ symbols = {symbol: X[:, i]
160
+ for i, symbol in enumerate(self.symbols_in)}
161
+ return self._node(symbols)
162
 
163
 
164
  def sympy2torch(expression, symbols_in,
 
168
  This function will assume the input to the module is a matrix X, where
169
  each column corresponds to each symbol you pass in `symbols_in`.
170
  """
171
+ global SingleSymPyModule
172
 
173
  _initialize_torch()
174
 
175
+ return SingleSymPyModule(expression, symbols_in,
176
+ selection=selection,
177
+ extra_funcs=extra_torch_mappings)