MilesCranmer commited on
Commit
295c6bd
1 Parent(s): 483a583

Clean up dataset integration

Browse files
Files changed (2) hide show
  1. README.md +5 -5
  2. eureqa.jl +8 -7
README.md CHANGED
@@ -12,17 +12,17 @@ You can change the binary and unary operators in `eureqa.jl` here:
12
  const binops = [plus, mult]
13
  const unaops = [sin, cos, exp];
14
  ```
15
- E.g., you can add another binary function with:
16
  ```
17
- const binops = [plus, mult, (x, y)->x^2*y]
 
18
  ```
19
 
20
  You can change the dataset here:
21
  ```
22
- const nvar = 5;
23
- const X = rand(100, nvar);
24
  # Here is the function we want to learn (x2^2 + cos(x3))
25
- const y = ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3])
26
  ```
27
  by either loading in a dataset, or modifying the definition of `y`.
28
 
 
12
  const binops = [plus, mult]
13
  const unaops = [sin, cos, exp];
14
  ```
15
+ E.g., you can add the function for powers with:
16
  ```
17
+ pow(x::Float32, y::Float32)::Float32 = sign(x)*abs(x)^y
18
+ const binops = [plus, mult, pow]
19
  ```
20
 
21
  You can change the dataset here:
22
  ```
23
+ const X = convert(Array{Float32, 2}, randn(100, 5)*2)
 
24
  # Here is the function we want to learn (x2^2 + cos(x3))
25
+ const y = convert(Array{Float32, 1}, ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3]))
26
  ```
27
  by either loading in a dataset, or modifying the definition of `y`.
28
 
eureqa.jl CHANGED
@@ -1,24 +1,24 @@
1
  # Define allowed operators
2
- plus(x::Float32, y::Float32) = x+y
3
- mult(x::Float32, y::Float32) = x*y;
 
4
 
5
  ##########################
6
  # # Allowed operators
7
  # (Apparently using const for globals helps speed)
8
- const binops = [plus, mult]
9
- const unaops = [sin, cos, exp]
10
  ##########################
11
 
12
  # How many equations to search when replacing
13
  const ns=10;
14
 
15
- const nvar = 5;
16
  # Here is the function we want to learn (x2^2 + cos(x3) + 5)
17
  #
18
  ##########################
19
  # # Dataset to learn
20
- const X = convert(Array{Float32, 2}, randn(100, nvar)*2)
21
- const y = convert(Array{Float32, 1}, ((cx,)->cx^2).(X[:, 2]) + cos.(X[:, 3]))
22
  ##########################
23
 
24
  ##################
@@ -34,6 +34,7 @@ id = (x,) -> x
34
  const nuna = size(unaops)[1]
35
  const nbin = size(binops)[1]
36
  const nops = nuna + nbin
 
37
 
38
  # Define a serialization format for the symbolic equations:
39
  mutable struct Node
 
1
  # Define allowed operators
2
+ plus(x::Float32, y::Float32)::Float32 = x+y
3
+ mult(x::Float32, y::Float32)::Float32 = x*y;
4
+ pow(x::Float32, y::Float32)::Float32 = sign(x)*abs(x)^y
5
 
6
  ##########################
7
  # # Allowed operators
8
  # (Apparently using const for globals helps speed)
9
+ const binops = [plus, mult, pow]
10
+ const unaops = [sin, cos]
11
  ##########################
12
 
13
  # How many equations to search when replacing
14
  const ns=10;
15
 
 
16
  # Here is the function we want to learn (x2^2 + cos(x3) + 5)
17
  #
18
  ##########################
19
  # # Dataset to learn
20
+ const X = convert(Array{Float32, 2}, randn(100, 5)*2)
21
+ const y = convert(Array{Float32, 1}, ((cx,)->sign(cx)*abs(cx)^3.5).(X[:, 2]) + cos.(X[:, 3]) .+ 5.0)
22
  ##########################
23
 
24
  ##################
 
34
  const nuna = size(unaops)[1]
35
  const nbin = size(binops)[1]
36
  const nops = nuna + nbin
37
+ const nvar = size(X)[2];
38
 
39
  # Define a serialization format for the symbolic equations:
40
  mutable struct Node