#####
# From https://github.com/patrick-kidger/sympytorch
# Copied here to allow PySR-specific tweaks
#####
import collections as co
import functools as ft

import sympy


def _reduce(fn):
    def fn_(*args):
        return ft.reduce(fn, args)

    return fn_
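
# Note: sympy.Mul and sympy.Add may take any number of arguments, while
# torch.mul and torch.add are binary; _reduce folds the arguments pairwise,
# e.g. _reduce(torch.add)(a, b, c) computes torch.add(torch.add(a, b), c).
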
torch_initialized = False
torch = None
_global_func_lookup = None
_Node = None
SingleSymPyModule = None


def _initialize_torch():
    global torch_initialized
    global torch
    global _global_func_lookup
    global _Node
    global SingleSymPyModule

    # Lazily import torch only when this is first called, so that this
    # module can still be imported from __init__ without torch loaded.
    if not torch_initialized:
        import torch as _torch

        torch = _torch

        _global_func_lookup = {
            sympy.Mul: _reduce(torch.mul),
            sympy.Add: _reduce(torch.add),
            sympy.div: torch.div,
            sympy.Abs: torch.abs,
            sympy.sign: torch.sign,
            # Note: May raise error for ints.
            sympy.ceiling: torch.ceil,
            sympy.floor: torch.floor,
            sympy.log: torch.log,
            sympy.exp: torch.exp,
            sympy.sqrt: torch.sqrt,
            sympy.cos: torch.cos,
            sympy.acos: torch.acos,
            sympy.sin: torch.sin,
            sympy.asin: torch.asin,
            sympy.tan: torch.tan,
            sympy.atan: torch.atan,
            sympy.atan2: torch.atan2,
            # Note: May give NaN for complex results.
            sympy.cosh: torch.cosh,
            sympy.acosh: torch.acosh,
            sympy.sinh: torch.sinh,
            sympy.asinh: torch.asinh,
            sympy.tanh: torch.tanh,
            sympy.atanh: torch.atanh,
            sympy.Pow: torch.pow,
            sympy.re: torch.real,
            sympy.im: torch.imag,
            sympy.arg: torch.angle,
            # Note: May raise error for ints and complexes
            sympy.erf: torch.erf,
            sympy.loggamma: torch.lgamma,
            sympy.Eq: torch.eq,
            sympy.Ne: torch.ne,
            sympy.StrictGreaterThan: torch.gt,
            sympy.StrictLessThan: torch.lt,
            sympy.LessThan: torch.le,
            sympy.GreaterThan: torch.ge,
            sympy.And: torch.logical_and,
            sympy.Or: torch.logical_or,
            sympy.Not: torch.logical_not,
            sympy.Max: torch.max,
            sympy.Min: torch.min,
            # Matrices
            sympy.MatAdd: torch.add,
            sympy.HadamardProduct: torch.mul,
            sympy.Trace: torch.trace,
            # Note: May raise error for integer matrices.
            sympy.Determinant: torch.det,
        }
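
        # Each entry maps a sympy function *class* to the torch callable that
        # is applied to the evaluated children of that node; e.g. a
        # sympy.sin(x) node is computed as torch.sin(<tensor for x>).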

        class _Node(torch.nn.Module):
            def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
                super().__init__(**kwargs)

                self._sympy_func = expr.func

                if issubclass(expr.func, sympy.Float):
                    # Floats become trainable parameters.
                    self._value = torch.nn.Parameter(torch.tensor(float(expr)))
                    self._torch_func = lambda: self._value
                    self._args = ()
                elif issubclass(expr.func, sympy.UnevaluatedExpr):
                    if len(expr.args) != 1 or not issubclass(
                        expr.args[0].func, sympy.Float
                    ):
                        raise ValueError(
                            "UnevaluatedExpr should only be used to wrap floats."
                        )
                    # Wrapped floats become non-trainable buffers.
                    self.register_buffer(
                        "_value", torch.tensor(float(expr.args[0]))
                    )
                    self._torch_func = lambda: self._value
                    self._args = ()
                elif issubclass(expr.func, sympy.Integer):
                    # Can get here if expr is one of the Integer special cases,
                    # e.g. NegativeOne
                    self._value = int(expr)
                    self._torch_func = lambda: self._value
                    self._args = ()
                elif issubclass(expr.func, sympy.Symbol):
                    # Symbols are looked up by name in the input dictionary.
                    self._name = expr.name
                    self._torch_func = lambda value: value
                    self._args = ((lambda memodict: memodict[expr.name]),)
                else:
                    self._torch_func = _func_lookup[expr.func]
                    args = []
                    for arg in expr.args:
                        try:
                            arg_ = _memodict[arg]
                        except KeyError:
                            arg_ = type(self)(
                                expr=arg,
                                _memodict=_memodict,
                                _func_lookup=_func_lookup,
                                **kwargs,
                            )
                            _memodict[arg] = arg_
                        args.append(arg_)
                    self._args = torch.nn.ModuleList(args)
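
            # The shared `_memodict` above deduplicates repeated sympy
            # subexpressions into a single _Node, so their parameters are
            # shared; `memodict` in forward() plays the same role at
            # evaluation time, caching each node's output.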
            def forward(self, memodict):
                args = []
                for arg in self._args:
                    try:
                        arg_ = memodict[arg]
                    except KeyError:
                        arg_ = arg(memodict)
                        memodict[arg] = arg_
                    args.append(arg_)
                return self._torch_func(*args)
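
            def sympy(self, _memodict):
                # This method is required by SingleSymPyModule.sympy below but
                # is missing from this copy; the following is a minimal sketch
                # reconstructed along the lines of upstream sympytorch, not
                # the verbatim original. It rebuilds the sympy expression,
                # substituting the current (possibly trained) values back in.
                if issubclass(self._sympy_func, sympy.Float):
                    return sympy.Float(self._value.item())
                elif issubclass(self._sympy_func, sympy.UnevaluatedExpr):
                    return self._sympy_func(sympy.Float(self._value.item()))
                elif issubclass(self._sympy_func, sympy.Integer):
                    return sympy.Integer(self._value)
                elif issubclass(self._sympy_func, sympy.Symbol):
                    return sympy.Symbol(self._name)
                else:
                    args = []
                    for arg in self._args:
                        try:
                            arg_ = _memodict[arg]
                        except KeyError:
                            arg_ = arg.sympy(_memodict)
                            _memodict[arg] = arg_
                        args.append(arg_)
                    return self._sympy_func(*args)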

        class SingleSymPyModule(torch.nn.Module):
            def __init__(self, expression, symbols_in, extra_funcs=None, **kwargs):
                super().__init__(**kwargs)

                if extra_funcs is None:
                    extra_funcs = {}
                # extra_funcs supplies torch equivalents for sympy functions
                # not covered by the default lookup table.
                _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

                _memodict = {}
                self._node = _Node(
                    expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
                )
                self._expression_string = str(expression)
                self.symbols_in = [str(symbol) for symbol in symbols_in]

            def __repr__(self):
                return f"{type(self).__name__}(expression={self._expression_string})"

            def sympy(self):
                _memodict = {}
                return self._node.sympy(_memodict)
            def forward(self, X):
                # Each column of X feeds the symbol at the same index.
                symbols = {
                    symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)
                }
                return self._node(symbols)

        # Mark initialization as done so later calls skip the block above.
        torch_initialized = True


def sympy2torch(expression, symbols_in):
    """Returns a torch module for a given sympy expression, with trainable parameters.

    The module assumes its input is a matrix X in which each column
    corresponds to the symbol at the same position in `symbols_in`.
    """
    global SingleSymPyModule

    _initialize_torch()

    return SingleSymPyModule(expression, symbols_in)
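

# Example usage (an illustrative sketch; variable names are arbitrary):
#
#     import sympy
#     import torch
#
#     x, y = sympy.symbols("x y")
#     expr = 1.5 * sympy.sin(x) + y
#     module = sympy2torch(expr, [x, y])
#
#     X = torch.rand(100, 2)  # column 0 -> x, column 1 -> y
#     out = module(X)         # shape (100,)
#
# The float 1.5 becomes a trainable torch.nn.Parameter, so `module` can be
# fine-tuned with an ordinary torch training loop.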