Spaces:
Sleeping
Sleeping
MilesCranmer
commited on
Commit
•
8c55475
1
Parent(s):
45b290b
Allow custom selection of X matrix in torch/jax modules
Browse files- pysr/export_jax.py +4 -1
- pysr/export_torch.py +9 -3
pysr/export_jax.py
CHANGED
@@ -90,7 +90,7 @@ def _initialize_jax():
|
|
90 |
jsp = _jsp
|
91 |
|
92 |
|
93 |
-
def sympy2jax(expression, symbols_in):
|
94 |
"""Returns a function f and its parameters;
|
95 |
the function takes an input matrix, and a list of arguments:
|
96 |
f(X, parameters)
|
@@ -171,6 +171,9 @@ def sympy2jax(expression, symbols_in):
|
|
171 |
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
|
172 |
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
|
173 |
text = f"def {hash_string}(X, parameters):\n"
|
|
|
|
|
|
|
174 |
text += " return "
|
175 |
text += functional_form_text
|
176 |
ldict = {}
|
|
|
90 |
jsp = _jsp
|
91 |
|
92 |
|
93 |
+
def sympy2jax(expression, symbols_in, selection=None):
|
94 |
"""Returns a function f and its parameters;
|
95 |
the function takes an input matrix, and a list of arguments:
|
96 |
f(X, parameters)
|
|
|
171 |
functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
|
172 |
hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
|
173 |
text = f"def {hash_string}(X, parameters):\n"
|
174 |
+
if selection is not None:
|
175 |
+
# Impose the feature selection:
|
176 |
+
text += f" X = X[:, {list(selection)}]"
|
177 |
text += " return "
|
178 |
text += functional_form_text
|
179 |
ldict = {}
|
pysr/export_torch.py
CHANGED
@@ -137,7 +137,7 @@ def _initialize_torch():
|
|
137 |
class SingleSymPyModule(torch.nn.Module):
|
138 |
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
139 |
def __init__(self, expression, symbols_in,
|
140 |
-
|
141 |
super().__init__(**kwargs)
|
142 |
|
143 |
if extra_funcs is None:
|
@@ -147,18 +147,22 @@ def _initialize_torch():
|
|
147 |
_memodict = {}
|
148 |
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
|
149 |
self._expression_string = str(expression)
|
|
|
150 |
self.symbols_in = [str(symbol) for symbol in symbols_in]
|
151 |
|
152 |
def __repr__(self):
|
153 |
return f"{type(self).__name__}(expression={self._expression_string})"
|
154 |
|
155 |
def forward(self, X):
|
|
|
|
|
156 |
symbols = {symbol: X[:, i]
|
157 |
for i, symbol in enumerate(self.symbols_in)}
|
158 |
return self._node(symbols)
|
159 |
|
160 |
|
161 |
-
def sympy2torch(expression, symbols_in,
|
|
|
162 |
"""Returns a module for a given sympy expression with trainable parameters;
|
163 |
|
164 |
This function will assume the input to the module is a matrix X, where
|
@@ -168,4 +172,6 @@ def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
|
|
168 |
|
169 |
_initialize_torch()
|
170 |
|
171 |
-
return SingleSymPyModule(expression, symbols_in,
|
|
|
|
|
|
137 |
class SingleSymPyModule(torch.nn.Module):
|
138 |
"""SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
|
139 |
def __init__(self, expression, symbols_in,
|
140 |
+
selection=None, extra_funcs=None, **kwargs):
|
141 |
super().__init__(**kwargs)
|
142 |
|
143 |
if extra_funcs is None:
|
|
|
147 |
_memodict = {}
|
148 |
self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
|
149 |
self._expression_string = str(expression)
|
150 |
+
self._selection = selection
|
151 |
self.symbols_in = [str(symbol) for symbol in symbols_in]
|
152 |
|
153 |
def __repr__(self):
|
154 |
return f"{type(self).__name__}(expression={self._expression_string})"
|
155 |
|
156 |
def forward(self, X):
|
157 |
+
if self._selection is not None:
|
158 |
+
X = X[:, self._selection]
|
159 |
symbols = {symbol: X[:, i]
|
160 |
for i, symbol in enumerate(self.symbols_in)}
|
161 |
return self._node(symbols)
|
162 |
|
163 |
|
164 |
+
def sympy2torch(expression, symbols_in,
|
165 |
+
selection=None, extra_torch_mappings=None):
|
166 |
"""Returns a module for a given sympy expression with trainable parameters;
|
167 |
|
168 |
This function will assume the input to the module is a matrix X, where
|
|
|
172 |
|
173 |
_initialize_torch()
|
174 |
|
175 |
+
return SingleSymPyModule(expression, symbols_in,
|
176 |
+
selection=selection,
|
177 |
+
extra_funcs=extra_torch_mappings)
|