MilesCranmer commited on
Commit
3d7c303
1 Parent(s): fb950bb

Add test for pytorch export

Browse files
Files changed (2) hide show
  1. .github/workflows/CI.yml +9 -3
  2. test/test_torch.py +12 -0
.github/workflows/CI.yml CHANGED
@@ -61,17 +61,23 @@ jobs:
61
  python setup.py install
62
  - name: "Install Coverage tool"
63
  run: pip install coverage coveralls
 
 
 
64
  - name: "Install JAX"
65
  if: matrix.os != 'windows-latest'
66
  run: pip install jax jaxlib # (optional import)
67
  shell: bash
68
- - name: "Run tests"
69
- run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
70
- shell: bash
71
  - name: "Run JAX tests"
72
  if: matrix.os != 'windows-latest'
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
 
 
 
 
 
 
75
  - name: Coveralls
76
  env:
77
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
 
61
  python setup.py install
62
  - name: "Install Coverage tool"
63
  run: pip install coverage coveralls
64
+ - name: "Run tests"
65
+ run: coverage run --source=pysr --omit='*/feynman_problems.py' -m unittest test.test
66
+ shell: bash
67
  - name: "Install JAX"
68
  if: matrix.os != 'windows-latest'
69
  run: pip install jax jaxlib # (optional import)
70
  shell: bash
 
 
 
71
  - name: "Run JAX tests"
72
  if: matrix.os != 'windows-latest'
73
  run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_jax
74
  shell: bash
75
+ - name: "Install Torch"
76
+ run: pip install torch # (optional import)
77
+ shell: bash
78
+ - name: "Run Torch tests"
79
+ run: coverage run --append --source=pysr --omit='*/feynman_problems.py' -m unittest test.test_torch
80
+ shell: bash
81
  - name: Coveralls
82
  env:
83
  GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
test/test_torch.py ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from pysr import sympy2torch
3
+ import torch
4
+ import sympy
5
+
6
+ print("Test Torch 1 - test export")
7
+ x, y, z = sympy.symbols('x y z')
8
+ cosx = 1.0 * sympy.cos(x) + y
9
+ X = torch.randn((1000, 2))
10
+ true = 1.0 * torch.cos(X[:, 0]) + X[:, 1]
11
+ torch_module = sympy2torch(cosx, [x, y, z])
12
+ assert jnp.all(jnp.isclose(torch_module(X), true)).item()