# Fork of https://github.com/patrick-kidger/sympytorch
"""Convert SymPy expressions into PyTorch modules.

torch is imported lazily (on the first call to `sympy2torch`) so that this
module can be imported without torch being loaded.
"""

import collections as co
import functools as ft

import numpy as np  # noqa: F401
import sympy  # type: ignore


def _reduce(fn):
    """Turn a binary function into a variadic one by left-folding over args."""

    def fn_(*args):
        return ft.reduce(fn, args)

    return fn_


# Lazy-initialisation state, filled in by `_initialize_torch`.
torch_initialized = False
torch = None
SingleSymPyModule = None


def _initialize_torch():
    """Import torch and define the torch-dependent classes exactly once.

    On the first call this imports torch, builds the SymPy -> torch function
    table, defines the module classes and publishes `SingleSymPyModule` at
    module level. Subsequent calls are no-ops, so every module returned by
    `sympy2torch` shares the same class.
    """
    global torch_initialized
    global torch
    global SingleSymPyModule

    # Way to lazy load torch, only if this is called,
    # but still allow this module to be loaded in __init__
    if torch_initialized:
        return

    import torch as _torch

    torch = _torch

    _global_func_lookup = {
        sympy.Mul: _reduce(torch.mul),
        sympy.Add: _reduce(torch.add),
        sympy.div: torch.div,
        sympy.Abs: torch.abs,
        sympy.sign: torch.sign,
        # Note: May raise error for ints.
        sympy.ceiling: torch.ceil,
        sympy.floor: torch.floor,
        sympy.log: torch.log,
        sympy.exp: torch.exp,
        sympy.sqrt: torch.sqrt,
        sympy.cos: torch.cos,
        sympy.acos: torch.acos,
        sympy.sin: torch.sin,
        sympy.asin: torch.asin,
        sympy.tan: torch.tan,
        sympy.atan: torch.atan,
        sympy.atan2: torch.atan2,
        # Note: May give NaN for complex results.
        sympy.cosh: torch.cosh,
        sympy.acosh: torch.acosh,
        sympy.sinh: torch.sinh,
        sympy.asinh: torch.asinh,
        sympy.tanh: torch.tanh,
        sympy.atanh: torch.atanh,
        sympy.Pow: torch.pow,
        sympy.re: torch.real,
        sympy.im: torch.imag,
        sympy.arg: torch.angle,
        # Note: May raise error for ints and complexes
        sympy.erf: torch.erf,
        sympy.loggamma: torch.lgamma,
        sympy.Eq: torch.eq,
        sympy.Ne: torch.ne,
        sympy.StrictGreaterThan: torch.gt,
        sympy.StrictLessThan: torch.lt,
        sympy.LessThan: torch.le,
        sympy.GreaterThan: torch.ge,
        sympy.And: torch.logical_and,
        sympy.Or: torch.logical_or,
        sympy.Not: torch.logical_not,
        sympy.Max: torch.max,
        sympy.Min: torch.min,
        sympy.Mod: torch.remainder,
        sympy.Heaviside: torch.heaviside,
        sympy.core.numbers.Half: (lambda: 0.5),
        sympy.core.numbers.One: (lambda: 1.0),
    }

    class _Node(torch.nn.Module):
        """Forked from https://github.com/patrick-kidger/sympytorch

        One node of the expression tree: a constant, a symbol lookup, or a
        torch function applied to child nodes.
        """

        def __init__(self, *, expr, _memodict, _func_lookup, **kwargs):
            """Build the node for `expr`, recursing into its arguments.

            Args:
                expr: SymPy expression this node represents.
                _memodict: Dict mapping already-converted SymPy
                    sub-expressions to their nodes, so repeated
                    sub-expressions share a single node.
                _func_lookup: Mapping from SymPy function to torch callable.
                **kwargs: Forwarded to `torch.nn.Module.__init__` and to
                    child nodes.

            Raises:
                ValueError: If an `UnevaluatedExpr` wraps anything other
                    than a single float.
                KeyError: If `expr.func` has no entry in `_func_lookup`.
            """
            super().__init__(**kwargs)

            self._sympy_func = expr.func

            if issubclass(expr.func, sympy.Float):
                # Floats become trainable parameters.
                self._value = torch.nn.Parameter(torch.tensor(float(expr)))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Rational):
                # This is some fraction fixed in the operator.
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.UnevaluatedExpr):
                if len(expr.args) != 1 or not issubclass(
                    expr.args[0].func, sympy.Float
                ):
                    raise ValueError(
                        "UnevaluatedExpr should only be used to wrap floats."
                    )
                # Stored as a buffer rather than a Parameter.
                self.register_buffer("_value", torch.tensor(float(expr.args[0])))
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Integer):
                # Can get here if expr is one of the Integer special cases,
                # e.g. NegativeOne
                # NOTE(review): sympy.Integer derives from sympy.Rational, so
                # integers appear to be caught by the Rational branch above
                # (and stored as floats) -- confirm whether this branch is
                # still reachable before relying on it.
                self._value = int(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.NumberSymbol):
                # Can get here from exp(1) or exact pi
                self._value = float(expr)
                self._torch_func = lambda: self._value
                self._args = ()
            elif issubclass(expr.func, sympy.Symbol):
                self._name = expr.name
                self._torch_func = lambda value: value
                # The single "argument" reads the symbol's value by name from
                # the dict handed to `forward`.
                self._args = ((lambda memodict: memodict[expr.name]),)
            else:
                try:
                    self._torch_func = _func_lookup[expr.func]
                except KeyError as err:
                    raise KeyError(
                        f"Function {expr.func} was not found in Torch function "
                        "mappings. Please add it to extra_torch_mappings in "
                        "the format, e.g., {sympy.sqrt: torch.sqrt}."
                    ) from err
                args = []
                for arg in expr.args:
                    try:
                        arg_ = _memodict[arg]
                    except KeyError:
                        arg_ = type(self)(
                            expr=arg,
                            _memodict=_memodict,
                            _func_lookup=_func_lookup,
                            **kwargs,
                        )
                        _memodict[arg] = arg_
                    args.append(arg_)
                self._args = torch.nn.ModuleList(args)

        def forward(self, memodict):
            """Evaluate this node.

            Args:
                memodict: Dict mapping symbol names to their values. It is
                    also used as a per-call cache keyed by child node, and
                    is mutated in place.

            Returns:
                The result of applying this node's torch function to its
                evaluated arguments.
            """
            args = []
            for arg in self._args:
                try:
                    arg_ = memodict[arg]
                except KeyError:
                    arg_ = arg(memodict)
                    memodict[arg] = arg_
                args.append(arg_)
            return self._torch_func(*args)

    class _SingleSymPyModule(torch.nn.Module):
        """Forked from https://github.com/patrick-kidger/sympytorch

        Torch module evaluating one SymPy expression on the columns of X.
        """

        def __init__(
            self, expression, symbols_in, selection=None, extra_funcs=None, **kwargs
        ):
            """Build the module.

            Args:
                expression: SymPy expression to convert.
                symbols_in: Symbols (or names); the i-th one is bound to
                    column i of the input.
                selection: Optional column index applied to the input
                    before the symbols are bound.
                extra_funcs: Optional extra SymPy -> torch mappings.
                **kwargs: Forwarded to `torch.nn.Module.__init__`.
            """
            super().__init__(**kwargs)

            if extra_funcs is None:
                extra_funcs = {}
            # ChainMap searches its maps in order, so the built-in table
            # takes precedence over `extra_funcs` for keys present in both.
            _func_lookup = co.ChainMap(_global_func_lookup, extra_funcs)

            _memodict = {}
            self._node = _Node(
                expr=expression, _memodict=_memodict, _func_lookup=_func_lookup
            )
            self._expression_string = str(expression)
            self._selection = selection
            self.symbols_in = [str(symbol) for symbol in symbols_in]

        def __repr__(self):
            return f"{type(self).__name__}(expression={self._expression_string})"

        def forward(self, X):
            """Evaluate the expression, binding column i of X to symbol i."""
            if self._selection is not None:
                X = X[:, self._selection]
            symbols = {symbol: X[:, i] for i, symbol in enumerate(self.symbols_in)}
            return self._node(symbols)

    SingleSymPyModule = _SingleSymPyModule

    # Mark initialisation as done only once everything above has succeeded,
    # so a failed torch import can be retried and the classes are defined
    # exactly once.
    torch_initialized = True


def sympy2torch(expression, symbols_in, selection=None, extra_torch_mappings=None):
    """Returns a module for a given sympy expression with trainable parameters;

    This function will assume the input to the module is a matrix X, where
    each column corresponds to each symbol you pass in `symbols_in`.

    Args:
        expression: SymPy expression to convert.
        symbols_in: Symbols (or names), in the column order of X.
        selection: Optional column index applied to X before the symbols
            are bound.
        extra_torch_mappings: Optional dict of additional SymPy -> torch
            function mappings, e.g. ``{sympy.sqrt: torch.sqrt}``.

    Returns:
        A `SingleSymPyModule` instance (a `torch.nn.Module`).
    """
    _initialize_torch()

    return SingleSymPyModule(
        expression, symbols_in, selection=selection, extra_funcs=extra_torch_mappings
    )