Spaces:
Sleeping
Sleeping
File size: 2,428 Bytes
a06bfc4 fb950bb b0e1209 fb950bb b0e1209 fb950bb b0e1209 fb950bb b0e1209 b80fb14 b0e1209 8c55475 fb950bb b0e1209 8c55475 b0e1209 fb950bb 8c55475 b0e1209 a06bfc4 8c55475 9068541 b0e1209 fb950bb b0e1209 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 |
#####
# From https://github.com/patrick-kidger/sympytorch
# Copied here to allow PySR-specific tweaks
#####
import collections as co
import functools as ft
import sympy
def _reduce(fn):
def fn_(*args):
return ft.reduce(fn, args)
return fn_
# Placeholders for lazily imported dependencies. `_initialize_torch()` flips the
# flag and rebinds these names the first time PyTorch export is requested.
torch_initialized = False
torch = sympytorch = PySRTorchModule = None
def _initialize_torch():
    """Lazily import ``torch``/``sympytorch`` and define ``PySRTorchModule``.

    The imports are deferred to this call so that the enclosing module can be
    imported without torch being installed. On the first call this rebinds
    the module-level ``torch``, ``sympytorch`` and ``PySRTorchModule`` names
    and sets ``torch_initialized``. Subsequent calls do nothing.

    Raises:
        ImportError: if ``torch`` or ``sympytorch`` cannot be imported. The
            original import error is chained as the cause.
    """
    global torch_initialized
    global torch
    global sympytorch
    global PySRTorchModule

    # Way to lazy load torch, only if this is called,
    # but still allow this module to be loaded in __init__
    if not torch_initialized:
        try:
            import torch
            import sympytorch
        except ImportError as err:
            # Chain the cause so the failing package/version is still visible.
            raise ImportError(
                "You need to pip install `torch` and `sympytorch` before exporting to pytorch."
            ) from err
        torch_initialized = True

        class PySRTorchModule(torch.nn.Module):
            """SympyTorch code from https://github.com/patrick-kidger/sympytorch

            Wraps a single sympy expression in a ``sympytorch.SymPyModule`` and
            evaluates it on the columns of an input matrix.

            Args:
                expression: sympy expression to evaluate.
                symbols_in: symbols in column order. Column ``i`` of the
                    (possibly selected) input is passed as ``str(symbols_in[i])``.
                selection: optional column index/indexer applied to the input
                    before symbols are extracted.
                extra_funcs: forwarded to ``SymPyModule(extra_funcs=...)``.
                **kwargs: forwarded to ``torch.nn.Module.__init__``.
            """

            def __init__(self, *, expression, symbols_in,
                         selection=None, extra_funcs=None, **kwargs):
                super().__init__(**kwargs)
                self._module = sympytorch.SymPyModule(
                    expressions=[expression],
                    extra_funcs=extra_funcs)
                self._selection = selection
                self._symbols = symbols_in
                # Printable form of the expression, read by __repr__. Without
                # this assignment repr() would raise AttributeError.
                self._expression_string = str(expression)

            def __repr__(self):
                return f"{type(self).__name__}(expression={self._expression_string})"

            def forward(self, X):
                """Evaluate the expression on the columns of ``X``.

                ``X`` is indexed as ``X[:, j]``, so it must be 2-D (rows by
                columns).
                """
                if self._selection is not None:
                    X = X[:, self._selection]
                symbols = {str(symbol): X[:, i]
                           for i, symbol in enumerate(self._symbols)}
                # NOTE(review): presumably SymPyModule stacks one output per
                # expression on the last axis. Only one expression was given,
                # so index 0 selects it -- confirm against sympytorch.
                return self._module(**symbols)[..., 0]
def sympy2torch(expression, symbols_in,
                selection=None, extra_torch_mappings=None):
    """Build a PyTorch module that evaluates a sympy expression.

    The returned module expects a matrix ``X`` whose columns correspond, in
    order, to the symbols given in ``symbols_in``. Any trainable parameters
    come from the underlying sympytorch module.

    Args:
        expression: sympy expression to export.
        symbols_in: symbols matching the input columns, in order.
        selection: optional column selection applied to ``X`` first.
        extra_torch_mappings: extra sympy-to-torch function mappings,
            forwarded as ``extra_funcs``.

    Returns:
        A ``PySRTorchModule`` instance.
    """
    # Importing torch lazily also defines PySRTorchModule on first use.
    _initialize_torch()
    module_kwargs = dict(
        expression=expression,
        symbols_in=symbols_in,
        selection=selection,
        extra_funcs=extra_torch_mappings,
    )
    return PySRTorchModule(**module_kwargs)
|