MilesCranmer commited on
Commit
8c55475
1 Parent(s): 45b290b

Allow custom selection of X matrix in torch/jax modules

Browse files
Files changed (2) hide show
  1. pysr/export_jax.py +4 -1
  2. pysr/export_torch.py +9 -3
pysr/export_jax.py CHANGED
@@ -90,7 +90,7 @@ def _initialize_jax():
90
  jsp = _jsp
91
 
92
 
93
- def sympy2jax(expression, symbols_in):
94
  """Returns a function f and its parameters;
95
  the function takes an input matrix, and a list of arguments:
96
  f(X, parameters)
@@ -171,6 +171,9 @@ def sympy2jax(expression, symbols_in):
171
  functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
172
  hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
173
  text = f"def {hash_string}(X, parameters):\n"
 
 
 
174
  text += " return "
175
  text += functional_form_text
176
  ldict = {}
 
90
  jsp = _jsp
91
 
92
 
93
+ def sympy2jax(expression, symbols_in, selection=None):
94
  """Returns a function f and its parameters;
95
  the function takes an input matrix, and a list of arguments:
96
  f(X, parameters)
 
171
  functional_form_text = sympy2jaxtext(expression, parameters, symbols_in)
172
  hash_string = 'A_' + str(abs(hash(str(expression) + str(symbols_in))))
173
  text = f"def {hash_string}(X, parameters):\n"
174
+ if selection is not None:
175
+ # Impose the feature selection:
176
+ text += f" X = X[:, {list(selection)}]"
177
  text += " return "
178
  text += functional_form_text
179
  ldict = {}
pysr/export_torch.py CHANGED
@@ -137,7 +137,7 @@ def _initialize_torch():
137
  class SingleSymPyModule(torch.nn.Module):
138
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
  def __init__(self, expression, symbols_in,
140
- extra_funcs=None, **kwargs):
141
  super().__init__(**kwargs)
142
 
143
  if extra_funcs is None:
@@ -147,18 +147,22 @@ def _initialize_torch():
147
  _memodict = {}
148
  self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
149
  self._expression_string = str(expression)
 
150
  self.symbols_in = [str(symbol) for symbol in symbols_in]
151
 
152
  def __repr__(self):
153
  return f"{type(self).__name__}(expression={self._expression_string})"
154
 
155
  def forward(self, X):
 
 
156
  symbols = {symbol: X[:, i]
157
  for i, symbol in enumerate(self.symbols_in)}
158
  return self._node(symbols)
159
 
160
 
161
- def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
 
162
  """Returns a module for a given sympy expression with trainable parameters;
163
 
164
  This function will assume the input to the module is a matrix X, where
@@ -168,4 +172,6 @@ def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
168
 
169
  _initialize_torch()
170
 
171
- return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)
 
 
 
137
  class SingleSymPyModule(torch.nn.Module):
138
  """SympyTorch code from https://github.com/patrick-kidger/sympytorch"""
139
  def __init__(self, expression, symbols_in,
140
+ selection=None, extra_funcs=None, **kwargs):
141
  super().__init__(**kwargs)
142
 
143
  if extra_funcs is None:
 
147
  _memodict = {}
148
  self._node = _Node(expr=expression, _memodict=_memodict, _func_lookup=_func_lookup)
149
  self._expression_string = str(expression)
150
+ self._selection = selection
151
  self.symbols_in = [str(symbol) for symbol in symbols_in]
152
 
153
  def __repr__(self):
154
  return f"{type(self).__name__}(expression={self._expression_string})"
155
 
156
  def forward(self, X):
157
+ if self._selection is not None:
158
+ X = X[:, self._selection]
159
  symbols = {symbol: X[:, i]
160
  for i, symbol in enumerate(self.symbols_in)}
161
  return self._node(symbols)
162
 
163
 
164
+ def sympy2torch(expression, symbols_in,
165
+ selection=None, extra_torch_mappings=None):
166
  """Returns a module for a given sympy expression with trainable parameters;
167
 
168
  This function will assume the input to the module is a matrix X, where
 
172
 
173
  _initialize_torch()
174
 
175
+ return SingleSymPyModule(expression, symbols_in,
176
+ selection=selection,
177
+ extra_funcs=extra_torch_mappings)