MilesCranmer commited on
Commit
e7ede78
1 Parent(s): a88a169

Allow user to pass extra torch operators

Browse files
Files changed (1) hide show
  1. pysr/export_torch.py +2 -2
pysr/export_torch.py CHANGED
@@ -160,7 +160,7 @@ def _initialize_torch():
160
  return self._node(symbols)
161
 
162
 
163
- def sympy2torch(expression, symbols_in):
164
  """Returns a module for a given sympy expression with trainable parameters;
165
 
166
  This function will assume the input to the module is a matrix X, where
@@ -170,4 +170,4 @@ def sympy2torch(expression, symbols_in):
170
 
171
  _initialize_torch()
172
 
173
- return SingleSymPyModule(expression, symbols_in)
 
160
  return self._node(symbols)
161
 
162
 
163
+ def sympy2torch(expression, symbols_in, extra_torch_mappings=None):
164
  """Returns a module for a given sympy expression with trainable parameters;
165
 
166
  This function will assume the input to the module is a matrix X, where
 
170
 
171
  _initialize_torch()
172
 
173
+ return SingleSymPyModule(expression, symbols_in, extra_funcs=extra_torch_mappings)